From c4b7980d1e6f41c1e57f557234142446a452cc96 Mon Sep 17 00:00:00 2001 From: Eric Yang Date: Tue, 16 Jun 2026 23:28:55 -0700 Subject: [PATCH] [SPARK-57500][SQL] Escape backslash in MySQL JDBC pushdown for string literals in comparison/IN predicates --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 24 ++++++++++++++++++- .../apache/spark/sql/jdbc/MySQLDialect.scala | 11 +++------ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 20 ++++++++++++++++ 3 files changed, 46 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index b0b10f1d09f27..c130191b9dc36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableChange} import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.catalog.index.TableIndex -import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} +import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal, NamedReference} import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, Predicate} import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder @@ -408,9 +408,31 @@ abstract class JdbcDialect extends Serializable with Logging { override def build(expr: Expression): String = expr match { case _: AlwaysTrue => "(1 = 1)" case _: AlwaysFalse => "(1 = 0)" + case e: GeneralScalarExpression if isLikeStringPredicate(e.name()) => + buildLikeStringPredicate(e).getOrElse(super.build(expr)) case _ => super.build(expr) } + private def isLikeStringPredicate(name: String): Boolean = + name == "STARTS_WITH" || name == "ENDS_WITH" || name == "CONTAINS" + + private def buildLikeStringPredicate(e: GeneralScalarExpression): Option[String] = { + val children = e.children() + if (children.length != 2) return None + children(1) match { + case lit: Literal[_] if lit.value() != null && lit.dataType().isInstanceOf[StringType] => + val escaped = escapeSpecialCharsForLikePattern(lit.value().toString) + val pattern = e.name() match { + case "STARTS_WITH" => escaped + "%" + case "ENDS_WITH" => "%" + escaped + case _ => "%" + escaped + "%" + } + val col = build(children(0)) + Some(col + " LIKE '" + escapeSql(pattern) + "' ESCAPE '" + escapeSql("\\") + "'") + case _ => None + } + } + // Some dialects do not support boolean type and this convenient util function is // provided to generate SQL string without boolean values. protected def inputToSQLNoBool(input: Expression): String = input match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index b301c0c0bd5bc..6698cbbeb6dce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -99,14 +99,6 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No } } - // MySQL treats backslash as an escape character inside string literals, so every backslash in - // a LIKE pattern (and the ESCAPE character) must be doubled to survive string-literal parsing - // before the LIKE engine applies its own escaping. The base STARTS_WITH/ENDS_WITH/CONTAINS - // pattern building is otherwise shared, so only this hook is overridden. - override def escapeStringLiteralForLikePattern(str: String): String = { - str.replace("\\", "\\\\") - } - override def visitAggregateFunction( funcName: String, isDistinct: Boolean, inputs: Array[String]): String = if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { @@ -120,6 +112,9 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No } } + override protected[jdbc] def escapeSql(value: String): String = + if (value == null) null else super.escapeSql(value).replace("\\", "\\\\") + override def compileExpression(expr: Expression): Option[String] = { val mysqlSQLBuilder = new MySQLSQLBuilder() try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 3a83a2549f924..86b1f6a422e30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1000,6 +1000,26 @@ class JDBCSuite extends SharedSparkSession { assert(mySQLSQL(StringStartsWith("c", "a%b_")) === """`c` LIKE 'a\\%b\\_%' ESCAPE '\\'""") } + test("SPARK-57500: escape backslash in pushed-down string literals for comparison/IN") { + val defaultDialect = JdbcDialects.get("jdbc:") + val mySQLDialect = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + def d(f: Filter): String = defaultDialect.compileExpression(f.toV2).getOrElse("") + def m(f: Filter): String = mySQLDialect.compileExpression(f.toV2).getOrElse("") + + assert(d(EqualTo("c", "a\\b")) === """"c" = 'a\b'""") + assert(m(EqualTo("c", "a\\b")) === """`c` = 'a\\b'""") + assert(d(LessThan("c", "a\\b")) === """"c" < 'a\b'""") + assert(m(LessThan("c", "a\\b")) === """`c` < 'a\\b'""") + assert(d(In("c", Array[Any]("a\\b", "x\\y"))) === """"c" IN ('a\b', 'x\y')""") + assert(m(In("c", Array[Any]("a\\b", "x\\y"))) === """`c` IN ('a\\b', 'x\\y')""") + assert(m(EqualTo("c", "a'\\b")) === """`c` = 'a''\\b'""") + + assert(d(StringStartsWith("c", "a\\b")) === """"c" LIKE 'a\\b%' ESCAPE '\'""") + assert(m(StringStartsWith("c", "a\\b")) === """`c` LIKE 'a\\\\b%' ESCAPE '\\'""") + assert(d(StringContains("c", "a\\b")) === """"c" LIKE '%a\\b%' ESCAPE '\'""") + assert(m(StringContains("c", "a\\b")) === """`c` LIKE '%a\\\\b%' ESCAPE '\\'""") + } + test("SPARK-57446: escape single quotes in JDBC comment queries") { val defaultDialect = JdbcDialects.get("jdbc:") assert(defaultDialect.getTableCommentQuery("t", "a'b") ===