diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md
index 5474894108..8b4d80a2bf 100644
--- a/docs/spark_expressions_support.md
+++ b/docs/spark_expressions_support.md
@@ -400,7 +400,7 @@
 ### string_funcs
 
 - [x] ascii
-- [ ] base64
+- [x] base64
 - [x] bit_length
 - [x] btrim
 - [x] char
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 9d13ccd9ed..b7514c05b0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -148,6 +148,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
 
   private val stringExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
     classOf[Ascii] -> CometScalarFunction("ascii"),
+    classOf[Base64] -> CometBase64,
     classOf[BitLength] -> CometScalarFunction("bit_length"),
     classOf[Chr] -> CometScalarFunction("char"),
     classOf[ConcatWs] -> CometConcatWs,
diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala
index 0737644ab9..372f209eff 100644
--- a/spark/src/main/scala/org/apache/comet/serde/statics.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala
@@ -19,11 +19,13 @@
 
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Base64, Literal}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
+import org.apache.spark.sql.types.BooleanType
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
 
 object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
 
@@ -34,7 +36,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
 
       : Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] = Map(
       ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
-        "read_side_padding"))
+        "read_side_padding"),
+      ("encode", classOf[Base64]) -> CometBase64Encode)
 
   override def convert(
       expr: StaticInvoke,
@@ -52,3 +55,27 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
     }
   }
 }
+
+/**
+ * Handler for Base64.encode StaticInvoke (Spark 3.5+, where Base64 is RuntimeReplaceable). Maps
+ * to DataFusion's built-in encode(input, 'base64') function.
+ */
+private object CometBase64Encode extends CometExpressionSerde[StaticInvoke] {
+
+  override def convert(
+      expr: StaticInvoke,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    // Reject chunked output (2nd argument, Spark 3.5+) up front, then translate to encode().
+    expr.arguments match {
+      case Seq(_, Literal(true, BooleanType)) =>
+        withInfo(expr, "base64 with chunk encoding is not supported")
+        None
+      case _ => // OK: either no chunkBase64 param (Spark 3.4) or chunkBase64=false
+        val inputExpr = exprToProtoInternal(expr.arguments.head, inputs, binding)
+        val encodingExpr = exprToProtoInternal(Literal("base64"), inputs, binding)
+        val optExpr = scalarFunctionExprToProto("encode", inputExpr, encodingExpr)
+        optExprWithInfo(optExpr, expr, expr.arguments.head)
+    }
+  }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index 64ba644048..a10c65519b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import java.util.Locale
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Base64, Cast, Concat, ConcatWs, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
 import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -31,6 +31,23 @@ import org.apache.comet.expressions.{CometCast, CometEvalMode, RegExp}
 import org.apache.comet.serde.ExprOuterClass.Expr
 import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
 
+/**
+ * Handler for Base64 as a direct expression (Spark 3.4 where it is not RuntimeReplaceable). In
+ * Spark 3.5+, Base64 is RuntimeReplaceable and handled via CometBase64Encode in statics.scala.
+ */
+object CometBase64 extends CometExpressionSerde[Base64] {
+
+  override def convert(
+      expr: Base64,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val inputExpr = exprToProtoInternal(expr.child, inputs, binding)
+    val encodingExpr = exprToProtoInternal(Literal("base64"), inputs, binding)
+    val optExpr = scalarFunctionExprToProto("encode", inputExpr, encodingExpr)
+    optExprWithInfo(optExpr, expr, expr.child)
+  }
+}
+
 object CometStringRepeat extends CometExpressionSerde[StringRepeat] {
 
   override def convert(
diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
index 121d7f7d5a..c591739527 100644
--- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
@@ -148,6 +148,31 @@ class CometStringExpressionSuite extends CometTestBase {
     }
   }
 
+  test("base64") {
+    withSQLConf("spark.sql.chunkBase64String.enabled" -> "false") {
+      val data = Seq(
+        Array[Byte](72, 101, 108, 108, 111), // "Hello"
+        Array[Byte](83, 112, 97, 114, 107, 32, 83, 81, 76), // "Spark SQL"
+        Array[Byte](), // empty
+        null).map(Tuple1(_))
+      withParquetTable(data, "tbl") {
+        checkSparkAnswerAndOperator("SELECT base64(_1) FROM tbl")
+        checkSparkAnswerAndOperator("SELECT base64(NULL) FROM tbl")
+      }
+    }
+  }
+
+  test("base64 with chunk encoding falls back") {
+    withSQLConf("spark.sql.chunkBase64String.enabled" -> "true") {
+      val data = Seq(Array[Byte](72, 101, 108, 108, 111)).map(Tuple1(_))
+      withParquetTable(data, "tbl") {
+        checkSparkAnswerAndFallbackReason(
+          "SELECT base64(_1) FROM tbl",
+          "base64 with chunk encoding is not supported")
+      }
+    }
+  }
+
   test("split string basic") {
     withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") {
       withParquetTable((0 until 5).map(i => (s"value$i,test$i", i)), "tbl") {