Spark 3.2: String Can No Longer Be Implicitly Cast to Timestamp

TL;DR: Starting with Spark 3.2, the data type backing interval literals was changed (CalendarIntervalType → YearMonthIntervalType) for ANSI SQL compliance, but the implicit-cast rules in the code only cover the String / CalendarIntervalType combination. On the new version you need to set spark.sql.legacy.interval.enabled=true so that intervals fall back to the old data type (CalendarIntervalType).

Cover image by DALL·E 3

Problem Description

After recently upgrading a project from Spark 2.4 to Spark 3.2.2, I found that one SQL statement started to fail:

scala> spark.sql("select '2022-11-12 23:33:55' - INTERVAL 3 YEAR")
org.apache.spark.sql.AnalysisException: cannot resolve '(CAST('2022-11-12 23:33:55' AS DOUBLE) - INTERVAL '3' YEAR)' due to data type mismatch: differing types in '(CAST('2022-11-12 23:33:55' AS DOUBLE) - INTERVAL '3' YEAR)' (double and interval year).; line 1 pos 7;
'Project [unresolvedalias((cast(2022-11-12 23:33:55 as double) - INTERVAL '3' YEAR), None)]
+- OneRowRelation

The same statement runs fine on Spark 2.4:

scala> spark.sql("select '2022-11-12 23:33:55' - INTERVAL 3 YEAR").show
+-------------------------------------------------------------------------+
|CAST(CAST(2022-11-12 23:33:55 AS TIMESTAMP) - interval 3 years AS STRING)|
+-------------------------------------------------------------------------+
|                                                      2019-11-12 23:33:55|
+-------------------------------------------------------------------------+

As we can see, when a String and an Interval appear in the same arithmetic expression, Spark 2.4 automatically casts the String to Timestamp, whereas Spark 3.2 casts it to Double, which makes the expression fail to resolve.
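
Incidentally, an explicit cast sidesteps the implicit coercion entirely. The one-liner below is my own workaround sketch (not part of the original report); since the analyzer then sees a Timestamp on the left instead of a String, it should behave the same on both versions:

scala> spark.sql("select cast('2022-11-12 23:33:55' as timestamp) - INTERVAL 3 YEAR").show(false)   // my own workaround sketch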

Problem Analysis

Spark 3.2

Let's dig into the Spark 3.2 source to find out where the String is being cast to Double.

First, the rules that the analyzer applies to the logical plan are organized into a number of batches, defined in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala:

override def batches: Seq[Batch] = Seq(
  Batch("Substitution", fixedPoint,
    OptimizeUpdateFields,
    CTESubstitution,
    WindowsSubstitution,
    EliminateUnions,
    SubstituteUnresolvedOrdinals),
  Batch("Disable Hints", Once,
    new ResolveHints.DisableHints),
  Batch("Hints", fixedPoint,
    ResolveHints.ResolveJoinStrategyHints,
    ResolveHints.ResolveCoalesceHints),
  // ... omitted
  Batch("HandleAnalysisOnlyCommand", Once,
    HandleAnalysisOnlyCommand)
)
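
As an aside, for a query that passes analysis you can inspect the analyzer's result without touching Catalyst at all; the failing query above throws during analysis itself, which is why the walkthrough below adds log statements to the rule executor instead. A minimal sketch (the typed timestamp literal is my own example, not from the original post):

scala> val df = spark.sql("select timestamp'2022-11-12 23:33:55' - INTERVAL 3 YEAR")   // my own example
scala> df.queryExecution.analyzed   // the plan after all analyzer batches have run
scala> df.explain(true)             // parsed, analyzed, optimized and physical plans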

The code that actually runs these batches lives in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala:

def execute(plan: TreeType): TreeType = {
  var curPlan = plan
  // ... omitted
  batches.foreach { batch =>
    val batchStartPlan = curPlan
    var iteration = 1
    var lastPlan = curPlan
    var continue = true
    while (continue) {
      curPlan = batch.rules.foldLeft(curPlan) {
        case (plan, rule) =>
          val startTime = System.nanoTime()
          val result = rule(plan)
          // The rule is applied to the plan here; log it so we can compare the plan before and after
          logWarning(s"apply rule ${rule.ruleName} to plan:\n$plan\n### to\n$result")
          // ... omitted
          result
      }
      iteration += 1
      if (iteration > batch.strategy.maxIterations) {
        // ... omitted
        continue = false
      }
      if (curPlan.fastEquals(lastPlan)) {
        logTrace(
          s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.")
        continue = false
      }
      lastPlan = curPlan
    }
  }
  curPlan
}

As you can see, each batch is applied in a loop until it no longer changes the plan, i.e. until curPlan.fastEquals(lastPlan). With the logWarning call added above, analyzing a query of this shape produces output like the following:

15:02:09.102 WARN org.apache.spark.sql.internal.BaseSessionStateBuilder$$anon$1: apply rule org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CombinedTypeCoercionRule to plan: 
'Project [unresolvedalias((time#218 - INTERVAL '3' MONTH), None)]
+- SubqueryAlias tmp
   +- Project [2022-12-12 12:00:59 AS time#218]
      +- OneRowRelation

### to

'Project [unresolvedalias((cast(time#218 as double) - INTERVAL '3' MONTH), None)]
+- SubqueryAlias tmp
   +- Project [2022-12-12 12:00:59 AS time#218]
      +- OneRowRelation

The added log output shows that after the rule defined in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala runs, the String has been turned into a Double. The code looks like this:

object TypeCoercion extends TypeCoercionBase {

  override def typeCoercionRules: List[Rule[LogicalPlan]] =
    WidenSetOperationTypes ::
      CombinedTypeCoercionRule(
        InConversion ::
          PromoteStrings ::
          // ... omitted
          StringLiteralCoercion :: Nil) :: Nil

  case class CombinedTypeCoercionRule(rules: Seq[TypeCoercionRule]) extends TypeCoercionRule {
    override def transform: PartialFunction[Expression, Expression] = {
      val transforms = rules.map(_.transform)
      Function.unlift { e: Expression =>
        val result = transforms.foldLeft(e) {
          case (current, transform) =>
            val p = transform.applyOrElse(current, identity[Expression])
            // Log each individual rule here to pinpoint which one rewrites the expression
            logWarning(s"transform = ${transform.getClass.getName} from \n $current to \n$p")
            p
        }
        if (result ne e) {
          Some(result)
        } else {
          None
        }
      }
    }
  }

CombinedTypeCoercionRule wraps a list of individual rules, so once again we add a log statement to work out which rule changes the String to a Double:

15:05:33.865 WARN org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CombinedTypeCoercionRule: transform = org.apache.spark.sql.catalyst.analysis.TypeCoercion$PromoteStrings$$anonfun$transform$6 from 
(time#218 - INTERVAL '3' MONTH)
to
(cast(time#218 as double) - INTERVAL '3' MONTH)

The log points to the PromoteStrings rule, whose code is as follows:

object PromoteStrings extends TypeCoercionRule {
  override def transform: PartialFunction[Expression, Expression] = {

    case a @ BinaryArithmetic(left @ StringType(), right)
        if right.dataType != CalendarIntervalType =>
      // logWarning(s"\nString to Double ${left.getClass.getName} right = ${right.dataType.getClass.getName}")
      a.makeCopy(Array(Cast(left, DoubleType), right))
    case a @ BinaryArithmetic(left, right @ StringType())
        if left.dataType != CalendarIntervalType =>
      a.makeCopy(Array(left, Cast(right, DoubleType)))

    // ... omitted
  }
}

The first case is the one that matches our query: in a binary arithmetic expression whose left operand is a String, if the right operand is not a CalendarIntervalType, the String is cast to Double.

In select '2022-11-12 23:33:55' - INTERVAL 3 YEAR, the right operand is a YearMonthIntervalType, a type newly introduced in Spark 3.2.0 for compliance with the ANSI SQL:2016 standard.
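
You can confirm this with a quick probe in a Spark 3.2 spark-shell (my own example, not from the original post):

scala> spark.sql("select INTERVAL 3 YEAR").schema   // my own probe of the literal's type

With the legacy config off this reports a YearMonthIntervalType field, whereas the same literal has type CalendarIntervalType on Spark 2.4.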

Spark 2.4

Now let's see how Spark 2.4.6 handles the String and Interval types.

Using the same approach as before, adding log statements to locate the responsible rule:

02:45:22.491 WARN org.apache.spark.sql.internal.BaseSessionStateBuilder$$anon$1: apply rule org.apache.spark.sql.catalyst.analysis.TypeCoercion$ImplicitTypeCasts to plan: 
'Project [unresolvedalias(cast(time#175 - interval 3 months as string), None)]
+- SubqueryAlias `tmp`
   +- Project [2022-12-12 12:00:59 AS time#175]
      +- OneRowRelation

### to

'Project [unresolvedalias(cast(cast(time#175 as timestamp) - interval 3 months as string), None)]
+- SubqueryAlias `tmp`
   +- Project [2022-12-12 12:00:59 AS time#175]
      +- OneRowRelation

Here it is ImplicitTypeCasts that casts the String to Timestamp.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala 中找到 ImplicitTypeCasts 代码:

object ImplicitTypeCasts extends TypeCoercionRule {
  override protected def coerceTypes(
      plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
    case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
      val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
        // If we cannot do the implicit cast, just use the original input.
        implicitCast(in, expected).getOrElse(in)
      }
      e.withNewChildren(children)
  }

  def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = {
    implicitCast(e.dataType, expectedType).map { dt =>
      if (dt == e.dataType) e else Cast(e, dt)
    }
  }

  private def implicitCast(inType: DataType, expectedType: AbstractDataType): Option[DataType] = {
    // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope.
    // We wrap immediately an Option after this.
    @Nullable val ret: DataType = (inType, expectedType) match {
      // If the expected type is already a parent of the input type, no need to cast.
      case _ if expectedType.acceptsType(inType) => inType
      // ... omitted
      // Implicit cast from/to string
      case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT
      case (StringType, target: NumericType) => target
      case (StringType, DateType) => DateType
      // The line below is the rule we are looking for
      case (StringType, TimestampType) => TimestampType
      case (StringType, BinaryType) => BinaryType
      // Cast any atomic type to string.
      case (any: AtomicType, StringType) if any != StringType => StringType
      // ... omitted
      case _ => null
    }
    Option(ret)
  }
}

The implicitCast method defines a long list of conversion rules, including the one we are after: case (StringType, TimestampType) => TimestampType.
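
To get a feel for how these entries surface in everyday SQL, here are two probes (my own examples, not from the original post): hour() declares a TimestampType input and datediff() declares DateType inputs, so their string arguments are cast implicitly by this same rule set:

scala> spark.sql("select hour('2022-11-12 23:33:55')").show          // my example: (StringType, TimestampType)
scala> spark.sql("select datediff('2022-11-12', '2022-01-01')").show // my example: (StringType, DateType)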

The Difference

It turns out that Spark 3.2 has an implicitCast method as well; the two versions behave differently because the data type of the interval on the right-hand side changed:

  • In Spark 2.4, interval 3 month has type CalendarIntervalType
  • In Spark 3.2, interval 3 month has type YearMonthIntervalType

As a result, different coercion rules end up being applied, and the query fails on Spark 3.2.

Also note that spark.sql.ansi.enabled is off by default; the behavior described here is unrelated to whether that mode is enabled.
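
If you want to double-check that default in your own session (my own probe, not from the original post):

scala> spark.conf.get("spark.sql.ansi.enabled")   // my own probe; "false" by default in Spark 3.2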

Conclusion

The Spark JIRA ticket "Support ANSI SQL INTERVAL types" mentions that the underlying type of intervals can be changed via configuration. Searching sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala turns up spark.sql.legacy.interval.enabled, which is disabled by default; after enabling it, the SQL above runs correctly. Some of the code that consults this configuration:

// How legacyIntervalEnabled is used in the ANTLR parser code
// sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
  // ... omitted
  try {
    valueType match {
      case "DATE" =>
        // ... omitted
      case "TIMESTAMP" =>
        // ... omitted
      case "INTERVAL" =>
        val interval = try {
          IntervalUtils.stringToInterval(UTF8String.fromString(value))
        } catch {
          case e: IllegalArgumentException =>
            val ex = QueryParsingErrors.cannotParseIntervalValueError(value, ctx)
            ex.setStackTrace(e.getStackTrace)
            throw ex
        }
        if (!conf.legacyIntervalEnabled) {
          val units = value
            .split("\\s")
            .map(_.toLowerCase(Locale.ROOT).stripSuffix("s"))
            .filter(s => s != "interval" && s.matches("[a-z]+"))
          constructMultiUnitsIntervalLiteral(ctx, interval, units)
        } else {
          Literal(interval, CalendarIntervalType)
        }
      case "X" =>
        val padding = if (value.length % 2 != 0) "0" else ""
        Literal(DatatypeConverter.parseHexBinary(padding + value))
      case other =>
        throw QueryParsingErrors.literalValueTypeUnsupportedError(other, ctx)
    }
  } catch {
    case e: IllegalArgumentException =>
      throw QueryParsingErrors.parsingValueTypeError(e, valueType, ctx)
  }
}

// How legacyInterval is used in datetimeExpressions
// sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
case class SubtractDates(
    left: Expression,
    right: Expression,
    legacyInterval: Boolean)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

  def this(left: Expression, right: Expression) =
    this(left, right, SQLConf.get.legacyIntervalEnabled)

  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType)
  override def dataType: DataType = {
    if (legacyInterval) CalendarIntervalType else DayTimeIntervalType(DAY)
  }
  // ... omitted
}
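
To close the loop, here is what applying the fix looks like in practice; this is a sketch of my own run, and the exact output formatting may differ:

scala> spark.conf.set("spark.sql.legacy.interval.enabled", "true")   // my own example of applying the fix
scala> spark.sql("select '2022-11-12 23:33:55' - INTERVAL 3 YEAR").show(false)

With the legacy CalendarIntervalType restored, the string is implicitly cast to timestamp again and the query returns 2019-11-12 23:33:55, as on Spark 2.4.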

References

For background on ANSI SQL support in Spark, see the following resources: