
Commit d9dbc88

[SPARK-46240][SQL] Support inject executed plan prep rules in SparkSessionExtensions
jiang13021 committed Dec 8, 2023
1 parent 105eee7 commit d9dbc88
Showing 5 changed files with 56 additions and 1 deletion.
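Usage at a glance — a minimal sketch of wiring the new extension point through SparkSession.Builder.withExtensions; NoopPrepRule and the builder setup are illustrative, not part of this commit:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Illustrative rule: any Rule[SparkPlan] can be injected.
object NoopPrepRule extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan
}

val spark = SparkSession.builder()
  .master("local[*]")
  .withExtensions(_.injectExecutedPlanPrepRule(_ => NoopPrepRule))
  .getOrCreate()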
SparkSessionExtensions.scala

@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
* <li>Customized Parser.</li>
* <li>(External) Catalog listeners.</li>
* <li>Columnar Rules.</li>
* <li>Executed Plan Preparation Rules.</li>
* <li>Adaptive Query Post Planner Strategy Rules.</li>
* <li>Adaptive Query Stage Preparation Rules.</li>
* <li>Adaptive Query Execution Runtime Optimizer Rules.</li>
@@ -113,11 +114,13 @@ class SparkSessionExtensions {
type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder)
type ColumnarRuleBuilder = SparkSession => ColumnarRule
type ExecutedPlanPrepRuleBuilder = SparkSession => Rule[SparkPlan]
type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan]
type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
type QueryStageOptimizerRuleBuilder = SparkSession => Rule[SparkPlan]

private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
private[this] val executedPlanPrepRuleBuilders = mutable.Buffer.empty[ExecutedPlanPrepRuleBuilder]
private[this] val queryPostPlannerStrategyRuleBuilders =
mutable.Buffer.empty[QueryPostPlannerStrategyBuilder]
private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
@@ -132,6 +135,13 @@ class SparkSessionExtensions {
columnarRuleBuilders.map(_.apply(session)).toSeq
}

/**
* Build the override rules for the executed plan preparation phase.
*/
private[sql] def buildExecutedPlanPrepRules(session: SparkSession): Seq[Rule[SparkPlan]] = {
executedPlanPrepRuleBuilders.map(_.apply(session)).toSeq
}

/**
* Build the override rules for the query post planner strategy phase of adaptive query execution.
*/
@@ -168,6 +178,13 @@ class SparkSessionExtensions {
columnarRuleBuilders += builder
}

/**
* Inject a rule that is applied after `plannerStrategy`, during the executed plan preparation phase.
*/
def injectExecutedPlanPrepRule(builder: ExecutedPlanPrepRuleBuilder): Unit = {
executedPlanPrepRuleBuilders += builder
}

/**
* Inject a rule that is applied between `plannerStrategy` and `queryStagePrepRules`, so
* it can get the whole plan before injecting exchanges.
QueryExecution.scala

@@ -473,7 +473,7 @@ object QueryExecution {
Nil
} else {
Seq(ReuseExchangeAndSubquery)
})
}) ++ sparkSession.sessionState.executedPlanPrepRules
}

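Note the placement: the injected rules are concatenated with ++ after the built-in sequence, so they run last in QueryExecution.preparations and observe the fully prepared physical plan. A read-only rule suited to this slot could look like the sketch below; LogPlanRule is a hypothetical name, not from this commit:

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical rule: Rule extends Logging, so logInfo is available.
object LogPlanRule extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    logInfo(s"Prepared plan:\n${plan.treeString}")
    plan // return the plan unchanged
  }
}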
BaseSessionStateBuilder.scala

@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaUDAF}
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
@@ -328,6 +329,10 @@ abstract class BaseSessionStateBuilder(
extensions.buildColumnarRules(session)
}

protected def executedPlanPrepRules: Seq[Rule[SparkPlan]] = {
extensions.buildExecutedPlanPrepRules(session)
}

protected def adaptiveRulesHolder: AdaptiveRulesHolder = {
new AdaptiveRulesHolder(
extensions.buildQueryStagePrepRules(session),
@@ -403,6 +408,7 @@ abstract class BaseSessionStateBuilder(
createQueryExecution,
createClone,
columnarRules,
executedPlanPrepRules,
adaptiveRulesHolder,
planNormalizationRules,
() => artifactManager)
SessionState.scala

@@ -89,6 +89,7 @@ private[sql] class SessionState(
createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution,
createClone: (SparkSession, SessionState) => SessionState,
val columnarRules: Seq[ColumnarRule],
val executedPlanPrepRules: Seq[Rule[SparkPlan]],
val adaptiveRulesHolder: AdaptiveRulesHolder,
val planNormalizationRules: Seq[Rule[LogicalPlan]],
val artifactManagerBuilder: () => ArtifactManager) {
SparkSessionExtensionSuite.scala

@@ -18,6 +18,7 @@ package org.apache.spark.sql

import java.util.{Locale, UUID}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.concurrent.Future

@@ -543,6 +544,22 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with AdaptiveSparkPlanHelper {
}
}
}

test("SPARK-46240: Support inject executed plan prep rules in SparkSessionExtensions") {
val extensions = create {
extensions =>
extensions.injectExecutedPlanPrepRule(_ => MyExecutedPlanPrepRule)
}
withSession(extensions) { session =>
session.sql(
"""
| select k, v
| from values ("key1", 1), ("key2", 2)
| as tbl(k, v)
|""".stripMargin).queryExecution.executedPlan
assert(MyExecutedPlanPrepRule.collectMap.get("LocalTableScan").contains(1))
}
}
}

case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1228,3 +1245,17 @@ object MyQueryPostPlannerStrategyRule extends Rule[SparkPlan] {
}
}
}

object MyExecutedPlanPrepRule extends Rule[SparkPlan] {
  // Counts how many times each physical operator (keyed by nodeName) appears in the prepared plan.
  var collectMap: mutable.HashMap[String, Long] = mutable.HashMap.empty

  override def apply(plan: SparkPlan): SparkPlan = {
    plan.foreachUp { node =>
      val key = node.nodeName
      collectMap.update(key, collectMap.getOrElseUpdate(key, 0) + 1)
    }
    plan
  }
}
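Taken together with the test above: materializing queryExecution.executedPlan triggers the injected rule; MyExecutedPlanPrepRule walks the prepared plan bottom-up via foreachUp, counting operators by nodeName, and the two-row VALUES query plans to a single LocalTableScan, which is what the assertion checks.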
