diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
index b65576403e9d8..b3b2288d2299a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.trees.UnaryLike
@@ -87,10 +88,11 @@ case class DynamicPruningSubquery(
   override def toString: String = s"dynamicpruning#${exprId.id} $conditionString"
 
   override lazy val canonicalized: DynamicPruning = {
+    val buildOutput = buildQuery.output
     copy(
       pruningKey = pruningKey.canonicalized,
       buildQuery = buildQuery.canonicalized,
-      buildKeys = buildKeys.map(_.canonicalized),
+      buildKeys = buildKeys.map(QueryPlan.normalizeExpressions(_, buildOutput)),
       exprId = ExprId(0))
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
index e0d3a176b1a43..6997076334526 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
@@ -22,7 +22,7 @@ import java.util.TimeZone
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.logical.Range
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.MULTI_COMMUTATIVE_OP_OPT_THRESHOLD
 import org.apache.spark.sql.types.{BooleanType, Decimal, DecimalType, DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampNTZType, TimestampType}
@@ -352,6 +352,23 @@ class CanonicalizeSuite extends SparkFunSuite {
     SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString)
   }
 
+  test("SPARK-45658: DynamicPruningSubquery canonicalization build keys not canonicalized" +
+    " relative to build query output") {
+    val pruneExprId = NamedExpression.newExprId
+    val pruneKey = AttributeReference("dummy", IntegerType)(pruneExprId)
+    val testRelation = LocalRelation($"a".int, $"b".int, $"c".int)
+
+    val buildQueryPlan1 = testRelation.where("a".attr > 10).select($"b".attr * Literal(5)).analyze
+    val buildKeys1 = Seq(buildQueryPlan1.output.head)
+    val dps1 = DynamicPruningSubquery(pruneKey, buildQueryPlan1, buildKeys1, Seq(0), true)
+
+    val buildQueryPlan2 = testRelation.where("a".attr > 10).select($"b".attr * Literal(5)).analyze
+    val buildKeys2 = Seq(buildQueryPlan2.output.head)
+    val dps2 = DynamicPruningSubquery(pruneKey, buildQueryPlan2, buildKeys2, Seq(0), true)
+
+    assert(dps1.canonicalized == dps2.canonicalized)
+  }
+
   test("canonicalization of With expressions with one common expression") {
     val expr = Divide(Literal.create(1, IntegerType), AttributeReference("a", IntegerType)())
     val common1 = IsNull(With(expr.copy()) { case Seq(expr) =>
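
Note (not part of the patch): a minimal sketch of why the build keys must be normalized against the build query's output rather than canonicalized in isolation. Each analysis run assigns fresh `ExprId`s, so the same attribute from two otherwise identical build queries canonicalizes differently under plain `_.canonicalized`; `QueryPlan.normalizeExpressions` instead rewrites an attribute's `ExprId` to its ordinal in the supplied output, making the two keys compare equal. The snippet assumes the Spark Catalyst classpath; variable names are ad hoc.

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.types.IntegerType

// Two analysis runs of the same query assign distinct, fresh ExprIds.
val key1 = AttributeReference("b", IntegerType)()
val key2 = AttributeReference("b", IntegerType)()

// Plain canonicalization preserves the ExprId, so semantically
// identical keys do not compare equal.
assert(key1.canonicalized != key2.canonicalized)

// Normalizing against each build query's output rewrites the ExprId to
// the attribute's ordinal in that output (0 here), so the keys match.
val norm1 = QueryPlan.normalizeExpressions(key1, Seq(key1))
val norm2 = QueryPlan.normalizeExpressions(key2, Seq(key2))
assert(norm1 == norm2)
```

This mirrors what the new test exercises end to end: `buildQueryPlan1` and `buildQueryPlan2` are the same plan analyzed twice, and only the normalized form of their output attributes lets the two `DynamicPruningSubquery` instances canonicalize to the same value.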