diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala
index af8fbb5815..3b54498e63 100644
--- a/spark/src/main/scala/org/apache/comet/CometConf.scala
+++ b/spark/src/main/scala/org/apache/comet/CometConf.scala
@@ -275,6 +275,16 @@ object CometConf extends ShimCometConf {
   val COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED: ConfigEntry[Boolean] =
     createExecEnabledConfig("localTableScan", defaultValue = false)
 
+  // NOTE(review): the key is spelled out here, while the neighbouring entries build theirs
+  // from COMET_EXEC_CONFIG_PREFIX or createExecEnabledConfig — confirm the literal matches.
+  val COMET_EXEC_IN_MEMORY_CACHE_ENABLED: ConfigEntry[Boolean] =
+    conf("spark.comet.exec.inMemoryCache.enabled")
+      .category(CATEGORY_EXEC)
+      .doc(
+        "Whether to enable Comet native execution for in-memory cached tables. " +
+          "When disabled or when spark.comet.enabled=false, Spark's default cache " +
+          "serializer and execution path will be used.")
+      .booleanConf
+      .createWithDefault(false)
+
   val COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED: ConfigEntry[Boolean] =
     conf(s"$COMET_EXEC_CONFIG_PREFIX.columnarToRow.native.enabled")
       .category(CATEGORY_EXEC)
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index d116d2f407..114304492e 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -29,11 +29,13 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.comet._
+import org.apache.spark.sql.comet.CometInMemoryTableScanExec
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
 import org.apache.spark.sql.execution.datasources.WriteFilesExec
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -85,6 +87,7 @@ object CometExecRule {
     classOf[SortMergeJoinExec] -> CometSortMergeJoinExec,
     classOf[SortExec] -> CometSortExec,
     classOf[LocalTableScanExec] -> CometLocalTableScanExec,
+    classOf[InMemoryTableScanExec] -> CometInMemoryTableScanExec,
     classOf[WindowExec] -> CometWindowExec)
 
   /**
@@ -282,6 +285,47 @@ case class CometExecRule(session: SparkSession)
       case op if isCometScan(op) =>
         convertToComet(op, CometScanWrapper).getOrElse(op)
 
+      // Cached-table scans. With the in-memory cache flag on, the scan is converted only when
+      // the first cached batch has the Comet batch class; otherwise a fallback reason is
+      // recorded and the scan is returned unchanged. With the flag off, a fallback reason is
+      // recorded for Comet-format caches and the scan takes the SparkToColumnar route when
+      // shouldApplySparkToColumnar allows it.
+      case scan: InMemoryTableScanExec =>
+        // NOTE(review): take(1) is called inside the planning rule and before the config
+        // check. It looks like this can run a Spark job for every planned cache scan, even
+        // with the feature disabled — confirm the cost is acceptable.
+        val cachedBuffers = scan.relation.cacheBuilder.cachedColumnBuffers
+        val firstBatchOpt = cachedBuffers.take(1).headOption
+        // The batch class is matched by fully qualified name, not by type.
+        val expectedBatchClass =
+          "org.apache.spark.sql.comet.execution.arrow.CometCachedBatch"
+
+        if (CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.get(conf)) {
+          firstBatchOpt match {
+            case Some(firstBatch) if firstBatch.getClass.getName == expectedBatchClass =>
+              convertToComet(scan, CometInMemoryTableScanExec).getOrElse(scan)
+
+            case Some(firstBatch) =>
+              withFallbackReason(
+                scan,
+                s"Comet in-memory cache requires $expectedBatchClass, " +
+                  s"but found ${firstBatch.getClass.getName}")
+              scan
+
+            case None =>
+              withFallbackReason(
+                scan,
+                "Comet in-memory cache rewrite skipped because cached buffer is empty")
+              scan
+          }
+        } else {
+          // Flag off: the match below only records why a Comet-format cache is not read
+          // natively; its result is discarded.
+          firstBatchOpt match {
+            case Some(firstBatch) if firstBatch.getClass.getName == expectedBatchClass =>
+              withFallbackReason(
+                scan,
+                s"Native support for operator InMemoryTableScanExec is disabled. " +
+                  s"Set ${CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key}=true to enable it.")
+            case _ =>
+          }
+
+          if (shouldApplySparkToColumnar(conf, scan)) {
+            convertToComet(scan, CometSparkToColumnarExec).getOrElse(scan)
+          } else {
+            scan
+          }
+        }
+
       case op if shouldApplySparkToColumnar(conf, op) =>
         convertToComet(op, CometSparkToColumnarExec).getOrElse(op)
 
diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala
index 7290ab436a..b849a481d9 100644
--- a/spark/src/main/scala/org/apache/spark/Plugins.scala
+++ b/spark/src/main/scala/org/apache/spark/Plugins.scala
@@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD_FACTOR}
 import org.apache.spark.sql.internal.StaticSQLConf
 
+import org.apache.comet.CometConf
 import org.apache.comet.CometConf.{COMET_METRICS_ENABLED, COMET_ONHEAP_ENABLED}
 import org.apache.comet.CometSparkSessionExtensions
 
@@ -54,6 +55,10 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
       return Collections.emptyMap[String, String]
     }
 
+    val extraConfs = new ju.HashMap[String, String]()
+
+    CometDriverPlugin.maybeSetCacheSerializer(sc.conf, extraConfs)
+
     // register CometSparkSessionExtensions if it isn't already registered
     CometDriverPlugin.registerCometSessionExtension(sc.conf)
 
@@ -87,7 +92,7 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
       logInfo("Comet is running in unified memory mode and sharing off-heap memory with Spark")
     }
 
-    Collections.emptyMap[String, String]
+    extraConfs
   }
 
   override def receive(message: Any): AnyRef = super.receive(message)
@@ -104,6 +109,29 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
 }
 
 object CometDriverPlugin extends Logging {
+  // Use Comet's cache serializer only for the native in-memory cache path.
+  // If the application already set spark.sql.cache.serializer, leave that value
+  // unchanged so Comet does not replace a user-selected cache format. "Already set"
+  // means the key is present in SparkConf, even when its value equals Spark's default.
+  //
+  // When the serializer is installed it is written both to `conf` and to `extraConfs`.
+  // Nothing happens unless the in-memory cache flag is true in `conf`.
+  private[apache] def maybeSetCacheSerializer(
+      conf: SparkConf,
+      extraConfs: ju.HashMap[String, String]): Unit = {
+    if (conf.getBoolean(CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key, false)) {
+      val serializerKey = StaticSQLConf.SPARK_CACHE_SERIALIZER.key
+      val serializerValue =
+        "org.apache.spark.sql.comet.execution.arrow.ArrowCachedBatchSerializer"
+
+      // Test for presence instead of comparing against the default value: an explicit
+      // user choice of Spark's default serializer is still a user choice.
+      conf.getOption(serializerKey) match {
+        case Some(userSerializer) =>
+          logInfo(s"Not overriding user-provided $serializerKey=$userSerializer")
+        case None =>
+          extraConfs.put(serializerKey, serializerValue)
+          conf.set(serializerKey, serializerValue)
+          logInfo(s"Auto-set $serializerKey=$serializerValue")
+      }
+    }
+  }
+
   def registerCometMetrics(sc: SparkContext): Unit = {
     if (sc.getConf.getBoolean(
         COMET_METRICS_ENABLED.key,
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometInMemoryTableScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometInMemoryTableScanExec.scala
new file mode 100644
index 0000000000..cb12f2a3ac
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometInMemoryTableScanExec.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.JavaConverters._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer} +import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometConf +import org.apache.comet.serde.CometOperatorSerde +import org.apache.comet.serde.OperatorOuterClass +import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.QueryPlanSerde.serializeDataType + +/** + * Reads Spark cached table data when the cache was written by Comet's cache serializer. + * + * Spark stores cached data through `CachedBatchSerializer`. This node keeps the scan inside Comet + * by asking the serializer to decode cached batches directly into `ColumnarBatch` output, + * avoiding the extra Spark columnar-to-Comet columnar conversion used by the default path. + * + * `relationOutput` is the full schema stored in the cache. `scanOutput` is the subset requested + * by this scan after pruning. 
+ */
+case class CometInMemoryTableScanExec(
+    originalPlan: InMemoryTableScanExec,
+    serializer: CachedBatchSerializer,
+    cachedBuffers: RDD[CachedBatch],
+    relationOutput: Seq[Attribute],
+    scanOutput: Seq[Attribute])
+    extends CometExec
+    with LeafExecNode {
+
+  // Single SQL metric: rows emitted by this scan.
+  override lazy val metrics: Map[String, SQLMetric] = {
+    val rowsMetric = SQLMetrics.createMetric(sparkContext, "number of output rows")
+    Map("numOutputRows" -> rowsMetric)
+  }
+
+  // The node reports the wrapped Spark scan's output attributes.
+  override def output: Seq[Attribute] = originalPlan.output
+
+  // The serializer owns the cached batch layout, so it also decides the vector types.
+  override def vectorTypes: Option[Seq[String]] = serializer.vectorTypes(scanOutput, conf)
+
+  // Ask the serializer for the requested columns and count rows as batches flow through.
+  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val rowCounter = longMetric("numOutputRows")
+    val decoded =
+      serializer.convertCachedBatchToColumnarBatch(cachedBuffers, relationOutput, scanOutput, conf)
+
+    decoded.map { batch =>
+      rowCounter += batch.numRows()
+      batch
+    }
+  }
+}
+
+object CometInMemoryTableScanExec extends CometOperatorSerde[InMemoryTableScanExec] {
+
+  override def enabledConfig: Option[org.apache.comet.ConfigEntry[Boolean]] =
+    Some(CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED)
+
+  // Attributes used for the scan schema. A scan with no output columns uses the full cache
+  // schema instead, because native planning still needs a schema.
+  private def outputOrCacheSchema(op: InMemoryTableScanExec): Seq[Attribute] =
+    if (op.output.isEmpty) op.relation.output else op.output
+
+  override def convert(
+      op: InMemoryTableScanExec,
+      builder: OperatorOuterClass.Operator.Builder,
+      childOp: Operator*): Option[Operator] = {
+    val fieldTypes = outputOrCacheSchema(op).flatMap(attr => serializeDataType(attr.dataType))
+
+    val scan = OperatorOuterClass.Scan.newBuilder()
+    scan.setSource(op.getClass.getSimpleName)
+    scan.addAllFields(fieldTypes.asJava)
+    // Cached batches are decoded on the JVM side; the native scan only receives Spark batches.
+    scan.setArrowFfiSafe(false)
+
+    Some(builder.setScan(scan).build())
+  }
+
+  // Serializer, cached buffers and schema all come from Spark's InMemoryRelation, so cache
+  // materialization, pruning, and storage stay under Spark's cache manager.
+  override def createExec(nativeOp: Operator, op: InMemoryTableScanExec): CometNativeExec = {
+    val cacheBuilder = op.relation.cacheBuilder
+    val scanExec = CometInMemoryTableScanExec(
+      originalPlan = op,
+      serializer = cacheBuilder.serializer,
+      cachedBuffers = cacheBuilder.cachedColumnBuffers,
+      relationOutput = op.relation.output,
+      scanOutput = outputOrCacheSchema(op))
+
+    CometScanWrapper(nativeOp, scanExec)
+  }
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowCachedBatchSerializer.scala
new file mode 100644
index 0000000000..9a183480e7
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowCachedBatchSerializer.scala
@@ -0,0 +1,353 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeProjection}
+import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer}
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.execution.columnar.{DefaultCachedBatch, DefaultCachedBatchSerializer}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
+import org.apache.spark.util.io.ChunkedByteBuffer
+
+import org.apache.comet.CometConf
+
+/**
+ * Cached batch format used when Comet writes Spark in-memory cache data.
+ *
+ * `bytes` contains compressed Arrow stream data produced by `Utils.serializeBatches`. The cache
+ * manager still owns storage and eviction; this class only changes the cached payload.
+ *
+ * @param numRows
+ *   row count that `Utils.serializeBatches` reported for this payload
+ * @param sizeInBytes
+ *   size of `bytes`, as reported by the buffer itself
+ * @param stats
+ *   per-column statistics row; each column contributes five consecutive values (lower bound,
+ *   upper bound, null count, row count, size in bytes)
+ * @param bytes
+ *   the serialized batch
+ */
+private case class CometCachedBatch(
+    override val numRows: Int,
+    override val sizeInBytes: Long,
+    override val stats: InternalRow,
+    bytes: ChunkedByteBuffer)
+    extends SimpleMetricsCachedBatch
+
+/**
+ * Cache serializer that stores Comet-compatible Arrow batches in Spark's in-memory cache.
+ *
+ * When Comet cache support is disabled, row-based cache writes and default cache reads are
+ * delegated to Spark's `DefaultCachedBatchSerializer`.
+ */
+class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer {
+
+  // Spark's own serializer: used for row-based writes while Comet caching is off and for
+  // reading caches that hold DefaultCachedBatch entries.
+  private val fallback = new DefaultCachedBatchSerializer()
+
+  // Cache writes use Comet format only when both Comet and the in-memory cache scan are enabled.
+  private def enabled(conf: SQLConf): Boolean = {
+    val cometEnabled = CometConf.COMET_ENABLED.get(conf)
+    cometEnabled && CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.get(conf)
+  }
+
+  // Cache APIs hand over attributes, but the row-to-Arrow converter wants a StructType.
+  private def toStructType(schema: Seq[Attribute]): StructType = {
+    val fields = schema.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))
+    StructType(fields)
+  }
+
+  // Produces the flat statistics row expected by SimpleMetricsCachedBatchSerializer: five
+  // slots per column holding lower bound, upper bound, null count, row count, and size in
+  // bytes. Bounds stay null for types rejected by tracksBounds and for columns without any
+  // non-null value; the size slot is always 0L.
+  private def computeStats(batch: ColumnarBatch, attrs: Seq[Attribute]): InternalRow = {
+    val slotsPerColumn = 5
+    val rowCount = batch.numRows()
+    val slots = new Array[Any](attrs.length * slotsPerColumn)
+
+    for (colIdx <- attrs.indices) {
+      val dataType = attrs(colIdx).dataType
+      val vector = batch.column(colIdx)
+      val bounded = tracksBounds(dataType)
+
+      var min: Any = null
+      var max: Any = null
+      var nullCount = 0
+
+      var rowIdx = 0
+      while (rowIdx < rowCount) {
+        if (vector.isNullAt(rowIdx)) {
+          nullCount += 1
+        } else if (bounded) {
+          val current = readValue(vector, dataType, rowIdx)
+          if (min == null || compare(dataType, current, min) < 0) min = current
+          if (max == null || compare(dataType, current, max) > 0) max = current
+        }
+        rowIdx += 1
+      }
+
+      val offset = colIdx * slotsPerColumn
+      slots(offset) = min
+      slots(offset + 1) = max
+      slots(offset + 2) = nullCount
+      slots(offset + 3) = rowCount
+      slots(offset + 4) = 0L
+    }
+
+    new GenericInternalRow(slots)
+  }
+
+  // Spark can prune cache batches only for types whose bounds can be compared.
+  // Other types still report null count and row count but leave bounds as null.
+  private def tracksBounds(dt: DataType): Boolean = dt match {
+    case _: DecimalType => true
+    case BooleanType | StringType => true
+    case ByteType | ShortType | IntegerType | LongType => true
+    case FloatType | DoubleType => true
+    case DateType | TimestampType | TimestampNTZType => true
+    case _ => false
+  }
+
+  // Fetches the non-null value at `rowId` through the ColumnVector getter that matches `dt`.
+  // Types without a branch yield null.
+  private def readValue(col: ColumnVector, dt: DataType, rowId: Int): Any = dt match {
+    case decimalType: DecimalType =>
+      col.getDecimal(rowId, decimalType.precision, decimalType.scale)
+    // The string is copied, presumably so it does not alias the vector's memory — confirm.
+    case StringType => col.getUTF8String(rowId).copy()
+    case BooleanType => col.getBoolean(rowId)
+    case ByteType => col.getByte(rowId)
+    case ShortType => col.getShort(rowId)
+    case IntegerType => col.getInt(rowId)
+    case DateType => col.getInt(rowId)
+    case LongType => col.getLong(rowId)
+    case TimestampType | TimestampNTZType => col.getLong(rowId)
+    case FloatType => col.getFloat(rowId)
+    case DoubleType => col.getDouble(rowId)
+    case _ => null
+  }
+
+  // Compare values using the same physical representation used in the stats row.
+  // Callers in this class use only the sign of the result. `left` and `right` are cast
+  // straight to the value type of `dt`; nothing here checks for null.
+  private def compare(dt: DataType, left: Any, right: Any): Int = dt match {
+    case BooleanType =>
+      java.lang.Boolean.compare(left.asInstanceOf[Boolean], right.asInstanceOf[Boolean])
+    case ByteType =>
+      java.lang.Byte.compare(left.asInstanceOf[Byte], right.asInstanceOf[Byte])
+    case ShortType =>
+      java.lang.Short.compare(left.asInstanceOf[Short], right.asInstanceOf[Short])
+    // Dates share the Int branch.
+    case IntegerType | DateType =>
+      java.lang.Integer.compare(left.asInstanceOf[Int], right.asInstanceOf[Int])
+    // Both timestamp types share the Long branch.
+    case LongType | TimestampType | TimestampNTZType =>
+      java.lang.Long.compare(left.asInstanceOf[Long], right.asInstanceOf[Long])
+    case FloatType =>
+      java.lang.Float.compare(left.asInstanceOf[Float], right.asInstanceOf[Float])
+    case DoubleType =>
+      java.lang.Double.compare(left.asInstanceOf[Double], right.asInstanceOf[Double])
+    case _: DecimalType =>
+      left.asInstanceOf[Decimal].compare(right.asInstanceOf[Decimal])
+    case StringType =>
+      // NOTE(review): compareBinary presumably orders by raw bytes — confirm.
+      ByteArray.compareBinary(
+        left.asInstanceOf[UTF8String].getBytes,
+        right.asInstanceOf[UTF8String].getBytes)
+    case other =>
+      // Reached only for a type that has no branch above.
+      throw new IllegalStateException(s"compare called for unsupported type $other")
+  }
+
+  // Compute Spark-compatible cache stats before serializing each batch to Arrow.
+  // The stats are stored beside the Arrow bytes so Spark's cache filter can prune
+  // CometCachedBatch without decoding the batch first.
+  private def encodeBatches(
+      batches: Iterator[ColumnarBatch],
+      attrs: Seq[Attribute]): Iterator[CachedBatch] = {
+    batches.flatMap { batch =>
+      // Stats are gathered from the in-memory batch before it is serialized.
+      val stats = computeStats(batch, attrs)
+
+      Utils.serializeBatches(Iterator.single(batch)).map { case (rows, buffer) =>
+        CometCachedBatch(
+          numRows = rows.toInt,
+          sizeInBytes = buffer.size,
+          stats = stats,
+          bytes = buffer)
+      }
+    }
+  }
+
+  // Columnar input is accepted only while the active SQLConf has Comet caching enabled.
+  override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = {
+    val activeConf = SQLConf.get
+    activeConf != null && enabled(activeConf)
+  }
+  override def supportsColumnarOutput(schema: StructType): Boolean = true
+
+  // Columnar Comet output is stored as compressed Arrow stream bytes.
+  override def convertColumnarBatchToCachedBatch(
+      input: RDD[ColumnarBatch],
+      schema: Seq[Attribute],
+      storageLevel: StorageLevel,
+      conf: SQLConf): RDD[CachedBatch] = {
+
+    input.mapPartitions { batches =>
+      encodeBatches(batches, schema)
+    }
+  }
+
+  // Decodes cached batches into ColumnarBatch output holding `selectedAttributes`, in that
+  // order. An empty selection returns every cached column. DefaultCachedBatch input is
+  // delegated to Spark's serializer; mixed or unknown batch classes raise
+  // IllegalStateException, as does a selected attribute that is absent from the cache.
+  override def convertCachedBatchToColumnarBatch(
+      input: RDD[CachedBatch],
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute],
+      conf: SQLConf): RDD[ColumnarBatch] = {
+
+    // Resolve requested columns by exprId, not by name, because aliases may reuse names.
+    val selectedIndices =
+      if (selectedAttributes.isEmpty) {
+        cacheAttributes.indices.toArray
+      } else {
+        val byExprId = cacheAttributes.zipWithIndex.map { case (attr, idx) =>
+          attr.exprId -> idx
+        }.toMap
+
+        selectedAttributes.map { attr =>
+          byExprId.getOrElse(
+            attr.exprId,
+            throw new IllegalStateException(
+              s"Could not resolve selected attribute ${attr.name} from cache attributes"))
+        }.toArray
+      }
+
+    // The decoded batch can be passed through untouched only when the selection is the
+    // identity mapping. Matching column counts alone is not enough: a same-width reordering
+    // (cache [key, value], selection [value, key]) must still be projected.
+    val selectsInCacheOrder = selectedIndices.indices.forall(i => selectedIndices(i) == i)
+
+    // NOTE(review): this collect() runs a job over the cached RDD each time the conversion
+    // is set up — confirm the cost is acceptable.
+    val batchTypes = input.map(_.getClass.getName).distinct().collect()
+
+    if (batchTypes.isEmpty) {
+      input.sparkContext.emptyRDD[ColumnarBatch]
+    } else if (batchTypes.length > 1) {
+      throw new IllegalStateException(
+        s"Mixed cached batch types are not supported: ${batchTypes.mkString(", ")}")
+    } else if (batchTypes.head == classOf[CometCachedBatch].getName) {
+      input.mapPartitions { it =>
+        it.flatMap {
+          case cb: CometCachedBatch =>
+            Utils.decodeBatches(cb.bytes, "CometCache").map { batch =>
+              if (selectsInCacheOrder && selectedIndices.length == batch.numCols()) {
+                batch
+              } else {
+                val cols =
+                  selectedIndices.map(i => batch.column(i).asInstanceOf[ColumnVector])
+                new ColumnarBatch(cols, batch.numRows())
+              }
+            }
+
+          case other =>
+            throw new IllegalStateException(
+              s"Expected CometCachedBatch, got ${other.getClass.getName}")
+        }
+      }
+    } else if (batchTypes.head == classOf[DefaultCachedBatch].getName) {
+      fallback.convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf)
+    } else {
+      throw new IllegalStateException(s"Unsupported cached batch type: ${batchTypes.head}")
+    }
+  }
+
+  // Row input can still be cached in Comet format by converting rows to Arrow batches first.
+  override def convertInternalRowToCachedBatch(
+      input: RDD[InternalRow],
+      schema: Seq[Attribute],
+      storageLevel: StorageLevel,
+      conf: SQLConf): RDD[CachedBatch] = {
+
+    if (!enabled(conf)) {
+      fallback.convertInternalRowToCachedBatch(input, schema, storageLevel, conf)
+    } else {
+      // Read the conf values here so the task closure captures plain values.
+      val batchSize = conf.columnBatchSize
+      val sessionTz = conf.sessionLocalTimeZone
+
+      input.mapPartitions { rows =>
+        val iter = CometArrowConverters.rowToArrowBatchIter(
+          rows,
+          toStructType(schema),
+          batchSize,
+          sessionTz,
+          TaskContext.get())
+
+        encodeBatches(iter, schema)
+      }
+    }
+  }
+
+  // Decodes cached batches into rows holding `selectedAttributes`, in that order.
+  // DefaultCachedBatch input is delegated to Spark's serializer; mixed or unknown batch
+  // classes raise IllegalStateException, as does a selected attribute that is absent from
+  // the cache.
+  override def convertCachedBatchToInternalRow(
+      input: RDD[CachedBatch],
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute],
+      conf: SQLConf): RDD[InternalRow] = {
+
+    // Resolve requested columns by exprId, not by name, because aliases may reuse names.
+    val selectedIndices =
+      if (selectedAttributes.isEmpty) {
+        cacheAttributes.indices.toArray
+      } else {
+        val byExprId = cacheAttributes.zipWithIndex.map { case (attr, idx) =>
+          attr.exprId -> idx
+        }.toMap
+
+        selectedAttributes.map { attr =>
+          byExprId.getOrElse(
+            attr.exprId,
+            throw new IllegalStateException(
+              s"Could not resolve selected attribute ${attr.name} from cache attributes"))
+        }.toArray
+      }
+
+    // The decoded batch can be used as-is only when the selection is the identity mapping.
+    // Matching column counts alone is not enough: a same-width reordering (cache
+    // [key, value], selection [value, key]) must still be projected.
+    val selectsInCacheOrder = selectedIndices.indices.forall(i => selectedIndices(i) == i)
+
+    // NOTE(review): this collect() runs a job over the cached RDD each time the conversion
+    // is set up — confirm the cost is acceptable.
+    val batchTypes = input.map(_.getClass.getName).distinct().collect()
+
+    if (batchTypes.isEmpty) {
+      input.sparkContext.emptyRDD[InternalRow]
+    } else if (batchTypes.length > 1) {
+      throw new IllegalStateException(
+        s"Mixed cached batch types are not supported: ${batchTypes.mkString(", ")}")
+    } else if (batchTypes.head == classOf[DefaultCachedBatch].getName) {
+      fallback.convertCachedBatchToInternalRow(input, cacheAttributes, selectedAttributes, conf)
+    } else if (batchTypes.head == classOf[CometCachedBatch].getName) {
+      input.mapPartitions { it =>
+        // Spark's row collect path expects UnsafeRow, not ColumnarBatchRow wrappers.
+        // The projection is built once per partition; every row it produces is copied
+        // before being handed out, so reusing it across batches is safe.
+        val toUnsafe = UnsafeProjection.create(selectedAttributes, selectedAttributes)
+
+        it.flatMap {
+          case cb: CometCachedBatch =>
+            Utils.decodeBatches(cb.bytes, "CometCache").flatMap { batch =>
+              val projectedBatch =
+                if (selectsInCacheOrder && selectedIndices.length == batch.numCols()) {
+                  batch
+                } else {
+                  val cols =
+                    selectedIndices.map(i => batch.column(i).asInstanceOf[ColumnVector])
+                  new ColumnarBatch(cols, batch.numRows())
+                }
+
+              projectedBatch.rowIterator().asScala.map(row => toUnsafe(row).copy())
+            }
+
+          case other =>
+            throw new IllegalStateException(
+              s"Expected CometCachedBatch, got ${other.getClass.getName}")
+        }
+      }
+    } else {
+      throw new IllegalStateException(s"Unsupported cached batch type: ${batchTypes.head}")
+    }
+  }
+
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 8cbf7c9189..eb506ddc6e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -612,7 +612,8 @@ abstract class CometNativeExec extends CometExec {
           _: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: CometUnionExec |
           _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec |
           _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec |
-          _: CometSparkToColumnarExec | _: CometLocalTableScanExec =>
+          _: CometSparkToColumnarExec | _: CometLocalTableScanExec |
+          _: CometInMemoryTableScanExec =>
         func(plan)
       case _: CometPlan =>
         // Other Comet operators, continue to traverse the tree.
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometInMemoryCacheSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometInMemoryCacheSuite.scala
new file mode 100644
index 0000000000..21c5334cd0
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/exec/CometInMemoryCacheSuite.scala
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.exec + +import java.{util => ju} + +import org.apache.spark.CometDriverPlugin +import org.apache.spark.SparkConf +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.catalyst.expressions.{And, Expression, GreaterThanOrEqual, LessThan, Literal} +import org.apache.spark.sql.columnar.SimpleMetricsCachedBatch +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} + +import org.apache.comet.CometConf + +class CometInMemoryCacheSuite extends CometTestBase { + override protected def sparkConf: SparkConf = { + val conf = new SparkConf() + conf.set("spark.driver.memory", "1G") + conf.set("spark.executor.memory", "1G") + conf.set("spark.executor.memoryOverhead", "2G") + conf.set("spark.plugins", "org.apache.spark.CometPlugin") + conf.set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + conf.set("spark.comet.enabled", "true") + conf.set("spark.comet.exec.enabled", "true") + conf.set("spark.comet.exec.onHeap.enabled", "true") + conf.set("spark.comet.metrics.enabled", "true") + conf.set( + "spark.sql.cache.serializer", + "org.apache.spark.sql.comet.execution.arrow.ArrowCachedBatchSerializer") + conf + } + + private def cachedBatchTypes(table: String): Array[String] = { + val ds 
= spark.table(table).asInstanceOf[org.apache.spark.sql.classic.Dataset[_]] + val cached = spark.sharedState.cacheManager.lookupCachedData(ds).get + cached.cachedRepresentation.cacheBuilder.cachedColumnBuffers + .map(_.getClass.getName) + .distinct() + .collect() + } + + test("CometInMemoryTableScan over CometCachedBatch") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm", + SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true", + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key -> "true", + "spark.comet.sparkToColumnar.enabled" -> "true") { + + spark.catalog.clearCache() + + spark + .range(1000) + .selectExpr("id as key", "id % 8 as value") + .createOrReplaceTempView("abc") + + spark.catalog.cacheTable("abc") + spark.table("abc").count() + + assert( + cachedBatchTypes("abc").sameElements( + Array("org.apache.spark.sql.comet.execution.arrow.CometCachedBatch"))) + + val df = spark.sql("SELECT key, count(*) FROM abc GROUP BY key") + checkSparkAnswer(df) + + val plan = df.queryExecution.executedPlan.toString() + assert(plan.contains("CometInMemoryTableScan")) + assert(!plan.contains("CometSparkColumnarToColumnar")) + + spark.catalog.clearCache() + } + } + + test("Comet in-memory cache disabled keeps SparkToColumnar fallback path") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm", + SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true", + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key -> "true", + "spark.comet.sparkToColumnar.enabled" -> "true") { + + spark.catalog.clearCache() + + spark + .range(1000) + .selectExpr("id as key", "id % 8 as value") + .createOrReplaceTempView("comet_cache_disabled") + + spark.catalog.cacheTable("comet_cache_disabled") + spark.table("comet_cache_disabled").count() + + assert( + cachedBatchTypes("comet_cache_disabled").sameElements( + Array("org.apache.spark.sql.comet.execution.arrow.CometCachedBatch"))) + } + + 
withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm", + SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true", + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key -> "false", + "spark.comet.sparkToColumnar.enabled" -> "true") { + + val df = spark.sql("SELECT key, count(*) FROM comet_cache_disabled GROUP BY key") + checkSparkAnswer(df) + + val plan = df.queryExecution.executedPlan.toString() + assert(!plan.contains("CometInMemoryTableScan")) + assert(plan.contains("CometSparkColumnarToColumnar")) + + spark.catalog.clearCache() + } + } + + /* Caches a 1000-row range split into 5 partitions with the Comet cache flag on, checks the cached batch class name, and checks that a grouped query over the cached view has CometInMemoryTableScan in its executed plan string. */ test("Comet in-memory cache handles multi-partition cache") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm", + SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true", + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key -> "true", + "spark.comet.sparkToColumnar.enabled" -> "true") { + + spark.catalog.clearCache() + + val multiPartition = + spark.range(0, 1000, 1, 5).toDF("id").cache() + multiPartition.createOrReplaceTempView("multi_partition_cache") + /* count() is an action; it is what materializes the cached batches inspected below. */ multiPartition.count() + + /* cachedBatchTypes is defined outside this excerpt; presumably it returns the distinct class names of the cached batches for the view - confirm against its definition. */ assert( + cachedBatchTypes("multi_partition_cache").sameElements( + Array("org.apache.spark.sql.comet.execution.arrow.CometCachedBatch"))) + + val grouped = spark.sql(""" + SELECT id % 100, count(*) + FROM multi_partition_cache + GROUP BY id % 100 + """) + checkSparkAnswer(grouped) + + val groupedPlan = grouped.queryExecution.executedPlan.toString() + assert(groupedPlan.contains("CometInMemoryTableScan")) + + multiPartition.unpersist() + spark.catalog.clearCache() + } + } + + /* Caches a zero-row range and checks that the query still matches Spark and that the executed plan string does NOT contain CometInMemoryTableScan. */ test("Comet in-memory cache handles empty cache") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm", + SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true", + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key -> "true", + "spark.comet.sparkToColumnar.enabled" -> "true") { + + spark.catalog.clearCache() + + 
val empty = spark.range(0).toDF("id").cache() + empty.createOrReplaceTempView("empty_cache") + empty.count() + + val emptyDf = spark.sql("SELECT * FROM empty_cache") + checkSparkAnswer(emptyDf) + + val emptyPlan = emptyDf.queryExecution.executedPlan.toString() + /* Expectation: the scan stays on the Spark path, presumably because the planner rule finds no cached batch to inspect for a zero-row cache. */ assert(!emptyPlan.contains("CometInMemoryTableScan")) + + empty.unpersist() + spark.catalog.clearCache() + } + } + + /* Selects one column from a three-column cached view and expects both CometInMemoryTableScan and CometNativeColumnarToRow in the executed plan string. */ test("Comet in-memory cache supports projection-only read") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm", + SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true", + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key -> "true", + "spark.comet.sparkToColumnar.enabled" -> "true") { + + spark.catalog.clearCache() + + spark + .range(1000) + .selectExpr("id as key", "id % 8 as value", "id + 1 as key_plus_1") + .createOrReplaceTempView("project_cache") + + spark.catalog.cacheTable("project_cache") + spark.table("project_cache").count() + + assert( + cachedBatchTypes("project_cache").sameElements( + Array("org.apache.spark.sql.comet.execution.arrow.CometCachedBatch"))) + + val df = spark.sql("SELECT key FROM project_cache") + checkSparkAnswer(df) + + val plan = df.queryExecution.executedPlan.toString() + assert(plan.contains("CometInMemoryTableScan")) + assert(plan.contains("CometNativeColumnarToRow")) + + spark.catalog.clearCache() + } + } + + /* Runs a GROUP BY over a cached view and expects CometInMemoryTableScan and CometHashAggregate in the executed plan string. NOTE(review): despite the test name, no exchange or shuffle operator is asserted on - consider asserting one so the name and the checks agree. */ test("Comet in-memory cache supports shuffle after cache read") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm", + SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true", + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key -> "true", + "spark.comet.sparkToColumnar.enabled" -> "true") { + + spark.catalog.clearCache() + + spark + .range(1000) + .selectExpr("id as key", "id % 100 as group") + .createOrReplaceTempView("shuffle_cache") + + spark.catalog.cacheTable("shuffle_cache") + spark.table("shuffle_cache").count() + + assert( + 
cachedBatchTypes("shuffle_cache").sameElements( + Array("org.apache.spark.sql.comet.execution.arrow.CometCachedBatch"))) + + val df = spark.sql("SELECT group, count(*) FROM shuffle_cache GROUP BY group") + checkSparkAnswer(df) + + val plan = df.queryExecution.executedPlan.toString() + assert(plan.contains("CometInMemoryTableScan")) + assert(plan.contains("CometHashAggregate")) + + spark.catalog.clearCache() + } + } + + /* Sets the cache batch size to 100 rows over a 1000-row, 10-partition range, then checks the per-batch statistics layout and calls the serializer's buildFilter directly to count how many cached batches survive each predicate. */ test("Comet in-memory cache supports stats-based batch pruning") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm", + SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true", + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key -> "true", + "spark.comet.sparkToColumnar.enabled" -> "true", + "spark.sql.inMemoryColumnarStorage.batchSize" -> "100") { + + spark.catalog.clearCache() + + spark + .range(0, 1000, 1, 10) + .selectExpr("id as key", "id % 7 as value") + .createOrReplaceTempView("prune_cache") + + spark.catalog.cacheTable("prune_cache") + spark.table("prune_cache").count() + + assert( + cachedBatchTypes("prune_cache").sameElements( + Array("org.apache.spark.sql.comet.execution.arrow.CometCachedBatch"))) + + /* NOTE(review): org.apache.spark.sql.classic.Dataset is referenced by fully qualified name; confirm it resolves on every Spark version this suite is compiled against. */ val ds = spark.table("prune_cache").asInstanceOf[org.apache.spark.sql.classic.Dataset[_]] + val cached = spark.sharedState.cacheManager.lookupCachedData(ds).get + val relation = cached.cachedRepresentation + val cachedBuffers = relation.cacheBuilder.cachedColumnBuffers + + // Spark's cache pruning reads statistics through SimpleMetricsCachedBatch. + // CometCachedBatch must expose the same five statistics per column: + // lower bound, upper bound, null count, row count, and size in bytes. 
+ val firstBatch = cachedBuffers.take(1).head + assert(firstBatch.isInstanceOf[SimpleMetricsCachedBatch]) + assert( + firstBatch.asInstanceOf[SimpleMetricsCachedBatch].stats.numFields == + relation.output.length * 5) + + val keyAttr = relation.output.find(_.name == "key").get + + // Call the serializer filter directly so the test fails if buildFilter is + // accidentally changed back to a no-op. + def prunedCount(predicate: Expression): Long = { + val filter = relation.cacheBuilder.serializer.buildFilter(Seq(predicate), relation.output) + cachedBuffers.mapPartitionsWithIndex(filter).count() + } + + /* Require more than one cached batch; with a single batch the pruning counts below could not distinguish a working filter from a no-op. */ val totalBatches = cachedBuffers.count() + assert(totalBatches > 1) + + /* Three predicates on key: a narrow window expected to keep exactly one batch, a range below every key expected to keep none, and a range covering every key expected to keep all batches. */ val targetPredicate = + And(GreaterThanOrEqual(keyAttr, Literal(900L)), LessThan(keyAttr, Literal(905L))) + assert(prunedCount(targetPredicate) == 1) + + val outsidePredicate = LessThan(keyAttr, Literal(0L)) + assert(prunedCount(outsidePredicate) == 0) + + val allPredicate = + And(GreaterThanOrEqual(keyAttr, Literal(0L)), LessThan(keyAttr, Literal(1000L))) + assert(prunedCount(allPredicate) == totalBatches) + + val df = spark.sql(""" + SELECT key, value + FROM prune_cache + WHERE key >= 900 AND key < 905 + """) + checkSparkAnswer(df) + + val plan = df.queryExecution.executedPlan.toString() + assert(plan.contains("CometInMemoryTableScan")) + assert(!plan.contains("CometSparkColumnarToColumnar")) + + spark.catalog.clearCache() + } + } + + /* Calls CometDriverPlugin.maybeSetCacheSerializer twice - once with no cache serializer configured and once with a user-supplied one - and checks both the SparkConf and the extraConfs map after each call. No Spark session is used. */ test("Comet plugin respects user-provided cache serializer") { + val serializerKey = StaticSQLConf.SPARK_CACHE_SERIALIZER.key + val cometSerializer = + "org.apache.spark.sql.comet.execution.arrow.ArrowCachedBatchSerializer" + val userSerializer = "com.example.CustomCachedBatchSerializer" + + val defaultConf = new SparkConf() + .set(CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key, "true") + val defaultExtraConfs = new ju.HashMap[String, String]() + + // With no user serializer configured, the plugin should install Comet's + // serializer and also return it through 
extraConfs for executors. + CometDriverPlugin.maybeSetCacheSerializer(defaultConf, defaultExtraConfs) + + assert(defaultConf.get(serializerKey) == cometSerializer) + assert(defaultExtraConfs.get(serializerKey) == cometSerializer) + + /* Second scenario: same feature flag, but the serializer key is already set to a user class before the plugin runs. */ val userConf = new SparkConf() + .set(CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.key, "true") + .set(serializerKey, userSerializer) + val userExtraConfs = new ju.HashMap[String, String]() + + // If the user already configured a cache serializer, keep it and do not + // send a replacement serializer through extraConfs. + CometDriverPlugin.maybeSetCacheSerializer(userConf, userExtraConfs) + + assert(userConf.get(serializerKey) == userSerializer) + assert(!userExtraConfs.containsKey(serializerKey)) + } +}