diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 0193f3012c..22f3791a1f 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -55,6 +55,7 @@ use datafusion_spark::function::math::hex::SparkHex; use datafusion_spark::function::math::width_bucket::SparkWidthBucket; use datafusion_spark::function::string::char::CharFunc; use datafusion_spark::function::string::concat::SparkConcat; +use datafusion_spark::function::url::url_decode::UrlDecode; use futures::poll; use futures::stream::StreamExt; use jni::objects::JByteBuffer; @@ -400,6 +401,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkCrc32::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(UrlDecode::default())); } /// Prepares arrow arrays for output. diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 0737644ab9..df4aaca7bc 100644 --- a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -19,11 +19,13 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, UrlCodec} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { @@ -34,7 +36,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { : Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] = Map( ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( - "read_side_padding")) + "read_side_padding"), + ("decode", UrlCodec.getClass) -> CometUrlDecode) override def convert( expr: StaticInvoke, @@ -52,3 +55,21 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { } } } + +/** + * Handler for UrlCodec.decode StaticInvoke (Spark 3.4+). Maps to datafusion-spark's built-in + * url_decode function. + */ +private object CometUrlDecode extends CometExpressionSerde[StaticInvoke] { + + override def convert( + expr: StaticInvoke, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + // StaticInvoke args: [child, Literal("UTF-8")] — only serialize the first + val childExpr = exprToProtoInternal(expr.arguments.head, inputs, binding) + val optExpr = + scalarFunctionExprToProtoWithReturnType("url_decode", expr.dataType, false, childExpr) + optExprWithInfo(optExpr, expr, expr.arguments.head) + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 121d7f7d5a..d6d408502a 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -478,4 +478,35 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("url_decode") { + val data = Seq( + "https%3A%2F%2Fspark.apache.org", // percent-encoded URL + "hello+world", // plus as space + "%E4%B8%AD%E6%96%87", // multi-byte UTF-8 (Chinese) + "no+encoding+needed", // spaces only + "%e4%b8%ad", // lowercase hex digits + "abc%20def%21%40%23", // mixed encoded/unencoded + "%F0%9F%94%A5", // 4-byte UTF-8 (emoji) + "already+decoded+%2B+literal+plus", // encoded plus sign (%2B) + "").map(Tuple1(_)) + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("SELECT url_decode(_1) FROM tbl") + } + } + + test("url_decode - null handling") { + withParquetTable( + Seq(Some("hello+world"), None, Some("%E4%B8%AD")).map(v => Tuple1(v.orNull)), + "tbl") { + checkSparkAnswerAndOperator("SELECT url_decode(_1) FROM tbl") + } + } + + test("url_decode - literals") { + withParquetTable(Seq(Tuple1(1)), "tbl") { + checkSparkAnswerAndOperator("SELECT url_decode('hello%20world') FROM tbl") + checkSparkAnswerAndOperator("SELECT url_decode('%E4%B8%AD%E6%96%87') FROM tbl") + } + } + }