How to Extend the Spark Catalyst Engine Without Modifying the Spark Source Code

Continuing from the previous post, which gave an overview of the basic concepts and the parsing flow in the Spark SQL source code, this article introduces one way to extend the engine.

(Figure: rounded boxes denote the Catalyst components.)

Since Spark 2.2, Spark has supported extending the Catalyst engine. The extension points are listed in the table below:

Stage       Extension                      Description
Parser      injectParser                   SQL parsing
Analyzer    injectResolutionRule           logical plan generation, catalog binding, and various checks
            injectPostHocResolutionRule
            injectCheckRule
Optimizer   injectOptimizerRule            logical plan optimization
Planner     injectPlannerStrategy          physical plan generation
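
To give a feel for the API before the full demo below, here is a minimal sketch of injectCheckRule from the table above. A CheckRuleBuilder is just a SparkSession => LogicalPlan => Unit function: it inspects the analyzed plan and throws to reject it. The "no commands" policy and the value name forbidCommands are made up for illustration.

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.plans.logical.Command

// Hypothetical policy: reject any DDL/DML command plan after analysis.
val forbidCommands: SparkSessionExtensions => Unit = { extensions =>
  extensions.injectCheckRule { session => plan =>
    if (plan.isInstanceOf[Command]) {
      throw new RuntimeException(s"commands are not allowed here:\n$plan")
    }
  }
}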

Spark 3.x added a few more extension points:

  • e.injectColumnar: columnar rules, related to low-level file reads and writes
  • e.injectFunction: register additional built-in functions (see the sketch after this list)
  • e.injectQueryStagePrepRule: rules for AQE query-stage preparation
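
A hedged sketch of injectFunction follows. A FunctionDescription is a (FunctionIdentifier, ExpressionInfo, FunctionBuilder) triple; the names my_upper and MyExtensions are invented for illustration. Because the class extends SparkSessionExtensions => Unit, it can also be registered through the spark.sql.extensions config shown later in the test.

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Upper}

// Hypothetical example: expose Spark's built-in Upper expression under the name "my_upper".
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectFunction((
      FunctionIdentifier("my_upper"),
      new ExpressionInfo(classOf[Upper].getCanonicalName, "my_upper"),
      (children: Seq[Expression]) => Upper(children.head)))
  }
}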

Note that these extension points can only add new rules; existing rules cannot be modified. If that is not enough, you still have to change the Spark source code.

A Small Demo

The demo implements a fairly simple rule: forbid select * from table when there is no limit.

StrictParser.scala

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LogicalPlan, Project}

case class StrictParser(session: SparkSession, parser: ParserInterface) extends ParserInterface with Logging {

  override def parsePlan(sqlText: String): LogicalPlan = {
    // delegate the actual parsing to the wrapped parser
    val plan = parser.parsePlan(sqlText)
    logDebug(s"parsePlan: $sqlText\n$plan")
    // only the top-level query is checked here; subqueries are not
    val hasLimit = plan.isInstanceOf[GlobalLimit]
    var hasStar = false
    plan transform {
      case project @ Project(projectList, _) =>
        // check whether * appears in the projection
        hasStar = projectList.exists(_.isInstanceOf[UnresolvedStar])
        project
    }
    if (!hasLimit && hasStar) {
      throw new RuntimeException(s"can't select * without limit: $sqlText")
    }
    plan
  }
  // ... the remaining ParserInterface methods simply delegate to the wrapped parser (omitted)
}

StrictParserTest.scala

import org.apache.spark.sql.SparkSession
import org.scalatest.funsuite.AnyFunSuite

class StrictParserTest extends AnyFunSuite {
  private val spark = SparkSession.builder()
    .config(SparkExtension.sparkConf) // project-specific SparkConf (not shown)
    // Two ways to inject the extension: set spark.sql.extensions,
    // or call withExtensions on the builder.
    // .config("spark.sql.extensions", "com.package.class")
    .withExtensions { extensions =>
      extensions.injectParser((session, parser) => StrictParser(session, parser))
    }
    .getOrCreate()
  spark.createDataFrame(Seq((1, "scala"), (2, "java")))
    .toDF("id", "name")
    .createOrReplaceTempView("test")

  test("parsePlan") {
    assertThrows[RuntimeException] {
      spark.sql("select * from test")
    }
    spark.sql("select * from test limit 10")
    spark.sql("select * from (select * from test) t1 limit 1")
  }
}

Parsed Plan

select * from test limit 10
---
'GlobalLimit 10
+- 'LocalLimit 10
   +- 'Project [*]
      +- 'UnresolvedRelation [test], [], false
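
For reference, the tree above can be reproduced (assuming the SparkSession and test view from the demo) by printing what the parser returns, before the analyzer resolves anything:

// Print the parsed (unresolved) logical plan; the quoted node names
// ('Project, 'UnresolvedRelation) indicate nothing has been resolved yet.
println(spark.sessionState.sqlParser.parsePlan("select * from test limit 10"))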

A Quick Look at the Source Code

Based on Spark 3.2.2.

SparkSessionExtensions.scala

private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder]

private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
  optimizerRules.map(_.apply(session)).toSeq
}

// injectOptimizerRule simply appends our custom rule builder to the optimizerRules buffer
def injectOptimizerRule(builder: RuleBuilder): Unit = {
  optimizerRules += builder
}
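
For comparison with the injectParser call in the demo, registering an optimizer rule looks like the hedged sketch below. A RuleBuilder is a SparkSession => Rule[LogicalPlan] function, and LogPlanRule is a made-up, do-nothing rule that only logs the plan it receives.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical rule: returns the plan unchanged and just logs it (Rule extends Logging).
case class LogPlanRule(session: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    logInfo(s"optimizer saw plan:\n$plan")
    plan
  }
}

// Registered the same way as the parser in the demo:
// builder.withExtensions(_.injectOptimizerRule(session => LogPlanRule(session)))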

BaseSessionStateBuilder.scala

protected def optimizer: Optimizer = {
  new SparkOptimizer(catalogManager, catalog, experimentalMethods) {
    override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] =
      super.earlyScanPushDownRules ++ customEarlyScanPushDownRules

    override def preCBORules: Seq[Rule[LogicalPlan]] =
      super.preCBORules ++ customPreCBORules

    override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
      super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules
  }
}

// the extendedOperatorOptimizationRules override above picks up the custom rules via the method below
protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = {
  extensions.buildOptimizerRules(session)
}

def build(): SessionState = {
  new SessionState(
    session.sharedState,
    conf,
    experimentalMethods,
    functionRegistry,
    tableFunctionRegistry,
    udfRegistration,
    () => catalog,
    sqlParser,
    () => analyzer,
    () => optimizer,
    planner,
    () => streamingQueryManager,
    listenerManager,
    () => resourceLoader,
    createQueryExecution,
    createClone,
    columnarRules,
    queryStagePrepRules)
}

SparkSession.scala

// injecting extensions through the builder
def withExtensions(f: SparkSessionExtensions => Unit): Builder = synchronized {
  f(extensions)
  this
}

lazy val sessionState: SessionState = {
  parentSessionState
    .map(_.clone(this))
    .getOrElse {
      val state = SparkSession.instantiateSessionState(
        SparkSession.sessionStateClassName(sharedState.conf),
        self)
      state
    }
}

// instantiate the SessionState via the configured session state builder
private def instantiateSessionState(
    className: String,
    sparkSession: SparkSession): SessionState = {
  try {
    val clazz = Utils.classForName(className)
    val ctor = clazz.getConstructors.head
    ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build()
  } catch {
    case NonFatal(e) =>
      throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
  }
}
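
The className passed to instantiateSessionState comes from SparkSession.sessionStateClassName, which, as far as I can tell, resolves to org.apache.spark.sql.internal.SessionStateBuilder for the built-in catalog or org.apache.spark.sql.hive.HiveSessionStateBuilder when spark.sql.catalogImplementation is hive; both extend BaseSessionStateBuilder, which is how the build() method shown above ends up being called.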

QueryExecution.scala

Everything above is about registering custom rules while Spark initializes; the code below shows how the rules are applied while a Spark SQL query executes:

lazy val optimizedPlan: LogicalPlan = {
  // We need to materialize the commandExecuted here because optimizedPlan is also tracked under
  // the optimizing phase
  assertCommandExecuted()
  executePhase(QueryPlanningTracker.OPTIMIZATION) {
    // clone the plan to avoid sharing the plan instance between different stages like analyzing,
    // optimizing and planning.
    val plan =
      sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker)
    // We do not want optimized plans to be re-analyzed as literals that have been constant
    // folded and such can cause issues during analysis. While `clone` should maintain the
    // `analyzed` state of the LogicalPlan, we set the plan as analyzed here as well out of
    // paranoia.
    plan.setAnalyzed()
    plan
  }
}

RuleExecutor.scala

The optimizer.executeAndTrack call above lands in the code below, which is where the optimization rules actually run:

As the code shows, optimization rules run in batches, and the rules within a batch may be applied multiple times. A batch stops for one of two reasons:

  • the iteration count reaches the limit, batch.strategy.maxIterations in the code below, which keeps a buggy rule from looping forever;

  • the plan no longer changes, meaning none of the rules have any further effect.

def executeAndTrack(plan: TreeType, tracker: QueryPlanningTracker): TreeType = {
  QueryPlanningTracker.withTracker(tracker) {
    execute(plan)
  }
}

def execute(plan: TreeType): TreeType = {
  var curPlan = plan
  val queryExecutionMetrics = RuleExecutor.queryExecutionMeter
  val planChangeLogger = new PlanChangeLogger[TreeType]()
  val tracker: Option[QueryPlanningTracker] = QueryPlanningTracker.get
  val beforeMetrics = RuleExecutor.getCurrentMetrics()

  // Run the structural integrity checker against the initial input
  if (!isPlanIntegral(plan, plan)) {
    throw QueryExecutionErrors.structuralIntegrityOfInputPlanIsBrokenInClassError(
      this.getClass.getName.stripSuffix("$"))
  }

  batches.foreach { batch =>
    val batchStartPlan = curPlan
    var iteration = 1
    var lastPlan = curPlan
    var continue = true

    // Run until fix point (or the max number of iterations as specified in the strategy).
    while (continue) {
      curPlan = batch.rules.foldLeft(curPlan) {
        case (plan, rule) =>
          val startTime = System.nanoTime()
          // apply the rule to the plan; result is the transformed plan
          val result = rule(plan)
          val runTime = System.nanoTime() - startTime
          val effective = !result.fastEquals(plan)

          if (effective) {
            queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName)
            queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime)
            planChangeLogger.logRule(rule.ruleName, plan, result)
          }
          queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime)
          queryExecutionMetrics.incNumExecution(rule.ruleName)

          // Record timing information using QueryPlanningTracker
          tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective))

          // Run the structural integrity checker against the plan after each rule.
          if (effective && !isPlanIntegral(plan, result)) {
            throw QueryExecutionErrors.structuralIntegrityIsBrokenAfterApplyingRuleError(
              rule.ruleName, batch.name)
          }

          result
      }
      iteration += 1
      // stop condition #1: the iteration limit has been reached
      if (iteration > batch.strategy.maxIterations) {
        // Only log if this is a rule that is supposed to run more than once.
        if (iteration != 2) {
          val endingMsg = if (batch.strategy.maxIterationsSetting == null) {
            "."
          } else {
            s", please set '${batch.strategy.maxIterationsSetting}' to a larger value."
          }
          val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}" +
            s"$endingMsg"
          if (Utils.isTesting || batch.strategy.errorOnExceed) {
            throw new RuntimeException(message)
          } else {
            logWarning(message)
          }
        }
        // Check idempotence for Once batches.
        if (batch.strategy == Once &&
          Utils.isTesting && !excludedOnceBatches.contains(batch.name)) {
          checkBatchIdempotence(batch, curPlan)
        }
        continue = false
      }
      // stop condition #2: the plan no longer changes (fixed point reached)
      if (curPlan.fastEquals(lastPlan)) {
        logTrace(
          s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.")
        continue = false
      }
      lastPlan = curPlan
    }

    planChangeLogger.logBatch(batch.name, batchStartPlan, curPlan)
  }
  planChangeLogger.logMetrics(RuleExecutor.getCurrentMetrics() - beforeMetrics)

  curPlan
}
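
A convenient way to watch these batches and rules fire is Spark's plan change logger. The sketch below assumes the session and view from the demo; the config key spark.sql.planChangeLog.level exists since Spark 3.1 if I remember correctly, and WARN is only chosen so the output shows up with default logging.

// Log every rule application that changes the plan, plus a per-batch summary.
spark.conf.set("spark.sql.planChangeLog.level", "WARN")
spark.sql("select id from test limit 10").collect()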

References