Custom Spark Logical Optimizer: Partition Pushdown

Picking up from the previous post, How to Extend the Spark Catalyst Engine Without Modifying Spark Source Code, which showed how to customize the Spark Catalyst engine, this post works through a real-world scenario: in large-volume ingestion tasks the table is partitioned, the partitioning is hidden from users, and when a user filters on a designated time field, a partition filter is added to the query automatically.

Cover image by DALL·E 3

Overview

A prerequisite for this optimization is that, when configuring an ingestion task, the user marks a time field to be used for partitioned storage. During ingestion this field is parsed and the data is written into partitions of the form year=yyyy/month=MM/day=dd.
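For reference, the test table used later in this post could look like the following. This is a minimal sketch, assuming a SparkSession named spark; in practice the table is created by the ingestion platform, and the column names and types here simply match the plans shown below:

```scala
// Hypothetical DDL for the example table: two string data columns plus the
// hidden year/month/day partition columns written by the ingestion task.
spark.sql("""
  CREATE TABLE test (id STRING, dt STRING)
  USING parquet
  PARTITIONED BY (year STRING, month STRING, day STRING)
""")
```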

For example, if the time field is ctime and the user's SQL contains where ctime > '2022-11-30 12:00:00', then after our optimization rule it becomes where ctime > '2022-11-30 12:00:00' AND concat(year, month, day) >= '20221130'. The user can even use the substring function: where substring(ctime, 1, 10) >= '2021-12-12' is likewise optimized to where substring(ctime, 1, 10) >= '2021-12-12' AND concat(year, month, day) >= '20211212'. Note that we are not actually rewriting the SQL text; we are rewriting the logical plan.

A test SQL: select * from test where dt > '2022-11-30 00:00:00'

The unoptimized logical plan:

```
Project [id#0, dt#1, year#2, month#3, day#4]
+- Filter (isnotnull(dt#1) AND (dt#1 > 2022-11-30 00:00:00))
   +- FileScan parquet xx.test[id#0,dt#1,year#2,month#3,day#4] Batched: true, DataFilters: [isnotnull(dt#1), (dt#1 > 2022-11-30 00:00:00)], Format: Parquet, Location: CatalogFileIndex(1 paths)[hdfs://xx.db/test], PartitionFilters: [], PushedFilters: [IsNotNull(dt), GreaterThan(dt,2022-11-30 00:00:00)], ReadSchema: struct<id:string,dt:string>
```

After optimization, the PartitionFilters entry now carries exactly the partition filter we added:

```
Project [id#0, dt#1, year#2, month#3, day#4]
+- Filter (isnotnull(dt#1) AND (dt#1 > 2022-11-30 00:00:00))
   +- FileScan parquet xx.test[id#0,dt#1,year#2,month#3,day#4] Batched: true, DataFilters: [isnotnull(dt#1), (dt#1 > 2022-11-30 00:00:00)], Format: Parquet, Location: InMemoryFileIndex(2 paths)[hdfs://xx.db/test/year=2022/month=12/day..., PartitionFilters: [(concat(year#2, month#3, day#4) >= 20221130)], PushedFilters: [IsNotNull(dt), GreaterThan(dt,2022-11-30 00:00:00)], ReadSchema: struct<id:string,dt:string>
```

1. Create the optimization rule

First, create a class, say PartitionPushDownRule:

```scala
package com.zhangnew.demo

import scala.collection.mutable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.types.StringType
import org.joda.time.format.DateTimeFormat

case class PartitionPushDownRule(session: SparkSession) extends Rule[LogicalPlan] {

  def apply(plan: LogicalPlan): LogicalPlan = {
    // 1. Read the config to see whether this rule is enabled
    if (!session.conf.get("spark.zhangnew.partition.push.down", "false").toBoolean) {
      return plan
    }
    // The remaining steps are covered in the sections below:
    // 2. Check whether the table is partitioned; get the time field and the partition columns
    // 3. Check whether the logical plan has already been optimized
    // 4. Apply the pushdown by adding a partition filter condition
    plan // placeholder until the steps below are filled in
  }
}
```

Registering the optimizer rule:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

val sparkConf: SparkConf = new SparkConf()
  .set("spark.zhangnew.partition.push.down", "true") // enable the optimization rule
  .setMaster("local[1]")

private val spark = SparkSession.builder()
  .config(sparkConf)
  .enableHiveSupport()
  .withExtensions { extensions =>
    // inject the optimizer rule
    extensions.injectOptimizerRule { session => PartitionPushDownRule(session) }
  }
  .getOrCreate()
```
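With the rule injected, it can be exercised by explaining the test query; PartitionFilters in the resulting plan should show the generated predicate (assuming the test table sketched in the overview exists):

```scala
// With spark.zhangnew.partition.push.down=true, the optimized plan should list
// PartitionFilters: [(concat(year, month, day) >= 20221130)]
spark.sql("select * from test where dt > '2022-11-30 00:00:00'").explain(true)
```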

2. Check whether the table is partitioned

```scala
var tbName = "" // table name
var tablePartitionKey = "" // the time field to filter on
var tablePartitionKeyFormat = "yyyy-MM-dd HH:mm:ss" // format of the time field
val partitionFields = mutable.Set[String]() // partition column names, e.g. year, month, day

plan foreach {
  case l: LogicalRelation =>
    l.relation match {
      // a file-backed Hive table
      case hadoopFsRelation: HadoopFsRelation =>
        val (keyFromProperties, formatFromProperties) = hadoopFsRelation.location match {
          case index: CatalogFileIndex =>
            tbName = index.table.qualifiedName // table name
            (index.table.properties.getOrElse(TABLE_PARTITION_KEY, ""), // time field and its format
              index.table.properties.getOrElse(TABLE_PARTITION_KEY_FORMAT, tablePartitionKeyFormat))
          case _ => ("", tablePartitionKeyFormat)
        }
        tablePartitionKeyFormat = formatFromProperties
        tablePartitionKey = hadoopFsRelation.options.getOrElse(TABLE_PARTITION_KEY, keyFromProperties)
        hadoopFsRelation.partitionSchema.foreach(col => partitionFields.add(col.name)) // partition columns
      case _ =>
    }
  case _ =>
}
// if the prerequisites are not met, return the original plan untouched
if (tablePartitionKey.isEmpty || partitionFields.isEmpty) return plan
```
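The constants TABLE_PARTITION_KEY and TABLE_PARTITION_KEY_FORMAT are not defined in the post; they are simply the property names under which the ingestion task records its metadata. A plausible definition (the key names here are hypothetical):

```scala
// Hypothetical table-property keys written by the ingestion task; use whatever
// names your platform actually stores in the Hive table properties / options.
val TABLE_PARTITION_KEY = "zhangnew.partition.key"
val TABLE_PARTITION_KEY_FORMAT = "zhangnew.partition.key.format"
```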

3. Check whether the plan has already been optimized

```scala
var resolved = false
// Partition columns as they appear in the logical plan. Above we only obtained
// the column names; a name alone is not a reliable match, so here we collect
// the actual attribute references with their expression IDs, e.g. year#2.
val partitionReferences = mutable.Set[AttributeReference]()

// An already-optimized plan contains a partition filter such as:
// concat(year#2, month#3, day#4) >= 20221130
def checkResolved(expr: Expression): Unit = {
  def check(leftExpr: Expression, rightExpr: Expression): Unit = {
    // one side must be a string literal ...
    leftExpr match {
      case Literal(_, dataType) =>
        if (!dataType.isInstanceOf[StringType]) return
        // ... and the other side CONCAT(<partition columns>)
        rightExpr match {
          case Concat(children) if children.nonEmpty && children.forall(_.isInstanceOf[AttributeReference]) =>
            resolved = children.map(_.asInstanceOf[AttributeReference].name).toSet
              .forall(col => partitionFields.contains(col))
          case _ =>
        }
      case _ =>
    }
  }

  expr match {
    case BinaryComparison(left, right) =>
      // check both orders: the literal may sit on either side of the concat
      check(left, right)
      check(right, left)
    case _ =>
  }
}

// check whether the filter is already resolved and collect the partition
// column references along the way
plan transformAllExpressions {
  case ar: AttributeReference if partitionFields.contains(ar.name) =>
    if (!partitionReferences.map(_.exprId).contains(ar.exprId)) {
      partitionReferences.add(ar)
    }
    ar
  case And(left, right) =>
    checkResolved(left)
    checkResolved(right)
    And(left, right) // inspection only, return the node unchanged
}
if (resolved) return plan
```
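To make the shape concrete, this is roughly the expression checkResolved is looking for, built by hand. A sketch only; in a real plan the attribute references carry the exprIds collected above:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StringType

// the already-pushed-down shape: concat(year, month, day) >= '20221130'
val year  = AttributeReference("year", StringType)()
val month = AttributeReference("month", StringType)()
val day   = AttributeReference("day", StringType)()
val alreadyPushed = GreaterThanOrEqual(Concat(Seq(year, month, day)), Literal("20221130"))
```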

4. Apply the partition pushdown

```scala
val parseFormat = DateTimeFormat.forPattern(tablePartitionKeyFormat)
// Predicates we have already extended, so that transformAllExpressions does not
// process them again when it recurses into the rewritten And expression.
val resolvedExpr = mutable.Set[Expression]()

def pushDown(left: Expression,
             right: Expression,
             factory: (Expression, Expression) => BinaryExpression): Expression = {
  val originPlan = factory(left, right)

  def buildExpr(left: Expression, right: Expression, reverse: Boolean = false): Expression = {

    def buildExprByRightExpr(rightExpr: Expression): Expression = {
      rightExpr match {
        case literal: Literal =>
          val dateStr = literal.value.toString
          val date = try {
            parseFormat.parseDateTime(dateStr)
          } catch {
            case _: Throwable =>
              // the substring(ctime, 1, 10) form compares against a plain date,
              // which the full timestamp format cannot parse, so fall back to
              // yyyy-MM-dd before giving up
              try {
                DateTimeFormat.forPattern("yyyy-MM-dd").parseDateTime(dateStr)
              } catch {
                case _: Throwable => return originPlan
              }
          }
          val partitionValueList = date.toString("yyyy-MM-dd").split("-")
          // sortBy(_.name) yields day, month, year (lexicographic order),
          // so reverse restores the year, month, day order
          val partitionKeyList = partitionReferences.toSeq.sortBy(_.name).reverse
          // builds concat(year#2, month#3, day#4) and the literal, e.g. 20221130
          val expr1 = Concat(partitionKeyList)
          val expr2 = Literal(partitionValueList.take(partitionKeyList.size).mkString(""))
          // mind the operand order: swapping the sides would invert the filter
          val (exprLeft, exprRight) = if (reverse) {
            (expr2, expr1)
          } else {
            (expr1, expr2)
          }
          // the partition range must be inclusive, since rows later that day
          // still live in the same partition: change > to >=, < to <=
          val partitionExpr = originPlan match {
            case GreaterThan(_, _) =>
              GreaterThanOrEqual(exprLeft, exprRight)
            case LessThan(_, _) =>
              LessThanOrEqual(exprLeft, exprRight)
            case _ => factory(exprLeft, exprRight)
          }
          // ctime > '2022-11-30 12:00:00' AND concat(year, month, day) >= '20221130'
          val expr = And(originPlan, partitionExpr)
          resolvedExpr.add(originPlan)
          expr
        case _ => originPlan
      }
    }

    left match {
      // matches ctime > '2022-11-30 12:00:00'
      case reference: AttributeReference if reference.name == tablePartitionKey =>
        return buildExprByRightExpr(right)
      // matches substring(ctime, 1, 10) >= '2021-12-12'
      case sub: Substring if sub.pos.isInstanceOf[Literal] && sub.pos.asInstanceOf[Literal].value == 1
        && sub.len.isInstanceOf[Literal] && sub.len.asInstanceOf[Literal].value == 10
        && sub.str.isInstanceOf[AttributeReference]
        && sub.str.asInstanceOf[AttributeReference].name == tablePartitionKey =>
        return buildExprByRightExpr(right)
      case _ =>
    }
    originPlan
  }

  // bail out early if the prerequisites are missing or this predicate
  // has already been extended
  if (partitionReferences.isEmpty || tablePartitionKey.isEmpty || partitionFields.isEmpty
    || resolvedExpr.contains(originPlan)) {
    return originPlan
  }

  val expr = buildExpr(left, right)
  // if nothing matched, try again with the operands swapped
  // (handles e.g. '2022-11-30 12:00:00' < ctime)
  if (!originPlan.fastEquals(expr)) expr else buildExpr(right, left, reverse = true)
}

plan transformAllExpressions {
  case GreaterThan(left, right) =>
    pushDown(left, right, (l, r) => GreaterThan(l, r))
  case GreaterThanOrEqual(left, right) =>
    pushDown(left, right, (l, r) => GreaterThanOrEqual(l, r))
  case LessThan(left, right) =>
    pushDown(left, right, (l, r) => LessThan(l, r))
  case LessThanOrEqual(left, right) =>
    pushDown(left, right, (l, r) => LessThanOrEqual(l, r))
  case EqualTo(left, right) =>
    pushDown(left, right, (l, r) => EqualTo(l, r))
  case x =>
    x
}
```

Postscript

The example above is not a real production scenario or real production code; it is only meant to illustrate the approach. A real system has more details to handle, for example when the comparison against the time field involves another column or a function result rather than just a Literal.