diff --git a/common/src/main/java/org/apache/comet/exceptions/CometQueryExecutionException.java b/common/src/main/java/org/apache/comet/exceptions/CometQueryExecutionException.java
new file mode 100644
index 0000000000..5ff19ea398
--- /dev/null
+++ b/common/src/main/java/org/apache/comet/exceptions/CometQueryExecutionException.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.exceptions;
+
+import org.apache.comet.CometNativeException;
+
+/**
+ * Exception thrown from Comet native execution containing JSON-encoded error information. The
+ * message contains a JSON object with the following structure:
+ *
+ * <pre>
+ * {
+ * "errorType": "DivideByZero",
+ * "errorClass": "DIVIDE_BY_ZERO",
+ * "params": { ... },
+ * "context": { "sqlText": "...", "startOffset": 0, "stopOffset": 10 },
+ * "hint": "Use `try_divide` to tolerate divisor being 0"
+ * }
+ * </pre>
+ *
+ * CometExecIterator parses this JSON and converts it to the appropriate Spark exception by calling
+ * the corresponding QueryExecutionErrors.* method.
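+ *
+ * <p>For example, the JSON above would typically be mapped back to
+ * {@code QueryExecutionErrors.divideByZeroError}, which raises a
+ * {@code SparkArithmeticException} with error class {@code DIVIDE_BY_ZERO}.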
+ */
+public final class CometQueryExecutionException extends CometNativeException {
+
+ /**
+ * Creates a new CometQueryExecutionException with a JSON-encoded error message.
+ *
+ * @param jsonMessage JSON string containing error information
+ */
+ public CometQueryExecutionException(String jsonMessage) {
+ super(jsonMessage);
+ }
+
+ /**
+ * Returns true if the message appears to be JSON-formatted. This is used to distinguish between
+ * JSON-encoded errors and legacy error messages.
+ *
+ * @return true if message starts with '{' and ends with '}'
+ */
+ public boolean isJsonMessage() {
+ String msg = getMessage();
+ return msg != null && msg.trim().startsWith("{") && msg.trim().endsWith("}");
+ }
+}
diff --git a/native/Cargo.lock b/native/Cargo.lock
index 3cd6ea29b7..450089b04f 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -1948,6 +1948,7 @@ dependencies = [
"num",
"rand 0.10.0",
"regex",
+ "serde",
"serde_json",
"thiserror 2.0.18",
"tokio",
diff --git a/native/core/src/errors.rs b/native/core/src/errors.rs
index ecac7af94e..1329614cd7 100644
--- a/native/core/src/errors.rs
+++ b/native/core/src/errors.rs
@@ -39,7 +39,7 @@ use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, js
use crate::execution::operators::ExecutionError;
use datafusion_comet_spark_expr::SparkError;
-use jni::objects::{GlobalRef, JThrowable, JValue};
+use jni::objects::{GlobalRef, JThrowable};
use jni::JNIEnv;
use lazy_static::lazy_static;
use parquet::errors::ParquetError;
@@ -223,9 +223,9 @@ impl jni::errors::ToException for CometError {
class: "java/lang/NullPointerException".to_string(),
msg: self.to_string(),
},
- CometError::Spark { .. } => Exception {
- class: "org/apache/spark/SparkException".to_string(),
- msg: self.to_string(),
+ CometError::Spark(spark_err) => Exception {
+ class: spark_err.exception_class().to_string(),
+ msg: spark_err.to_string(),
},
CometError::NumberIntFormat { source: s } => Exception {
class: "java/lang/NumberFormatException".to_string(),
@@ -392,33 +392,37 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option<String>) {
 } => env.throw(<&JThrowable>::from(throwable.as_obj())),
+ // Handle DataFusion errors containing SparkError - serialize to JSON
CometError::DataFusion {
msg: _,
source: DataFusionError::External(e),
- } if matches!(e.downcast_ref(), Some(SparkError::CastOverFlow { .. })) => {
- match e.downcast_ref() {
- Some(SparkError::CastOverFlow {
- value,
- from_type,
- to_type,
- }) => {
- let throwable: JThrowable = env
- .new_object(
- "org/apache/spark/sql/comet/CastOverflowException",
- "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V",
- &[
- JValue::Object(&env.new_string(value).unwrap()),
- JValue::Object(&env.new_string(from_type).unwrap()),
- JValue::Object(&env.new_string(to_type).unwrap()),
- ],
- )
- .unwrap()
- .into();
- env.throw(throwable)
+ } => {
+ // Try SparkErrorWithContext first (includes context)
+ if let Some(spark_error_with_ctx) =
+ e.downcast_ref::<SparkErrorWithContext>()
+ {
+ let json_message = spark_error_with_ctx.to_json();
+ env.throw_new(
+ "org/apache/comet/exceptions/CometQueryExecutionException",
+ json_message,
+ )
+ } else if let Some(spark_error) = e.downcast_ref::<SparkError>() {
+ // Fall back to plain SparkError (no context)
+ throw_spark_error_as_json(env, spark_error)
+ } else {
+ // Not a SparkError, use generic exception
+ let exception = error.to_exception();
+ match backtrace {
+ Some(backtrace_string) => env.throw_new(
+ exception.class,
+ to_stacktrace_string(exception.msg, backtrace_string).unwrap(),
+ ),
+ _ => env.throw_new(exception.class, exception.msg),
}
- _ => unreachable!(),
}
}
+ // Handle direct SparkError - serialize to JSON
+ CometError::Spark(spark_error) => throw_spark_error_as_json(env, spark_error),
_ => {
let exception = error.to_exception();
match backtrace {
@@ -434,6 +438,21 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option<String>) {
+fn throw_spark_error_as_json(
+ env: &mut JNIEnv,
+ spark_error: &SparkError,
+) -> jni::errors::Result<()> {
+ // Serialize error to JSON
+ let json_message = spark_error.to_json();
+
+ // Throw CometQueryExecutionException with JSON message
+ env.throw_new(
+ "org/apache/comet/exceptions/CometQueryExecutionException",
+ json_message,
+ )
+}
+
#[derive(Debug, Error)]
enum StacktraceError {
#[error("Unable to initialize message: {0}")]
diff --git a/native/core/src/execution/expressions/arithmetic.rs b/native/core/src/execution/expressions/arithmetic.rs
index a9749678db..320532d773 100644
--- a/native/core/src/execution/expressions/arithmetic.rs
+++ b/native/core/src/execution/expressions/arithmetic.rs
@@ -17,6 +17,122 @@
//! Arithmetic expression builders
+use std::any::Any;
+use std::fmt::{Display, Formatter};
+use std::hash::{Hash, Hasher};
+
+use arrow::datatypes::{DataType, Schema};
+use arrow::record_batch::RecordBatch;
+use datafusion::common::DataFusionError;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_comet_spark_expr::{QueryContext, SparkError, SparkErrorWithContext};
+
+/// Wrapper expression that catches and wraps SparkError with QueryContext
+/// for binary arithmetic operations.
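+///
+/// For example, if an ANSI-mode division fails with a divide-by-zero `SparkError`,
+/// this wrapper re-raises it as a `SparkErrorWithContext` carrying the originating
+/// SQL fragment, so the JVM side can attach the query context to the error.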
+#[derive(Debug)]
+pub struct CheckedBinaryExpr {
+ /// The underlying physical expression (typically a ScalarFunctionExpr)
+ child: Arc<dyn PhysicalExpr>,
+ /// Optional query context to attach to errors
+ query_context: Option<Arc<QueryContext>>,
+}
+
+impl CheckedBinaryExpr {
+ pub fn new(child: Arc<dyn PhysicalExpr>, query_context: Option<Arc<QueryContext>>) -> Self {
+ Self {
+ child,
+ query_context,
+ }
+ }
+}
+
+impl Display for CheckedBinaryExpr {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "CheckedBinaryExpr({})", self.child)
+ }
+}
+
+impl PartialEq for CheckedBinaryExpr {
+ fn eq(&self, other: &Self) -> bool {
+ self.child.eq(&other.child)
+ }
+}
+
+impl Eq for CheckedBinaryExpr {}
+
+impl PartialEq<dyn Any> for CheckedBinaryExpr {
+ fn eq(&self, other: &dyn Any) -> bool {
+ other
+ .downcast_ref::<Self>()
+ .map(|x| self.eq(x))
+ .unwrap_or(false)
+ }
+}
+
+impl Hash for CheckedBinaryExpr {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.child.hash(state);
+ }
+}
+
+impl PhysicalExpr for CheckedBinaryExpr {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ self.child.fmt_sql(f)
+ }
+
+ fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result<DataType> {
+ self.child.data_type(input_schema)
+ }
+
+ fn nullable(&self, input_schema: &Schema) -> datafusion::common::Result<bool> {
+ self.child.nullable(input_schema)
+ }
+
+ fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
+ let result = self.child.evaluate(batch);
+
+ // If there's an error and we have query_context, wrap it
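+ // SparkError values from the spark-expr crate surface here boxed inside
+ // DataFusionError::External, so downcast to recover the original error.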
+ match result {
+ Err(DataFusionError::External(e)) if self.query_context.is_some() => {
+ if let Some(spark_err) = e.downcast_ref::<SparkError>() {
+ let wrapped = SparkErrorWithContext::with_context(
+ spark_err.clone(),
+ Arc::clone(self.query_context.as_ref().unwrap()),
+ );
+ Err(DataFusionError::External(Box::new(wrapped)))
+ } else {
+ Err(DataFusionError::External(e))
+ }
+ }
+ other => other,
+ }
+ }
+
+ fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+ vec![&self.child]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+ match children.len() {
+ 1 => Ok(Arc::new(CheckedBinaryExpr::new(
+ Arc::clone(&children[0]),
+ self.query_context.clone(),
+ ))),
+ _ => Err(DataFusionError::Internal(
+ "CheckedBinaryExpr should have exactly one child".to_string(),
+ )),
+ }
+ }
+}
+
/// Macro to generate arithmetic expression builders that need eval_mode handling
#[macro_export]
macro_rules! arithmetic_expr_builder {
@@ -37,6 +153,7 @@ macro_rules! arithmetic_expr_builder {
let eval_mode =
$crate::execution::planner::from_protobuf_eval_mode(expr.eval_mode)?;
planner.create_binary_expr(
+ spark_expr, // Pass the full spark_expr for query_context lookup
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.return_type.as_ref(),
@@ -53,7 +170,6 @@ use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use datafusion::logical_expr::Operator as DataFusionOperator;
-use datafusion::physical_expr::PhysicalExpr;
use datafusion_comet_proto::spark_expression::Expr;
use datafusion_comet_spark_expr::{create_modulo_expr, create_negate_expr, EvalMode};
@@ -95,6 +211,7 @@ impl ExpressionBuilder for IntegralDivideBuilder {
let expr = extract_expr!(spark_expr, IntegralDivide);
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
planner.create_binary_expr_with_options(
+ spark_expr, // Pass the full spark_expr for query_context lookup
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.return_type.as_ref(),
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index fb3d35c512..a9d0d4370c 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -163,6 +163,7 @@ pub struct PhysicalPlanner {
exec_context_id: i64,
partition: i32,
session_ctx: Arc<SessionContext>,
+ query_context_registry: Arc,
}
impl Default for PhysicalPlanner {
@@ -177,6 +178,7 @@ impl PhysicalPlanner {
exec_context_id: TEST_EXEC_CONTEXT_ID,
session_ctx,
partition,
+ query_context_registry: datafusion_comet_spark_expr::create_query_context_map(),
}
}
@@ -185,6 +187,7 @@ impl PhysicalPlanner {
exec_context_id,
partition: self.partition,
session_ctx: Arc::clone(&self.session_ctx),
+ query_context_registry: Arc::clone(&self.query_context_registry),
}
}
@@ -251,6 +254,26 @@ impl PhysicalPlanner {
spark_expr: &Expr,
input_schema: SchemaRef,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
+ // Register QueryContext if present
+ if let (Some(expr_id), Some(ctx_proto)) =
+ (spark_expr.expr_id, spark_expr.query_context.as_ref())
+ {
+ // Deserialize QueryContext from protobuf
+ let query_ctx = datafusion_comet_spark_expr::QueryContext::new(
+ ctx_proto.sql_text.clone(),
+ ctx_proto.start_index,
+ ctx_proto.stop_index,
+ ctx_proto.object_type.clone(),
+ ctx_proto.object_name.clone(),
+ ctx_proto.line,
+ ctx_proto.start_position,
+ );
+
+ // Register query context for error reporting
+ let registry = &self.query_context_registry;
+ registry.register(expr_id, query_ctx);
+ }
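+ // The expression-specific arms below (e.g. Cast, CheckOverflow, binary arithmetic)
+ // look this context up again by expr_id when building their physical expressions.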
+
// Try to use the modular registry first - this automatically handles any registered expression types
if ExpressionRegistry::global().can_handle(spark_expr) {
return ExpressionRegistry::global().create_expr(spark_expr, input_schema, self);
@@ -369,10 +392,19 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+
+ // Look up query context from registry if expr_id is present
+ let query_context = spark_expr.expr_id.and_then(|expr_id| {
+ let registry = &self.query_context_registry;
+ registry.get(expr_id)
+ });
+
Ok(Arc::new(Cast::new(
child,
datatype,
SparkCastOptions::new(eval_mode, &expr.timezone, expr.allow_incompat),
+ spark_expr.expr_id,
+ query_context,
)))
}
ExprStruct::CheckOverflow(expr) => {
@@ -380,10 +412,18 @@ impl PhysicalPlanner {
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let fail_on_error = expr.fail_on_error;
+ // Look up query context from registry if expr_id is present
+ let query_context = spark_expr.expr_id.and_then(|expr_id| {
+ let registry = &self.query_context_registry;
+ registry.get(expr_id)
+ });
+
Ok(Arc::new(CheckOverflow::new(
child,
data_type,
fail_on_error,
+ spark_expr.expr_id,
+ query_context,
)))
}
ExprStruct::ScalarFunc(expr) => {
@@ -397,6 +437,8 @@ impl PhysicalPlanner {
None,
true,
false,
+ None, // No expr_id for internal map_extract wrapper
+ Arc::clone(&self.query_context_registry),
))),
// DataFusion 49 hardcodes return type for MD5 built in function as UTF8View
// which is not yet supported in Comet
@@ -405,6 +447,8 @@ impl PhysicalPlanner {
func?,
DataType::Utf8,
SparkCastOptions::new_without_timezone(EvalMode::Try, true),
+ None,
+ None,
))),
_ => func,
}
@@ -519,6 +563,8 @@ impl PhysicalPlanner {
Arc::clone(&child),
DataType::Utf8,
spark_cast_options,
+ None,
+ None,
));
Ok(Arc::new(IfExpr::new(
Arc::new(IsNullExpr::new(child)),
@@ -544,6 +590,8 @@ impl PhysicalPlanner {
default_value,
expr.one_based,
expr.fail_on_error,
+ spark_expr.expr_id,
+ Arc::clone(&self.query_context_registry),
)))
}
ExprStruct::GetArrayStructFields(expr) => {
@@ -634,8 +682,10 @@ impl PhysicalPlanner {
}
}
+ #[allow(clippy::too_many_arguments)]
pub fn create_binary_expr(
&self,
+ spark_expr: &Expr,
left: &Expr,
right: &Expr,
return_type: Option<&spark_expression::DataType>,
@@ -644,6 +694,7 @@ impl PhysicalPlanner {
eval_mode: EvalMode,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
self.create_binary_expr_with_options(
+ spark_expr,
left,
right,
return_type,
@@ -657,6 +708,7 @@ impl PhysicalPlanner {
#[allow(clippy::too_many_arguments)]
pub fn create_binary_expr_with_options(
&self,
+ spark_expr: &Expr,
left: &Expr,
right: &Expr,
return_type: Option<&spark_expression::DataType>,
@@ -665,6 +717,12 @@ impl PhysicalPlanner {
options: BinaryExprOptions,
eval_mode: EvalMode,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
+ // Look up query context from registry if expr_id is present
+ let query_context = spark_expr.expr_id.and_then(|expr_id| {
+ let registry = &self.query_context_registry;
+ registry.get(expr_id)
+ });
+
let left = self.create_expr(left, Arc::clone(&input_schema))?;
let right = self.create_expr(right, Arc::clone(&input_schema))?;
match (
@@ -688,17 +746,23 @@ impl PhysicalPlanner {
left,
DataType::Decimal256(p1, s1),
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
+ None,
+ None,
));
let right = Arc::new(Cast::new(
right,
DataType::Decimal256(p2, s2),
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
+ None,
+ None,
));
let child = Arc::new(BinaryExpr::new(left, op, right));
Ok(Arc::new(Cast::new(
child,
data_type,
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
+ None,
+ None,
)))
}
(
@@ -787,13 +851,17 @@ impl PhysicalPlanner {
None,
eval_mode,
)?;
- Ok(Arc::new(ScalarFunctionExpr::new(
+ let scalar_expr = Arc::new(ScalarFunctionExpr::new(
op_str,
fun_expr,
vec![left, right],
Arc::new(Field::new(op_str, data_type, true)),
Arc::new(ConfigOptions::default()),
- )))
+ ));
+
+ // Wrap with CheckedBinaryExpr to add query_context to errors
+ use crate::execution::expressions::arithmetic::CheckedBinaryExpr;
+ Ok(Arc::new(CheckedBinaryExpr::new(scalar_expr, query_context)))
} else {
Ok(Arc::new(BinaryExpr::new(left, op, right)))
}
@@ -1804,6 +1872,26 @@ impl PhysicalPlanner {
spark_expr: &AggExpr,
schema: SchemaRef,
) -> Result {
+ // Register QueryContext if present
+ if let (Some(expr_id), Some(ctx_proto)) =
+ (spark_expr.expr_id, spark_expr.query_context.as_ref())
+ {
+ // Deserialize QueryContext from protobuf
+ let query_ctx = datafusion_comet_spark_expr::QueryContext::new(
+ ctx_proto.sql_text.clone(),
+ ctx_proto.start_index,
+ ctx_proto.stop_index,
+ ctx_proto.object_type.clone(),
+ ctx_proto.object_name.clone(),
+ ctx_proto.line,
+ ctx_proto.start_position,
+ );
+
+ // Register query context for error reporting
+ let registry = &self.query_context_registry;
+ registry.register(expr_id, query_ctx);
+ }
+
match spark_expr.expr_struct.as_ref().unwrap() {
AggExprStruct::Count(expr) => {
assert!(!expr.children.is_empty());
@@ -1854,8 +1942,12 @@ impl PhysicalPlanner {
let builder = match datatype {
DataType::Decimal128(_, _) => {
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
- let func =
- AggregateUDF::new_from_impl(SumDecimal::try_new(datatype, eval_mode)?);
+ let func = AggregateUDF::new_from_impl(SumDecimal::try_new(
+ datatype,
+ eval_mode,
+ spark_expr.expr_id,
+ Arc::clone(&self.query_context_registry),
+ )?);
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
@@ -1887,8 +1979,14 @@ impl PhysicalPlanner {
let builder = match datatype {
DataType::Decimal128(_, _) => {
- let func =
- AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype));
+ let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+ let func = AggregateUDF::new_from_impl(AvgDecimal::new(
+ datatype,
+ input_datatype,
+ eval_mode,
+ spark_expr.expr_id,
+ Arc::clone(&self.query_context_registry),
+ ));
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
_ => {
@@ -3023,14 +3121,21 @@ fn create_case_expr(
Arc::clone(&x.1),
coerce_type.clone(),
cast_options.clone(),
+ None,
+ None,
));
(Arc::clone(&x.0), t)
})
.collect::<Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>>();
let else_phy_expr: Option<Arc<dyn PhysicalExpr>> = else_expr.clone().map(|x| {
- Arc::new(Cast::new(x, coerce_type.clone(), cast_options.clone()))
- as Arc<dyn PhysicalExpr>
+ Arc::new(Cast::new(
+ x,
+ coerce_type.clone(),
+ cast_options.clone(),
+ None,
+ None,
+ )) as Arc<dyn PhysicalExpr>
});
Ok(Arc::new(CaseExpr::try_new(
None,
@@ -3717,9 +3822,13 @@ mod tests {
type_info: None,
}),
})),
+ query_context: None,
+ expr_id: None,
};
let right = spark_expression::Expr {
expr_struct: Some(Literal(lit)),
+ query_context: None,
+ expr_id: None,
};
let expr = spark_expression::Expr {
@@ -3727,6 +3836,8 @@ mod tests {
left: Some(Box::new(left)),
right: Some(Box::new(right)),
}))),
+ query_context: None,
+ expr_id: None,
};
Operator {
@@ -3782,6 +3893,8 @@ mod tests {
index,
datatype: Some(create_proto_datatype()),
})),
+ query_context: None,
+ expr_id: None,
}
}
@@ -3847,6 +3960,8 @@ mod tests {
type_info: None,
}),
})),
+ query_context: None,
+ expr_id: None,
};
let array_col_1 = spark_expression::Expr {
@@ -3857,6 +3972,8 @@ mod tests {
type_info: None,
}),
})),
+ query_context: None,
+ expr_id: None,
};
let projection = Operator {
@@ -3870,6 +3987,8 @@ mod tests {
return_type: None,
fail_on_error: false,
})),
+ query_context: None,
+ expr_id: None,
}],
})),
};
@@ -3925,6 +4044,140 @@ mod tests {
});
}
+ #[test]
+ fn test_array_repeat() {
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let planner = PhysicalPlanner::new(Arc::from(session_ctx), 0);
+
+ // Mock scan operator with 3 INT32 columns
+ let op_scan = Operator {
+ plan_id: 0,
+ children: vec![],
+ op_struct: Some(OpStruct::Scan(spark_operator::Scan {
+ fields: vec![
+ spark_expression::DataType {
+ type_id: 3, // Int32
+ type_info: None,
+ },
+ spark_expression::DataType {
+ type_id: 3, // Int32
+ type_info: None,
+ },
+ spark_expression::DataType {
+ type_id: 3, // Int32
+ type_info: None,
+ },
+ ],
+ source: "".to_string(),
+ arrow_ffi_safe: false,
+ })),
+ };
+
+ // Mock expression to read a INT32 column with position 0
+ let array_col = spark_expression::Expr {
+ expr_struct: Some(Bound(spark_expression::BoundReference {
+ index: 0,
+ datatype: Some(spark_expression::DataType {
+ type_id: 3,
+ type_info: None,
+ }),
+ })),
+ query_context: None,
+ expr_id: None,
+ };
+
+ // Mock expression to read a INT32 column with position 1
+ let array_col_1 = spark_expression::Expr {
+ expr_struct: Some(Bound(spark_expression::BoundReference {
+ index: 1,
+ datatype: Some(spark_expression::DataType {
+ type_id: 3,
+ type_info: None,
+ }),
+ })),
+ query_context: None,
+ expr_id: None,
+ };
+
+ // Make a projection operator with array_repeat(array_col, array_col_1)
+ let projection = Operator {
+ children: vec![op_scan],
+ plan_id: 0,
+ op_struct: Some(OpStruct::Projection(spark_operator::Projection {
+ project_list: vec![spark_expression::Expr {
+ expr_struct: Some(ExprStruct::ScalarFunc(spark_expression::ScalarFunc {
+ func: "array_repeat".to_string(),
+ args: vec![array_col, array_col_1],
+ return_type: None,
+ fail_on_error: false,
+ })),
+ query_context: None,
+ expr_id: None,
+ }],
+ })),
+ };
+
+ // Create a physical plan
+ let (mut scans, datafusion_plan) =
+ planner.create_plan(&projection, &mut vec![], 1).unwrap();
+
+ // Start executing the plan in a separate thread
+ // The plan waits for incoming batches and emitting result as input comes
+ let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();
+
+ let runtime = tokio::runtime::Runtime::new().unwrap();
+ // create async channel
+ let (tx, mut rx) = mpsc::channel(1);
+
+ // Send data as input to the plan being executed in a separate thread
+ runtime.spawn(async move {
+ // create data batch
+ // 0, 1, 2
+ // 3, 4, 5
+ // 6, null, null
+ let a = Int32Array::from(vec![Some(0), Some(3), Some(6)]);
+ let b = Int32Array::from(vec![Some(1), Some(4), None]);
+ let c = Int32Array::from(vec![Some(2), Some(5), None]);
+ let input_batch1 = InputBatch::Batch(vec![Arc::new(a), Arc::new(b), Arc::new(c)], 3);
+ let input_batch2 = InputBatch::EOF;
+
+ let batches = vec![input_batch1, input_batch2];
+
+ for batch in batches.into_iter() {
+ tx.send(batch).await.unwrap();
+ }
+ });
+
+ // Wait for the plan to finish executing and assert the result
+ runtime.block_on(async move {
+ loop {
+ let batch = rx.recv().await.unwrap();
+ scans[0].set_input_batch(batch);
+ match poll!(stream.next()) {
+ Poll::Ready(Some(batch)) => {
+ assert!(batch.is_ok(), "got error {}", batch.unwrap_err());
+ let batch = batch.unwrap();
+ let expected = [
+ "+--------------+",
+ "| col_0 |",
+ "+--------------+",
+ "| [0] |",
+ "| [3, 3, 3, 3] |",
+ "| [] |",
+ "+--------------+",
+ ];
+ assert_batches_eq!(expected, &[batch]);
+ }
+ Poll::Ready(None) => {
+ break;
+ }
+ _ => {}
+ }
+ }
+ });
+ }
+
/// Executes a `test_data_query` SQL query
/// and saves the result into a temp folder using parquet format
/// Read the file back to the memory using a custom schema
@@ -4313,6 +4566,8 @@ mod tests {
type_info: None,
}),
})),
+ expr_id: None,
+ query_context: None,
};
// Create bound reference for the INT8 column (index 1)
@@ -4324,6 +4579,8 @@ mod tests {
type_info: None,
}),
})),
+ expr_id: None,
+ query_context: None,
};
// Create a Subtract expression: date_col - int8_col
@@ -4339,6 +4596,8 @@ mod tests {
}),
eval_mode: 0, // Legacy mode
}))),
+ expr_id: None,
+ query_context: None,
};
// Create a projection operator with the subtract expression
diff --git a/native/core/src/parquet/schema_adapter.rs b/native/core/src/parquet/schema_adapter.rs
index 42f0e7fc61..0ad61df426 100644
--- a/native/core/src/parquet/schema_adapter.rs
+++ b/native/core/src/parquet/schema_adapter.rs
@@ -349,7 +349,13 @@ impl SparkPhysicalExprAdapter {
cast_options.allow_cast_unsigned_ints = self.parquet_options.allow_cast_unsigned_ints;
cast_options.is_adapting_schema = true;
- let spark_cast = Arc::new(Cast::new(child, target_type.clone(), cast_options));
+ let spark_cast = Arc::new(Cast::new(
+ child,
+ target_type.clone(),
+ cast_options,
+ None,
+ None,
+ ));
return Ok(Transformed::yes(spark_cast as Arc<dyn PhysicalExpr>));
}
diff --git a/native/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/native/org/apache/spark/sql/errors/QueryExecutionErrors.scala
new file mode 100644
index 0000000000..7981468394
--- /dev/null
+++ b/native/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -0,0 +1,2718 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.errors
+
+import java.io.{FileNotFoundException, IOException}
+import java.lang.reflect.InvocationTargetException
+import java.net.{URISyntaxException, URL}
+import java.time.DateTimeException
+import java.util.concurrent.TimeoutException
+
+import com.fasterxml.jackson.core.{JsonParser, JsonToken}
+import org.apache.hadoop.fs.{FileAlreadyExistsException, FileStatus, Path}
+import org.apache.hadoop.fs.permission.FsPermission
+import org.codehaus.commons.compiler.{CompileException, InternalCompilerException}
+
+import org.apache.spark._
+import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.memory.SparkOutOfMemoryError
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.ScalaReflection.Schema
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval
+import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode}
+import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode}
+import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.{CircularBuffer, Utils}
+
+/**
+ * Object for grouping error messages from (most) exceptions thrown during query execution.
+ * This does not include exceptions thrown during the eager execution of commands, which are
+ * grouped into [[QueryCompilationErrors]].
+ */
+private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionErrors {
+
+ def cannotEvaluateExpressionError(expression: Expression): Throwable = {
+ SparkException.internalError(s"Cannot evaluate expression: $expression")
+ }
+
+ def cannotGenerateCodeForExpressionError(expression: Expression): Throwable = {
+ SparkException.internalError(s"Cannot generate code for expression: $expression")
+ }
+
+ def cannotTerminateGeneratorError(generator: UnresolvedGenerator): Throwable = {
+ SparkException.internalError(s"Cannot terminate expression: $generator")
+ }
+
+ def castingCauseOverflowError(t: Any, from: DataType, to: DataType): ArithmeticException = {
+ new SparkArithmeticException(
+ errorClass = "CAST_OVERFLOW",
+ messageParameters = Map(
+ "value" -> toSQLValue(t, from),
+ "sourceType" -> toSQLType(from),
+ "targetType" -> toSQLType(to),
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def castingCauseOverflowErrorInTableInsert(
+ from: DataType,
+ to: DataType,
+ columnName: String): ArithmeticException = {
+ new SparkArithmeticException(
+ errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT",
+ messageParameters = Map(
+ "sourceType" -> toSQLType(from),
+ "targetType" -> toSQLType(to),
+ "columnName" -> toSQLId(columnName)),
+ context = Array.empty,
+ summary = ""
+ )
+ }
+
+ def cannotChangeDecimalPrecisionError(
+ value: Decimal,
+ decimalPrecision: Int,
+ decimalScale: Int,
+ context: SQLQueryContext = null): ArithmeticException = {
+ new SparkArithmeticException(
+ errorClass = "NUMERIC_VALUE_OUT_OF_RANGE",
+ messageParameters = Map(
+ "value" -> value.toPlainString,
+ "precision" -> decimalPrecision.toString,
+ "scale" -> decimalScale.toString,
+ "config" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ context = getQueryContext(context),
+ summary = getSummary(context))
+ }
+
+ def invalidInputSyntaxForBooleanError(
+ s: UTF8String,
+ context: SQLQueryContext): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "CAST_INVALID_INPUT",
+ messageParameters = Map(
+ "expression" -> toSQLValue(s, StringType),
+ "sourceType" -> toSQLType(StringType),
+ "targetType" -> toSQLType(BooleanType),
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ context = getQueryContext(context),
+ summary = getSummary(context))
+ }
+
+ def invalidInputInCastToNumberError(
+ to: DataType,
+ s: UTF8String,
+ context: SQLQueryContext): SparkNumberFormatException = {
+ new SparkNumberFormatException(
+ errorClass = "CAST_INVALID_INPUT",
+ messageParameters = Map(
+ "expression" -> toSQLValue(s, StringType),
+ "sourceType" -> toSQLType(StringType),
+ "targetType" -> toSQLType(to),
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ context = getQueryContext(context),
+ summary = getSummary(context))
+ }
+
+ def invalidInputInConversionError(
+ to: DataType,
+ s: UTF8String,
+ fmt: UTF8String,
+ hint: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "CONVERSION_INVALID_INPUT",
+ messageParameters = Map(
+ "str" -> toSQLValue(s, StringType),
+ "fmt" -> toSQLValue(fmt, StringType),
+ "targetType" -> toSQLType(to),
+ "suggestion" -> toSQLId(hint)))
+ }
+
+ def cannotCastFromNullTypeError(to: DataType): Throwable = {
+ new SparkException(
+ errorClass = "CANNOT_CAST_DATATYPE",
+ messageParameters = Map(
+ "sourceType" -> NullType.typeName,
+ "targetType" -> to.typeName),
+ cause = null)
+ }
+
+ def cannotCastError(from: DataType, to: DataType): Throwable = {
+ new SparkException(
+ errorClass = "CANNOT_CAST_DATATYPE",
+ messageParameters = Map(
+ "sourceType" -> from.typeName,
+ "targetType" -> to.typeName),
+ cause = null)
+ }
+
+ def cannotParseDecimalError(): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "CANNOT_PARSE_DECIMAL",
+ messageParameters = Map.empty)
+ }
+
+ def dataTypeUnsupportedError(dataType: String, failure: String): Throwable = {
+ DataTypeErrors.dataTypeUnsupportedError(dataType, failure)
+ }
+
+ def failedExecuteUserDefinedFunctionError(functionName: String, inputTypes: String,
+ outputType: String, e: Throwable): Throwable = {
+ new SparkException(
+ errorClass = "FAILED_EXECUTE_UDF",
+ messageParameters = Map(
+ "functionName" -> toSQLId(functionName),
+ "signature" -> inputTypes,
+ "result" -> outputType),
+ cause = e)
+ }
+
+ def divideByZeroError(context: SQLQueryContext): ArithmeticException = {
+ new SparkArithmeticException(
+ errorClass = "DIVIDE_BY_ZERO",
+ messageParameters = Map("config" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ context = getQueryContext(context),
+ summary = getSummary(context))
+ }
+
+ def intervalDividedByZeroError(context: SQLQueryContext): ArithmeticException = {
+ new SparkArithmeticException(
+ errorClass = "INTERVAL_DIVIDED_BY_ZERO",
+ messageParameters = Map.empty,
+ context = getQueryContext(context),
+ summary = getSummary(context))
+ }
+
+ def invalidArrayIndexError(
+ index: Int,
+ numElements: Int,
+ context: SQLQueryContext): ArrayIndexOutOfBoundsException = {
+ new SparkArrayIndexOutOfBoundsException(
+ errorClass = "INVALID_ARRAY_INDEX",
+ messageParameters = Map(
+ "indexValue" -> toSQLValue(index, IntegerType),
+ "arraySize" -> toSQLValue(numElements, IntegerType),
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ context = getQueryContext(context),
+ summary = getSummary(context))
+ }
+
+ def invalidElementAtIndexError(
+ index: Int,
+ numElements: Int,
+ context: SQLQueryContext): ArrayIndexOutOfBoundsException = {
+ new SparkArrayIndexOutOfBoundsException(
+ errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT",
+ messageParameters = Map(
+ "indexValue" -> toSQLValue(index, IntegerType),
+ "arraySize" -> toSQLValue(numElements, IntegerType),
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ context = getQueryContext(context),
+ summary = getSummary(context))
+ }
+
+ def invalidBitmapPositionError(bitPosition: Long,
+ bitmapNumBytes: Long): ArrayIndexOutOfBoundsException = {
+ new SparkArrayIndexOutOfBoundsException(
+ errorClass = "INVALID_BITMAP_POSITION",
+ messageParameters = Map(
+ "bitPosition" -> s"$bitPosition",
+ "bitmapNumBytes" -> s"$bitmapNumBytes",
+ "bitmapNumBits" -> s"${bitmapNumBytes * 8}"),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def invalidFractionOfSecondError(): DateTimeException = {
+ new SparkDateTimeException(
+ errorClass = "INVALID_FRACTION_OF_SECOND",
+ messageParameters = Map(
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)
+ ),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def ansiDateTimeParseError(e: Exception): SparkDateTimeException = {
+ new SparkDateTimeException(
+ errorClass = "CANNOT_PARSE_TIMESTAMP",
+ messageParameters = Map(
+ "message" -> e.getMessage,
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def ansiDateTimeError(e: Exception): SparkDateTimeException = {
+ new SparkDateTimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2000",
+ messageParameters = Map(
+ "message" -> e.getMessage,
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def ansiIllegalArgumentError(message: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2000",
+ messageParameters = Map(
+ "message" -> message,
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)))
+ }
+
+ def ansiIllegalArgumentError(e: IllegalArgumentException): IllegalArgumentException = {
+ ansiIllegalArgumentError(e.getMessage)
+ }
+
+ def overflowInSumOfDecimalError(context: SQLQueryContext): ArithmeticException = {
+ arithmeticOverflowError("Overflow in sum of decimals", context = context)
+ }
+
+ def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException = {
+ arithmeticOverflowError("Overflow in integral divide", "try_divide", context)
+ }
+
+ def overflowInConvError(context: SQLQueryContext): ArithmeticException = {
+ arithmeticOverflowError("Overflow in function conv()", context = context)
+ }
+
+ def mapSizeExceedArraySizeWhenZipMapError(size: Int): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2003",
+ messageParameters = Map(
+ "size" -> size.toString(),
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ }
+
+ def literalTypeUnsupportedError(v: Any): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "UNSUPPORTED_FEATURE.LITERAL_TYPE",
+ messageParameters = Map(
+ "value" -> v.toString,
+ "type" -> v.getClass.toString))
+ }
+
+ def pivotColumnUnsupportedError(v: Any, dataType: DataType): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "UNSUPPORTED_FEATURE.PIVOT_TYPE",
+ messageParameters = Map(
+ "value" -> v.toString,
+ "type" -> toSQLType(dataType)))
+ }
+
+ def noDefaultForDataTypeError(dataType: DataType): SparkException = {
+ SparkException.internalError(s"No default value for type: ${toSQLType(dataType)}.")
+ }
+
+ def orderedOperationUnsupportedByDataTypeError(
+ dataType: DataType): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2005",
+ messageParameters = Map("dataType" -> dataType.toString()))
+ }
+
+ def orderedOperationUnsupportedByDataTypeError(
+ dataType: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2005",
+ messageParameters = Map("dataType" -> dataType))
+ }
+
+ def invalidRegexGroupIndexError(
+ funcName: String,
+ groupCount: Int,
+ groupIndex: Int): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX",
+ messageParameters = Map(
+ "parameter" -> toSQLId("idx"),
+ "functionName" -> toSQLId(funcName),
+ "groupCount" -> groupCount.toString(),
+ "groupIndex" -> groupIndex.toString()))
+ }
+
+ def invalidUrlError(url: UTF8String, e: URISyntaxException): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "INVALID_URL",
+ messageParameters = Map(
+ "url" -> url.toString,
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ cause = e)
+ }
+
+ def illegalUrlError(url: UTF8String, e: IllegalArgumentException): Throwable = {
+ new SparkIllegalArgumentException(
+ errorClass = "CANNOT_DECODE_URL",
+ messageParameters = Map("url" -> url.toString),
+ cause = e
+ )
+ }
+
+ def mergeUnsupportedByWindowFunctionError(funcName: String): Throwable = {
+ SparkException.internalError(
+ s"The aggregate window function ${toSQLId(funcName)} does not support merging.")
+ }
+
+ def dataTypeUnexpectedError(dataType: DataType): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2011",
+ messageParameters = Map("dataType" -> dataType.catalogString))
+ }
+
+ def typeUnsupportedError(dataType: DataType): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2011",
+ messageParameters = Map("dataType" -> dataType.toString()))
+ }
+
+ def negativeValueUnexpectedError(
+ frequencyExpression : Expression): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2013",
+ messageParameters = Map("frequencyExpression" -> frequencyExpression.sql))
+ }
+
+ def addNewFunctionMismatchedWithFunctionError(funcName: String): Throwable = {
+ SparkException.internalError(
+ "Cannot add new function to generated class, " +
+ s"failed to match ${toSQLId(funcName)} at `addNewFunction`.")
+ }
+
+ def cannotGenerateCodeForIncomparableTypeError(
+ codeType: String, dataType: DataType): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2015",
+ messageParameters = Map(
+ "codeType" -> codeType,
+ "dataType" -> dataType.catalogString))
+ }
+
+ def cannotInterpolateClassIntoCodeBlockError(arg: Any): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2016",
+ messageParameters = Map("arg" -> arg.getClass.getName))
+ }
+
+ def customCollectionClsNotResolvedError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2017",
+ messageParameters = Map.empty)
+ }
+
+ def classUnsupportedByMapObjectsError(cls: Class[_]): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2018",
+ messageParameters = Map("cls" -> cls.getName))
+ }
+
+ def nullAsMapKeyNotAllowedError(): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "NULL_MAP_KEY",
+ messageParameters = Map.empty)
+ }
+
+ def methodNotDeclaredError(name: String): Throwable = {
+ SparkException.internalError(
+ s"""A method named "$name" is not declared in any enclosing class nor any supertype""")
+ }
+
+ def methodNotFoundError(
+ cls: Class[_],
+ functionName: String,
+ argClasses: Seq[Class[_]]): Throwable = {
+ SparkException.internalError(
+ s"Couldn't find method $functionName with arguments " +
+ s"${argClasses.mkString("(", ", ", ")")} on $cls.")
+ }
+
+ def constructorNotFoundError(cls: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2020",
+ messageParameters = Map("cls" -> cls))
+ }
+
+ def unsupportedNaturalJoinTypeError(joinType: JoinType): SparkException = {
+ SparkException.internalError(
+ s"Unsupported natural join type ${joinType.toString}")
+ }
+
+ def notExpectedUnresolvedEncoderError(attr: AttributeReference): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2023",
+ messageParameters = Map("attr" -> attr.toString()))
+ }
+
+ def unsupportedEncoderError(): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2024",
+ messageParameters = Map.empty)
+ }
+
+ def notOverrideExpectedMethodsError(
+ className: String, m1: String, m2: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2025",
+ messageParameters = Map("className" -> className, "m1" -> m1, "m2" -> m2))
+ }
+
+ def failToConvertValueToJsonError(
+ value: AnyRef, cls: Class[_], dataType: DataType): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2026",
+ messageParameters = Map(
+ "value" -> value.toString(),
+ "cls" -> cls.toString(),
+ "dataType" -> dataType.toString()))
+ }
+
+ def unexpectedOperatorInCorrelatedSubquery(
+ op: LogicalPlan, pos: String = ""): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2027",
+ messageParameters = Map("op" -> op.toString(), "pos" -> pos))
+ }
+
+ def unsupportedRoundingMode(roundMode: BigDecimal.RoundingMode.Value): SparkException = {
+ DataTypeErrors.unsupportedRoundingMode(roundMode)
+ }
+
+ def resolveCannotHandleNestedSchema(plan: LogicalPlan): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2030",
+ messageParameters = Map("plan" -> plan.toString()))
+ }
+
+ def inputExternalRowCannotBeNullError(): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2031",
+ messageParameters = Map.empty)
+ }
+
+ def fieldCannotBeNullMsg(index: Int, fieldName: String): String = {
+ s"The ${index}th field '$fieldName' of input row cannot be null."
+ }
+
+ def fieldCannotBeNullError(index: Int, fieldName: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2032",
+ messageParameters = Map("fieldCannotBeNullMsg" -> fieldCannotBeNullMsg(index, fieldName)))
+ }
+
+ def unableToCreateDatabaseAsFailedToCreateDirectoryError(
+ dbDefinition: CatalogDatabase, e: IOException): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2033",
+ messageParameters = Map(
+ "name" -> dbDefinition.name,
+ "locationUri" -> dbDefinition.locationUri.toString()),
+ cause = e)
+ }
+
+ def unableToDropDatabaseAsFailedToDeleteDirectoryError(
+ dbDefinition: CatalogDatabase, e: IOException): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2034",
+ messageParameters = Map(
+ "name" -> dbDefinition.name,
+ "locationUri" -> dbDefinition.locationUri.toString()),
+ cause = e)
+ }
+
+ def unableToCreateTableAsFailedToCreateDirectoryError(
+ table: String, defaultTableLocation: Path, e: IOException): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2035",
+ messageParameters = Map(
+ "table" -> table,
+ "defaultTableLocation" -> defaultTableLocation.toString()),
+ cause = e)
+ }
+
+ def unableToDeletePartitionPathError(partitionPath: Path, e: IOException): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2036",
+ messageParameters = Map("partitionPath" -> partitionPath.toString()),
+ cause = e)
+ }
+
+ def unableToDropTableAsFailedToDeleteDirectoryError(
+ table: String, dir: Path, e: IOException): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2037",
+ messageParameters = Map("table" -> table, "dir" -> dir.toString()),
+ cause = e)
+ }
+
+ def unableToRenameTableAsFailedToRenameDirectoryError(
+ oldName: String, newName: String, oldDir: Path, e: IOException): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2038",
+ messageParameters = Map(
+ "oldName" -> oldName,
+ "newName" -> newName,
+ "oldDir" -> oldDir.toString()),
+ cause = e)
+ }
+
+ def unableToCreatePartitionPathError(partitionPath: Path, e: IOException): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2039",
+ messageParameters = Map("partitionPath" -> partitionPath.toString()),
+ cause = e)
+ }
+
+ def unableToRenamePartitionPathError(oldPartPath: Path, e: IOException): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2040",
+ messageParameters = Map("oldPartPath" -> oldPartPath.toString()),
+ cause = e)
+ }
+
+ def methodNotImplementedError(methodName: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2041",
+ messageParameters = Map("methodName" -> methodName))
+ }
+
+ def arithmeticOverflowError(e: ArithmeticException): SparkArithmeticException = {
+ new SparkArithmeticException(
+ errorClass = "_LEGACY_ERROR_TEMP_2042",
+ messageParameters = Map(
+ "message" -> e.getMessage,
+ "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def unaryMinusCauseOverflowError(originValue: Int): SparkArithmeticException = {
+ new SparkArithmeticException(
+ errorClass = "_LEGACY_ERROR_TEMP_2043",
+ messageParameters = Map("sqlValue" -> toSQLValue(originValue, IntegerType)),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def binaryArithmeticCauseOverflowError(
+ eval1: Short, symbol: String, eval2: Short): SparkArithmeticException = {
+ new SparkArithmeticException(
+ errorClass = "BINARY_ARITHMETIC_OVERFLOW",
+ messageParameters = Map(
+ "value1" -> toSQLValue(eval1, ShortType),
+ "symbol" -> symbol,
+ "value2" -> toSQLValue(eval2, ShortType)),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def intervalArithmeticOverflowError(
+ message: String,
+ hint: String = "",
+ context: SQLQueryContext): ArithmeticException = {
+ val alternative = if (hint.nonEmpty) {
+ s" Use '$hint' to tolerate overflow and return NULL instead."
+ } else ""
+ new SparkArithmeticException(
+ errorClass = "INTERVAL_ARITHMETIC_OVERFLOW",
+ messageParameters = Map(
+ "message" -> message,
+ "alternative" -> alternative),
+ context = getQueryContext(context),
+ summary = getSummary(context))
+ }
+
+ def failedToCompileMsg(e: Exception): String = {
+ s"failed to compile: $e"
+ }
+
+ def internalCompilerError(e: InternalCompilerException): Throwable = {
+ new InternalCompilerException(failedToCompileMsg(e), e)
+ }
+
+ def compilerError(e: CompileException): Throwable = {
+ new CompileException(failedToCompileMsg(e), e.getLocation)
+ }
+
+ def unsupportedTableChangeError(e: IllegalArgumentException): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2045",
+ messageParameters = Map("message" -> e.getMessage),
+ cause = e)
+ }
+
+ def notADatasourceRDDPartitionError(split: Partition): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2046",
+ messageParameters = Map("split" -> split.toString()),
+ cause = null)
+ }
+
+ def dataPathNotSpecifiedError(): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2047",
+ messageParameters = Map.empty)
+ }
+
+ def createStreamingSourceNotSpecifySchemaError(): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2048",
+ messageParameters = Map.empty)
+ }
+
+ def streamedOperatorUnsupportedByDataSourceError(
+ className: String, operator: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2049",
+ messageParameters = Map("className" -> className, "operator" -> operator))
+ }
+
+ def nonTimeWindowNotSupportedInStreamingError(
+ windowFuncList: Seq[String],
+ columnNameList: Seq[String],
+ windowSpecList: Seq[String],
+ origin: Origin): AnalysisException = {
+ new AnalysisException(
+ errorClass = "NON_TIME_WINDOW_NOT_SUPPORTED_IN_STREAMING",
+ messageParameters = Map(
+ "windowFunc" -> windowFuncList.map(toSQLStmt(_)).mkString(","),
+ "columnName" -> columnNameList.map(toSQLId(_)).mkString(","),
+ "windowSpec" -> windowSpecList.map(toSQLStmt(_)).mkString(",")),
+ origin = origin)
+ }
+
+ def multiplePathsSpecifiedError(allPaths: Seq[String]): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2050",
+ messageParameters = Map("paths" -> allPaths.mkString(", ")))
+ }
+
+ def dataSourceNotFoundError(
+ provider: String, error: Throwable): SparkClassNotFoundException = {
+ new SparkClassNotFoundException(
+ errorClass = "DATA_SOURCE_NOT_FOUND",
+ messageParameters = Map("provider" -> provider),
+ cause = error)
+ }
+
+ def removedClassInSpark2Error(className: String, e: Throwable): SparkClassNotFoundException = {
+ new SparkClassNotFoundException(
+ errorClass = "_LEGACY_ERROR_TEMP_2052",
+ messageParameters = Map("className" -> className),
+ cause = e)
+ }
+
+ def incompatibleDataSourceRegisterError(e: Throwable): Throwable = {
+ new SparkClassNotFoundException(
+ errorClass = "INCOMPATIBLE_DATASOURCE_REGISTER",
+ messageParameters = Map("message" -> e.getMessage),
+ cause = e)
+ }
+
+ def sparkUpgradeInReadingDatesError(
+ format: String, config: String, option: String): SparkUpgradeException = {
+ new SparkUpgradeException(
+ errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.READ_ANCIENT_DATETIME",
+ messageParameters = Map(
+ "format" -> format,
+ "config" -> toSQLConf(config),
+ "option" -> toDSOption(option)),
+ cause = null
+ )
+ }
+
+ def sparkUpgradeInWritingDatesError(format: String, config: String): SparkUpgradeException = {
+ new SparkUpgradeException(
+ errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.WRITE_ANCIENT_DATETIME",
+ messageParameters = Map(
+ "format" -> format,
+ "config" -> toSQLConf(config)),
+ cause = null
+ )
+ }
+
+ def buildReaderUnsupportedForFileFormatError(
+ format: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2053",
+ messageParameters = Map("format" -> format))
+ }
+
+ def taskFailedWhileWritingRowsError(path: String, cause: Throwable): Throwable = {
+ new SparkException(
+ errorClass = "TASK_WRITE_FAILED",
+ messageParameters = Map("path" -> path),
+ cause = cause)
+ }
+
+ def readCurrentFileNotFoundError(e: FileNotFoundException): SparkFileNotFoundException = {
+ new SparkFileNotFoundException(
+ errorClass = "_LEGACY_ERROR_TEMP_2055",
+ messageParameters = Map("message" -> e.getMessage))
+ }
+
+ def saveModeUnsupportedError(saveMode: Any, pathExists: Boolean): Throwable = {
+ val errorSubClass = if (pathExists) "EXISTENT_PATH" else "NON_EXISTENT_PATH"
+ new SparkIllegalArgumentException(
+ errorClass = s"UNSUPPORTED_SAVE_MODE.$errorSubClass",
+ messageParameters = Map("saveMode" -> toSQLValue(saveMode, StringType)))
+ }
+
+ def cannotClearOutputDirectoryError(staticPrefixPath: Path): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2056",
+ messageParameters = Map("staticPrefixPath" -> staticPrefixPath.toString()),
+ cause = null)
+ }
+
+ def cannotClearPartitionDirectoryError(path: Path): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2057",
+ messageParameters = Map("path" -> path.toString()),
+ cause = null)
+ }
+
+ def failedToCastValueToDataTypeForPartitionColumnError(
+ value: String, dataType: DataType, columnName: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2058",
+ messageParameters = Map(
+ "value" -> value,
+ "dataType" -> dataType.toString(),
+ "columnName" -> columnName))
+ }
+
+ def endOfStreamError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2059",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def fallbackV1RelationReportsInconsistentSchemaError(
+ v2Schema: StructType, v1Schema: StructType): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2060",
+ messageParameters = Map("v2Schema" -> v2Schema.toString(), "v1Schema" -> v1Schema.toString()))
+ }
+
+ def noRecordsFromEmptyDataReaderError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2061",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def fileNotFoundError(e: FileNotFoundException): SparkFileNotFoundException = {
+ new SparkFileNotFoundException(
+ errorClass = "_LEGACY_ERROR_TEMP_2062",
+ messageParameters = Map("message" -> e.getMessage))
+ }
+
+ def unsupportedSchemaColumnConvertError(
+ filePath: String,
+ column: String,
+ logicalType: String,
+ physicalType: String,
+ e: Exception): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2063",
+ messageParameters = Map(
+ "filePath" -> filePath,
+ "column" -> column,
+ "logicalType" -> logicalType,
+ "physicalType" -> physicalType),
+ cause = e)
+ }
+
+ def cannotReadFilesError(
+ e: Throwable,
+ path: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2064",
+ messageParameters = Map("path" -> path),
+ cause = e)
+ }
+
+ def cannotCreateColumnarReaderError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2065",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def invalidNamespaceNameError(namespace: Array[String]): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2066",
+ messageParameters = Map("namespace" -> namespace.quoted))
+ }
+
+ def unsupportedPartitionTransformError(
+ transform: Transform): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2067",
+ messageParameters = Map("transform" -> transform.toString()))
+ }
+
+ def missingDatabaseLocationError(): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2068",
+ messageParameters = Map.empty)
+ }
+
+ def cannotRemoveReservedPropertyError(property: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2069",
+ messageParameters = Map("property" -> property))
+ }
+
+ def writingJobFailedError(cause: Throwable): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2070",
+ messageParameters = Map.empty,
+ cause = cause)
+ }
+
+ def commitDeniedError(
+ partId: Int, taskId: Long, attemptId: Int, stageId: Int, stageAttempt: Int): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2071",
+ messageParameters = Map(
+ "partId" -> partId.toString(),
+ "taskId" -> taskId.toString(),
+ "attemptId" -> attemptId.toString(),
+ "stageId" -> stageId.toString(),
+ "stageAttempt" -> stageAttempt.toString()),
+ cause = null)
+ }
+
+ def cannotCreateJDBCTableWithPartitionsError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2073",
+ messageParameters = Map.empty)
+ }
+
+ def unsupportedUserSpecifiedSchemaError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2074",
+ messageParameters = Map.empty)
+ }
+
+ def writeUnsupportedForBinaryFileDataSourceError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2075",
+ messageParameters = Map.empty)
+ }
+
+ def fileLengthExceedsMaxLengthError(status: FileStatus, maxLength: Int): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2076",
+ messageParameters = Map(
+ "path" -> status.getPath.toString(),
+ "len" -> status.getLen.toString(),
+ "maxLength" -> maxLength.toString()),
+ cause = null)
+ }
+
+ def unsupportedFieldNameError(fieldName: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2077",
+ messageParameters = Map("fieldName" -> fieldName))
+ }
+
+ def cannotSpecifyBothJdbcTableNameAndQueryError(
+ jdbcTableName: String, jdbcQueryString: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2078",
+ messageParameters = Map(
+ "jdbcTableName" -> jdbcTableName,
+ "jdbcQueryString" -> jdbcQueryString))
+ }
+
+ def missingJdbcTableNameAndQueryError(
+ jdbcTableName: String, jdbcQueryString: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2079",
+ messageParameters = Map(
+ "jdbcTableName" -> jdbcTableName,
+ "jdbcQueryString" -> jdbcQueryString))
+ }
+
+ def emptyOptionError(optionName: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2080",
+ messageParameters = Map("optionName" -> optionName))
+ }
+
+ def invalidJdbcTxnIsolationLevelError(
+ jdbcTxnIsolationLevel: String, value: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2081",
+ messageParameters = Map("value" -> value, "jdbcTxnIsolationLevel" -> jdbcTxnIsolationLevel))
+ }
+
+ def cannotGetJdbcTypeError(dt: DataType): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2082",
+ messageParameters = Map("catalogString" -> dt.catalogString))
+ }
+
+ def unrecognizedSqlTypeError(jdbcTypeId: String, typeName: String): Throwable = {
+ new SparkSQLException(
+ errorClass = "UNRECOGNIZED_SQL_TYPE",
+ messageParameters = Map("typeName" -> typeName, "jdbcType" -> jdbcTypeId))
+ }
+
+ def unsupportedJdbcTypeError(content: String): SparkSQLException = {
+ new SparkSQLException(
+ errorClass = "_LEGACY_ERROR_TEMP_2083",
+ messageParameters = Map("content" -> content))
+ }
+
+ def unsupportedArrayElementTypeBasedOnBinaryError(dt: DataType): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2084",
+ messageParameters = Map("catalogString" -> dt.catalogString))
+ }
+
+ def nestedArraysUnsupportedError(): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2085",
+ messageParameters = Map.empty)
+ }
+
+ def cannotTranslateNonNullValueForFieldError(pos: Int): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2086",
+ messageParameters = Map("pos" -> pos.toString()))
+ }
+
+ def invalidJdbcNumPartitionsError(
+ n: Int, jdbcNumPartitions: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2087",
+ messageParameters = Map("n" -> n.toString(), "jdbcNumPartitions" -> jdbcNumPartitions))
+ }
+
+ def multiActionAlterError(tableName: String): Throwable = {
+ new SparkSQLFeatureNotSupportedException(
+ errorClass = "UNSUPPORTED_FEATURE.MULTI_ACTION_ALTER",
+ messageParameters = Map("tableName" -> tableName))
+ }
+
+ def dataTypeUnsupportedYetError(dataType: DataType): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2088",
+ messageParameters = Map("dataType" -> dataType.toString()))
+ }
+
+ def unsupportedOperationForDataTypeError(
+ dataType: DataType): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2089",
+ messageParameters = Map("catalogString" -> dataType.catalogString))
+ }
+
+ def inputFilterNotFullyConvertibleError(owner: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2090",
+ messageParameters = Map("owner" -> owner),
+ cause = null)
+ }
+
+ def cannotReadFooterForFileError(file: Path, e: Exception): Throwable = {
+ new SparkException(
+ errorClass = "CANNOT_READ_FILE_FOOTER",
+ messageParameters = Map("file" -> file.toString()),
+ cause = e)
+ }
+
+ def foundDuplicateFieldInCaseInsensitiveModeError(
+ requiredFieldName: String, matchedOrcFields: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2093",
+ messageParameters = Map(
+ "requiredFieldName" -> requiredFieldName,
+ "matchedOrcFields" -> matchedOrcFields))
+ }
+
+ def foundDuplicateFieldInFieldIdLookupModeError(
+ requiredId: Int, matchedFields: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2094",
+ messageParameters = Map(
+ "requiredId" -> requiredId.toString(),
+ "matchedFields" -> matchedFields))
+ }
+
+ def failedToMergeIncompatibleSchemasError(
+ left: StructType, right: StructType, e: Throwable): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2095",
+ messageParameters = Map("left" -> left.toString(), "right" -> right.toString()),
+ cause = e)
+ }
+
+ def ddlUnsupportedTemporarilyError(ddl: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2096",
+ messageParameters = Map("ddl" -> ddl))
+ }
+
+ def executeBroadcastTimeoutError(timeout: Long, ex: Option[TimeoutException]): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2097",
+ messageParameters = Map(
+ "timeout" -> timeout.toString(),
+ "broadcastTimeout" -> toSQLConf(SQLConf.BROADCAST_TIMEOUT.key),
+ "autoBroadcastJoinThreshold" -> toSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key)),
+ cause = ex.orNull)
+ }
+
+ def cannotCompareCostWithTargetCostError(cost: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2098",
+ messageParameters = Map("cost" -> cost))
+ }
+
+ def notSupportTypeError(dataType: DataType): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2100",
+ messageParameters = Map("dataType" -> dataType.toString()),
+ cause = null)
+ }
+
+ def notSupportNonPrimitiveTypeError(): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2101",
+ messageParameters = Map.empty)
+ }
+
+ def unsupportedTypeError(dataType: DataType): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2102",
+ messageParameters = Map("catalogString" -> dataType.catalogString),
+ cause = null)
+ }
+
+ def useDictionaryEncodingWhenDictionaryOverflowError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2103",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def endOfIteratorError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2104",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def cannotAllocateMemoryToGrowBytesToBytesMapError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2105",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def cannotAcquireMemoryToBuildLongHashedRelationError(size: Long, got: Long): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2106",
+ messageParameters = Map("size" -> size.toString(), "got" -> got.toString()),
+ cause = null)
+ }
+
+ def cannotAcquireMemoryToBuildUnsafeHashedRelationError(): Throwable = {
+ new SparkOutOfMemoryError(
+ "_LEGACY_ERROR_TEMP_2107")
+ }
+
+ def rowLargerThan256MUnsupportedError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2108",
+ messageParameters = Map.empty)
+ }
+
+ def cannotBuildHashedRelationWithUniqueKeysExceededError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2109",
+ messageParameters = Map.empty)
+ }
+
+ def cannotBuildHashedRelationLargerThan8GError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2110",
+ messageParameters = Map.empty)
+ }
+
+ def failedToPushRowIntoRowQueueError(rowQueue: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2111",
+ messageParameters = Map("rowQueue" -> rowQueue),
+ cause = null)
+ }
+
+ def unexpectedWindowFunctionFrameError(frame: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2112",
+ messageParameters = Map("frame" -> frame))
+ }
+
+ def cannotParseStatisticAsPercentileError(
+ stats: String, e: NumberFormatException): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2113",
+ messageParameters = Map("stats" -> stats),
+ cause = e)
+ }
+
+ def statisticNotRecognizedError(stats: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2114",
+ messageParameters = Map("stats" -> stats))
+ }
+
+ def unknownColumnError(unknownColumn: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2115",
+ messageParameters = Map("unknownColumn" -> unknownColumn))
+ }
+
+ def unexpectedAccumulableUpdateValueError(o: Any): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_2116",
+ messageParameters = Map("o" -> o.toString()))
+ }
+
+ def unscaledValueTooLargeForPrecisionError(
+ value: Decimal,
+ decimalPrecision: Int,
+ decimalScale: Int,
+ context: SQLQueryContext = null): ArithmeticException = {
+ DataTypeErrors.unscaledValueTooLargeForPrecisionError(
+ value, decimalPrecision, decimalScale, context)
+ }
+
+ def decimalPrecisionExceedsMaxPrecisionError(
+ precision: Int, maxPrecision: Int): SparkArithmeticException = {
+ DataTypeErrors.decimalPrecisionExceedsMaxPrecisionError(precision, maxPrecision)
+ }
+
+ def outOfDecimalTypeRangeError(str: UTF8String): SparkArithmeticException = {
+ new SparkArithmeticException(
+ errorClass = "NUMERIC_OUT_OF_SUPPORTED_RANGE",
+ messageParameters = Map(
+ "value" -> str.toString),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def unsupportedArrayTypeError(clazz: Class[_]): SparkRuntimeException = {
+ DataTypeErrors.unsupportedJavaTypeError(clazz)
+ }
+
+ def unsupportedJavaTypeError(clazz: Class[_]): SparkRuntimeException = {
+ DataTypeErrors.unsupportedJavaTypeError(clazz)
+ }
+
+ def failedParsingStructTypeError(raw: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "FAILED_PARSE_STRUCT_TYPE",
+ messageParameters = Map("raw" -> toSQLValue(raw, StringType)))
+ }
+
+ def cannotMergeDecimalTypesWithIncompatibleScaleError(
+ leftScale: Int, rightScale: Int): Throwable = {
+ DataTypeErrors.cannotMergeDecimalTypesWithIncompatibleScaleError(leftScale, rightScale)
+ }
+
+ def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = {
+ DataTypeErrors.cannotMergeIncompatibleDataTypesError(left, right)
+ }
+
+ def exceedMapSizeLimitError(size: Int): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2126",
+ messageParameters = Map(
+ "size" -> size.toString(),
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ }
+
+ def duplicateMapKeyFoundError(key: Any): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "DUPLICATED_MAP_KEY",
+ messageParameters = Map(
+ "key" -> key.toString(),
+ "mapKeyDedupPolicy" -> toSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key)))
+ }
+
+ def mapDataKeyArrayLengthDiffersFromValueArrayLengthError(): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2128",
+ messageParameters = Map.empty)
+ }
+
+ def registeringStreamingQueryListenerError(e: Exception): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2131",
+ messageParameters = Map.empty,
+ cause = e)
+ }
+
+ def concurrentQueryInstanceError(): Throwable = {
+ new SparkConcurrentModificationException(
+ errorClass = "CONCURRENT_QUERY",
+ messageParameters = Map.empty[String, String])
+ }
+
+ def concurrentStreamLogUpdate(batchId: Long): Throwable = {
+ new SparkException(
+ errorClass = "CONCURRENT_STREAM_LOG_UPDATE",
+ messageParameters = Map("batchId" -> batchId.toString),
+ cause = null)
+ }
+
+ def cannotParseJsonArraysAsStructsError(recordStr: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_JSON_ARRAYS_AS_STRUCTS",
+ messageParameters = Map(
+ "badRecord" -> recordStr,
+ "failFastMode" -> FailFastMode.name))
+ }
+
+ def cannotParseStringAsDataTypeError(parser: JsonParser, token: JsonToken, dataType: DataType)
+ : SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2133",
+ messageParameters = Map(
+ "fieldName" -> parser.getCurrentName,
+ "fieldValue" -> parser.getText,
+ "token" -> token.toString(),
+ "dataType" -> dataType.toString()))
+ }
+
+ def emptyJsonFieldValueError(dataType: DataType): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "EMPTY_JSON_FIELD_VALUE",
+ messageParameters = Map("dataType" -> toSQLType(dataType)))
+ }
+
+ def cannotParseJSONFieldError(parser: JsonParser, jsonType: JsonToken, dataType: DataType)
+ : SparkRuntimeException = {
+ cannotParseJSONFieldError(parser.getCurrentName, parser.getText, jsonType, dataType)
+ }
+
+ def cannotParseJSONFieldError(
+ fieldName: String,
+ fieldValue: String,
+ jsonType: JsonToken,
+ dataType: DataType): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "CANNOT_PARSE_JSON_FIELD",
+ messageParameters = Map(
+ "fieldName" -> toSQLValue(fieldName, StringType),
+ "fieldValue" -> fieldValue,
+ "jsonType" -> jsonType.toString(),
+ "dataType" -> toSQLType(dataType)))
+ }
+
+ def rootConverterReturnNullError(): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "INVALID_JSON_ROOT_FIELD",
+ messageParameters = Map.empty)
+ }
+
+ def attributesForTypeUnsupportedError(schema: Schema): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2142",
+ messageParameters = Map(
+ "schema" -> schema.toString()))
+ }
+
+ def paramExceedOneCharError(paramName: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2145",
+ messageParameters = Map(
+ "paramName" -> paramName))
+ }
+
+ def paramIsNotIntegerError(paramName: String, value: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2146",
+ messageParameters = Map(
+ "paramName" -> paramName,
+ "value" -> value))
+ }
+
+ def paramIsNotBooleanValueError(paramName: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2147",
+ messageParameters = Map(
+ "paramName" -> paramName),
+ cause = null)
+ }
+
+ def foundNullValueForNotNullableFieldError(name: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2148",
+ messageParameters = Map(
+ "name" -> name))
+ }
+
+ def malformedCSVRecordError(badRecord: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "MALFORMED_CSV_RECORD",
+ messageParameters = Map("badRecord" -> badRecord))
+ }
+
+ def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2150",
+ messageParameters = Map.empty)
+ }
+
+ def expressionDecodingError(e: Exception, expressions: Seq[Expression]): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2151",
+ messageParameters = Map(
+ "e" -> e.toString(),
+ "expressions" -> expressions.map(
+ _.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")),
+ cause = e)
+ }
+
+ def expressionEncodingError(e: Exception, expressions: Seq[Expression]): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2152",
+ messageParameters = Map(
+ "e" -> e.toString(),
+ "expressions" -> expressions.map(
+ _.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")),
+ cause = e)
+ }
+
+ def classHasUnexpectedSerializerError(
+ clsName: String, objSerializer: Expression): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2153",
+ messageParameters = Map(
+ "clsName" -> clsName,
+ "objSerializer" -> objSerializer.toString()))
+ }
+
+ def unsupportedOperandTypeForSizeFunctionError(
+ dataType: DataType): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2156",
+ messageParameters = Map(
+ "dataType" -> dataType.getClass.getCanonicalName))
+ }
+
+ def unexpectedValueForStartInFunctionError(prettyName: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2157",
+ messageParameters = Map(
+ "prettyName" -> prettyName))
+ }
+
+ def unexpectedValueForLengthInFunctionError(prettyName: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2158",
+ messageParameters = Map(
+ "prettyName" -> prettyName))
+ }
+
+ def invalidIndexOfZeroError(context: SQLQueryContext): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "INVALID_INDEX_OF_ZERO",
+ cause = null,
+ messageParameters = Map.empty,
+ context = getQueryContext(context),
+ summary = getSummary(context))
+ }
+
+ def concatArraysWithElementsExceedLimitError(numberOfElements: Long): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2159",
+ messageParameters = Map(
+ "numberOfElements" -> numberOfElements.toString(),
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ }
+
+ def flattenArraysWithElementsExceedLimitError(numberOfElements: Long): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2160",
+ messageParameters = Map(
+ "numberOfElements" -> numberOfElements.toString(),
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ }
+
+ def createArrayWithElementsExceedLimitError(count: Any): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2161",
+ messageParameters = Map(
+ "count" -> count.toString(),
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ }
+
+ def unionArrayWithElementsExceedLimitError(length: Int): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2162",
+ messageParameters = Map(
+ "length" -> length.toString(),
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ }
+
+ def initialTypeNotTargetDataTypeError(
+ dataType: DataType, target: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2163",
+ messageParameters = Map(
+ "dataType" -> dataType.catalogString,
+ "target" -> target))
+ }
+
+ def initialTypeNotTargetDataTypesError(dataType: DataType): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2164",
+ messageParameters = Map(
+ "dataType" -> dataType.catalogString,
+ "arrayType" -> ArrayType.simpleString,
+ "structType" -> StructType.simpleString,
+ "mapType" -> MapType.simpleString))
+ }
+
+ def malformedRecordsDetectedInSchemaInferenceError(e: Throwable): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2165",
+ messageParameters = Map(
+ "failFastMode" -> FailFastMode.name),
+ cause = e)
+ }
+
+ def malformedJSONError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2166",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def malformedRecordsDetectedInSchemaInferenceError(dataType: DataType): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2167",
+ messageParameters = Map(
+ "failFastMode" -> FailFastMode.name,
+ "dataType" -> dataType.catalogString),
+ cause = null)
+ }
+
+ def decorrelateInnerQueryThroughPlanUnsupportedError(
+ plan: LogicalPlan): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2168",
+ messageParameters = Map(
+ "plan" -> plan.nodeName))
+ }
+
+ def methodCalledInAnalyzerNotAllowedError(): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2169",
+ messageParameters = Map.empty)
+ }
+
+ def cannotSafelyMergeSerdePropertiesError(
+ props1: Map[String, String],
+ props2: Map[String, String],
+ conflictKeys: Set[String]): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2170",
+ messageParameters = Map(
+ "props1" -> props1.map { case (k, v) => s"$k=$v" }.mkString("{", ",", "}"),
+ "props2" -> props2.map { case (k, v) => s"$k=$v" }.mkString("{", ",", "}"),
+ "conflictKeys" -> conflictKeys.mkString(", ")))
+ }
+
+ def pairUnsupportedAtFunctionError(
+ r1: ValueInterval,
+ r2: ValueInterval,
+ function: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2171",
+ messageParameters = Map(
+ "r1" -> r1.toString(),
+ "r2" -> r2.toString(),
+ "function" -> function))
+ }
+
+ def onceStrategyIdempotenceIsBrokenForBatchError[TreeType <: TreeNode[_]](
+ batchName: String, plan: TreeType, reOptimized: TreeType): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2172",
+ messageParameters = Map(
+ "batchName" -> batchName,
+ "plan" -> sideBySide(plan.treeString, reOptimized.treeString).mkString("\n")))
+ }
+
+ def ruleIdNotFoundForRuleError(ruleName: String): Throwable = {
+ new SparkException(
+ errorClass = "RULE_ID_NOT_FOUND",
+ messageParameters = Map("ruleName" -> ruleName),
+ cause = null)
+ }
+
+ def cannotCreateArrayWithElementsExceedLimitError(
+ numElements: Long, additionalErrorMessage: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2176",
+ messageParameters = Map(
+ "numElements" -> numElements.toString(),
+ "maxRoundedArrayLength"-> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+ "additionalErrorMessage" -> additionalErrorMessage))
+ }
+
+ def malformedRecordsDetectedInRecordParsingError(
+ badRecord: String, e: BadRecordException): Throwable = {
+ new SparkException(
+ errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
+ messageParameters = Map(
+ "badRecord" -> badRecord,
+ "failFastMode" -> FailFastMode.name),
+ cause = e)
+ }
+
+ def remoteOperationsUnsupportedError(): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2178",
+ messageParameters = Map.empty)
+ }
+
+ def invalidKerberosConfigForHiveServer2Error(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2179",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def parentSparkUIToAttachTabNotFoundError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2180",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def inferSchemaUnsupportedForHiveError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2181",
+ messageParameters = Map.empty)
+ }
+
+ def requestedPartitionsMismatchTablePartitionsError(
+ table: CatalogTable, partition: Map[String, Option[String]]): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2182",
+ messageParameters = Map(
+ "tableIdentifier" -> table.identifier.table,
+ "partitionKeys" -> partition.keys.mkString(","),
+ "partitionColumnNames" -> table.partitionColumnNames.mkString(",")),
+ cause = null)
+ }
+
+ def dynamicPartitionKeyNotAmongWrittenPartitionPathsError(key: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2183",
+ messageParameters = Map(
+ "key" -> toSQLValue(key, StringType)),
+ cause = null)
+ }
+
+ def cannotRemovePartitionDirError(partitionPath: Path): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2184",
+ messageParameters = Map(
+ "partitionPath" -> partitionPath.toString()))
+ }
+
+ def cannotCreateStagingDirError(message: String, e: IOException): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2185",
+ messageParameters = Map(
+ "message" -> message),
+ cause = e)
+ }
+
+ def serDeInterfaceNotFoundError(e: NoClassDefFoundError): SparkClassNotFoundException = {
+ new SparkClassNotFoundException(
+ errorClass = "_LEGACY_ERROR_TEMP_2186",
+ messageParameters = Map.empty,
+ cause = e)
+ }
+
+ def convertHiveTableToCatalogTableError(
+ e: SparkException, dbName: String, tableName: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2187",
+ messageParameters = Map(
+ "message" -> e.getMessage,
+ "dbName" -> dbName,
+ "tableName" -> tableName),
+ cause = e)
+ }
+
+ def cannotRecognizeHiveTypeError(
+ e: ParseException, fieldType: String, fieldName: String): Throwable = {
+ new SparkException(
+ errorClass = "CANNOT_RECOGNIZE_HIVE_TYPE",
+ messageParameters = Map(
+ "fieldType" -> toSQLType(fieldType),
+ "fieldName" -> toSQLId(fieldName)),
+ cause = e)
+ }
+
+ def getTablesByTypeUnsupportedByHiveVersionError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2189",
+ messageParameters = Map.empty)
+ }
+
+ def dropTableWithPurgeUnsupportedError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2190",
+ messageParameters = Map.empty)
+ }
+
+ def alterTableWithDropPartitionAndPurgeUnsupportedError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2191",
+ messageParameters = Map.empty)
+ }
+
+ def invalidPartitionFilterError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2192",
+ messageParameters = Map.empty)
+ }
+
+ def getPartitionMetadataByFilterError(e: InvocationTargetException): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2193",
+ messageParameters = Map(
+ "hiveMetastorePartitionPruningFallbackOnException" ->
+ SQLConf.HIVE_METASTORE_PARTITION_PRUNING_FALLBACK_ON_EXCEPTION.key),
+ cause = e)
+ }
+
+ def unsupportedHiveMetastoreVersionError(
+ version: String, key: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2194",
+ messageParameters = Map(
+ "version" -> version,
+ "key" -> key))
+ }
+
+ def loadHiveClientCausesNoClassDefFoundError(
+ cnf: NoClassDefFoundError,
+ execJars: Seq[URL],
+ key: String,
+ e: InvocationTargetException): SparkClassNotFoundException = {
+ new SparkClassNotFoundException(
+ errorClass = "_LEGACY_ERROR_TEMP_2195",
+ messageParameters = Map(
+ "cnf" -> cnf.toString(),
+ "execJars" -> execJars.mkString(", "),
+ "key" -> key),
+ cause = e)
+ }
+
+ def cannotFetchTablesOfDatabaseError(dbName: String, e: Exception): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2196",
+ messageParameters = Map(
+ "dbName" -> dbName),
+ cause = e)
+ }
+
+ def illegalLocationClauseForViewPartitionError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2197",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def renamePathAsExistsPathError(srcPath: Path, dstPath: Path): Throwable = {
+ new SparkFileAlreadyExistsException(
+ errorClass = "FAILED_RENAME_PATH",
+ messageParameters = Map(
+ "sourcePath" -> srcPath.toString,
+ "targetPath" -> dstPath.toString))
+ }
+
+ def renameAsExistsPathError(dstPath: Path): SparkFileAlreadyExistsException = {
+ new SparkFileAlreadyExistsException(
+ errorClass = "_LEGACY_ERROR_TEMP_2198",
+ messageParameters = Map(
+ "dstPath" -> dstPath.toString()))
+ }
+
+ def renameSrcPathNotFoundError(srcPath: Path): Throwable = {
+ new SparkFileNotFoundException(
+ errorClass = "RENAME_SRC_PATH_NOT_FOUND",
+ messageParameters = Map("sourcePath" -> srcPath.toString))
+ }
+
+ def failedRenameTempFileError(srcPath: Path, dstPath: Path): Throwable = {
+ new SparkException(
+ errorClass = "FAILED_RENAME_TEMP_FILE",
+ messageParameters = Map(
+ "srcPath" -> srcPath.toString(),
+ "dstPath" -> dstPath.toString()),
+ cause = null)
+ }
+
+ def legacyMetadataPathExistsError(metadataPath: Path, legacyMetadataPath: Path): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2200",
+ messageParameters = Map(
+ "metadataPath" -> metadataPath.toString(),
+ "legacyMetadataPath" -> legacyMetadataPath.toString(),
+ "StreamingCheckpointEscaptedPathCheckEnabled" ->
+ SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key),
+ cause = null)
+ }
+
+ def partitionColumnNotFoundInSchemaError(
+ col: String, schema: StructType): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2201",
+ messageParameters = Map(
+ "col" -> col,
+ "schema" -> schema.toString()))
+ }
+
+ def stateNotDefinedOrAlreadyRemovedError(): Throwable = {
+ new NoSuchElementException("State is either not defined or has already been removed")
+ }
+
+ def cannotSetTimeoutDurationError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2203",
+ messageParameters = Map.empty)
+ }
+
+ def cannotGetEventTimeWatermarkError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2204",
+ messageParameters = Map.empty)
+ }
+
+ def cannotSetTimeoutTimestampError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2205",
+ messageParameters = Map.empty)
+ }
+
+ def batchMetadataFileNotFoundError(batchMetadataFile: Path): SparkFileNotFoundException = {
+ new SparkFileNotFoundException(
+ errorClass = "BATCH_METADATA_NOT_FOUND",
+ messageParameters = Map(
+ "batchMetadataFile" -> batchMetadataFile.toString()))
+ }
+
+ def multiStreamingQueriesUsingPathConcurrentlyError(
+ path: String, e: FileAlreadyExistsException): SparkConcurrentModificationException = {
+ new SparkConcurrentModificationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2207",
+ messageParameters = Map(
+ "path" -> path),
+ cause = e)
+ }
+
+ def addFilesWithAbsolutePathUnsupportedError(
+ commitProtocol: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2208",
+ messageParameters = Map(
+ "commitProtocol" -> commitProtocol))
+ }
+
+ def microBatchUnsupportedByDataSourceError(
+ srcName: String,
+ disabledSources: String,
+ table: Table): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2209",
+ messageParameters = Map(
+ "srcName" -> srcName,
+ "disabledSources" -> disabledSources,
+ "table" -> table.toString()))
+ }
+
+ def cannotExecuteStreamingRelationExecError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2210",
+ messageParameters = Map.empty)
+ }
+
+ def invalidStreamingOutputModeError(
+ outputMode: Option[OutputMode]): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2211",
+ messageParameters = Map(
+ "outputMode" -> outputMode.toString()))
+ }
+
+ def invalidCatalogNameError(name: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2212",
+ messageParameters = Map(
+ "name" -> name),
+ cause = null)
+ }
+
+ def catalogPluginClassNotFoundError(name: String): Throwable = {
+ new CatalogNotFoundException(
+ s"Catalog '$name' plugin class not found: spark.sql.catalog.$name is not defined")
+ }
+
+ def catalogPluginClassNotImplementedError(name: String, pluginClassName: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2214",
+ messageParameters = Map(
+ "name" -> name,
+ "pluginClassName" -> pluginClassName),
+ cause = null)
+ }
+
+ def catalogPluginClassNotFoundForCatalogError(
+ name: String,
+ pluginClassName: String,
+ e: Exception): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2215",
+ messageParameters = Map(
+ "name" -> name,
+ "pluginClassName" -> pluginClassName),
+ cause = e)
+ }
+
+ def catalogFailToFindPublicNoArgConstructorError(
+ name: String,
+ pluginClassName: String,
+ e: Exception): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2216",
+ messageParameters = Map(
+ "name" -> name,
+ "pluginClassName" -> pluginClassName),
+ cause = e)
+ }
+
+ def catalogFailToCallPublicNoArgConstructorError(
+ name: String,
+ pluginClassName: String,
+ e: Exception): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2217",
+ messageParameters = Map(
+ "name" -> name,
+ "pluginClassName" -> pluginClassName),
+ cause = e)
+ }
+
+ def cannotInstantiateAbstractCatalogPluginClassError(
+ name: String,
+ pluginClassName: String,
+ e: Exception): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2218",
+ messageParameters = Map(
+ "name" -> name,
+ "pluginClassName" -> pluginClassName),
+ cause = e.getCause)
+ }
+
+ def failedToInstantiateConstructorForCatalogError(
+ name: String,
+ pluginClassName: String,
+ e: Exception): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2219",
+ messageParameters = Map(
+ "name" -> name,
+ "pluginClassName" -> pluginClassName),
+ cause = e.getCause)
+ }
+
+ def noSuchElementExceptionError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2220",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def sqlConfigNotFoundError(key: String): SparkNoSuchElementException = {
+ new SparkNoSuchElementException(
+ errorClass = "SQL_CONF_NOT_FOUND",
+ messageParameters = Map("sqlConf" -> toSQLConf(key)))
+ }
+
+ def cannotMutateReadOnlySQLConfError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2222",
+ messageParameters = Map.empty)
+ }
+
+ def cannotCloneOrCopyReadOnlySQLConfError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2223",
+ messageParameters = Map.empty)
+ }
+
+ def cannotGetSQLConfInSchedulerEventLoopThreadError(): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2224",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def unsupportedOperationExceptionError(): SparkUnsupportedOperationException = {
+ DataTypeErrors.unsupportedOperationExceptionError()
+ }
+
+ def nullLiteralsCannotBeCastedError(name: String): SparkUnsupportedOperationException = {
+ DataTypeErrors.nullLiteralsCannotBeCastedError(name)
+ }
+
+ def notUserDefinedTypeError(name: String, userClass: String): Throwable = {
+ DataTypeErrors.notUserDefinedTypeError(name, userClass)
+ }
+
+ def cannotLoadUserDefinedTypeError(name: String, userClass: String): Throwable = {
+ DataTypeErrors.cannotLoadUserDefinedTypeError(name, userClass)
+ }
+
+ def notPublicClassError(name: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2229",
+ messageParameters = Map(
+ "name" -> name))
+ }
+
+ def primitiveTypesNotSupportedError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2230",
+ messageParameters = Map.empty)
+ }
+
+ def fieldIndexOnRowWithoutSchemaError(): SparkUnsupportedOperationException = {
+ DataTypeErrors.fieldIndexOnRowWithoutSchemaError()
+ }
+
+ def valueIsNullError(index: Int): Throwable = {
+ DataTypeErrors.valueIsNullError(index)
+ }
+
+ def onlySupportDataSourcesProvidingFileFormatError(providingClass: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2233",
+ messageParameters = Map(
+ "providingClass" -> providingClass),
+ cause = null)
+ }
+
+ def cannotRestorePermissionsForPathError(permission: FsPermission, path: Path): Throwable = {
+ new SparkSecurityException(
+ errorClass = "CANNOT_RESTORE_PERMISSIONS_FOR_PATH",
+ messageParameters = Map(
+ "permission" -> permission.toString,
+ "path" -> path.toString))
+ }
+
+ def failToSetOriginalACLBackError(
+ aclEntries: String, path: Path, e: Throwable): SparkSecurityException = {
+ new SparkSecurityException(
+ errorClass = "_LEGACY_ERROR_TEMP_2234",
+ messageParameters = Map(
+ "aclEntries" -> aclEntries,
+ "path" -> path.toString(),
+ "message" -> e.getMessage))
+ }
+
+ def multiFailuresInStageMaterializationError(error: Throwable): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2235",
+ messageParameters = Map.empty,
+ cause = error)
+ }
+
+ def unrecognizedCompressionSchemaTypeIDError(typeId: Int): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2236",
+ messageParameters = Map(
+ "typeId" -> typeId.toString()))
+ }
+
+ def getParentLoggerNotImplementedError(
+ className: String): SparkSQLFeatureNotSupportedException = {
+ new SparkSQLFeatureNotSupportedException(
+ errorClass = "_LEGACY_ERROR_TEMP_2237",
+ messageParameters = Map(
+ "className" -> className))
+ }
+
+ def cannotCreateParquetConverterForTypeError(
+ t: DecimalType, parquetType: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2238",
+ messageParameters = Map(
+ "typeName" -> t.typeName,
+ "parquetType" -> parquetType))
+ }
+
+ def cannotCreateParquetConverterForDecimalTypeError(
+ t: DecimalType, parquetType: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2239",
+ messageParameters = Map(
+ "t" -> t.json,
+ "parquetType" -> parquetType))
+ }
+
+ def cannotCreateParquetConverterForDataTypeError(
+ t: DataType, parquetType: String): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2240",
+ messageParameters = Map(
+ "t" -> t.json,
+ "parquetType" -> parquetType))
+ }
+
+ def cannotAddMultiPartitionsOnNonatomicPartitionTableError(
+ tableName: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2241",
+ messageParameters = Map(
+ "tableName" -> tableName))
+ }
+
+ def userSpecifiedSchemaUnsupportedByDataSourceError(
+ provider: TableProvider): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2242",
+ messageParameters = Map(
+ "provider" -> provider.getClass.getSimpleName))
+ }
+
+ def cannotDropMultiPartitionsOnNonatomicPartitionTableError(
+ tableName: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2243",
+ messageParameters = Map(
+ "tableName" -> tableName))
+ }
+
+ def truncateMultiPartitionUnsupportedError(
+ tableName: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2244",
+ messageParameters = Map(
+ "tableName" -> tableName))
+ }
+
+ def overwriteTableByUnsupportedExpressionError(table: Table): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2245",
+ messageParameters = Map(
+ "table" -> table.toString()),
+ cause = null)
+ }
+
+ def dynamicPartitionOverwriteUnsupportedByTableError(table: Table): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2246",
+ messageParameters = Map(
+ "table" -> table.toString()),
+ cause = null)
+ }
+
+ def failedMergingSchemaError(
+ leftSchema: StructType,
+ rightSchema: StructType,
+ e: SparkException): Throwable = {
+ new SparkException(
+ errorClass = "CANNOT_MERGE_SCHEMAS",
+ messageParameters = Map("left" -> toSQLType(leftSchema), "right" -> toSQLType(rightSchema)),
+ cause = e)
+ }
+
+ def cannotBroadcastTableOverMaxTableRowsError(
+ maxBroadcastTableRows: Long, numRows: Long): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2248",
+ messageParameters = Map(
+ "maxBroadcastTableRows" -> maxBroadcastTableRows.toString(),
+ "numRows" -> numRows.toString()),
+ cause = null)
+ }
+
+ def cannotBroadcastTableOverMaxTableBytesError(
+ maxBroadcastTableBytes: Long, dataSize: Long): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2249",
+ messageParameters = Map(
+ "maxBroadcastTableBytes" -> Utils.bytesToString(maxBroadcastTableBytes),
+ "dataSize" -> Utils.bytesToString(dataSize)),
+ cause = null)
+ }
+
+ def notEnoughMemoryToBuildAndBroadcastTableError(
+ oe: OutOfMemoryError, tables: Seq[TableIdentifier]): Throwable = {
+ val analyzeTblMsg = if (tables.nonEmpty) {
+ " or analyze these tables through: " +
+ s"${tables.map(t => s"ANALYZE TABLE $t COMPUTE STATISTICS;").mkString(" ")}."
+ } else {
+ "."
+ }
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2250",
+ messageParameters = Map(
+ "autoBroadcastjoinThreshold" -> SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key,
+ "driverMemory" -> SparkLauncher.DRIVER_MEMORY,
+ "analyzeTblMsg" -> analyzeTblMsg),
+ cause = oe.getCause)
+ }
+
+ def executeCodePathUnsupportedError(execName: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2251",
+ messageParameters = Map(
+ "execName" -> execName))
+ }
+
+ def cannotMergeClassWithOtherClassError(
+ className: String, otherClass: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2252",
+ messageParameters = Map(
+ "className" -> className,
+ "otherClass" -> otherClass))
+ }
+
+ def continuousProcessingUnsupportedByDataSourceError(
+ sourceName: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2253",
+ messageParameters = Map(
+ "sourceName" -> sourceName))
+ }
+
+ def failedToReadDataError(failureReason: Throwable): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2254",
+ messageParameters = Map.empty,
+ cause = failureReason)
+ }
+
+ def failedToGenerateEpochMarkerError(failureReason: Throwable): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2255",
+ messageParameters = Map.empty,
+ cause = failureReason)
+ }
+
+ def foreachWriterAbortedDueToTaskFailureError(): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2256",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def incorrectRampUpRate(rowsPerSecond: Long,
+ maxSeconds: Long,
+ rampUpTimeSeconds: Long): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "INCORRECT_RAMP_UP_RATE",
+ messageParameters = Map(
+ "rowsPerSecond" -> rowsPerSecond.toString,
+ "maxSeconds" -> maxSeconds.toString,
+ "rampUpTimeSeconds" -> rampUpTimeSeconds.toString
+ ))
+ }
+
+ def incorrectEndOffset(rowsPerSecond: Long,
+ maxSeconds: Long,
+ endSeconds: Long): Throwable = {
+ SparkException.internalError(
+ s"Max offset with ${rowsPerSecond.toString} rowsPerSecond is ${maxSeconds.toString}, " +
+ s"but it's ${endSeconds.toString} now.")
+ }
+
+ def failedToReadDeltaFileError(fileToRead: Path, clazz: String, keySize: Int): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2258",
+ messageParameters = Map(
+ "fileToRead" -> fileToRead.toString(),
+ "clazz" -> clazz,
+ "keySize" -> keySize.toString()),
+ cause = null)
+ }
+
+ def failedToReadSnapshotFileError(fileToRead: Path, clazz: String, message: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2259",
+ messageParameters = Map(
+ "fileToRead" -> fileToRead.toString(),
+ "clazz" -> clazz,
+ "message" -> message),
+ cause = null)
+ }
+
+ def cannotPurgeAsBreakInternalStateError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2260",
+ messageParameters = Map.empty)
+ }
+
+ def cleanUpSourceFilesUnsupportedError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2261",
+ messageParameters = Map.empty)
+ }
+
+ def latestOffsetNotCalledError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2262",
+ messageParameters = Map.empty)
+ }
+
+ def legacyCheckpointDirectoryExistsError(
+ checkpointPath: Path, legacyCheckpointDir: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2263",
+ messageParameters = Map(
+ "checkpointPath" -> checkpointPath.toString(),
+ "legacyCheckpointDir" -> legacyCheckpointDir,
+ "StreamingCheckpointEscapedPathCheckEnabled"
+ -> SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key),
+ cause = null)
+ }
+
+ def subprocessExitedError(
+ exitCode: Int, stderrBuffer: CircularBuffer, cause: Throwable): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2264",
+ messageParameters = Map(
+ "exitCode" -> exitCode.toString(),
+ "stderrBuffer" -> stderrBuffer.toString()),
+ cause = cause)
+ }
+
+ def outputDataTypeUnsupportedByNodeWithoutSerdeError(
+ nodeName: String, dt: DataType): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2265",
+ messageParameters = Map(
+ "nodeName" -> nodeName,
+ "dt" -> dt.getClass.getSimpleName),
+ cause = null)
+ }
+
+ def invalidStartIndexError(numRows: Int, startIndex: Int): SparkArrayIndexOutOfBoundsException = {
+ new SparkArrayIndexOutOfBoundsException(
+ errorClass = "_LEGACY_ERROR_TEMP_2266",
+ messageParameters = Map(
+ "numRows" -> numRows.toString(),
+ "startIndex" -> startIndex.toString()),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def concurrentModificationOnExternalAppendOnlyUnsafeRowArrayError(
+ className: String): SparkConcurrentModificationException = {
+ new SparkConcurrentModificationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2267",
+ messageParameters = Map(
+ "className" -> className))
+ }
+
+ def doExecuteBroadcastNotImplementedError(
+ nodeName: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2268",
+ messageParameters = Map(
+ "nodeName" -> nodeName))
+ }
+
+ def defaultDatabaseNotExistsError(defaultDatabase: String): Throwable = {
+ new SparkException(
+ errorClass = "DEFAULT_DATABASE_NOT_EXISTS",
+ messageParameters = Map("defaultDatabase" -> defaultDatabase),
+ cause = null
+ )
+ }
+
+ def databaseNameConflictWithSystemPreservedDatabaseError(globalTempDB: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2269",
+ messageParameters = Map(
+ "globalTempDB" -> globalTempDB,
+ "globalTempDatabase" -> GLOBAL_TEMP_DATABASE.key),
+ cause = null)
+ }
+
+ def commentOnTableUnsupportedError(): SparkSQLFeatureNotSupportedException = {
+ new SparkSQLFeatureNotSupportedException(
+ errorClass = "_LEGACY_ERROR_TEMP_2270",
+ messageParameters = Map.empty)
+ }
+
+ def unsupportedUpdateColumnNullabilityError(): SparkSQLFeatureNotSupportedException = {
+ new SparkSQLFeatureNotSupportedException(
+ errorClass = "_LEGACY_ERROR_TEMP_2271",
+ messageParameters = Map.empty)
+ }
+
+ def renameColumnUnsupportedForOlderMySQLError(): SparkSQLFeatureNotSupportedException = {
+ new SparkSQLFeatureNotSupportedException(
+ errorClass = "_LEGACY_ERROR_TEMP_2272",
+ messageParameters = Map.empty)
+ }
+
+ def failedToExecuteQueryError(e: Throwable): SparkException = {
+ val message = "Hit an error when executing a query" +
+ (if (e.getMessage == null) "" else s": ${e.getMessage}")
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2273",
+ messageParameters = Map(
+ "message" -> message),
+ cause = e)
+ }
+
+ def nestedFieldUnsupportedError(colName: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "UNSUPPORTED_FEATURE.REPLACE_NESTED_COLUMN",
+ messageParameters = Map("colName" -> toSQLId(colName)))
+ }
+
+ def transformationsAndActionsNotInvokedByDriverError(): Throwable = {
+ new SparkException(
+ errorClass = "CANNOT_INVOKE_IN_TRANSFORMATIONS",
+ messageParameters = Map.empty,
+ cause = null)
+ }
+
+ def repeatedPivotsUnsupportedError(clause: String, operation: String): Throwable = {
+ new SparkUnsupportedOperationException(
+ errorClass = "REPEATED_CLAUSE",
+ messageParameters = Map("clause" -> clause, "operation" -> operation))
+ }
+
+ def pivotNotAfterGroupByUnsupportedError(): Throwable = {
+ new SparkUnsupportedOperationException(
+ errorClass = "UNSUPPORTED_FEATURE.PIVOT_AFTER_GROUP_BY",
+ messageParameters = Map.empty[String, String])
+ }
+
+ private val aesFuncName = toSQLId("aes_encrypt") + "/" + toSQLId("aes_decrypt")
+
+ def invalidAesKeyLengthError(actualLength: Int): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "INVALID_PARAMETER_VALUE.AES_KEY_LENGTH",
+ messageParameters = Map(
+ "parameter" -> toSQLId("key"),
+ "functionName" -> aesFuncName,
+ "actualLength" -> actualLength.toString()))
+ }
+
+ def aesModeUnsupportedError(mode: String, padding: String): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "UNSUPPORTED_FEATURE.AES_MODE",
+ messageParameters = Map(
+ "mode" -> mode,
+ "padding" -> padding,
+ "functionName" -> aesFuncName))
+ }
+
+ def aesCryptoError(detailMessage: String): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "INVALID_PARAMETER_VALUE.AES_CRYPTO_ERROR",
+ messageParameters = Map(
+ "parameter" -> (toSQLId("expr") + ", " + toSQLId("key")),
+ "functionName" -> aesFuncName,
+ "detailMessage" -> detailMessage))
+ }
+
+ def invalidAesIvLengthError(mode: String, actualLength: Int): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "INVALID_PARAMETER_VALUE.AES_IV_LENGTH",
+ messageParameters = Map(
+ "mode" -> mode,
+ "parameter" -> toSQLId("iv"),
+ "functionName" -> aesFuncName,
+ "actualLength" -> actualLength.toString()))
+ }
+
+ def aesUnsupportedIv(mode: String): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "UNSUPPORTED_FEATURE.AES_MODE_IV",
+ messageParameters = Map(
+ "mode" -> mode,
+ "functionName" -> toSQLId("aes_encrypt")))
+ }
+
+ def aesUnsupportedAad(mode: String): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "UNSUPPORTED_FEATURE.AES_MODE_AAD",
+ messageParameters = Map(
+ "mode" -> mode,
+ "functionName" -> toSQLId("aes_encrypt")))
+ }
+
+ def hiveTableWithAnsiIntervalsError(
+ table: TableIdentifier): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "UNSUPPORTED_FEATURE.HIVE_WITH_ANSI_INTERVALS",
+ messageParameters = Map("tableName" -> toSQLId(table.nameParts)))
+ }
+
+ def cannotConvertOrcTimestampToTimestampNTZError(): Throwable = {
+ new SparkUnsupportedOperationException(
+ errorClass = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST",
+ messageParameters = Map(
+ "orcType" -> toSQLType(TimestampType),
+ "toType" -> toSQLType(TimestampNTZType)))
+ }
+
+ def cannotConvertOrcTimestampNTZToTimestampLTZError(): Throwable = {
+ new SparkUnsupportedOperationException(
+ errorClass = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST",
+ messageParameters = Map(
+ "orcType" -> toSQLType(TimestampNTZType),
+ "toType" -> toSQLType(TimestampType)))
+ }
+
+ def writePartitionExceedConfigSizeWhenDynamicPartitionError(
+ numWrittenParts: Int,
+ maxDynamicPartitions: Int,
+ maxDynamicPartitionsKey: String): Throwable = {
+ new SparkException(
+ errorClass = "_LEGACY_ERROR_TEMP_2277",
+ messageParameters = Map(
+ "numWrittenParts" -> numWrittenParts.toString(),
+ "maxDynamicPartitionsKey" -> maxDynamicPartitionsKey,
+ "maxDynamicPartitions" -> maxDynamicPartitions.toString(),
+ "numWrittenParts" -> numWrittenParts.toString()),
+ cause = null)
+ }
+
+ def invalidNumberFormatError(
+ dataType: DataType, input: String, format: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "INVALID_FORMAT.MISMATCH_INPUT",
+ messageParameters = Map(
+ "inputType" -> toSQLType(dataType),
+ "input" -> input,
+ "format" -> format))
+ }
+
+ def unsupportedMultipleBucketTransformsError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "UNSUPPORTED_FEATURE.MULTIPLE_BUCKET_TRANSFORMS",
+ messageParameters = Map.empty)
+ }
+
+ def unsupportedCommentNamespaceError(
+ namespace: String): SparkSQLFeatureNotSupportedException = {
+ new SparkSQLFeatureNotSupportedException(
+ errorClass = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE",
+ messageParameters = Map("namespace" -> toSQLId(namespace)))
+ }
+
+ def unsupportedRemoveNamespaceCommentError(
+ namespace: String): SparkSQLFeatureNotSupportedException = {
+ new SparkSQLFeatureNotSupportedException(
+ errorClass = "UNSUPPORTED_FEATURE.REMOVE_NAMESPACE_COMMENT",
+ messageParameters = Map("namespace" -> toSQLId(namespace)))
+ }
+
+ def unsupportedDropNamespaceError(
+ namespace: String): SparkSQLFeatureNotSupportedException = {
+ new SparkSQLFeatureNotSupportedException(
+ errorClass = "UNSUPPORTED_FEATURE.DROP_NAMESPACE",
+ messageParameters = Map("namespace" -> toSQLId(namespace)))
+ }
+
+ def exceedMaxLimit(limit: Int): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "EXCEED_LIMIT_LENGTH",
+ messageParameters = Map("limit" -> limit.toString)
+ )
+ }
+
+ def timestampAddOverflowError(micros: Long, amount: Int, unit: String): ArithmeticException = {
+ new SparkArithmeticException(
+ errorClass = "DATETIME_OVERFLOW",
+ messageParameters = Map(
+ "operation" -> (s"add ${toSQLValue(amount, IntegerType)} $unit to " +
+ s"${toSQLValue(DateTimeUtils.microsToInstant(micros), TimestampType)}")),
+ context = Array.empty,
+ summary = "")
+ }
+
+ def invalidBucketFile(path: String): Throwable = {
+ new SparkException(
+ errorClass = "INVALID_BUCKET_FILE",
+ messageParameters = Map("path" -> path),
+ cause = null)
+ }
+
+ def multipleRowSubqueryError(context: SQLQueryContext): Throwable = {
+ new SparkException(
+ errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
+ messageParameters = Map.empty,
+ cause = null,
+ context = getQueryContext(context),
+ summary = getSummary(context))
+ }
+
+ def comparatorReturnsNull(firstValue: String, secondValue: String): Throwable = {
+ new SparkException(
+ errorClass = "COMPARATOR_RETURNS_NULL",
+ messageParameters = Map("firstValue" -> firstValue, "secondValue" -> secondValue),
+ cause = null)
+ }
+
+ def invalidPatternError(
+ funcName: String,
+ pattern: String,
+ cause: Throwable): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "INVALID_PARAMETER_VALUE.PATTERN",
+ messageParameters = Map(
+ "parameter" -> toSQLId("regexp"),
+ "functionName" -> toSQLId(funcName),
+ "value" -> toSQLValue(pattern, StringType)),
+ cause = cause)
+ }
+
+ def tooManyArrayElementsError(
+ numElements: Int,
+ elementSize: Int): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "TOO_MANY_ARRAY_ELEMENTS",
+ messageParameters = Map(
+ "numElements" -> numElements.toString,
+ "size" -> elementSize.toString))
+ }
+
+ def invalidEmptyLocationError(location: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "INVALID_EMPTY_LOCATION",
+ messageParameters = Map("location" -> location))
+ }
+
+ def malformedProtobufMessageDetectedInMessageParsingError(e: Throwable): Throwable = {
+ new SparkException(
+ errorClass = "MALFORMED_PROTOBUF_MESSAGE",
+ messageParameters = Map(
+ "failFastMode" -> FailFastMode.name),
+ cause = e)
+ }
+
+ def locationAlreadyExists(tableId: TableIdentifier, location: Path): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "LOCATION_ALREADY_EXISTS",
+ messageParameters = Map(
+ "location" -> toSQLValue(location.toString, StringType),
+ "identifier" -> toSQLId(tableId.nameParts)))
+ }
+
+ def cannotConvertCatalystValueToProtobufEnumTypeError(
+ sqlColumn: Seq[String],
+ protobufColumn: String,
+ data: String,
+ enumString: String): Throwable = {
+ new AnalysisException(
+ errorClass = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE",
+ messageParameters = Map(
+ "sqlColumn" -> toSQLId(sqlColumn),
+ "protobufColumn" -> protobufColumn,
+ "data" -> data,
+ "enumString" -> enumString))
+ }
+
+ def hllInvalidLgK(function: String, min: Int, max: Int, value: String): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "HLL_INVALID_LG_K",
+ messageParameters = Map(
+ "function" -> toSQLId(function),
+ "min" -> toSQLValue(min, IntegerType),
+ "max" -> toSQLValue(max, IntegerType),
+ "value" -> value))
+ }
+
+ def hllInvalidInputSketchBuffer(function: String): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "HLL_INVALID_INPUT_SKETCH_BUFFER",
+ messageParameters = Map(
+ "function" -> toSQLId(function)))
+ }
+
+ def hllUnionDifferentLgK(left: Int, right: Int, function: String): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "HLL_UNION_DIFFERENT_LG_K",
+ messageParameters = Map(
+ "left" -> toSQLValue(left, IntegerType),
+ "right" -> toSQLValue(right, IntegerType),
+ "function" -> toSQLId(function)))
+ }
+
+ def mergeCardinalityViolationError(): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "MERGE_CARDINALITY_VIOLATION",
+ messageParameters = Map.empty)
+ }
+
+ def unsupportedPurgePartitionError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "UNSUPPORTED_FEATURE.PURGE_PARTITION",
+ messageParameters = Map.empty)
+ }
+
+ def unsupportedPurgeTableError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "UNSUPPORTED_FEATURE.PURGE_TABLE",
+ messageParameters = Map.empty)
+ }
+}
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 944505ba1c..32cbc0ce13 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -89,6 +89,37 @@ message Expr {
FromJson from_json = 66;
ToCsv to_csv = 67;
}
+
+ // Optional QueryContext for error reporting (contains SQL text and position)
+ optional QueryContext query_context = 90;
+
+ // Unique expression ID for context lookup during error creation
+ optional uint64 expr_id = 91;
+}
+
+// QueryContext provides SQL query context for error messages.
+// Mirrors Spark's SQLQueryContext for rich error reporting.
+message QueryContext {
+ // Full SQL query text
+ string sql_text = 1;
+
+ // Character offset where expression starts (0-based)
+ int32 start_index = 2;
+
+ // Character offset where expression ends (0-based, inclusive)
+ int32 stop_index = 3;
+
+ // Type of SQL object (e.g., "VIEW", "Project", "Filter")
+ optional string object_type = 4;
+
+ // Name of object (e.g., view name, column name)
+ optional string object_name = 5;
+
+ // Line number in SQL query (1-based)
+ int32 line = 6;
+
+ // Column position within the line (0-based)
+ int32 start_position = 7;
}
message AggExpr {
@@ -109,6 +140,12 @@ message AggExpr {
Correlation correlation = 15;
BloomFilterAgg bloomFilterAgg = 16;
}
+
+ // Optional QueryContext for error reporting (contains SQL text and position)
+ optional QueryContext query_context = 90;
+
+ // Unique expression ID for context lookup during error creation
+ optional uint64 expr_id = 91;
}
enum StatisticsType {
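The expr_id added above is the key used on the native side to look up a QueryContext while an error is being built. The concrete registry type behind create_query_context_map() is not part of this diff, so the following is only a minimal sketch of the idea, with hypothetical names (QueryContextRegistry, insert); only the get(expr_id) lookup is taken from the code below.

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Rust-side mirror of the QueryContext proto message (fields reduced for brevity).
#[derive(Debug, Clone)]
struct QueryContext {
    sql_text: String,
    start_index: i32,
    stop_index: i32,
}

// Hypothetical registry: expression id -> query context, shared across operators.
#[derive(Default)]
struct QueryContextRegistry {
    map: RwLock<HashMap<u64, Arc<QueryContext>>>,
}

impl QueryContextRegistry {
    fn insert(&self, expr_id: u64, ctx: QueryContext) {
        self.map.write().unwrap().insert(expr_id, Arc::new(ctx));
    }

    fn get(&self, expr_id: u64) -> Option<Arc<QueryContext>> {
        self.map.read().unwrap().get(&expr_id).cloned()
    }
}

fn main() {
    let registry = Arc::new(QueryContextRegistry::default());
    registry.insert(
        42,
        QueryContext {
            sql_text: "SELECT 1 / 0".to_string(),
            start_index: 7,
            stop_index: 11,
        },
    );
    // An accumulator holding (Some(42), registry) can attach this context to its error.
    if let Some(ctx) = registry.get(42) {
        println!("{} [{}..{}]", ctx.sql_text, ctx.start_index, ctx.stop_index);
    }
    assert!(registry.get(99).is_none());
}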
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index e7c238f7eb..41f70c5027 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -33,6 +33,7 @@ datafusion = { workspace = true }
chrono-tz = { workspace = true }
num = { workspace = true }
regex = { workspace = true }
+serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = { workspace = true }
futures = { workspace = true }
diff --git a/native/spark-expr/benches/aggregate.rs b/native/spark-expr/benches/aggregate.rs
index 47e2cf61c3..7944d392f6 100644
--- a/native/spark-expr/benches/aggregate.rs
+++ b/native/spark-expr/benches/aggregate.rs
@@ -70,6 +70,9 @@ fn criterion_benchmark(c: &mut Criterion) {
let comet_avg_decimal = Arc::new(AggregateUDF::new_from_impl(AvgDecimal::new(
DataType::Decimal128(38, 10),
DataType::Decimal128(38, 10),
+ datafusion_comet_spark_expr::EvalMode::Legacy,
+ None,
+ datafusion_comet_spark_expr::create_query_context_map(),
)));
b.to_async(&rt).iter(|| {
black_box(agg_test(
@@ -97,7 +100,13 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("sum_decimal_comet", |b| {
let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(
- SumDecimal::try_new(DataType::Decimal128(38, 10), EvalMode::Legacy).unwrap(),
+ SumDecimal::try_new(
+ DataType::Decimal128(38, 10),
+ EvalMode::Legacy,
+ None,
+ datafusion_comet_spark_expr::create_query_context_map(),
+ )
+ .unwrap(),
));
b.to_async(&rt).iter(|| {
black_box(agg_test(
diff --git a/native/spark-expr/benches/cast_from_boolean.rs b/native/spark-expr/benches/cast_from_boolean.rs
index dbb986df91..04bd72dc01 100644
--- a/native/spark-expr/benches/cast_from_boolean.rs
+++ b/native/spark-expr/benches/cast_from_boolean.rs
@@ -27,14 +27,62 @@ fn criterion_benchmark(c: &mut Criterion) {
let expr = Arc::new(Column::new("a", 0));
let boolean_batch = create_boolean_batch();
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
- let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
- let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
- let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
- let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options.clone());
- let cast_to_f32 = Cast::new(expr.clone(), DataType::Float32, spark_cast_options.clone());
- let cast_to_f64 = Cast::new(expr.clone(), DataType::Float64, spark_cast_options.clone());
- let cast_to_str = Cast::new(expr.clone(), DataType::Utf8, spark_cast_options.clone());
- let cast_to_decimal = Cast::new(expr, DataType::Decimal128(10, 4), spark_cast_options);
+ let cast_to_i8 = Cast::new(
+ expr.clone(),
+ DataType::Int8,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_i16 = Cast::new(
+ expr.clone(),
+ DataType::Int16,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_i32 = Cast::new(
+ expr.clone(),
+ DataType::Int32,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_i64 = Cast::new(
+ expr.clone(),
+ DataType::Int64,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_f32 = Cast::new(
+ expr.clone(),
+ DataType::Float32,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_f64 = Cast::new(
+ expr.clone(),
+ DataType::Float64,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_str = Cast::new(
+ expr.clone(),
+ DataType::Utf8,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_decimal = Cast::new(
+ expr,
+ DataType::Decimal128(10, 4),
+ spark_cast_options,
+ None,
+ None,
+ );
let mut group = c.benchmark_group("cast_bool".to_string());
group.bench_function("i8", |b| {
diff --git a/native/spark-expr/benches/cast_from_string.rs b/native/spark-expr/benches/cast_from_string.rs
index 9b2cb73fb4..1fbb7c535b 100644
--- a/native/spark-expr/benches/cast_from_string.rs
+++ b/native/spark-expr/benches/cast_from_string.rs
@@ -34,10 +34,34 @@ fn criterion_benchmark(c: &mut Criterion) {
(EvalMode::Try, "try"),
] {
let spark_cast_options = SparkCastOptions::new(mode, "", false);
- let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
- let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
- let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
- let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options);
+ let cast_to_i8 = Cast::new(
+ expr.clone(),
+ DataType::Int8,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_i16 = Cast::new(
+ expr.clone(),
+ DataType::Int16,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_i32 = Cast::new(
+ expr.clone(),
+ DataType::Int32,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_i64 = Cast::new(
+ expr.clone(),
+ DataType::Int64,
+ spark_cast_options,
+ None,
+ None,
+ );
let mut group = c.benchmark_group(format!("cast_string_to_int/{}", mode_name));
group.bench_function("i8", |b| {
@@ -57,8 +81,20 @@ fn criterion_benchmark(c: &mut Criterion) {
// Benchmark decimal truncation (Legacy mode only)
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "", false);
- let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
- let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options);
+ let cast_to_i32 = Cast::new(
+ expr.clone(),
+ DataType::Int32,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_to_i64 = Cast::new(
+ expr.clone(),
+ DataType::Int64,
+ spark_cast_options,
+ None,
+ None,
+ );
let mut group = c.benchmark_group("cast_string_to_int/legacy_decimals");
group.bench_function("i32", |b| {
@@ -81,6 +117,8 @@ fn criterion_benchmark(c: &mut Criterion) {
expr.clone(),
DataType::Decimal128(38, 10),
spark_cast_options,
+ None,
+ None,
);
let mut group = c.benchmark_group(format!("cast_string_to_decimal/{}", mode_name));
diff --git a/native/spark-expr/benches/cast_int_to_timestamp.rs b/native/spark-expr/benches/cast_int_to_timestamp.rs
index 20143d2b0e..4479627ae6 100644
--- a/native/spark-expr/benches/cast_int_to_timestamp.rs
+++ b/native/spark-expr/benches/cast_int_to_timestamp.rs
@@ -35,7 +35,13 @@ fn criterion_benchmark(c: &mut Criterion) {
// Int8 -> Timestamp
let batch_i8 = create_int8_batch();
let expr_i8 = Arc::new(Column::new("a", 0));
- let cast_i8_to_ts = Cast::new(expr_i8, timestamp_type.clone(), spark_cast_options.clone());
+ let cast_i8_to_ts = Cast::new(
+ expr_i8,
+ timestamp_type.clone(),
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
group.bench_function("cast_i8_to_timestamp", |b| {
b.iter(|| cast_i8_to_ts.evaluate(&batch_i8).unwrap());
});
@@ -43,7 +49,13 @@ fn criterion_benchmark(c: &mut Criterion) {
// Int16 -> Timestamp
let batch_i16 = create_int16_batch();
let expr_i16 = Arc::new(Column::new("a", 0));
- let cast_i16_to_ts = Cast::new(expr_i16, timestamp_type.clone(), spark_cast_options.clone());
+ let cast_i16_to_ts = Cast::new(
+ expr_i16,
+ timestamp_type.clone(),
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
group.bench_function("cast_i16_to_timestamp", |b| {
b.iter(|| cast_i16_to_ts.evaluate(&batch_i16).unwrap());
});
@@ -51,7 +63,13 @@ fn criterion_benchmark(c: &mut Criterion) {
// Int32 -> Timestamp
let batch_i32 = create_int32_batch();
let expr_i32 = Arc::new(Column::new("a", 0));
- let cast_i32_to_ts = Cast::new(expr_i32, timestamp_type.clone(), spark_cast_options.clone());
+ let cast_i32_to_ts = Cast::new(
+ expr_i32,
+ timestamp_type.clone(),
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
group.bench_function("cast_i32_to_timestamp", |b| {
b.iter(|| cast_i32_to_ts.evaluate(&batch_i32).unwrap());
});
@@ -59,7 +77,13 @@ fn criterion_benchmark(c: &mut Criterion) {
// Int64 -> Timestamp
let batch_i64 = create_int64_batch();
let expr_i64 = Arc::new(Column::new("a", 0));
- let cast_i64_to_ts = Cast::new(expr_i64, timestamp_type.clone(), spark_cast_options.clone());
+ let cast_i64_to_ts = Cast::new(
+ expr_i64,
+ timestamp_type.clone(),
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
group.bench_function("cast_i64_to_timestamp", |b| {
b.iter(|| cast_i64_to_ts.evaluate(&batch_i64).unwrap());
});
diff --git a/native/spark-expr/benches/cast_numeric.rs b/native/spark-expr/benches/cast_numeric.rs
index bd14e4cb24..989cbf4d2c 100644
--- a/native/spark-expr/benches/cast_numeric.rs
+++ b/native/spark-expr/benches/cast_numeric.rs
@@ -26,9 +26,21 @@ fn criterion_benchmark(c: &mut Criterion) {
let batch = create_int32_batch();
let expr = Arc::new(Column::new("a", 0));
let spark_cast_options = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
- let cast_i32_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
- let cast_i32_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
- let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options);
+ let cast_i32_to_i8 = Cast::new(
+ expr.clone(),
+ DataType::Int8,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_i32_to_i16 = Cast::new(
+ expr.clone(),
+ DataType::Int16,
+ spark_cast_options.clone(),
+ None,
+ None,
+ );
+ let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options, None, None);
let mut group = c.benchmark_group("cast_int_to_int");
group.bench_function("cast_i32_to_i8", |b| {
diff --git a/native/spark-expr/src/agg_funcs/avg_decimal.rs b/native/spark-expr/src/agg_funcs/avg_decimal.rs
index b3b2731d9d..773ddea050 100644
--- a/native/spark-expr/src/agg_funcs/avg_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/avg_decimal.rs
@@ -31,6 +31,7 @@ use datafusion::physical_expr::expressions::format_state_name;
use std::{any::Any, sync::Arc};
use crate::utils::{build_bool_state, is_valid_decimal_precision, unlikely};
+use crate::{decimal_sum_overflow_error, EvalMode, SparkErrorWithContext};
use arrow::array::ArrowNativeTypeOp;
use arrow::datatypes::{
DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, MAX_DECIMAL128_FOR_EACH_PRECISION,
@@ -55,20 +56,53 @@ fn avg_return_type(_name: &str, data_type: &DataType) -> Result<DataType> {
}
/// AVG aggregate expression
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone)]
pub struct AvgDecimal {
signature: Signature,
sum_data_type: DataType,
result_data_type: DataType,
+ eval_mode: EvalMode,
+ expr_id: Option<u64>,
+ registry: Arc,
+}
+
+// Manually implement PartialEq, Eq, and Hash excluding the registry field
+impl PartialEq for AvgDecimal {
+ fn eq(&self, other: &Self) -> bool {
+ self.sum_data_type == other.sum_data_type
+ && self.result_data_type == other.result_data_type
+ && self.eval_mode == other.eval_mode
+ && self.expr_id == other.expr_id
+ }
+}
+
+impl Eq for AvgDecimal {}
+
+impl std::hash::Hash for AvgDecimal {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.sum_data_type.hash(state);
+ self.result_data_type.hash(state);
+ self.eval_mode.hash(state);
+ self.expr_id.hash(state);
+ }
}
impl AvgDecimal {
/// Create a new AVG aggregate function
- pub fn new(result_type: DataType, sum_type: DataType) -> Self {
+ pub fn new(
+ result_type: DataType,
+ sum_type: DataType,
+ eval_mode: EvalMode,
+ expr_id: Option<u64>,
+ registry: Arc,
+ ) -> Self {
Self {
signature: Signature::user_defined(Immutable),
result_data_type: result_type,
sum_data_type: sum_type,
+ eval_mode,
+ expr_id,
+ registry,
}
}
}
@@ -87,6 +121,9 @@ impl AggregateUDFImpl for AvgDecimal {
*sum_precision,
*target_precision,
*target_scale,
+ self.eval_mode,
+ self.expr_id,
+ Arc::clone(&self.registry),
)))
}
_ => not_impl_err!(
@@ -138,6 +175,9 @@ impl AggregateUDFImpl for AvgDecimal {
*target_scale,
*sum_precision,
*sum_scale,
+ self.eval_mode,
+ self.expr_id,
+ Arc::clone(&self.registry),
)))
}
_ => not_impl_err!(
@@ -180,10 +220,21 @@ struct AvgDecimalAccumulator {
sum_precision: u8,
target_precision: u8,
target_scale: i8,
+ eval_mode: EvalMode,
+ expr_id: Option<u64>,
+ registry: Arc,
}
impl AvgDecimalAccumulator {
- pub fn new(sum_scale: i8, sum_precision: u8, target_precision: u8, target_scale: i8) -> Self {
+ pub fn new(
+ sum_scale: i8,
+ sum_precision: u8,
+ target_precision: u8,
+ target_scale: i8,
+ eval_mode: EvalMode,
+ expr_id: Option<u64>,
+ registry: Arc,
+ ) -> Self {
Self {
sum: None,
count: 0,
@@ -193,10 +244,27 @@ impl AvgDecimalAccumulator {
sum_precision,
target_precision,
target_scale,
+ eval_mode,
+ expr_id,
+ registry,
}
}
- fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
+ /// Wrap a SparkError with QueryContext if expr_id is available
+ fn wrap_error_with_context(
+ &self,
+ error: crate::SparkError,
+ ) -> datafusion::common::DataFusionError {
+ if let Some(expr_id) = self.expr_id {
+ if let Some(query_ctx) = self.registry.get(expr_id) {
+ let wrapped = SparkErrorWithContext::with_context(error, query_ctx);
+ return datafusion::common::DataFusionError::External(Box::new(wrapped));
+ }
+ }
+ datafusion::common::DataFusionError::from(error)
+ }
+
+ fn update_single(&mut self, values: &Decimal128Array, idx: usize) -> Result<()> {
let v = unsafe { values.value_unchecked(idx) };
let (new_sum, is_overflow) = match self.sum {
Some(sum) => sum.overflowing_add(v),
@@ -204,9 +272,10 @@ impl AvgDecimalAccumulator {
};
if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
- // Overflow: set buffer accumulator to null
+ // Overflow: set to null. Error will be thrown during evaluate in ANSI mode.
+ // This matches Spark's DecimalAddNoOverflowCheck behavior.
self.is_not_null = false;
- return;
+ return Ok(());
}
self.sum = Some(new_sum);
@@ -214,11 +283,13 @@ impl AvgDecimalAccumulator {
if let Some(new_count) = self.count.checked_add(1) {
self.count = new_count;
} else {
+ // Count overflow: set to null. Error will be thrown during evaluate in ANSI mode.
self.is_not_null = false;
- return;
+ return Ok(());
}
self.is_not_null = true;
+ Ok(())
}
}
@@ -240,37 +311,44 @@ impl Accumulator for AvgDecimalAccumulator {
// of the computation
return Ok(());
}
-
let values = &values[0];
let data = values.as_primitive::<Decimal128Type>();
-
self.is_empty = self.is_empty && values.len() == values.null_count();
-
if values.null_count() == 0 {
for i in 0..data.len() {
- self.update_single(data, i);
+ self.update_single(data, i)?;
}
} else {
for i in 0..data.len() {
if data.is_null(i) {
continue;
}
- self.update_single(data, i);
+ self.update_single(data, i)?;
}
}
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let partial_sums = states[0].as_primitive::<Decimal128Type>();
+ let partial_counts = states[1].as_primitive::<Int64Type>();
+
+ // Update is_empty: if any partial state has data, we're not empty
+ if self.is_empty {
+ self.is_empty = partial_counts.len() == partial_counts.null_count();
+ }
+
// counts are summed
- self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
+ self.count += sum(partial_counts).unwrap_or_default();
// sums are summed
- if let Some(x) = sum(states[0].as_primitive::<Decimal128Type>()) {
+ if let Some(x) = sum(partial_sums) {
let v = self.sum.get_or_insert(0);
let (result, overflowed) = v.overflowing_add(x);
- if overflowed {
- // Set to None if overflow happens
+
+ if overflowed || !is_valid_decimal_precision(result, self.sum_precision) {
+ // Overflow during merge: set to null, error will be thrown during evaluate in ANSI mode
+ self.is_not_null = false;
self.sum = None;
} else {
*v = result;
@@ -280,6 +358,19 @@ impl Accumulator for AvgDecimalAccumulator {
}
fn evaluate(&mut self) -> Result<ScalarValue> {
+ // Check for overflow during sum accumulation in ANSI mode.
+ // This matches Spark's DecimalDivideWithOverflowCheck behavior.
+ if self.sum.is_none() && !self.is_empty && self.eval_mode == EvalMode::Ansi {
+ let error = decimal_sum_overflow_error();
+ return Err(self.wrap_error_with_context(error));
+ }
+
+ // Also check if is_not_null is false (indicates overflow)
+ if !self.is_not_null && self.count > 0 && self.eval_mode == EvalMode::Ansi {
+ let error = decimal_sum_overflow_error();
+ return Err(self.wrap_error_with_context(error));
+ }
+
let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32);
let target_min = MIN_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
@@ -328,9 +419,17 @@ struct AvgDecimalGroupsAccumulator {
/// This is input_precision + 10 to be consistent with Spark
sum_precision: u8,
sum_scale: i8,
+
+ /// Evaluation mode for error handling
+ eval_mode: EvalMode,
+ /// Optional expression ID for query context lookup during error creation
+ expr_id: Option<u64>,
+ /// Session-scoped query context registry for error reporting
+ registry: Arc,
}
impl AvgDecimalGroupsAccumulator {
+ #[allow(clippy::too_many_arguments)]
pub fn new(
return_data_type: &DataType,
sum_data_type: &DataType,
@@ -338,6 +437,9 @@ impl AvgDecimalGroupsAccumulator {
target_scale: i8,
sum_precision: u8,
sum_scale: i8,
+ eval_mode: EvalMode,
+ expr_id: Option<u64>,
+ registry: Arc,
) -> Self {
Self {
is_not_null: BooleanBufferBuilder::new(0),
@@ -349,19 +451,38 @@ impl AvgDecimalGroupsAccumulator {
sum_scale,
counts: vec![],
sums: vec![],
+ eval_mode,
+ expr_id,
+ registry,
}
}
+ /// Wrap a SparkError with QueryContext if expr_id is available
+ fn wrap_error_with_context(
+ &self,
+ error: crate::SparkError,
+ ) -> datafusion::common::DataFusionError {
+ if let Some(expr_id) = self.expr_id {
+ if let Some(query_ctx) = self.registry.get(expr_id) {
+ let wrapped = SparkErrorWithContext::with_context(error, query_ctx);
+ return datafusion::common::DataFusionError::External(Box::new(wrapped));
+ }
+ }
+ datafusion::common::DataFusionError::from(error)
+ }
+
#[inline]
- fn update_single(&mut self, group_index: usize, value: i128) {
+ fn update_single(&mut self, group_index: usize, value: i128) -> Result<()> {
let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value);
self.counts[group_index] += 1;
self.sums[group_index] = new_sum;
if unlikely(is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision)) {
- // Overflow: set buffer accumulator to null
+ // Overflow: set to null. Error will be thrown during evaluate in ANSI mode.
+ // This matches Spark's DecimalAddNoOverflowCheck behavior.
self.is_not_null.set_bit(group_index, false);
}
+ Ok(())
}
}
@@ -392,14 +513,14 @@ impl GroupsAccumulator for AvgDecimalGroupsAccumulator {
let iter = group_indices.iter().zip(data.iter());
if values.null_count() == 0 {
for (&group_index, &value) in iter {
- self.update_single(group_index, value);
+ self.update_single(group_index, value)?;
}
} else {
for (idx, (&group_index, &value)) in iter.enumerate() {
if values.is_null(idx) {
continue;
}
- self.update_single(group_index, value);
+ self.update_single(group_index, value)?;
}
}
Ok(())
@@ -425,13 +546,30 @@ impl GroupsAccumulator for AvgDecimalGroupsAccumulator {
// update sums
self.sums.resize(total_num_groups, 0);
+ // Ensure bit capacity BEFORE setting any bits
+ ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
+
let iter2 = group_indices.iter().zip(partial_sums.values().iter());
- for (&group_index, &new_value) in iter2 {
- let sum = &mut self.sums[group_index];
- *sum = sum.add_wrapping(new_value);
- }
+ for (idx, (&group_index, &new_value)) in iter2.enumerate() {
+ // Check if partial sum is null (indicates overflow in that partition)
+ if partial_sums.is_null(idx) {
+ self.is_not_null.set_bit(group_index, false);
+ continue;
+ }
- ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
+ let sum = self.sums[group_index];
+ let (new_sum, is_overflow) = sum.overflowing_add(new_value);
+
+ if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
+ if self.eval_mode == EvalMode::Ansi {
+ let error = decimal_sum_overflow_error();
+ return Err(self.wrap_error_with_context(error));
+ }
+ self.is_not_null.set_bit(group_index, false);
+ } else {
+ self.sums[group_index] = new_sum;
+ }
+ }
if partial_counts.null_count() != 0 {
for (index, &group_index) in group_indices.iter().enumerate() {
if partial_counts.is_null(index) {
@@ -457,6 +595,13 @@ impl GroupsAccumulator for AvgDecimalGroupsAccumulator {
let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize];
for (is_not_null, (sum, count)) in nulls.into_iter().zip(iter) {
+ // Check for overflow during sum accumulation in ANSI mode.
+ // This matches Spark's DecimalDivideWithOverflowCheck behavior.
+ if !is_not_null && count > 0 && self.eval_mode == EvalMode::Ansi {
+ let error = decimal_sum_overflow_error();
+ return Err(self.wrap_error_with_context(error));
+ }
+
if !is_not_null || count == 0 {
builder.append_null();
continue;
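The net effect of the AvgDecimal changes is that overflow is deferred: update and merge only null out the running sum, and evaluate raises decimal_sum_overflow_error only when eval_mode is Ansi, mirroring the DecimalAddNoOverflowCheck / DecimalDivideWithOverflowCheck split noted in the comments above. A standalone sketch of that deferral pattern, using an illustrative EvalMode and plain string errors rather than the crate's types:

#[derive(PartialEq, Clone, Copy)]
enum EvalMode {
    Legacy,
    Ansi,
}

struct SumAcc {
    sum: Option<i128>,
    saw_input: bool,
    eval_mode: EvalMode,
}

impl SumAcc {
    fn new(eval_mode: EvalMode) -> Self {
        Self { sum: Some(0), saw_input: false, eval_mode }
    }

    // Accumulate: on overflow just mark the sum as null; do not error yet.
    fn update(&mut self, v: i128) {
        self.saw_input = true;
        if let Some(s) = self.sum {
            self.sum = s.checked_add(v); // None on overflow
        }
    }

    // Evaluate: only now, and only in ANSI mode, the null-due-to-overflow becomes an error.
    fn evaluate(&self) -> Result<Option<i128>, String> {
        if self.sum.is_none() && self.saw_input && self.eval_mode == EvalMode::Ansi {
            return Err("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals.".to_string());
        }
        Ok(self.sum)
    }
}

fn main() {
    let mut legacy = SumAcc::new(EvalMode::Legacy);
    legacy.update(i128::MAX);
    legacy.update(1);
    assert_eq!(legacy.evaluate(), Ok(None)); // Legacy: overflow becomes NULL

    let mut ansi = SumAcc::new(EvalMode::Ansi);
    ansi.update(i128::MAX);
    ansi.update(1);
    assert!(ansi.evaluate().is_err()); // ANSI: the error surfaces at evaluate time
}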
diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs
index 50645391fd..bf5569b00b 100644
--- a/native/spark-expr/src/agg_funcs/sum_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs
@@ -16,7 +16,7 @@
// under the License.
use crate::utils::is_valid_decimal_precision;
-use crate::{arithmetic_overflow_error, EvalMode};
+use crate::{decimal_sum_overflow_error, EvalMode, SparkErrorWithContext};
use arrow::array::{
cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
};
@@ -29,7 +29,7 @@ use datafusion::logical_expr::{
};
use std::{any::Any, sync::Arc};
-#[derive(Debug, PartialEq, Eq, Hash)]
+#[derive(Debug)]
pub struct SumDecimal {
/// Aggregate function signature
signature: Signature,
@@ -41,10 +41,43 @@ pub struct SumDecimal {
/// Decimal scale
scale: i8,
eval_mode: EvalMode,
+ /// Optional expression ID for query context lookup during error creation
+ expr_id: Option<u64>,
+ /// Session-scoped query context registry for error reporting
+ registry: Arc,
+}
+
+// Manually implement PartialEq, Eq, and Hash excluding the registry field
+// since registry is only for error reporting and doesn't affect function behavior
+impl PartialEq for SumDecimal {
+ fn eq(&self, other: &Self) -> bool {
+ self.precision == other.precision
+ && self.scale == other.scale
+ && self.eval_mode == other.eval_mode
+ && self.expr_id == other.expr_id
+ && self.result_type == other.result_type
+ }
+}
+
+impl Eq for SumDecimal {}
+
+impl std::hash::Hash for SumDecimal {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.precision.hash(state);
+ self.scale.hash(state);
+ self.eval_mode.hash(state);
+ self.expr_id.hash(state);
+ self.result_type.hash(state);
+ }
}
impl SumDecimal {
- pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
+ pub fn try_new(
+ data_type: DataType,
+ eval_mode: EvalMode,
+ expr_id: Option<u64>,
+ registry: Arc,
+ ) -> DFResult<Self> {
let (precision, scale) = match data_type {
DataType::Decimal128(p, s) => (p, s),
_ => {
@@ -59,6 +92,8 @@ impl SumDecimal {
precision,
scale,
eval_mode,
+ expr_id,
+ registry,
})
}
}
@@ -73,6 +108,8 @@ impl AggregateUDFImpl for SumDecimal {
self.precision,
self.scale,
self.eval_mode,
+ self.expr_id,
+ Arc::clone(&self.registry),
)))
}
@@ -110,6 +147,8 @@ impl AggregateUDFImpl for SumDecimal {
self.result_type.clone(),
self.precision,
self.eval_mode,
+ self.expr_id,
+ Arc::clone(&self.registry),
)))
}
@@ -137,10 +176,18 @@ struct SumDecimalAccumulator {
precision: u8,
scale: i8,
eval_mode: EvalMode,
+ expr_id: Option<u64>,
+ registry: Arc,
}
impl SumDecimalAccumulator {
- fn new(precision: u8, scale: i8, eval_mode: EvalMode) -> Self {
+ fn new(
+ precision: u8,
+ scale: i8,
+ eval_mode: EvalMode,
+ expr_id: Option<u64>,
+ registry: Arc,
+ ) -> Self {
// For decimal sum, always track is_empty regardless of eval_mode
// This matches Spark's behavior where DecimalType always uses shouldTrackIsEmpty = true
Self {
@@ -149,6 +196,8 @@ impl SumDecimalAccumulator {
precision,
scale,
eval_mode,
+ expr_id,
+ registry,
}
}
@@ -164,7 +213,8 @@ impl SumDecimalAccumulator {
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
if self.eval_mode == EvalMode::Ansi {
- return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+ let error = decimal_sum_overflow_error();
+ return Err(self.wrap_error_with_context(error));
}
self.sum = None;
self.is_empty = false;
@@ -175,6 +225,17 @@ impl SumDecimalAccumulator {
self.is_empty = false;
Ok(())
}
+
+ /// Wrap a SparkError with QueryContext if expr_id is available
+ fn wrap_error_with_context(&self, error: crate::SparkError) -> DataFusionError {
+ if let Some(expr_id) = self.expr_id {
+ if let Some(query_ctx) = self.registry.get(expr_id) {
+ let wrapped = SparkErrorWithContext::with_context(error, query_ctx);
+ return DataFusionError::External(Box::new(wrapped));
+ }
+ }
+ DataFusionError::from(error)
+ }
}
impl Accumulator for SumDecimalAccumulator {
@@ -292,7 +353,8 @@ impl Accumulator for SumDecimalAccumulator {
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
if self.eval_mode == EvalMode::Ansi {
- return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+ let error = decimal_sum_overflow_error();
+ return Err(self.wrap_error_with_context(error));
} else {
self.sum = None;
self.is_empty = false;
@@ -311,16 +373,26 @@ struct SumDecimalGroupsAccumulator {
result_type: DataType,
precision: u8,
eval_mode: EvalMode,
+ expr_id: Option<u64>,
+ registry: Arc,
}
impl SumDecimalGroupsAccumulator {
- fn new(result_type: DataType, precision: u8, eval_mode: EvalMode) -> Self {
+ fn new(
+ result_type: DataType,
+ precision: u8,
+ eval_mode: EvalMode,
+ expr_id: Option<u64>,
+ registry: Arc,
+ ) -> Self {
Self {
sum: Vec::new(),
is_empty: Vec::new(),
result_type,
precision,
eval_mode,
+ expr_id,
+ registry,
}
}
@@ -330,6 +402,17 @@ impl SumDecimalGroupsAccumulator {
self.is_empty.resize(total_num_groups, true);
}
+ /// Wrap a SparkError with QueryContext if expr_id is available
+ fn wrap_error_with_context(&self, error: crate::SparkError) -> DataFusionError {
+ if let Some(expr_id) = self.expr_id {
+ if let Some(query_ctx) = self.registry.get(expr_id) {
+ let wrapped = SparkErrorWithContext::with_context(error, query_ctx);
+ return DataFusionError::External(Box::new(wrapped));
+ }
+ }
+ DataFusionError::from(error)
+ }
+
#[inline]
fn update_single(&mut self, group_index: usize, value: i128) -> DFResult<()> {
// For decimal sum, always check for overflow regardless of eval_mode
@@ -342,7 +425,8 @@ impl SumDecimalGroupsAccumulator {
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
if self.eval_mode == EvalMode::Ansi {
- return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+ let error = decimal_sum_overflow_error();
+ return Err(self.wrap_error_with_context(error));
}
self.sum[group_index] = None;
} else {
@@ -503,7 +587,8 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
if self.eval_mode == EvalMode::Ansi {
- return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+ let error = decimal_sum_overflow_error();
+ return Err(self.wrap_error_with_context(error));
} else {
self.sum[group_index] = None;
self.is_empty[group_index] = false;
@@ -542,7 +627,13 @@ mod tests {
#[test]
fn invalid_data_type() {
- assert!(SumDecimal::try_new(DataType::Int32, EvalMode::Legacy).is_err());
+ assert!(SumDecimal::try_new(
+ DataType::Int32,
+ EvalMode::Legacy,
+ None,
+ crate::create_query_context_map(),
+ )
+ .is_err());
}
#[tokio::test]
@@ -566,6 +657,8 @@ mod tests {
let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new(
data_type.clone(),
EvalMode::Legacy,
+ None,
+ crate::create_query_context_map(),
)?));
let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1])
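SumDecimal pairs overflowing_add with is_valid_decimal_precision because a sum can stay inside i128 yet exceed what Decimal128(precision, scale) may hold. A small illustration of that bound; the helper below is a simplified stand-in, not the crate's is_valid_decimal_precision:

// A Decimal128 of precision p holds at most p significant decimal digits, i.e. |v| <= 10^p - 1.
fn fits_decimal128_precision(value: i128, precision: u8) -> bool {
    let bound = 10_i128.pow(precision as u32);
    value > -bound && value < bound
}

fn main() {
    // Decimal(5, 2) stores values scaled by 10^2, so 999.99 (unscaled 99_999) is the maximum.
    assert!(fits_decimal128_precision(99_999, 5));
    assert!(!fits_decimal128_precision(100_000, 5)); // 1000.00 no longer fits
    // The i128 addition below does not wrap, yet the result exceeds Decimal(5, 2):
    let (sum, wrapped) = 99_999_i128.overflowing_add(1);
    assert!(!wrapped);
    assert!(!fits_decimal128_precision(sum, 5));
}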
diff --git a/native/spark-expr/src/array_funcs/list_extract.rs b/native/spark-expr/src/array_funcs/list_extract.rs
index b912f0c7f6..2cbd62a2c6 100644
--- a/native/spark-expr/src/array_funcs/list_extract.rs
+++ b/native/spark-expr/src/array_funcs/list_extract.rs
@@ -31,13 +31,17 @@ use std::{
sync::Arc,
};
-#[derive(Debug, Eq)]
+use crate::SparkError;
+
+#[derive(Debug, Clone)]
pub struct ListExtract {
child: Arc<dyn PhysicalExpr>,
ordinal: Arc<dyn PhysicalExpr>,
default_value: Option<Arc<dyn PhysicalExpr>>,
one_based: bool,
fail_on_error: bool,
+ expr_id: Option<u64>,
+ registry: Arc,
}
impl Hash for ListExtract {
@@ -47,8 +51,11 @@ impl Hash for ListExtract {
self.default_value.hash(state);
self.one_based.hash(state);
self.fail_on_error.hash(state);
+ self.expr_id.hash(state);
+ // Exclude registry from hash
}
}
+
impl PartialEq for ListExtract {
fn eq(&self, other: &Self) -> bool {
self.child.eq(&other.child)
@@ -56,9 +63,13 @@ impl PartialEq for ListExtract {
&& self.default_value.eq(&other.default_value)
&& self.one_based.eq(&other.one_based)
&& self.fail_on_error.eq(&other.fail_on_error)
+ && self.expr_id.eq(&other.expr_id)
+ // Exclude registry from equality check
}
}
+impl Eq for ListExtract {}
+
impl ListExtract {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -66,6 +77,8 @@ impl ListExtract {
default_value: Option<Arc<dyn PhysicalExpr>>,
one_based: bool,
fail_on_error: bool,
+ expr_id: Option<u64>,
+ registry: Arc,
) -> Self {
Self {
child,
@@ -73,6 +86,8 @@ impl ListExtract {
default_value,
one_based,
fail_on_error,
+ expr_id,
+ registry,
}
}
@@ -84,6 +99,17 @@ impl ListExtract {
))),
}
}
+
+ /// Wrap a SparkError with QueryContext if expr_id is available
+ fn wrap_error_with_context(&self, error: SparkError) -> DataFusionError {
+ if let Some(expr_id) = self.expr_id {
+ if let Some(query_ctx) = self.registry.get(expr_id) {
+ let wrapped = crate::SparkErrorWithContext::with_context(error, query_ctx);
+ return DataFusionError::External(Box::new(wrapped));
+ }
+ }
+ DataFusionError::from(error)
+ }
}
impl PhysicalExpr for ListExtract {
@@ -127,11 +153,15 @@ impl PhysicalExpr for ListExtract {
.transpose()?
.unwrap_or(self.data_type(&batch.schema())?.try_into())?;
- let adjust_index = if self.one_based {
- one_based_index
- } else {
- zero_based_index
- };
+ // Create error wrapper closure that has access to self
+ let error_wrapper = |error: SparkError| self.wrap_error_with_context(error);
+
+ let adjust_index: Box<dyn Fn(i32, usize) -> DataFusionResult<Option<usize>>> =
+ if self.one_based {
+ Box::new(|idx, len| one_based_index(idx, len, &error_wrapper))
+ } else {
+ Box::new(|idx, len| zero_based_index(idx, len, &error_wrapper))
+ };
match child_value.data_type() {
DataType::List(_) => {
@@ -143,7 +173,9 @@ impl PhysicalExpr for ListExtract {
index_array,
&default_value,
self.fail_on_error,
+ self.one_based,
adjust_index,
+ &error_wrapper,
)
}
DataType::LargeList(_) => {
@@ -155,7 +187,9 @@ impl PhysicalExpr for ListExtract {
index_array,
&default_value,
self.fail_on_error,
+ self.one_based,
adjust_index,
+ &error_wrapper,
)
}
data_type => Err(DataFusionError::Internal(format!(
@@ -179,17 +213,21 @@ impl PhysicalExpr for ListExtract {
self.default_value.clone(),
self.one_based,
self.fail_on_error,
+ self.expr_id,
+ Arc::clone(&self.registry),
))),
_ => internal_err!("ListExtract should have exactly two children"),
}
}
}
-fn one_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
+fn one_based_index(
+ index: i32,
+ len: usize,
+ error_wrapper: &impl Fn(SparkError) -> DataFusionError,
+) -> DataFusionResult<Option<usize>> {
if index == 0 {
- return Err(DataFusionError::Execution(
- "Invalid index of 0 for one-based ListExtract".to_string(),
- ));
+ return Err(error_wrapper(SparkError::InvalidIndexOfZero));
}
let abs_index = index.abs().as_usize();
@@ -204,7 +242,11 @@ fn one_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
}
}
-fn zero_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
+fn zero_based_index(
+ index: i32,
+ len: usize,
+ _error_wrapper: &impl Fn(SparkError) -> DataFusionError,
+) -> DataFusionResult<Option<usize>> {
if index < 0 {
Ok(None)
} else {
@@ -222,7 +264,9 @@ fn list_extract(
index_array: &Int32Array,
default_value: &ScalarValue,
fail_on_error: bool,
+ one_based: bool,
adjust_index: impl Fn(i32, usize) -> DataFusionResult<Option<usize>>,
+ error_wrapper: &impl Fn(SparkError) -> DataFusionError,
) -> DataFusionResult<ColumnarValue> {
let values = list_array.values();
let offsets = list_array.offsets();
@@ -242,9 +286,22 @@ fn list_extract(
} else if list_array.is_null(row) {
mutable.extend_nulls(1);
} else if fail_on_error {
- return Err(DataFusionError::Execution(
- "Index out of bounds for array".to_string(),
- ));
+ // Throw appropriate error based on whether this is element_at (one_based=true)
+ // or GetArrayItem (one_based=false)
+ let error = if one_based {
+ // element_at function
+ SparkError::InvalidElementAtIndex {
+ index_value: *index,
+ array_size: len as i32,
+ }
+ } else {
+ // GetArrayItem (arr[index])
+ SparkError::InvalidArrayIndex {
+ index_value: *index,
+ array_size: len as i32,
+ }
+ };
+ return Err(error_wrapper(error));
} else {
mutable.extend(1, 0, 1);
}
@@ -283,8 +340,18 @@ mod test {
let null_default = ScalarValue::Int32(None);
- let ColumnarValue::Array(result) =
- list_extract(&list, &indices, &null_default, false, zero_based_index)?
+ // Simple error wrapper for tests - just converts SparkError to DataFusionError
+ let error_wrapper = |error: SparkError| DataFusionError::from(error);
+
+ let ColumnarValue::Array(result) = list_extract(
+ &list,
+ &indices,
+ &null_default,
+ false,
+ false,
+ |idx, len| zero_based_index(idx, len, &error_wrapper),
+ &error_wrapper,
+ )?
else {
unreachable!()
};
@@ -296,8 +363,15 @@ mod test {
let zero_default = ScalarValue::Int32(Some(0));
- let ColumnarValue::Array(result) =
- list_extract(&list, &indices, &zero_default, false, zero_based_index)?
+ let ColumnarValue::Array(result) = list_extract(
+ &list,
+ &indices,
+ &zero_default,
+ false,
+ false,
+ |idx, len| zero_based_index(idx, len, &error_wrapper),
+ &error_wrapper,
+ )?
else {
unreachable!()
};
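For context on the one_based flag threaded through list_extract: element_at is one-based (0 is invalid and negative indexes count from the end), while arr[i] is zero-based, and an out-of-range index maps to InvalidElementAtIndex or InvalidArrayIndex respectively when fail_on_error is set. A simplified sketch of the two adjustments, with plain string errors standing in for SparkError:

// One-based (element_at): index 0 is invalid, negative indexes count from the end.
fn adjust_one_based(index: i32, len: usize) -> Result<Option<usize>, String> {
    if index == 0 {
        return Err("[INVALID_INDEX_OF_ZERO] The index 0 is invalid.".to_string());
    }
    let abs = index.unsigned_abs() as usize;
    if abs > len {
        return Ok(None); // out of range; the caller decides whether ANSI mode turns this into an error
    }
    Ok(Some(if index > 0 { abs - 1 } else { len - abs }))
}

// Zero-based (arr[i]): negative or too-large indexes are simply out of range.
fn adjust_zero_based(index: i32, len: usize) -> Result<Option<usize>, String> {
    if index < 0 || index as usize >= len {
        Ok(None)
    } else {
        Ok(Some(index as usize))
    }
}

fn main() {
    let len = 3; // e.g. the array [10, 20, 30]
    assert_eq!(adjust_one_based(1, len), Ok(Some(0)));
    assert_eq!(adjust_one_based(-1, len), Ok(Some(2)));
    assert!(adjust_one_based(0, len).is_err());
    assert_eq!(adjust_zero_based(0, len), Ok(Some(0)));
    assert_eq!(adjust_zero_based(-1, len), Ok(None));
    assert_eq!(adjust_zero_based(3, len), Ok(None));
}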
diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs
index db288fa32a..3de6b47311 100644
--- a/native/spark-expr/src/conversion_funcs/boolean.rs
+++ b/native/spark-expr/src/conversion_funcs/boolean.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::SparkResult;
+use crate::{SparkError, SparkResult};
use arrow::array::{ArrayRef, AsArray, Decimal128Array};
use arrow::datatypes::DataType;
use std::sync::Arc;
@@ -40,7 +40,22 @@ pub fn cast_boolean_to_decimal(
.iter()
.map(|v| v.map(|b| if b { scaled_val } else { 0 }))
.collect();
- Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
+
+ // Convert Arrow decimal overflow errors to SparkError
+ let decimal_array = result
+ .with_precision_and_scale(precision, scale)
+ .map_err(|e| {
+ if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_))
+ && e.to_string().contains("too large to store in a Decimal128")
+ {
+ // Use the scaled value as it's the only non-zero value that could overflow
+ crate::error::decimal_overflow_error(scaled_val, precision, scale)
+ } else {
+ SparkError::Arrow(Arc::new(e))
+ }
+ })?;
+
+ Ok(Arc::new(decimal_array))
}
#[cfg(test)]
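In cast_boolean_to_decimal the only non-zero output is scaled_val (true stored as 1 * 10^scale), so that is the only value that can trip the precision check, as the comment above notes. A quick illustration of when it stops fitting, without Arrow arrays:

// Casting a boolean to Decimal(precision, scale): true is 1, stored unscaled as 10^scale.
fn boolean_as_decimal_unscaled(b: bool, scale: u32) -> i128 {
    if b { 10_i128.pow(scale) } else { 0 }
}

fn fits_precision(value: i128, precision: u8) -> bool {
    value.abs() < 10_i128.pow(precision as u32)
}

fn main() {
    // Decimal(3, 2): true -> 100 (i.e. 1.00), three digits, fits.
    assert!(fits_precision(boolean_as_decimal_unscaled(true, 2), 3));
    // Decimal(2, 2): true -> 100 still, but only two digits are allowed, so it overflows.
    assert!(!fits_precision(boolean_as_decimal_unscaled(true, 2), 2));
    // false is always 0 and never overflows.
    assert!(fits_precision(boolean_as_decimal_unscaled(false, 2), 2));
}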
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index e83877de1b..49198e6efa 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -83,6 +83,8 @@ pub struct Cast {
pub child: Arc<dyn PhysicalExpr>,
pub data_type: DataType,
pub cast_options: SparkCastOptions,
+ pub expr_id: Option<u64>,
+ pub query_context: Option<Arc<QueryContext>>,
}
impl PartialEq for Cast {
@@ -549,11 +551,15 @@ impl Cast {
child: Arc<dyn PhysicalExpr>,
data_type: DataType,
cast_options: SparkCastOptions,
+ expr_id: Option<u64>,
+ query_context: Option<Arc<QueryContext>>,
) -> Self {
Self {
child,
data_type,
cast_options,
+ expr_id,
+ query_context,
}
}
}
@@ -1256,7 +1262,22 @@ where
let res = Arc::new(
cast_array
- .with_precision_and_scale(precision, scale)?
+ .with_precision_and_scale(precision, scale)
+ .map_err(|e| {
+ if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_))
+ && e.to_string().contains("too large to store in a Decimal128")
+ {
+ // Extract the overflowing value from the cast_array
+ // In practice, this should be caught above, but handle as a fallback
+ SparkError::NumericValueOutOfRange {
+ value: "overflow".to_string(),
+ precision,
+ scale,
+ }
+ } else {
+ SparkError::Arrow(Arc::new(e))
+ }
+ })?
.finish(),
) as ArrayRef;
Ok(res)
@@ -1332,7 +1353,23 @@ where
}
}
Ok(Arc::new(
- builder.with_precision_and_scale(precision, scale)?.finish(),
+ builder
+ .with_precision_and_scale(precision, scale)
+ .map_err(|e| {
+ if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_))
+ && e.to_string().contains("too large to store in a Decimal128")
+ {
+ // Fallback error handling - should be caught above in most cases
+ SparkError::NumericValueOutOfRange {
+ value: "overflow".to_string(),
+ precision,
+ scale,
+ }
+ } else {
+ SparkError::Arrow(Arc::new(e))
+ }
+ })?
+ .finish(),
))
}
@@ -1598,7 +1635,23 @@ impl PhysicalExpr for Cast {
fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
let arg = self.child.evaluate(batch)?;
- spark_cast(arg, &self.data_type, &self.cast_options)
+ let result = spark_cast(arg, &self.data_type, &self.cast_options);
+
+ // If there's an error and we have query_context, wrap it
+ match result {
+ Err(DataFusionError::External(e)) if self.query_context.is_some() => {
+ if let Some(spark_err) = e.downcast_ref::<SparkError>() {
+ let wrapped = crate::SparkErrorWithContext::with_context(
+ spark_err.clone(),
+ Arc::clone(self.query_context.as_ref().unwrap()),
+ );
+ Err(DataFusionError::External(Box::new(wrapped)))
+ } else {
+ Err(DataFusionError::External(e))
+ }
+ }
+ other => other,
+ }
}
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
@@ -1614,6 +1667,8 @@ impl PhysicalExpr for Cast {
Arc::clone(&children[0]),
self.data_type.clone(),
self.cast_options.clone(),
+ self.expr_id,
+ self.query_context.clone(),
))),
_ => internal_err!("Cast should have exactly one child"),
}
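Cast::evaluate recovers the original SparkError from DataFusionError::External by downcasting before re-wrapping it with the query context. The round trip only works when the boxed error really is the concrete type; a standalone sketch of that downcast pattern with a stand-in error type (not the crate's SparkError or DataFusion's error enum):

use std::error::Error;
use std::fmt;

// Stand-in for SparkError; any type implementing std::error::Error behaves the same way.
#[derive(Debug, Clone)]
struct DivideByZero;

impl fmt::Display for DivideByZero {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[DIVIDE_BY_ZERO] Division by zero.")
    }
}

impl Error for DivideByZero {}

// Stand-in for DataFusionError::External(Box<dyn Error + Send + Sync>).
enum EngineError {
    External(Box<dyn Error + Send + Sync>),
}

fn main() {
    let err = EngineError::External(Box::new(DivideByZero));
    match err {
        EngineError::External(boxed) => {
            // downcast_ref succeeds only when the boxed error is a DivideByZero,
            // which is what lets the caller clone it and re-wrap it with context.
            if let Some(original) = boxed.downcast_ref::<DivideByZero>() {
                println!("recovered: {original}");
            } else {
                println!("some other external error: {boxed}");
            }
        }
    }
}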
diff --git a/native/spark-expr/src/conversion_funcs/string.rs b/native/spark-expr/src/conversion_funcs/string.rs
index 531d334d15..7c193716d0 100644
--- a/native/spark-expr/src/conversion_funcs/string.rs
+++ b/native/spark-expr/src/conversion_funcs/string.rs
@@ -324,7 +324,21 @@ fn cast_string_to_decimal128_impl(
Ok(Arc::new(
decimal_builder
- .with_precision_and_scale(precision, scale)?
+ .with_precision_and_scale(precision, scale)
+ .map_err(|e| {
+ if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_))
+ && e.to_string().contains("too large to store in a Decimal128")
+ {
+ // Fallback error handling
+ SparkError::NumericValueOutOfRange {
+ value: "overflow".to_string(),
+ precision,
+ scale,
+ }
+ } else {
+ SparkError::Arrow(Arc::new(e))
+ }
+ })?
.finish(),
))
}
@@ -375,7 +389,21 @@ fn cast_string_to_decimal256_impl(
Ok(Arc::new(
decimal_builder
- .with_precision_and_scale(precision, scale)?
+ .with_precision_and_scale(precision, scale)
+ .map_err(|e| {
+ if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_))
+ && e.to_string().contains("too large to store in a Decimal256")
+ {
+ // Fallback error handling
+ SparkError::NumericValueOutOfRange {
+ value: "overflow".to_string(),
+ precision,
+ scale,
+ }
+ } else {
+ SparkError::Arrow(Arc::new(e))
+ }
+ })?
.finish(),
))
}
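The same map_err closure for Arrow's "too large to store in a Decimal128" message now appears in boolean.rs, cast.rs, and string.rs. A possible follow-up, not part of this patch, is to share one mapping helper; a rough sketch with plain strings standing in for ArrowError and SparkError:

// Stand-in for SparkError, reduced to the two cases this mapping cares about.
#[derive(Debug, PartialEq)]
enum MappedError {
    NumericValueOutOfRange { value: String, precision: u8, scale: i8 },
    Other(String),
}

// Shared mapping: treat the Arrow decimal-overflow message as an ANSI out-of-range error
// and pass every other Arrow error message through unchanged.
fn map_decimal_overflow(msg: &str, precision: u8, scale: i8) -> MappedError {
    if msg.contains("too large to store in a Decimal128") {
        MappedError::NumericValueOutOfRange {
            value: "overflow".to_string(),
            precision,
            scale,
        }
    } else {
        MappedError::Other(msg.to_string())
    }
}

fn main() {
    let overflow = map_decimal_overflow(
        "123456 is too large to store in a Decimal128 of precision 5",
        5,
        2,
    );
    assert!(matches!(
        overflow,
        MappedError::NumericValueOutOfRange { precision: 5, scale: 2, .. }
    ));
    assert_eq!(
        map_decimal_overflow("divide by zero", 5, 2),
        MappedError::Other("divide by zero".to_string())
    );
}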
diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs
index c39a05cd4e..ae3b5c0eda 100644
--- a/native/spark-expr/src/error.rs
+++ b/native/spark-expr/src/error.rs
@@ -17,11 +17,11 @@
use arrow::error::ArrowError;
use datafusion::common::DataFusionError;
+use std::sync::Arc;
-#[derive(thiserror::Error, Debug)]
+#[derive(thiserror::Error, Debug, Clone)]
pub enum SparkError {
- // Note that this message format is based on Spark 3.4 and is more detailed than the message
- // returned by Spark 3.3
+ // This list was generated from the Spark code. Many of the exceptions are not yet used by Comet
#[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
because it is malformed. Correct the value as per the syntax, or change its target type. \
Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
@@ -32,7 +32,7 @@ pub enum SparkError {
to_type: String,
},
- #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
+ #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
NumericValueOutOfRange {
value: String,
precision: u8,
@@ -51,24 +51,646 @@ pub enum SparkError {
to_type: String,
},
+ #[error("[CANNOT_PARSE_DECIMAL] Cannot parse decimal.")]
+ CannotParseDecimal,
+
#[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
ArithmeticOverflow { from_type: String },
+ #[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. Use `try_divide` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ IntegralDivideOverflow,
+
+ #[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ DecimalSumOverflow,
+
#[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
DivideByZero,
+ #[error("[REMAINDER_BY_ZERO] Division by zero. Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ RemainderByZero,
+
+ #[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")]
+ IntervalDividedByZero,
+
+ #[error("[BINARY_ARITHMETIC_OVERFLOW] {value1} {symbol} {value2} caused overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ BinaryArithmeticOverflow {
+ value1: String,
+ symbol: String,
+ value2: String,
+ function_name: String,
+ },
+
+ #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION] Interval arithmetic overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ IntervalArithmeticOverflowWithSuggestion { function_name: String },
+
+ #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION] Interval arithmetic overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ IntervalArithmeticOverflowWithoutSuggestion,
+
+ #[error("[DATETIME_OVERFLOW] Datetime arithmetic overflow.")]
+ DatetimeOverflow,
+
+ #[error("[INVALID_ARRAY_INDEX] The index {index_value} is out of bounds. The array has {array_size} elements. Use the SQL function get() to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ InvalidArrayIndex { index_value: i32, array_size: i32 },
+
+ #[error("[INVALID_ARRAY_INDEX_IN_ELEMENT_AT] The index {index_value} is out of bounds. The array has {array_size} elements. Use try_element_at to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ InvalidElementAtIndex { index_value: i32, array_size: i32 },
+
+ #[error("[INVALID_BITMAP_POSITION] The bit position {bit_position} is out of bounds. The bitmap has {bitmap_num_bytes} bytes ({bitmap_num_bits} bits).")]
+ InvalidBitmapPosition {
+ bit_position: i64,
+ bitmap_num_bytes: i64,
+ bitmap_num_bits: i64,
+ },
+
+ #[error("[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).")]
+ InvalidIndexOfZero,
+
+ #[error("[DUPLICATED_MAP_KEY] Cannot create map with duplicate keys: {key}.")]
+ DuplicatedMapKey { key: String },
+
+ #[error("[NULL_MAP_KEY] Cannot use null as map key.")]
+ NullMapKey,
+
+ #[error("[MAP_KEY_VALUE_DIFF_SIZES] The key array and value array of a map must have the same length.")]
+ MapKeyValueDiffSizes,
+
+ #[error("[EXCEED_LIMIT_LENGTH] Cannot create a map with {size} elements which exceeds the limit {max_size}.")]
+ ExceedMapSizeLimit { size: i32, max_size: i32 },
+
+ #[error("[COLLECTION_SIZE_LIMIT_EXCEEDED] Cannot create array with {num_elements} elements which exceeds the limit {max_elements}.")]
+ CollectionSizeLimitExceeded {
+ num_elements: i64,
+ max_elements: i64,
+ },
+
+ #[error("[NOT_NULL_ASSERT_VIOLATION] The field `{field_name}` cannot be null.")]
+ NotNullAssertViolation { field_name: String },
+
+ #[error("[VALUE_IS_NULL] The value of field `{field_name}` at row {row_index} is null.")]
+ ValueIsNull { field_name: String, row_index: i32 },
+
+ #[error("[CANNOT_PARSE_TIMESTAMP] Cannot parse timestamp: {message}. Try using `{suggested_func}` instead.")]
+ CannotParseTimestamp {
+ message: String,
+ suggested_func: String,
+ },
+
+ #[error("[INVALID_FRACTION_OF_SECOND] The fraction of second {value} is invalid. Valid values are in the range [0, 60]. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ InvalidFractionOfSecond { value: f64 },
+
+ #[error("[INVALID_UTF8_STRING] Invalid UTF-8 string: {hex_string}.")]
+ InvalidUtf8String { hex_string: String },
+
+ #[error("[UNEXPECTED_POSITIVE_VALUE] The {parameter_name} parameter must be less than or equal to 0. The actual value is {actual_value}.")]
+ UnexpectedPositiveValue {
+ parameter_name: String,
+ actual_value: i32,
+ },
+
+ #[error("[UNEXPECTED_NEGATIVE_VALUE] The {parameter_name} parameter must be greater than or equal to 0. The actual value is {actual_value}.")]
+ UnexpectedNegativeValue {
+ parameter_name: String,
+ actual_value: i32,
+ },
+
+ #[error("[INVALID_PARAMETER_VALUE] Invalid regex group index {group_index} in function `{function_name}`. Group count is {group_count}.")]
+ InvalidRegexGroupIndex {
+ function_name: String,
+ group_count: i32,
+ group_index: i32,
+ },
+
+ #[error("[DATATYPE_CANNOT_ORDER] Cannot order by type: {data_type}.")]
+ DatatypeCannotOrder { data_type: String },
+
+ #[error("[SCALAR_SUBQUERY_TOO_MANY_ROWS] Scalar subquery returned more than one row.")]
+ ScalarSubqueryTooManyRows,
+
#[error("ArrowError: {0}.")]
- Arrow(ArrowError),
+ Arrow(Arc<ArrowError>),
#[error("InternalError: {0}.")]
Internal(String),
}
+impl SparkError {
+ /// Serialize this error to JSON format for JNI transfer
+ pub fn to_json(&self) -> String {
+ let error_class = self.error_class().unwrap_or("");
+
+ // Create a JSON structure with errorType, errorClass, and params
+ match serde_json::to_string(&serde_json::json!({
+ "errorType": self.error_type_name(),
+ "errorClass": error_class,
+ "params": self.params_as_json(),
+ })) {
+ Ok(json) => json,
+ Err(e) => {
+ // Fallback if serialization fails
+ format!(
+ "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}",
+ e
+ )
+ }
+ }
+ }
+
+ /// Get the error type name for JSON serialization
+ fn error_type_name(&self) -> &'static str {
+ match self {
+ SparkError::CastInvalidValue { .. } => "CastInvalidValue",
+ SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange",
+ SparkError::NumericOutOfRange { .. } => "NumericOutOfRange",
+ SparkError::CastOverFlow { .. } => "CastOverFlow",
+ SparkError::CannotParseDecimal => "CannotParseDecimal",
+ SparkError::ArithmeticOverflow { .. } => "ArithmeticOverflow",
+ SparkError::IntegralDivideOverflow => "IntegralDivideOverflow",
+ SparkError::DecimalSumOverflow => "DecimalSumOverflow",
+ SparkError::DivideByZero => "DivideByZero",
+ SparkError::RemainderByZero => "RemainderByZero",
+ SparkError::IntervalDividedByZero => "IntervalDividedByZero",
+ SparkError::BinaryArithmeticOverflow { .. } => "BinaryArithmeticOverflow",
+ SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => {
+ "IntervalArithmeticOverflowWithSuggestion"
+ }
+ SparkError::IntervalArithmeticOverflowWithoutSuggestion => {
+ "IntervalArithmeticOverflowWithoutSuggestion"
+ }
+ SparkError::DatetimeOverflow => "DatetimeOverflow",
+ SparkError::InvalidArrayIndex { .. } => "InvalidArrayIndex",
+ SparkError::InvalidElementAtIndex { .. } => "InvalidElementAtIndex",
+ SparkError::InvalidBitmapPosition { .. } => "InvalidBitmapPosition",
+ SparkError::InvalidIndexOfZero => "InvalidIndexOfZero",
+ SparkError::DuplicatedMapKey { .. } => "DuplicatedMapKey",
+ SparkError::NullMapKey => "NullMapKey",
+ SparkError::MapKeyValueDiffSizes => "MapKeyValueDiffSizes",
+ SparkError::ExceedMapSizeLimit { .. } => "ExceedMapSizeLimit",
+ SparkError::CollectionSizeLimitExceeded { .. } => "CollectionSizeLimitExceeded",
+ SparkError::NotNullAssertViolation { .. } => "NotNullAssertViolation",
+ SparkError::ValueIsNull { .. } => "ValueIsNull",
+ SparkError::CannotParseTimestamp { .. } => "CannotParseTimestamp",
+ SparkError::InvalidFractionOfSecond { .. } => "InvalidFractionOfSecond",
+ SparkError::InvalidUtf8String { .. } => "InvalidUtf8String",
+ SparkError::UnexpectedPositiveValue { .. } => "UnexpectedPositiveValue",
+ SparkError::UnexpectedNegativeValue { .. } => "UnexpectedNegativeValue",
+ SparkError::InvalidRegexGroupIndex { .. } => "InvalidRegexGroupIndex",
+ SparkError::DatatypeCannotOrder { .. } => "DatatypeCannotOrder",
+ SparkError::ScalarSubqueryTooManyRows => "ScalarSubqueryTooManyRows",
+ SparkError::Arrow(_) => "Arrow",
+ SparkError::Internal(_) => "Internal",
+ }
+ }
+
+ /// Extract parameters as JSON value
+ fn params_as_json(&self) -> serde_json::Value {
+ match self {
+ SparkError::CastInvalidValue {
+ value,
+ from_type,
+ to_type,
+ } => {
+ serde_json::json!({
+ "value": value,
+ "fromType": from_type,
+ "toType": to_type,
+ })
+ }
+ SparkError::NumericValueOutOfRange {
+ value,
+ precision,
+ scale,
+ } => {
+ serde_json::json!({
+ "value": value,
+ "precision": precision,
+ "scale": scale,
+ })
+ }
+ SparkError::NumericOutOfRange { value } => {
+ serde_json::json!({
+ "value": value,
+ })
+ }
+ SparkError::CastOverFlow {
+ value,
+ from_type,
+ to_type,
+ } => {
+ serde_json::json!({
+ "value": value,
+ "fromType": from_type,
+ "toType": to_type,
+ })
+ }
+ SparkError::ArithmeticOverflow { from_type } => {
+ serde_json::json!({
+ "fromType": from_type,
+ })
+ }
+ SparkError::BinaryArithmeticOverflow {
+ value1,
+ symbol,
+ value2,
+ function_name,
+ } => {
+ serde_json::json!({
+ "value1": value1,
+ "symbol": symbol,
+ "value2": value2,
+ "functionName": function_name,
+ })
+ }
+ SparkError::IntervalArithmeticOverflowWithSuggestion { function_name } => {
+ serde_json::json!({
+ "functionName": function_name,
+ })
+ }
+ SparkError::InvalidArrayIndex {
+ index_value,
+ array_size,
+ } => {
+ serde_json::json!({
+ "indexValue": index_value,
+ "arraySize": array_size,
+ })
+ }
+ SparkError::InvalidElementAtIndex {
+ index_value,
+ array_size,
+ } => {
+ serde_json::json!({
+ "indexValue": index_value,
+ "arraySize": array_size,
+ })
+ }
+ SparkError::InvalidBitmapPosition {
+ bit_position,
+ bitmap_num_bytes,
+ bitmap_num_bits,
+ } => {
+ serde_json::json!({
+ "bitPosition": bit_position,
+ "bitmapNumBytes": bitmap_num_bytes,
+ "bitmapNumBits": bitmap_num_bits,
+ })
+ }
+ SparkError::DuplicatedMapKey { key } => {
+ serde_json::json!({
+ "key": key,
+ })
+ }
+ SparkError::ExceedMapSizeLimit { size, max_size } => {
+ serde_json::json!({
+ "size": size,
+ "maxSize": max_size,
+ })
+ }
+ SparkError::CollectionSizeLimitExceeded {
+ num_elements,
+ max_elements,
+ } => {
+ serde_json::json!({
+ "numElements": num_elements,
+ "maxElements": max_elements,
+ })
+ }
+ SparkError::NotNullAssertViolation { field_name } => {
+ serde_json::json!({
+ "fieldName": field_name,
+ })
+ }
+ SparkError::ValueIsNull {
+ field_name,
+ row_index,
+ } => {
+ serde_json::json!({
+ "fieldName": field_name,
+ "rowIndex": row_index,
+ })
+ }
+ SparkError::CannotParseTimestamp {
+ message,
+ suggested_func,
+ } => {
+ serde_json::json!({
+ "message": message,
+ "suggestedFunc": suggested_func,
+ })
+ }
+ SparkError::InvalidFractionOfSecond { value } => {
+ serde_json::json!({
+ "value": value,
+ })
+ }
+ SparkError::InvalidUtf8String { hex_string } => {
+ serde_json::json!({
+ "hexString": hex_string,
+ })
+ }
+ SparkError::UnexpectedPositiveValue {
+ parameter_name,
+ actual_value,
+ } => {
+ serde_json::json!({
+ "parameterName": parameter_name,
+ "actualValue": actual_value,
+ })
+ }
+ SparkError::UnexpectedNegativeValue {
+ parameter_name,
+ actual_value,
+ } => {
+ serde_json::json!({
+ "parameterName": parameter_name,
+ "actualValue": actual_value,
+ })
+ }
+ SparkError::InvalidRegexGroupIndex {
+ function_name,
+ group_count,
+ group_index,
+ } => {
+ serde_json::json!({
+ "functionName": function_name,
+ "groupCount": group_count,
+ "groupIndex": group_index,
+ })
+ }
+ SparkError::DatatypeCannotOrder { data_type } => {
+ serde_json::json!({
+ "dataType": data_type,
+ })
+ }
+ SparkError::Arrow(e) => {
+ serde_json::json!({
+ "message": e.to_string(),
+ })
+ }
+ SparkError::Internal(msg) => {
+ serde_json::json!({
+ "message": msg,
+ })
+ }
+ // Simple errors with no parameters
+ _ => serde_json::json!({}),
+ }
+ }
+
+ /// Returns the appropriate Spark exception class for this error
+ pub fn exception_class(&self) -> &'static str {
+ match self {
+ // ArithmeticException
+ SparkError::DivideByZero
+ | SparkError::RemainderByZero
+ | SparkError::IntervalDividedByZero
+ | SparkError::NumericValueOutOfRange { .. }
+ | SparkError::NumericOutOfRange { .. } // Comet-specific extension
+ | SparkError::ArithmeticOverflow { .. }
+ | SparkError::IntegralDivideOverflow
+ | SparkError::DecimalSumOverflow
+ | SparkError::BinaryArithmeticOverflow { .. }
+ | SparkError::IntervalArithmeticOverflowWithSuggestion { .. }
+ | SparkError::IntervalArithmeticOverflowWithoutSuggestion
+ | SparkError::DatetimeOverflow => "org/apache/spark/SparkArithmeticException",
+
+ // CastOverflow gets special handling with CastOverflowException
+ SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException",
+
+ // NumberFormatException (for cast invalid input errors)
+ SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException",
+
+ // ArrayIndexOutOfBoundsException
+ SparkError::InvalidArrayIndex { .. }
+ | SparkError::InvalidElementAtIndex { .. }
+ | SparkError::InvalidBitmapPosition { .. }
+ | SparkError::InvalidIndexOfZero => "org/apache/spark/SparkArrayIndexOutOfBoundsException",
+
+ // RuntimeException
+ SparkError::CannotParseDecimal
+ | SparkError::DuplicatedMapKey { .. }
+ | SparkError::NullMapKey
+ | SparkError::MapKeyValueDiffSizes
+ | SparkError::ExceedMapSizeLimit { .. }
+ | SparkError::CollectionSizeLimitExceeded { .. }
+ | SparkError::NotNullAssertViolation { .. }
+ | SparkError::ValueIsNull { .. } // Comet-specific extension
+ | SparkError::UnexpectedPositiveValue { .. }
+ | SparkError::UnexpectedNegativeValue { .. }
+ | SparkError::InvalidRegexGroupIndex { .. }
+ | SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException",
+
+ // DateTimeException
+ SparkError::CannotParseTimestamp { .. }
+ | SparkError::InvalidFractionOfSecond { .. } => "org/apache/spark/SparkDateTimeException",
+
+ // IllegalArgumentException
+ SparkError::DatatypeCannotOrder { .. }
+ | SparkError::InvalidUtf8String { .. } => "org/apache/spark/SparkIllegalArgumentException",
+
+ // Generic errors
+ SparkError::Arrow(_) | SparkError::Internal(_) => "org/apache/spark/SparkException",
+ }
+ }
+
+ /// Returns the Spark error class code for this error
+ pub fn error_class(&self) -> Option<&'static str> {
+ match self {
+ // Cast errors
+ SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"),
+ SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"),
+ SparkError::NumericValueOutOfRange { .. } => {
+ Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION")
+ }
+ SparkError::NumericOutOfRange { .. } => Some("NUMERIC_OUT_OF_SUPPORTED_RANGE"),
+ SparkError::CannotParseDecimal => Some("CANNOT_PARSE_DECIMAL"),
+
+ // Arithmetic errors
+ SparkError::DivideByZero => Some("DIVIDE_BY_ZERO"),
+ SparkError::RemainderByZero => Some("REMAINDER_BY_ZERO"),
+ SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"),
+ SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"),
+ SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"),
+ SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"),
+ SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"),
+ SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => {
+ Some("INTERVAL_ARITHMETIC_OVERFLOW")
+ }
+ SparkError::IntervalArithmeticOverflowWithoutSuggestion => {
+ Some("INTERVAL_ARITHMETIC_OVERFLOW")
+ }
+ SparkError::DatetimeOverflow => Some("DATETIME_OVERFLOW"),
+
+ // Array index errors
+ SparkError::InvalidArrayIndex { .. } => Some("INVALID_ARRAY_INDEX"),
+ SparkError::InvalidElementAtIndex { .. } => Some("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"),
+ SparkError::InvalidBitmapPosition { .. } => Some("INVALID_BITMAP_POSITION"),
+ SparkError::InvalidIndexOfZero => Some("INVALID_INDEX_OF_ZERO"),
+
+ // Map/Collection errors
+ SparkError::DuplicatedMapKey { .. } => Some("DUPLICATED_MAP_KEY"),
+ SparkError::NullMapKey => Some("NULL_MAP_KEY"),
+ SparkError::MapKeyValueDiffSizes => Some("MAP_KEY_VALUE_DIFF_SIZES"),
+ SparkError::ExceedMapSizeLimit { .. } => Some("EXCEED_LIMIT_LENGTH"),
+ SparkError::CollectionSizeLimitExceeded { .. } => {
+ Some("COLLECTION_SIZE_LIMIT_EXCEEDED")
+ }
+
+ // Null validation errors
+ SparkError::NotNullAssertViolation { .. } => Some("NOT_NULL_ASSERT_VIOLATION"),
+ SparkError::ValueIsNull { .. } => Some("VALUE_IS_NULL"),
+
+ // DateTime errors
+ SparkError::CannotParseTimestamp { .. } => Some("CANNOT_PARSE_TIMESTAMP"),
+ SparkError::InvalidFractionOfSecond { .. } => Some("INVALID_FRACTION_OF_SECOND"),
+
+ // String/UTF8 errors
+ SparkError::InvalidUtf8String { .. } => Some("INVALID_UTF8_STRING"),
+
+ // Function parameter errors
+ SparkError::UnexpectedPositiveValue { .. } => Some("UNEXPECTED_POSITIVE_VALUE"),
+ SparkError::UnexpectedNegativeValue { .. } => Some("UNEXPECTED_NEGATIVE_VALUE"),
+
+ // Regex errors
+ SparkError::InvalidRegexGroupIndex { .. } => Some("INVALID_PARAMETER_VALUE"),
+
+ // Unsupported operation errors
+ SparkError::DatatypeCannotOrder { .. } => Some("DATATYPE_CANNOT_ORDER"),
+
+ // Subquery errors
+ SparkError::ScalarSubqueryTooManyRows => Some("SCALAR_SUBQUERY_TOO_MANY_ROWS"),
+
+ // Generic errors (no error class)
+ SparkError::Arrow(_) | SparkError::Internal(_) => None,
+ }
+ }
+}
+
+/// Convert decimal overflow to SparkError::NumericValueOutOfRange.
+///
+/// Creates the appropriate SparkError when a decimal value exceeds the precision limit for Decimal128 storage.
+///
+/// # Arguments
+/// * `value` - The i128 decimal value that overflowed
+/// * `precision` - The target precision
+/// * `scale` - The scale of the decimal
+///
+/// # Returns
+/// SparkError::NumericValueOutOfRange with the value, precision, and scale
+pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError {
+ SparkError::NumericValueOutOfRange {
+ value: value.to_string(),
+ precision,
+ scale,
+ }
+}
+
pub type SparkResult<T> = Result<T, SparkError>;
+/// Wrapper that adds QueryContext to SparkError
+///
+/// This allows attaching SQL context information (query text, line/position, object name) to errors
+#[derive(Debug, Clone)]
+pub struct SparkErrorWithContext {
+ /// The underlying SparkError
+ pub error: SparkError,
+ /// Optional QueryContext for SQL location information
+ pub context: Option<Arc<QueryContext>>,
+}
+
+impl SparkErrorWithContext {
+ /// Create a SparkErrorWithContext without context
+ pub fn new(error: SparkError) -> Self {
+ Self {
+ error,
+ context: None,
+ }
+ }
+
+ /// Create a SparkErrorWithContext with QueryContext
+ pub fn with_context(error: SparkError, context: Arc<QueryContext>) -> Self {
+ Self {
+ error,
+ context: Some(context),
+ }
+ }
+
+ /// Serialize to JSON including optional context field
+ ///
+ /// JSON structure:
+ /// ```json
+ /// {
+ /// "errorType": "DivideByZero",
+ /// "errorClass": "DIVIDE_BY_ZERO",
+ /// "params": {},
+ /// "context": {
+ /// "sqlText": "SELECT a/b FROM t",
+ /// "startIndex": 7,
+ /// "stopIndex": 9,
+ /// "line": 1,
+ /// "startPosition": 7
+ /// },
+ /// "summary": "== SQL (line 1, position 8) ==\n..."
+ /// }
+ /// ```
+ pub fn to_json(&self) -> String {
+ let mut json_obj = serde_json::json!({
+ "errorType": self.error.error_type_name(),
+ "errorClass": self.error.error_class().unwrap_or(""),
+ "params": self.error.params_as_json(),
+ });
+
+ if let Some(ctx) = &self.context {
+ // Serialize context fields
+ json_obj["context"] = serde_json::json!({
+ "sqlText": ctx.sql_text.as_str(),
+ "startIndex": ctx.start_index,
+ "stopIndex": ctx.stop_index,
+ "objectType": ctx.object_type,
+ "objectName": ctx.object_name,
+ "line": ctx.line,
+ "startPosition": ctx.start_position,
+ });
+
+ // Add formatted summary
+ json_obj["summary"] = serde_json::json!(ctx.format_summary());
+ }
+
+ serde_json::to_string(&json_obj).unwrap_or_else(|e| {
+ format!(
+ "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}",
+ e
+ )
+ })
+ }
+}
+
+impl std::fmt::Display for SparkErrorWithContext {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.error)?;
+ if let Some(ctx) = &self.context {
+ write!(f, "\n{}", ctx.format_summary())?;
+ }
+ Ok(())
+ }
+}
+
+impl std::error::Error for SparkErrorWithContext {}
+
+impl From<SparkError> for SparkErrorWithContext {
+ fn from(error: SparkError) -> Self {
+ SparkErrorWithContext::new(error)
+ }
+}
+
+impl From<SparkErrorWithContext> for DataFusionError {
+ fn from(value: SparkErrorWithContext) -> Self {
+ DataFusionError::External(Box::new(value))
+ }
+}
+
impl From<ArrowError> for SparkError {
fn from(value: ArrowError) -> Self {
- SparkError::Arrow(value)
+ SparkError::Arrow(Arc::new(value))
}
}
@@ -77,3 +699,171 @@ impl From<SparkError> for DataFusionError {
DataFusionError::External(Box::new(value))
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_divide_by_zero_json() {
+ let error = SparkError::DivideByZero;
+ let json = error.to_json();
+
+ assert!(json.contains("\"errorType\":\"DivideByZero\""));
+ assert!(json.contains("\"errorClass\":\"DIVIDE_BY_ZERO\""));
+
+ // Verify it's valid JSON
+ let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
+ assert_eq!(parsed["errorType"], "DivideByZero");
+ assert_eq!(parsed["errorClass"], "DIVIDE_BY_ZERO");
+ }
+
+ #[test]
+ fn test_remainder_by_zero_json() {
+ let error = SparkError::RemainderByZero;
+ let json = error.to_json();
+
+ assert!(json.contains("\"errorType\":\"RemainderByZero\""));
+ assert!(json.contains("\"errorClass\":\"REMAINDER_BY_ZERO\""));
+ }
+
+ #[test]
+ fn test_binary_overflow_json() {
+ let error = SparkError::BinaryArithmeticOverflow {
+ value1: "32767".to_string(),
+ symbol: "+".to_string(),
+ value2: "1".to_string(),
+ function_name: "try_add".to_string(),
+ };
+ let json = error.to_json();
+
+ // Verify it's valid JSON
+ let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
+ assert_eq!(parsed["errorType"], "BinaryArithmeticOverflow");
+ assert_eq!(parsed["errorClass"], "BINARY_ARITHMETIC_OVERFLOW");
+ assert_eq!(parsed["params"]["value1"], "32767");
+ assert_eq!(parsed["params"]["symbol"], "+");
+ assert_eq!(parsed["params"]["value2"], "1");
+ assert_eq!(parsed["params"]["functionName"], "try_add");
+ }
+
+ #[test]
+ fn test_invalid_array_index_json() {
+ let error = SparkError::InvalidArrayIndex {
+ index_value: 10,
+ array_size: 3,
+ };
+ let json = error.to_json();
+
+ let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
+ assert_eq!(parsed["errorType"], "InvalidArrayIndex");
+ assert_eq!(parsed["errorClass"], "INVALID_ARRAY_INDEX");
+ assert_eq!(parsed["params"]["indexValue"], 10);
+ assert_eq!(parsed["params"]["arraySize"], 3);
+ }
+
+ #[test]
+ fn test_numeric_value_out_of_range_json() {
+ let error = SparkError::NumericValueOutOfRange {
+ value: "999.99".to_string(),
+ precision: 5,
+ scale: 2,
+ };
+ let json = error.to_json();
+
+ let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
+ assert_eq!(parsed["errorType"], "NumericValueOutOfRange");
+ assert_eq!(
+ parsed["errorClass"],
+ "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION"
+ );
+ assert_eq!(parsed["params"]["value"], "999.99");
+ assert_eq!(parsed["params"]["precision"], 5);
+ assert_eq!(parsed["params"]["scale"], 2);
+ }
+
+ #[test]
+ fn test_cast_invalid_value_json() {
+ let error = SparkError::CastInvalidValue {
+ value: "abc".to_string(),
+ from_type: "STRING".to_string(),
+ to_type: "INT".to_string(),
+ };
+ let json = error.to_json();
+
+ let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
+ assert_eq!(parsed["errorType"], "CastInvalidValue");
+ assert_eq!(parsed["errorClass"], "CAST_INVALID_INPUT");
+ assert_eq!(parsed["params"]["value"], "abc");
+ assert_eq!(parsed["params"]["fromType"], "STRING");
+ assert_eq!(parsed["params"]["toType"], "INT");
+ }
+
+ #[test]
+ fn test_duplicated_map_key_json() {
+ let error = SparkError::DuplicatedMapKey {
+ key: "duplicate_key".to_string(),
+ };
+ let json = error.to_json();
+
+ let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
+ assert_eq!(parsed["errorType"], "DuplicatedMapKey");
+ assert_eq!(parsed["errorClass"], "DUPLICATED_MAP_KEY");
+ assert_eq!(parsed["params"]["key"], "duplicate_key");
+ }
+
+ #[test]
+ fn test_null_map_key_json() {
+ let error = SparkError::NullMapKey;
+ let json = error.to_json();
+
+ let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
+ assert_eq!(parsed["errorType"], "NullMapKey");
+ assert_eq!(parsed["errorClass"], "NULL_MAP_KEY");
+ // Params should be an empty object
+ assert_eq!(parsed["params"], serde_json::json!({}));
+ }
+
+ #[test]
+ fn test_error_class_mapping() {
+ // Test that error_class() returns the correct error class
+ assert_eq!(
+ SparkError::DivideByZero.error_class(),
+ Some("DIVIDE_BY_ZERO")
+ );
+ assert_eq!(
+ SparkError::RemainderByZero.error_class(),
+ Some("REMAINDER_BY_ZERO")
+ );
+ assert_eq!(
+ SparkError::InvalidArrayIndex {
+ index_value: 0,
+ array_size: 0
+ }
+ .error_class(),
+ Some("INVALID_ARRAY_INDEX")
+ );
+ assert_eq!(SparkError::NullMapKey.error_class(), Some("NULL_MAP_KEY"));
+ }
+
+ #[test]
+ fn test_exception_class_mapping() {
+ // Test that exception_class() returns the correct Java exception class
+ assert_eq!(
+ SparkError::DivideByZero.exception_class(),
+ "org/apache/spark/SparkArithmeticException"
+ );
+ assert_eq!(
+ SparkError::InvalidArrayIndex {
+ index_value: 0,
+ array_size: 0
+ }
+ .exception_class(),
+ "org/apache/spark/SparkArrayIndexOutOfBoundsException"
+ );
+ assert_eq!(
+ SparkError::NullMapKey.exception_class(),
+ "org/apache/spark/SparkRuntimeException"
+ );
+ }
+}
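For illustration only (not part of the patch): a minimal sketch, assuming the crate is consumed as `datafusion_comet_spark_expr`, of how a `SparkError` is wrapped in `SparkErrorWithContext` and serialized into the JSON envelope that `CometQueryExecutionException` later carries to the JVM. The `QueryContext::new` values below are made-up positions for `SELECT a/b FROM t`, not taken from a real plan.

```rust
use std::sync::Arc;

use datafusion_comet_spark_expr::{QueryContext, SparkError, SparkErrorWithContext};

fn main() {
    // Hypothetical error site: the division `a/b` at characters 7..=9, line 1, column 7.
    let ctx = QueryContext::new(
        "SELECT a/b FROM t".to_string(),
        7,    // startIndex (0-based char offset)
        9,    // stopIndex (inclusive)
        None, // objectType
        None, // objectName
        1,    // line (1-based)
        7,    // startPosition (0-based column)
    );

    let err = SparkErrorWithContext::with_context(SparkError::DivideByZero, Arc::new(ctx));

    // Emits the envelope parsed later by SparkErrorConverter, e.g.
    // {"errorType":"DivideByZero","errorClass":"DIVIDE_BY_ZERO","params":{},"context":{...},"summary":"..."}
    println!("{}", err.to_json());
}
```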
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 40eb180ab8..368223c00f 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -20,6 +20,7 @@
#![deny(clippy::clone_on_ref_ptr)]
mod error;
+mod query_context;
pub mod kernels;
pub use kernels::temporal::date_trunc_dyn;
@@ -75,7 +76,7 @@ pub use datetime_funcs::{
SparkDateDiff, SparkDateTrunc, SparkHour, SparkMakeDate, SparkMinute, SparkSecond,
SparkUnixTimestamp, TimestampTruncExpr,
};
-pub use error::{SparkError, SparkResult};
+pub use error::{SparkError, SparkErrorWithContext, SparkResult};
pub use hash_funcs::*;
pub use json_funcs::{FromJson, ToJson};
pub use math_funcs::{
@@ -83,6 +84,7 @@ pub use math_funcs::{
spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex,
spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero,
};
+pub use query_context::{create_query_context_map, QueryContext, QueryContextMap};
pub use string_funcs::*;
/// Spark supports three evaluation modes when evaluating expressions, which affect
@@ -119,6 +121,10 @@ pub(crate) fn arithmetic_overflow_error(from_type: &str) -> SparkError {
}
}
+pub(crate) fn decimal_sum_overflow_error() -> SparkError {
+ SparkError::DecimalSumOverflow
+}
+
pub(crate) fn divide_by_zero_error() -> SparkError {
SparkError::DivideByZero
}
diff --git a/native/spark-expr/src/math_funcs/internal/checkoverflow.rs b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
index c7caab0594..a9e8f6748d 100644
--- a/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
+++ b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
@@ -25,6 +25,8 @@ use datafusion::common::{DataFusionError, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use std::hash::Hash;
+
+use crate::SparkError;
use std::{
any::Any,
fmt::{Display, Formatter},
@@ -40,6 +42,8 @@ pub struct CheckOverflow {
+ pub child: Arc<dyn PhysicalExpr>,
pub data_type: DataType,
pub fail_on_error: bool,
+ pub expr_id: Option<u64>,
+ pub query_context: Option<Arc<QueryContext>>,
}
impl Hash for CheckOverflow {
@@ -59,11 +63,19 @@ impl PartialEq for CheckOverflow {
}
impl CheckOverflow {
- pub fn new(child: Arc<dyn PhysicalExpr>, data_type: DataType, fail_on_error: bool) -> Self {
+ pub fn new(
+ child: Arc<dyn PhysicalExpr>,
+ data_type: DataType,
+ fail_on_error: bool,
+ expr_id: Option<u64>,
+ query_context: Option<Arc<QueryContext>>,
+ ) -> Self {
Self {
child,
data_type,
fail_on_error,
+ expr_id,
+ query_context,
}
}
}
@@ -113,8 +125,41 @@ impl PhysicalExpr for CheckOverflow {
let decimal_array = as_primitive_array::<Decimal128Type>(&array);
let casted_array = if self.fail_on_error {
- // Returning error if overflow
- decimal_array.validate_decimal_precision(*precision)?;
+ // Returning error if overflow - convert decimal overflow to SparkError
+ decimal_array
+ .validate_decimal_precision(*precision)
+ .map_err(|e| {
+ if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_))
+ && e.to_string().contains("too large to store in a Decimal128") {
+ // Find the first overflowing value
+ let overflow_value = decimal_array
+ .iter()
+ .find(|v| {
+ if let Some(val) = v {
+ arrow::array::types::Decimal128Type::validate_decimal_precision(
+ *val, *precision, *scale
+ ).is_err()
+ } else {
+ false
+ }
+ })
+ .and_then(|v| v)
+ .unwrap_or(0);
+
+ let spark_error = crate::error::decimal_overflow_error(overflow_value, *precision, *scale);
+
+ // Wrap with query_context if present
+ if let Some(ctx) = &self.query_context {
+ DataFusionError::External(Box::new(
+ crate::SparkErrorWithContext::with_context(spark_error, Arc::clone(ctx))
+ ))
+ } else {
+ DataFusionError::External(Box::new(spark_error))
+ }
+ } else {
+ DataFusionError::ArrowError(Box::new(e), None)
+ }
+ })?;
decimal_array
} else {
// Overflowing gets null value
@@ -123,7 +168,33 @@ impl PhysicalExpr for CheckOverflow {
let new_array = Decimal128Array::from(casted_array.into_data())
.with_precision_and_scale(*precision, *scale)
- .map(|a| Arc::new(a) as ArrayRef)?;
+ .map(|a| Arc::new(a) as ArrayRef)
+ .map_err(|e| {
+ if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_))
+ && e.to_string().contains("too large to store in a Decimal128")
+ {
+ // Fallback error handling
+ let spark_error = SparkError::NumericValueOutOfRange {
+ value: "overflow".to_string(),
+ precision: *precision,
+ scale: *scale,
+ };
+
+ // Wrap with query_context if present
+ if let Some(ctx) = &self.query_context {
+ DataFusionError::External(Box::new(
+ crate::SparkErrorWithContext::with_context(
+ spark_error,
+ Arc::clone(ctx),
+ ),
+ ))
+ } else {
+ DataFusionError::External(Box::new(spark_error))
+ }
+ } else {
+ DataFusionError::ArrowError(Box::new(e), None)
+ }
+ })?;
Ok(ColumnarValue::Array(new_array))
}
@@ -163,6 +234,8 @@ impl PhysicalExpr for CheckOverflow {
Arc::clone(&children[0]),
self.data_type.clone(),
self.fail_on_error,
+ self.expr_id,
+ self.query_context.clone(),
)))
}
}
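For illustration only (not part of the patch): a sketch of a call site for the widened `CheckOverflow::new`, assuming the restored generic signatures above (`Arc<dyn PhysicalExpr>` child, `Option<u64>` expression id, `Option<Arc<QueryContext>>` context); the column name, index, and decimal type are placeholders.

```rust
use std::sync::Arc;

use arrow::datatypes::DataType;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::PhysicalExpr;
use datafusion_comet_spark_expr::{CheckOverflow, QueryContext};

// Wrap a hypothetical Decimal128(10, 2) column in an ANSI overflow check and attach
// the QueryContext that was registered for protobuf expression id 42.
fn build_check_overflow(ctx: Arc<QueryContext>) -> CheckOverflow {
    let child: Arc<dyn PhysicalExpr> = Arc::new(Column::new("amount", 0));
    CheckOverflow::new(
        child,
        DataType::Decimal128(10, 2),
        true,      // fail_on_error (ANSI mode)
        Some(42),  // expr_id from the serialized plan
        Some(ctx), // QueryContext for error reporting
    )
}
```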
diff --git a/native/spark-expr/src/math_funcs/modulo_expr.rs b/native/spark-expr/src/math_funcs/modulo_expr.rs
index 733d653d2c..5df14548e2 100644
--- a/native/spark-expr/src/math_funcs/modulo_expr.rs
+++ b/native/spark-expr/src/math_funcs/modulo_expr.rs
@@ -93,11 +93,15 @@ pub fn create_modulo_expr(
left,
DataType::Decimal256(p1, s1),
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
+ None,
+ None,
));
let right_256 = Arc::new(Cast::new(
right_non_ansi_safe,
DataType::Decimal256(p2, s2),
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
+ None,
+ None,
));
// The UDF's return type must match what Arrow's rem function will actually return.
@@ -118,6 +122,8 @@ pub fn create_modulo_expr(
modulo_scalar_func,
data_type,
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
+ None,
+ None,
)))
}
_ => create_modulo_scalar_function(
diff --git a/native/spark-expr/src/query_context.rs b/native/spark-expr/src/query_context.rs
new file mode 100644
index 0000000000..e6591135e0
--- /dev/null
+++ b/native/spark-expr/src/query_context.rs
@@ -0,0 +1,402 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Query execution context for error reporting
+//!
+//! This module provides QueryContext which mirrors Spark's SQLQueryContext
+//! for providing SQL text, line/position information, and error location
+//! pointers in exception messages.
+
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+
+/// Based on Spark's SQLQueryContext for error reporting.
+///
+/// Contains information about where an error occurred in a SQL query,
+/// including the full SQL text, line/column positions, and object context.
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct QueryContext {
+ /// Full SQL query text
+ #[serde(rename = "sqlText")]
+ pub sql_text: Arc<String>,
+
+ /// Start offset in SQL text (0-based, character index)
+ #[serde(rename = "startIndex")]
+ pub start_index: i32,
+
+ /// Stop offset in SQL text (0-based, character index, inclusive)
+ #[serde(rename = "stopIndex")]
+ pub stop_index: i32,
+
+ /// Object type (e.g., "VIEW", "Project", "Filter")
+ #[serde(rename = "objectType", skip_serializing_if = "Option::is_none")]
+ pub object_type: Option<String>,
+
+ /// Object name (e.g., view name, column name)
+ #[serde(rename = "objectName", skip_serializing_if = "Option::is_none")]
+ pub object_name: Option<String>,
+
+ /// Line number in SQL query (1-based)
+ pub line: i32,
+
+ /// Column position within the line (0-based)
+ #[serde(rename = "startPosition")]
+ pub start_position: i32,
+}
+
+impl QueryContext {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ sql_text: String,
+ start_index: i32,
+ stop_index: i32,
+ object_type: Option<String>,
+ object_name: Option<String>,
+ line: i32,
+ start_position: i32,
+ ) -> Self {
+ Self {
+ sql_text: Arc::new(sql_text),
+ start_index,
+ stop_index,
+ object_type,
+ object_name,
+ line,
+ start_position,
+ }
+ }
+
+ /// Convert a character index to a byte offset in the SQL text.
+ /// Returns None if the character index is out of range.
+ fn char_index_to_byte_offset(&self, char_index: usize) -> Option<usize> {
+ self.sql_text
+ .char_indices()
+ .nth(char_index)
+ .map(|(byte_offset, _)| byte_offset)
+ }
+
+ /// Generate a summary string showing SQL fragment with error location.
+ /// (From SQLQueryContext.summary)
+ ///
+ /// Format example:
+ /// ```text
+ /// == SQL of VIEW v1 (line 1, position 8) ==
+ /// SELECT a/b FROM t
+ /// ^^^
+ /// ```
+ pub fn format_summary(&self) -> String {
+ let start_char = self.start_index.max(0) as usize;
+ // stop_index is inclusive; fragment covers [start, stop]
+ let stop_char = (self.stop_index + 1).max(0) as usize;
+
+ let fragment = match (
+ self.char_index_to_byte_offset(start_char),
+ // stop_char may equal sql_text.chars().count() (one past the end)
+ self.char_index_to_byte_offset(stop_char).or_else(|| {
+ if stop_char == self.sql_text.chars().count() {
+ Some(self.sql_text.len())
+ } else {
+ None
+ }
+ }),
+ ) {
+ (Some(start_byte), Some(stop_byte)) => &self.sql_text[start_byte..stop_byte],
+ _ => "",
+ };
+
+ // Build the header line
+ let mut summary = String::from("== SQL");
+
+ if let Some(obj_type) = &self.object_type {
+ if !obj_type.is_empty() {
+ summary.push_str(" of ");
+ summary.push_str(obj_type);
+
+ if let Some(obj_name) = &self.object_name {
+ if !obj_name.is_empty() {
+ summary.push(' ');
+ summary.push_str(obj_name);
+ }
+ }
+ }
+ }
+
+ summary.push_str(&format!(
+ " (line {}, position {}) ==\n",
+ self.line,
+ self.start_position + 1 // Convert 0-based to 1-based for display
+ ));
+
+ // Add the SQL text with fragment highlighted
+ summary.push_str(&self.sql_text);
+ summary.push('\n');
+
+ // Add caret pointer
+ let caret_position = self.start_position.max(0) as usize;
+ summary.push_str(&" ".repeat(caret_position));
+ // fragment.chars().count() gives the correct display width for non-ASCII
+ summary.push_str(&"^".repeat(fragment.chars().count().max(1)));
+
+ summary
+ }
+
+ /// Returns the SQL fragment that caused the error.
+ pub fn fragment(&self) -> String {
+ let start_char = self.start_index.max(0) as usize;
+ let stop_char = (self.stop_index + 1).max(0) as usize;
+
+ match (
+ self.char_index_to_byte_offset(start_char),
+ self.char_index_to_byte_offset(stop_char).or_else(|| {
+ if stop_char == self.sql_text.chars().count() {
+ Some(self.sql_text.len())
+ } else {
+ None
+ }
+ }),
+ ) {
+ (Some(start_byte), Some(stop_byte)) => self.sql_text[start_byte..stop_byte].to_string(),
+ _ => String::new(),
+ }
+ }
+}
+
+use std::collections::HashMap;
+use std::sync::RwLock;
+
+/// Map that stores QueryContext information for expressions during execution.
+///
+/// This map is populated during plan deserialization and accessed
+/// during error creation to attach SQL context to exceptions.
+#[derive(Debug)]
+pub struct QueryContextMap {
+ /// Map from expression ID to QueryContext
+ contexts: RwLock<HashMap<u64, Arc<QueryContext>>>,
+}
+
+impl QueryContextMap {
+ pub fn new() -> Self {
+ Self {
+ contexts: RwLock::new(HashMap::new()),
+ }
+ }
+
+ /// Register a QueryContext for an expression ID.
+ ///
+ /// If the expression ID already exists, it will be replaced.
+ ///
+ /// # Arguments
+ /// * `expr_id` - Unique expression identifier from protobuf
+ /// * `context` - QueryContext containing SQL text and position info
+ pub fn register(&self, expr_id: u64, context: QueryContext) {
+ let mut contexts = self.contexts.write().unwrap();
+ contexts.insert(expr_id, Arc::new(context));
+ }
+
+ /// Get the QueryContext for an expression ID.
+ ///
+ /// Returns None if no context is registered for this expression.
+ ///
+ /// # Arguments
+ /// * `expr_id` - Expression identifier to look up
+ pub fn get(&self, expr_id: u64) -> Option<Arc<QueryContext>> {
+ let contexts = self.contexts.read().unwrap();
+ contexts.get(&expr_id).cloned()
+ }
+
+ /// Clear all registered contexts.
+ ///
+ /// This is typically called after plan execution completes to free memory.
+ pub fn clear(&self) {
+ let mut contexts = self.contexts.write().unwrap();
+ contexts.clear();
+ }
+
+ /// Return the number of registered contexts (for debugging/testing)
+ pub fn len(&self) -> usize {
+ let contexts = self.contexts.read().unwrap();
+ contexts.len()
+ }
+
+ /// Check if the map is empty
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+}
+
+impl Default for QueryContextMap {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+/// Create a new session-scoped QueryContextMap.
+///
+/// This should be called once per SessionContext during plan creation
+/// and passed to expressions that need query context for error reporting.
+pub fn create_query_context_map() -> Arc<QueryContextMap> {
+ Arc::new(QueryContextMap::new())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_query_context_creation() {
+ let ctx = QueryContext::new(
+ "SELECT a/b FROM t".to_string(),
+ 7,
+ 9,
+ Some("Divide".to_string()),
+ Some("a/b".to_string()),
+ 1,
+ 7,
+ );
+
+ assert_eq!(*ctx.sql_text, "SELECT a/b FROM t");
+ assert_eq!(ctx.start_index, 7);
+ assert_eq!(ctx.stop_index, 9);
+ assert_eq!(ctx.object_type, Some("Divide".to_string()));
+ assert_eq!(ctx.object_name, Some("a/b".to_string()));
+ assert_eq!(ctx.line, 1);
+ assert_eq!(ctx.start_position, 7);
+ }
+
+ #[test]
+ fn test_query_context_serialization() {
+ let ctx = QueryContext::new(
+ "SELECT a/b FROM t".to_string(),
+ 7,
+ 9,
+ Some("Divide".to_string()),
+ Some("a/b".to_string()),
+ 1,
+ 7,
+ );
+
+ let json = serde_json::to_string(&ctx).unwrap();
+ let deserialized: QueryContext = serde_json::from_str(&json).unwrap();
+
+ assert_eq!(ctx, deserialized);
+ }
+
+ #[test]
+ fn test_format_summary() {
+ let ctx = QueryContext::new(
+ "SELECT a/b FROM t".to_string(),
+ 7,
+ 9,
+ Some("VIEW".to_string()),
+ Some("v1".to_string()),
+ 1,
+ 7,
+ );
+
+ let summary = ctx.format_summary();
+
+ assert!(summary.contains("== SQL of VIEW v1 (line 1, position 8) =="));
+ assert!(summary.contains("SELECT a/b FROM t"));
+ assert!(summary.contains("^^^")); // Three carets for "a/b"
+ }
+
+ #[test]
+ fn test_format_summary_without_object() {
+ let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7);
+
+ let summary = ctx.format_summary();
+
+ assert!(summary.contains("== SQL (line 1, position 8) =="));
+ assert!(summary.contains("SELECT a/b FROM t"));
+ }
+
+ #[test]
+ fn test_fragment() {
+ let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7);
+
+ assert_eq!(ctx.fragment(), "a/b");
+ }
+
+ #[test]
+ fn test_arc_string_sharing() {
+ let ctx1 = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7);
+
+ let ctx2 = ctx1.clone();
+
+ // Arc should share the same allocation
+ assert!(Arc::ptr_eq(&ctx1.sql_text, &ctx2.sql_text));
+ }
+
+ #[test]
+ fn test_json_with_optional_fields() {
+ let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7);
+
+ let json = serde_json::to_string(&ctx).unwrap();
+
+ // Should not serialize objectType and objectName when None
+ assert!(!json.contains("objectType"));
+ assert!(!json.contains("objectName"));
+ }
+
+ #[test]
+ fn test_map_register_and_get() {
+ let map = QueryContextMap::new();
+
+ let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7);
+
+ map.register(1, ctx.clone());
+
+ let retrieved = map.get(1).unwrap();
+ assert_eq!(*retrieved.sql_text, "SELECT a/b FROM t");
+ assert_eq!(retrieved.start_index, 7);
+ }
+
+ #[test]
+ fn test_map_get_nonexistent() {
+ let map = QueryContextMap::new();
+ assert!(map.get(999).is_none());
+ }
+
+ #[test]
+ fn test_map_clear() {
+ let map = QueryContextMap::new();
+
+ let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7);
+
+ map.register(1, ctx);
+ assert_eq!(map.len(), 1);
+
+ map.clear();
+ assert_eq!(map.len(), 0);
+ assert!(map.is_empty());
+ }
+
+ // Verify that fragment() and format_summary() correctly handle SQL text that
+ // contains multi-byte characters
+
+ #[test]
+ fn test_fragment_non_ascii_accented() {
+ // "é" is a 2-byte UTF-8 sequence (U+00E9).
+ // SQL: "SELECT café FROM t"
+ // 0123456789...
+ // char indices: c=7, a=8, f=9, é=10, ' '=11 ... FROM = 12..
+ // start_index=7, stop_index=10 should yield "café"
+ let sql = "SELECT café FROM t".to_string();
+ let ctx = QueryContext::new(sql, 7, 10, None, None, 1, 7);
+ assert_eq!(ctx.fragment(), "café");
+ }
+}
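For illustration only (not part of the patch): a minimal sketch of the intended lifecycle of `QueryContextMap` using only the APIs defined above - register a context under an expression id during plan deserialization, then look it up when building an error.

```rust
use datafusion_comet_spark_expr::{create_query_context_map, QueryContext};

fn main() {
    // One map per session/plan, shared with expressions that need error context.
    let map = create_query_context_map();

    // During plan deserialization: register the context received for expr_id = 1.
    map.register(
        1,
        QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7),
    );

    // During error creation: look the context up and format the location pointer.
    if let Some(ctx) = map.get(1) {
        // == SQL (line 1, position 8) ==
        // SELECT a/b FROM t
        //        ^^^
        println!("{}", ctx.format_summary());
    }
}
```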
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index f17d8f4f72..28c1645718 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -36,6 +36,7 @@ import org.apache.spark.util.SerializableConfiguration
import org.apache.comet.CometConf._
import org.apache.comet.Tracing.withTrace
+import org.apache.comet.exceptions.CometQueryExecutionException
import org.apache.comet.parquet.CometFileKeyUnwrapper
import org.apache.comet.serde.Config.ConfigMap
import org.apache.comet.vector.NativeUtil
@@ -151,6 +152,11 @@ class CometExecIterator(
})
})
} catch {
+ // Handle CometQueryExecutionException with JSON payload first
+ case e: CometQueryExecutionException =>
+ logError(s"Native execution for task $taskAttemptId failed", e)
+ throw SparkErrorConverter.convertToSparkException(e)
+
case e: CometNativeException =>
// it is generally considered bad practice to log and then rethrow an
// exception, but it really helps debugging to be able to see which task
diff --git a/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala b/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala
new file mode 100644
index 0000000000..a8dea4cf46
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{QueryContext, SparkException}
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.comet.shims.ShimSparkErrorConverter
+
+import com.fasterxml.jackson.core.JsonParseException
+
+import org.apache.comet.exceptions.CometQueryExecutionException
+
+/**
+ * Converts a CometQueryExecutionException thrown from native code (with a JSON payload) into the
+ * appropriate Spark QueryExecutionErrors.* exception.
+ *
+ * Parses the JSON-encoded error information from native execution and delegates to the
+ * version-specific ShimSparkErrorConverter trait for conversion to proper Spark exception types.
+ *
+ * The ShimSparkErrorConverter handles all error cases using the correct QueryExecutionErrors API
+ * for each Spark version.
+ */
+object SparkErrorConverter extends ShimSparkErrorConverter {
+
+ implicit val formats: DefaultFormats.type = DefaultFormats
+
+ case class QueryContextJson(
+ sqlText: String,
+ startIndex: Int,
+ stopIndex: Int,
+ objectType: Option[String],
+ objectName: Option[String],
+ line: Int,
+ startPosition: Int)
+
+ case class ErrorJson(
+ errorType: String,
+ errorClass: Option[String],
+ params: Option[Map[String, Any]],
+ context: Option[QueryContextJson],
+ summary: Option[String])
+
+ /**
+ * Parse JSON from exception and convert to appropriate Spark exception.
+ *
+ * @param e
+ * the CometQueryExecutionException with JSON message
+ * @return
+ * the corresponding Spark exception, or the original exception if parsing fails
+ */
+ def convertToSparkException(e: CometQueryExecutionException): Throwable = {
+ if (!e.isJsonMessage()) {
+ // Not JSON, return original exception
+ return e
+ }
+
+ // Only catch JSON parsing/mapping exceptions here - let conversion exceptions propagate
+ val errorJson =
+ try {
+ parse(e.getMessage).extract[ErrorJson]
+ } catch {
+ case _: MappingException | _: JsonParseException =>
+ return e
+ }
+ val params = errorJson.params.getOrElse(Map.empty)
+ val errorClass = errorJson.errorClass.getOrElse("UNKNOWN_ERROR_TEMP_COMET")
+
+ // Build a Spark SQLQueryContext if context is present (not all errors carry query context)
+ val sparkContext: Array[QueryContext] = errorJson.context match {
+ case Some(ctx) =>
+ Array(
+ SQLQueryContext(
+ sqlText = Some(ctx.sqlText),
+ line = Some(ctx.line),
+ startPosition = Some(ctx.startPosition),
+ originStartIndex = Some(ctx.startIndex),
+ originStopIndex = Some(ctx.stopIndex),
+ originObjectType = ctx.objectType,
+ originObjectName = ctx.objectName))
+ case None => Array.empty[QueryContext] // No context
+ }
+
+ val summary: String = errorJson.summary.orNull
+
+ // Delegate to version-specific shim - let conversion exceptions propagate
+ val optEx = convertErrorType(errorJson.errorType, errorClass, params, sparkContext, summary)
+ optEx match {
+ case Some(exception) =>
+ // successfully converted - return the proper typed exception
+ exception
+
+ case None =>
+ // Unknown error type - fallback to generic SparkException
+ new SparkException(
+ errorClass = errorClass,
+ messageParameters = paramsToStringMap(params),
+ cause = null)
+ }
+ }
+
+ /**
+ * Convert parameter map to string-keyed map for SparkException.
+ */
+ private[comet] def paramsToStringMap(params: Map[String, Any]): Map[String, String] = {
+ params.map { case (k, v) => (k, v.toString) }
+ }
+}
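For illustration only (not part of the patch): the shape of the payload that `convertToSparkException` extracts into `ErrorJson`/`QueryContextJson`, sketched here with serde_json on the Rust side; all field values are made up.

```rust
fn main() {
    // Illustrative envelope; mirrors the ErrorJson / QueryContextJson case classes above.
    let payload = serde_json::json!({
        "errorType": "InvalidArrayIndex",
        "errorClass": "INVALID_ARRAY_INDEX",
        "params": { "indexValue": 10, "arraySize": 3 },
        "context": {
            "sqlText": "SELECT arr[10] FROM t",
            "startIndex": 7,
            "stopIndex": 13,
            "line": 1,
            "startPosition": 7
        },
        "summary": "== SQL (line 1, position 8) ==\nSELECT arr[10] FROM t\n       ^^^^^^^"
    });
    println!("{}", serde_json::to_string_pretty(&payload).unwrap());
}
```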
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 9d13ccd9ed..e9fb200144 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -19,6 +19,8 @@
package org.apache.comet.serde
+import java.util.concurrent.atomic.AtomicLong
+
import scala.jdk.CollectionConverters._
import org.apache.spark.internal.Logging
@@ -267,6 +269,68 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[VariancePop] -> CometVariancePop,
classOf[VarianceSamp] -> CometVarianceSamp)
+ // A unique id for each expression, used to look up the QueryContext during error creation.
+ private val exprIdCounter = new AtomicLong(0)
+
+ private def nextExprId(): Long = exprIdCounter.incrementAndGet()
+
+ /**
+ * Extract SQL context information (query text, line/position, object name) from the
+ * expression's origin.
+ *
+ * @param expr
+ * The Spark expression to extract context from
+ * @return
+ * Some(QueryContext) if origin is present, None otherwise
+ */
+ private def extractQueryContext(expr: Expression): Option[ExprOuterClass.QueryContext] = {
+ val contexts = expr.origin.getQueryContext
+ if (contexts != null && contexts.length > 0) {
+ try {
+ val ctx = contexts(0)
+ // Check if this is a SQLQueryContext with additional fields
+ ctx match {
+ case sqlCtx: org.apache.spark.sql.catalyst.trees.SQLQueryContext =>
+ val builder = ExprOuterClass.QueryContext
+ .newBuilder()
+ .setSqlText(sqlCtx.sqlText.getOrElse(""))
+ .setStartIndex(sqlCtx.originStartIndex.getOrElse(ctx.startIndex))
+ .setStopIndex(sqlCtx.originStopIndex.getOrElse(ctx.stopIndex))
+ .setLine(sqlCtx.line.getOrElse(0))
+ .setStartPosition(sqlCtx.startPosition.getOrElse(0))
+
+ // Add optional fields if present
+ sqlCtx.originObjectType.foreach(builder.setObjectType)
+ sqlCtx.originObjectName.foreach(builder.setObjectName)
+
+ Some(builder.build())
+ case _ =>
+ // Fallback: use only QueryContext interface methods
+ val builder = ExprOuterClass.QueryContext
+ .newBuilder()
+ .setSqlText(ctx.fragment())
+ .setStartIndex(ctx.startIndex())
+ .setStopIndex(ctx.stopIndex())
+ .setLine(0)
+ .setStartPosition(0)
+
+ if (ctx.objectType() != null && !ctx.objectType().isEmpty) {
+ builder.setObjectType(ctx.objectType())
+ }
+ if (ctx.objectName() != null && !ctx.objectName().isEmpty) {
+ builder.setObjectName(ctx.objectName())
+ }
+
+ Some(builder.build())
+ }
+ } catch {
+ case _: Exception => None
+ }
+ } else {
+ None
+ }
+ }
+
def supportedDataType(dt: DataType, allowComplex: Boolean = false): Boolean = dt match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType |
@@ -403,7 +467,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
val fn = aggExpr.aggregateFunction
val cometExpr = aggrSerdeMap.get(fn.getClass)
- cometExpr match {
+ val protoAggExprOpt = cometExpr match {
case Some(handler) =>
val aggHandler = handler.asInstanceOf[CometAggregateExpressionSerde[AggregateFunction]]
val exprConfName = aggHandler.getExprConfigName(fn)
@@ -450,6 +514,16 @@ object QueryPlanSerde extends Logging with CometExprShim {
fn.children: _*)
None
}
+
+ // Attach QueryContext and expr_id to the aggregate expression
+ protoAggExprOpt.map { protoAggExpr =>
+ val builder = protoAggExpr.toBuilder
+ builder.setExprId(nextExprId())
+ extractQueryContext(fn).foreach { ctx =>
+ builder.setQueryContext(ctx)
+ }
+ builder.build()
+ }
}
def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode = {
@@ -550,22 +624,32 @@ object QueryPlanSerde extends Logging with CometExprShim {
}
}
- versionSpecificExprToProtoInternal(expr, inputs, binding).orElse(expr match {
+ versionSpecificExprToProtoInternal(expr, inputs, binding)
+ .orElse(expr match {
- case UnaryExpression(child) if expr.prettyName == "promote_precision" =>
- // `UnaryExpression` includes `PromotePrecision` for Spark 3.3
- // `PromotePrecision` is just a wrapper, don't need to serialize it.
- exprToProtoInternal(child, inputs, binding)
+ case UnaryExpression(child) if expr.prettyName == "promote_precision" =>
+ // `UnaryExpression` includes `PromotePrecision` for Spark 3.3
+ // `PromotePrecision` is just a wrapper, don't need to serialize it.
+ exprToProtoInternal(child, inputs, binding)
- case expr =>
- QueryPlanSerde.exprSerdeMap.get(expr.getClass) match {
- case Some(handler) =>
- convert(expr, handler.asInstanceOf[CometExpressionSerde[Expression]])
- case _ =>
- withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
- None
+ case expr =>
+ QueryPlanSerde.exprSerdeMap.get(expr.getClass) match {
+ case Some(handler) =>
+ convert(expr, handler.asInstanceOf[CometExpressionSerde[Expression]])
+ case _ =>
+ withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
+ None
+ }
+ })
+ .map { protoExpr =>
+ // Attach QueryContext and expr_id to the expression
+ val builder = protoExpr.toBuilder
+ builder.setExprId(nextExprId())
+ extractQueryContext(expr).foreach { ctx =>
+ builder.setQueryContext(ctx)
}
- })
+ builder.build()
+ }
}
/**
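For illustration only (not part of the patch): the protobuf fields populated by `extractQueryContext` above correspond one-to-one to the `QueryContext::new` arguments in `query_context.rs`; a hedged sketch of the reconstruction the native deserializer would perform, with local variables standing in for the decoded protobuf fields.

```rust
use datafusion_comet_spark_expr::QueryContext;

// Placeholder values standing in for a decoded protobuf QueryContext message.
fn rebuild_query_context() -> QueryContext {
    let sql_text = "SELECT a/b FROM t".to_string(); // setSqlText
    let start_index = 7; // setStartIndex
    let stop_index = 9; // setStopIndex
    let object_type: Option<String> = None; // setObjectType (optional)
    let object_name: Option<String> = None; // setObjectName (optional)
    let line = 1; // setLine
    let start_position = 7; // setStartPosition

    QueryContext::new(
        sql_text,
        start_index,
        stop_index,
        object_type,
        object_name,
        line,
        start_position,
    )
}
```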
diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
new file mode 100644
index 0000000000..83f3a7f12a
--- /dev/null
+++ b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.{QueryContext, SparkException}
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Spark 3.4 implementation for converting error types to proper Spark exceptions.
+ *
+ * Handles all error cases using the Spark 3.4 QueryExecutionErrors API. Differs from the 3.5
+ * version in 4 places where the API changed between 3.4 and 3.5.
+ */
+trait ShimSparkErrorConverter {
+
+ private def sqlCtx(context: Array[QueryContext]): SQLQueryContext =
+ context.headOption.map(_.asInstanceOf[SQLQueryContext]).getOrElse(null)
+
+ def convertErrorType(
+ errorType: String,
+ errorClass: String,
+ params: Map[String, Any],
+ context: Array[QueryContext],
+ summary: String): Option[Throwable] = {
+ val _ = (errorClass, summary)
+
+ errorType match {
+
+ case "DivideByZero" =>
+ Some(QueryExecutionErrors.divideByZeroError(sqlCtx(context)))
+
+ case "RemainderByZero" =>
+ Some(
+ new SparkException(
+ errorClass = "REMAINDER_BY_ZERO",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "IntervalDividedByZero" =>
+ Some(QueryExecutionErrors.intervalDividedByZeroError(sqlCtx(context)))
+
+ case "BinaryArithmeticOverflow" =>
+ // Spark 3.x does not take functionName parameter
+ Some(
+ QueryExecutionErrors.binaryArithmeticCauseOverflowError(
+ params("value1").toString.toShort,
+ params("symbol").toString,
+ params("value2").toString.toShort))
+
+ case "ArithmeticOverflow" =>
+ val fromType = params("fromType").toString
+ Some(QueryExecutionErrors.arithmeticOverflowError(fromType + " overflow", ""))
+
+ case "IntegralDivideOverflow" =>
+ Some(QueryExecutionErrors.overflowInIntegralDivideError(sqlCtx(context)))
+
+ case "DecimalSumOverflow" =>
+ // Spark 3.x takes SQLQueryContext, not QueryContext
+ Some(QueryExecutionErrors.overflowInSumOfDecimalError(sqlCtx(context)))
+
+ case "NumericValueOutOfRange" =>
+ val decimal = Decimal(params("value").toString)
+ Some(
+ QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+ decimal,
+ params("precision").toString.toInt,
+ params("scale").toString.toInt,
+ sqlCtx(context)))
+
+ case "DatetimeOverflow" =>
+ Some(
+ new SparkException(
+ errorClass = "DATETIME_OVERFLOW",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "InvalidArrayIndex" =>
+ Some(
+ QueryExecutionErrors.invalidArrayIndexError(
+ params("indexValue").toString.toInt,
+ params("arraySize").toString.toInt,
+ sqlCtx(context)))
+
+ case "InvalidElementAtIndex" =>
+ Some(
+ QueryExecutionErrors.invalidElementAtIndexError(
+ params("indexValue").toString.toInt,
+ params("arraySize").toString.toInt,
+ sqlCtx(context)))
+
+ case "InvalidIndexOfZero" =>
+ Some(QueryExecutionErrors.invalidIndexOfZeroError(sqlCtx(context)))
+
+ case "InvalidBitmapPosition" =>
+ // invalidBitmapPositionError does not exist in Spark 3.4
+ Some(
+ new SparkException(
+ errorClass = "INVALID_BITMAP_POSITION",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "DuplicatedMapKey" =>
+ Some(QueryExecutionErrors.duplicateMapKeyFoundError(params("key")))
+
+ case "NullMapKey" =>
+ Some(QueryExecutionErrors.nullAsMapKeyNotAllowedError())
+
+ case "MapKeyValueDiffSizes" =>
+ Some(QueryExecutionErrors.mapDataKeyArrayLengthDiffersFromValueArrayLengthError())
+
+ case "ExceedMapSizeLimit" =>
+ Some(QueryExecutionErrors.exceedMapSizeLimitError(params("size").toString.toInt))
+
+ case "CollectionSizeLimitExceeded" =>
+ // createArrayWithElementsExceedLimitError takes (count: Any) in Spark 3.4
+ Some(
+ QueryExecutionErrors.createArrayWithElementsExceedLimitError(
+ params("numElements").toString.toLong))
+
+ case "NotNullAssertViolation" =>
+ Some(
+ QueryExecutionErrors.foundNullValueForNotNullableFieldError(
+ params("fieldName").toString))
+
+ case "ValueIsNull" =>
+ Some(
+ QueryExecutionErrors.fieldCannotBeNullError(
+ params.getOrElse("rowIndex", 0).toString.toInt,
+ params("fieldName").toString))
+
+ case "CannotParseTimestamp" =>
+ Some(
+ QueryExecutionErrors.ansiDateTimeParseError(new Exception(params("message").toString)))
+
+ case "InvalidFractionOfSecond" =>
+ Some(QueryExecutionErrors.invalidFractionOfSecondError())
+
+ case "CastInvalidValue" =>
+ val str = UTF8String.fromString(params("value").toString)
+ val targetType = getDataType(params("toType").toString)
+ Some(
+ QueryExecutionErrors
+ .invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
+
+ case "CastOverFlow" =>
+ val fromType = getDataType(params("fromType").toString)
+ val toType = getDataType(params("toType").toString)
+ val valueStr = params("value").toString
+
+ val typedValue: Any = fromType match {
+ case _: DecimalType =>
+ val cleanStr = if (valueStr.endsWith("BD")) valueStr.dropRight(2) else valueStr
+ Decimal(cleanStr)
+ case ByteType =>
+ val cleanStr = if (valueStr.endsWith("T")) valueStr.dropRight(1) else valueStr
+ cleanStr.toByte
+ case ShortType =>
+ val cleanStr = if (valueStr.endsWith("S")) valueStr.dropRight(1) else valueStr
+ cleanStr.toShort
+ case IntegerType => valueStr.toInt
+ case LongType =>
+ val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr
+ cleanStr.toLong
+ case FloatType => valueStr.toFloat
+ case DoubleType => valueStr.toDouble
+ case StringType => UTF8String.fromString(valueStr)
+ case _ => valueStr
+ }
+
+ Some(QueryExecutionErrors.castingCauseOverflowError(typedValue, fromType, toType))
+
+ case "CannotParseDecimal" =>
+ Some(QueryExecutionErrors.cannotParseDecimalError())
+
+ case "InvalidUtf8String" =>
+ // invalidUTF8StringError does not exist in Spark 3.x; use generic fallback
+ Some(
+ new SparkException(
+ errorClass = "INVALID_UTF8_STRING",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "UnexpectedPositiveValue" =>
+ Some(
+ QueryExecutionErrors.unexpectedValueForStartInFunctionError(
+ params("parameterName").toString))
+
+ case "UnexpectedNegativeValue" =>
+ Some(
+ QueryExecutionErrors.unexpectedValueForLengthInFunctionError(
+ params("parameterName").toString))
+
+ case "InvalidRegexGroupIndex" =>
+ // invalidRegexGroupIndexError does not exist in Spark 3.4
+ Some(
+ new SparkException(
+ errorClass = "INVALID_REGEX_GROUP_INDEX",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "DatatypeCannotOrder" =>
+ // orderedOperationUnsupportedByDataTypeError takes DataType in Spark 3.4, not String
+ Some(
+ QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError(
+ getDataType(params("dataType").toString)))
+
+ case "ScalarSubqueryTooManyRows" =>
+ // multipleRowScalarSubqueryError was renamed to multipleRowSubqueryError in Spark 3.x
+ Some(QueryExecutionErrors.multipleRowSubqueryError(sqlCtx(context)))
+
+ case "IntervalArithmeticOverflowWithSuggestion" =>
+ // Spark 3.x uses a single intervalArithmeticOverflowError method
+ Some(
+ QueryExecutionErrors.intervalArithmeticOverflowError(
+ "Interval arithmetic overflow",
+ params.get("functionName").map(_.toString).getOrElse(""),
+ sqlCtx(context)))
+
+ case "IntervalArithmeticOverflowWithoutSuggestion" =>
+ Some(
+ QueryExecutionErrors
+ .intervalArithmeticOverflowError("Interval arithmetic overflow", "", sqlCtx(context)))
+
+ case _ =>
+ None
+ }
+ }
+
+ private def getDataType(typeName: String): DataType = {
+ typeName.toUpperCase match {
+ case "BYTE" | "TINYINT" => ByteType
+ case "SHORT" | "SMALLINT" => ShortType
+ case "INT" | "INTEGER" => IntegerType
+ case "LONG" | "BIGINT" => LongType
+ case "FLOAT" | "REAL" => FloatType
+ case "DOUBLE" => DoubleType
+ case "DECIMAL" => DecimalType.SYSTEM_DEFAULT
+ case "STRING" | "VARCHAR" => StringType
+ case "BINARY" => BinaryType
+ case "BOOLEAN" => BooleanType
+ case "DATE" => DateType
+ case "TIMESTAMP" => TimestampType
+ case _ =>
+ try {
+ DataType.fromDDL(typeName)
+ } catch {
+ case _: Exception => StringType
+ }
+ }
+ }
+}
diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
new file mode 100644
index 0000000000..44c34c1185
--- /dev/null
+++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.{QueryContext, SparkException}
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Spark 3.5 implementation for converting error types to proper Spark exceptions.
+ *
+ * Handles all error cases using the Spark 3.5 QueryExecutionErrors API. The 4 cases with API
+ * differences from Spark 4.0 are handled with Spark 3.x-specific calls.
+ */
+trait ShimSparkErrorConverter {
+
+ private def sqlCtx(context: Array[QueryContext]): SQLQueryContext =
+ context.headOption.map(_.asInstanceOf[SQLQueryContext]).getOrElse(null)
+
+ def convertErrorType(
+ errorType: String,
+ errorClass: String,
+ params: Map[String, Any],
+ context: Array[QueryContext],
+ summary: String): Option[Throwable] = {
+ val _ = (errorClass, summary)
+
+ errorType match {
+
+ case "DivideByZero" =>
+ Some(QueryExecutionErrors.divideByZeroError(sqlCtx(context)))
+
+ case "RemainderByZero" =>
+ Some(
+ new SparkException(
+ errorClass = "REMAINDER_BY_ZERO",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "IntervalDividedByZero" =>
+ Some(QueryExecutionErrors.intervalDividedByZeroError(sqlCtx(context)))
+
+ case "BinaryArithmeticOverflow" =>
+ // Spark 3.x does not take functionName parameter
+ Some(
+ QueryExecutionErrors.binaryArithmeticCauseOverflowError(
+ params("value1").toString.toShort,
+ params("symbol").toString,
+ params("value2").toString.toShort))
+
+ case "ArithmeticOverflow" =>
+ val fromType = params("fromType").toString
+ Some(QueryExecutionErrors.arithmeticOverflowError(fromType + " overflow", ""))
+
+ case "IntegralDivideOverflow" =>
+ Some(QueryExecutionErrors.overflowInIntegralDivideError(sqlCtx(context)))
+
+ case "DecimalSumOverflow" =>
+ // Spark 3.x takes SQLQueryContext, not QueryContext
+ Some(QueryExecutionErrors.overflowInSumOfDecimalError(sqlCtx(context)))
+
+ case "NumericValueOutOfRange" =>
+ val decimal = Decimal(params("value").toString)
+ Some(
+ QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+ decimal,
+ params("precision").toString.toInt,
+ params("scale").toString.toInt,
+ sqlCtx(context)))
+
+ case "DatetimeOverflow" =>
+ Some(
+ new SparkException(
+ errorClass = "DATETIME_OVERFLOW",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "InvalidArrayIndex" =>
+ Some(
+ QueryExecutionErrors.invalidArrayIndexError(
+ params("indexValue").toString.toInt,
+ params("arraySize").toString.toInt,
+ sqlCtx(context)))
+
+ case "InvalidElementAtIndex" =>
+ Some(
+ QueryExecutionErrors.invalidElementAtIndexError(
+ params("indexValue").toString.toInt,
+ params("arraySize").toString.toInt,
+ sqlCtx(context)))
+
+ case "InvalidIndexOfZero" =>
+ Some(QueryExecutionErrors.invalidIndexOfZeroError(sqlCtx(context)))
+
+ case "InvalidBitmapPosition" =>
+ Some(
+ QueryExecutionErrors.invalidBitmapPositionError(
+ params("bitPosition").toString.toLong,
+ params("bitmapNumBytes").toString.toLong))
+
+ case "DuplicatedMapKey" =>
+ Some(QueryExecutionErrors.duplicateMapKeyFoundError(params("key")))
+
+ case "NullMapKey" =>
+ Some(QueryExecutionErrors.nullAsMapKeyNotAllowedError())
+
+ case "MapKeyValueDiffSizes" =>
+ Some(QueryExecutionErrors.mapDataKeyArrayLengthDiffersFromValueArrayLengthError())
+
+ case "ExceedMapSizeLimit" =>
+ Some(QueryExecutionErrors.exceedMapSizeLimitError(params("size").toString.toInt))
+
+ case "CollectionSizeLimitExceeded" =>
+ Some(
+ QueryExecutionErrors.createArrayWithElementsExceedLimitError(
+ "array",
+ params("numElements").toString.toLong))
+
+ case "NotNullAssertViolation" =>
+ Some(
+ QueryExecutionErrors.foundNullValueForNotNullableFieldError(
+ params("fieldName").toString))
+
+ case "ValueIsNull" =>
+ Some(
+ QueryExecutionErrors.fieldCannotBeNullError(
+ params.getOrElse("rowIndex", 0).toString.toInt,
+ params("fieldName").toString))
+
+ case "CannotParseTimestamp" =>
+ Some(
+ QueryExecutionErrors.ansiDateTimeParseError(new Exception(params("message").toString)))
+
+ case "InvalidFractionOfSecond" =>
+ Some(QueryExecutionErrors.invalidFractionOfSecondError())
+
+ case "CastInvalidValue" =>
+ val str = UTF8String.fromString(params("value").toString)
+ val targetType = getDataType(params("toType").toString)
+ Some(
+ QueryExecutionErrors
+ .invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
+
+ case "CastOverFlow" =>
+ val fromType = getDataType(params("fromType").toString)
+ val toType = getDataType(params("toType").toString)
+ val valueStr = params("value").toString
+
+ val typedValue: Any = fromType match {
+ case _: DecimalType =>
+ val cleanStr = if (valueStr.endsWith("BD")) valueStr.dropRight(2) else valueStr
+ Decimal(cleanStr)
+ case ByteType =>
+ val cleanStr = if (valueStr.endsWith("T")) valueStr.dropRight(1) else valueStr
+ cleanStr.toByte
+ case ShortType =>
+ val cleanStr = if (valueStr.endsWith("S")) valueStr.dropRight(1) else valueStr
+ cleanStr.toShort
+ case IntegerType => valueStr.toInt
+ case LongType =>
+ val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr
+ cleanStr.toLong
+ case FloatType => valueStr.toFloat
+ case DoubleType => valueStr.toDouble
+ case StringType => UTF8String.fromString(valueStr)
+ case _ => valueStr
+ }
+
+ Some(QueryExecutionErrors.castingCauseOverflowError(typedValue, fromType, toType))
+
+ case "CannotParseDecimal" =>
+ Some(QueryExecutionErrors.cannotParseDecimalError())
+
+ case "InvalidUtf8String" =>
+ // invalidUTF8StringError does not exist in Spark 3.x; use generic fallback
+ Some(
+ new SparkException(
+ errorClass = "INVALID_UTF8_STRING",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "UnexpectedPositiveValue" =>
+ Some(
+ QueryExecutionErrors.unexpectedValueForStartInFunctionError(
+ params("parameterName").toString))
+
+ case "UnexpectedNegativeValue" =>
+ Some(
+ QueryExecutionErrors.unexpectedValueForLengthInFunctionError(
+ params("parameterName").toString))
+
+ case "InvalidRegexGroupIndex" =>
+ Some(
+ QueryExecutionErrors.invalidRegexGroupIndexError(
+ params("functionName").toString,
+ params("groupCount").toString.toInt,
+ params("groupIndex").toString.toInt))
+
+ case "DatatypeCannotOrder" =>
+ Some(
+ QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError(
+ params("dataType").toString))
+
+ case "ScalarSubqueryTooManyRows" =>
+ // Spark 3.x exposes multipleRowSubqueryError (Spark 4.0 renamed it to multipleRowScalarSubqueryError)
+ Some(QueryExecutionErrors.multipleRowSubqueryError(sqlCtx(context)))
+
+ case "IntervalArithmeticOverflowWithSuggestion" =>
+ // Spark 3.x uses a single intervalArithmeticOverflowError method
+ Some(
+ QueryExecutionErrors.intervalArithmeticOverflowError(
+ "Interval arithmetic overflow",
+ params.get("functionName").map(_.toString).getOrElse(""),
+ sqlCtx(context)))
+
+ case "IntervalArithmeticOverflowWithoutSuggestion" =>
+ Some(
+ QueryExecutionErrors
+ .intervalArithmeticOverflowError("Interval arithmetic overflow", "", sqlCtx(context)))
+
+ case _ =>
+ None
+ }
+ }
+
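+ /**
+ * Best-effort mapping from a type name in the error payload to a Spark DataType:
+ * common names are matched directly, anything else is parsed as DDL, and StringType
+ * is the fallback when parsing fails.
+ */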
+ private def getDataType(typeName: String): DataType = {
+ typeName.toUpperCase match {
+ case "BYTE" | "TINYINT" => ByteType
+ case "SHORT" | "SMALLINT" => ShortType
+ case "INT" | "INTEGER" => IntegerType
+ case "LONG" | "BIGINT" => LongType
+ case "FLOAT" | "REAL" => FloatType
+ case "DOUBLE" => DoubleType
+ case "DECIMAL" => DecimalType.SYSTEM_DEFAULT
+ case "STRING" | "VARCHAR" => StringType
+ case "BINARY" => BinaryType
+ case "BOOLEAN" => BooleanType
+ case "DATE" => DateType
+ case "TIMESTAMP" => TimestampType
+ case _ =>
+ try {
+ DataType.fromDDL(typeName)
+ } catch {
+ case _: Exception => StringType
+ }
+ }
+ }
+}
diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
new file mode 100644
index 0000000000..b19c7688be
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.QueryContext
+import org.apache.spark.SparkException
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Spark 4.0-specific implementation for converting error types to proper Spark exceptions.
+ */
+trait ShimSparkErrorConverter {
+
+ /**
+ * Convert error type string and parameters to appropriate Spark exception. Version-specific
+ * implementations call the correct QueryExecutionErrors.* methods.
+ *
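+ * For example (illustrative only), a divide-by-zero reported by the native side as
+ * errorType "DivideByZero" (params are not consulted for this case) maps to
+ * Some(QueryExecutionErrors.divideByZeroError(context.headOption.orNull)).
+ *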
+ * @param errorType
+ * The error type from JSON (e.g., "DivideByZero")
+ * @param _errorClass
+ * The Spark error class (e.g., "DIVIDE_BY_ZERO")
+ * @param params
+ * Error parameters from JSON
+ * @param context
+ * QueryContext array with SQL text and position information
+ * @param _summary
+ * Formatted summary string showing error location
+ * @return
+ * Some(exception created via QueryExecutionErrors), or None if the error type is unknown
+ */
+ def convertErrorType(
+ errorType: String,
+ _errorClass: String,
+ params: Map[String, Any],
+ context: Array[QueryContext],
+ _summary: String): Option[Throwable] = {
+
+ errorType match {
+
+ case "DivideByZero" =>
+ Some(QueryExecutionErrors.divideByZeroError(context.headOption.orNull))
+
+ case "RemainderByZero" =>
+ // Spark 4.0 removed remainderByZeroError, so build a SparkException with the REMAINDER_BY_ZERO error class directly
+ Some(
+ new SparkException(
+ errorClass = "REMAINDER_BY_ZERO",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "IntervalDividedByZero" =>
+ Some(QueryExecutionErrors.intervalDividedByZeroError(context.headOption.orNull))
+
+ case "BinaryArithmeticOverflow" =>
+ Some(
+ QueryExecutionErrors.binaryArithmeticCauseOverflowError(
+ params("value1").toString.toShort,
+ params("symbol").toString,
+ params("value2").toString.toShort,
+ params("functionName").toString))
+
+ case "ArithmeticOverflow" =>
+ val fromType = params("fromType").toString
+ Some(QueryExecutionErrors.arithmeticOverflowError(fromType + " overflow", ""))
+
+ case "IntegralDivideOverflow" =>
+ Some(QueryExecutionErrors.overflowInIntegralDivideError(context.headOption.orNull))
+
+ case "DecimalSumOverflow" =>
+ Some(QueryExecutionErrors.overflowInSumOfDecimalError(context.headOption.orNull, ""))
+
+ case "NumericValueOutOfRange" =>
+ val decimal = Decimal(params("value").toString)
+ Some(
+ QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+ decimal,
+ params("precision").toString.toInt,
+ params("scale").toString.toInt,
+ context.headOption.orNull))
+
+ case "DatetimeOverflow" =>
+ // Spark 4.0 doesn't have datetimeOverflowError
+ Some(
+ new SparkException(
+ errorClass = "DATETIME_OVERFLOW",
+ messageParameters = params.map { case (k, v) => (k, v.toString) },
+ cause = null))
+
+ case "InvalidArrayIndex" =>
+ Some(
+ QueryExecutionErrors.invalidArrayIndexError(
+ params("indexValue").toString.toInt,
+ params("arraySize").toString.toInt,
+ context.headOption.orNull))
+
+ case "InvalidElementAtIndex" =>
+ Some(
+ QueryExecutionErrors.invalidElementAtIndexError(
+ params("indexValue").toString.toInt,
+ params("arraySize").toString.toInt,
+ context.headOption.orNull))
+
+ case "InvalidIndexOfZero" =>
+ Some(QueryExecutionErrors.invalidIndexOfZeroError(context.headOption.orNull))
+
+ case "InvalidBitmapPosition" =>
+ Some(
+ QueryExecutionErrors.invalidBitmapPositionError(
+ params("bitPosition").toString.toLong,
+ params("bitmapNumBytes").toString.toLong))
+
+ case "DuplicatedMapKey" =>
+ Some(QueryExecutionErrors.duplicateMapKeyFoundError(params("key")))
+
+ case "NullMapKey" =>
+ Some(QueryExecutionErrors.nullAsMapKeyNotAllowedError())
+
+ case "MapKeyValueDiffSizes" =>
+ Some(QueryExecutionErrors.mapDataKeyArrayLengthDiffersFromValueArrayLengthError())
+
+ case "ExceedMapSizeLimit" =>
+ Some(QueryExecutionErrors.exceedMapSizeLimitError(params("size").toString.toInt))
+
+ case "CollectionSizeLimitExceeded" =>
+ Some(
+ QueryExecutionErrors.createArrayWithElementsExceedLimitError(
+ "array",
+ params("numElements").toString.toLong))
+
+ case "NotNullAssertViolation" =>
+ Some(
+ QueryExecutionErrors.foundNullValueForNotNullableFieldError(
+ params("fieldName").toString))
+
+ case "ValueIsNull" =>
+ Some(
+ QueryExecutionErrors.fieldCannotBeNullError(
+ params.getOrElse("rowIndex", 0).toString.toInt,
+ params("fieldName").toString))
+
+ case "CannotParseTimestamp" =>
+ Some(
+ QueryExecutionErrors.ansiDateTimeParseError(
+ new Exception(params("message").toString),
+ params("suggestedFunc").toString))
+
+ case "InvalidFractionOfSecond" =>
+ Some(QueryExecutionErrors.invalidFractionOfSecondError(params("value").toString.toDouble))
+
+ case "CastInvalidValue" =>
+ val str = UTF8String.fromString(params("value").toString)
+ val targetType = getDataType(params("toType").toString)
+ Some(
+ QueryExecutionErrors
+ .invalidInputInCastToNumberError(targetType, str, context.headOption.orNull))
+
+ case "CastOverFlow" =>
+ val fromType = getDataType(params("fromType").toString)
+ val toType = getDataType(params("toType").toString)
+ val valueStr = params("value").toString
+
+ // Convert string value to appropriate type for toSQLValue
+ val typedValue: Any = fromType match {
+ case _: DecimalType =>
+ // Parse decimal string (may have "BD" suffix from BigDecimal.toString)
+ val cleanStr = if (valueStr.endsWith("BD")) valueStr.dropRight(2) else valueStr
+ Decimal(cleanStr)
+ case ByteType =>
+ // Strip "T" suffix for TINYINT literals
+ val cleanStr = if (valueStr.endsWith("T")) valueStr.dropRight(1) else valueStr
+ cleanStr.toByte
+ case ShortType =>
+ // Strip "S" suffix for SMALLINT literals
+ val cleanStr = if (valueStr.endsWith("S")) valueStr.dropRight(1) else valueStr
+ cleanStr.toShort
+ case IntegerType => valueStr.toInt
+ case LongType =>
+ // Strip "L" suffix for BIGINT literals
+ val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr
+ cleanStr.toLong
+ case FloatType => valueStr.toFloat
+ case DoubleType => valueStr.toDouble
+ case StringType => UTF8String.fromString(valueStr)
+ case _ => valueStr // Fallback to string
+ }
+
+ Some(QueryExecutionErrors.castingCauseOverflowError(typedValue, fromType, toType))
+
+ case "CannotParseDecimal" =>
+ Some(QueryExecutionErrors.cannotParseDecimalError())
+
+ case "InvalidUtf8String" =>
+ val hexStr = UTF8String.fromString(params("hexString").toString)
+ Some(QueryExecutionErrors.invalidUTF8StringError(hexStr))
+
+ case "UnexpectedPositiveValue" =>
+ Some(
+ QueryExecutionErrors.unexpectedValueForStartInFunctionError(
+ params("parameterName").toString))
+
+ case "UnexpectedNegativeValue" =>
+ Some(
+ QueryExecutionErrors.unexpectedValueForLengthInFunctionError(
+ params("parameterName").toString,
+ params("actualValue").toString.toInt))
+
+ case "InvalidRegexGroupIndex" =>
+ Some(
+ QueryExecutionErrors.invalidRegexGroupIndexError(
+ params("functionName").toString,
+ params("groupCount").toString.toInt,
+ params("groupIndex").toString.toInt))
+
+ case "DatatypeCannotOrder" =>
+ Some(
+ QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError(
+ params("dataType").toString))
+
+ case "ScalarSubqueryTooManyRows" =>
+ Some(QueryExecutionErrors.multipleRowScalarSubqueryError(context.headOption.orNull))
+
+ case "IntervalArithmeticOverflowWithSuggestion" =>
+ Some(
+ QueryExecutionErrors.withSuggestionIntervalArithmeticOverflowError(
+ params.get("functionName").map(_.toString).getOrElse(""),
+ context.headOption.orNull))
+
+ case "IntervalArithmeticOverflowWithoutSuggestion" =>
+ Some(
+ QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(
+ context.headOption.orNull))
+
+ case _ =>
+ // Unknown error type - return None to trigger fallback
+ None
+ }
+ }
+
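+ /**
+ * Best-effort mapping from a type name in the error payload to a Spark DataType,
+ * mirroring the Spark 3.x shim: known names first, then DDL parsing, then StringType
+ * as the fallback.
+ */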
+ private def getDataType(typeName: String): DataType = {
+ typeName.toUpperCase match {
+ case "BYTE" | "TINYINT" => ByteType
+ case "SHORT" | "SMALLINT" => ShortType
+ case "INT" | "INTEGER" => IntegerType
+ case "LONG" | "BIGINT" => LongType
+ case "FLOAT" | "REAL" => FloatType
+ case "DOUBLE" => DoubleType
+ case "DECIMAL" => DecimalType.SYSTEM_DEFAULT
+ case "STRING" | "VARCHAR" => StringType
+ case "BINARY" => BinaryType
+ case "BOOLEAN" => BooleanType
+ case "DATE" => DateType
+ case "TIMESTAMP" => TimestampType
+ case _ =>
+ try {
+ DataType.fromDDL(typeName)
+ } catch {
+ case _: Exception => StringType
+ }
+ }
+ }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index b22d0f72db..21b4276a64 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -922,4 +922,92 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
}
}
}
+
+ // https://github.com/apache/datafusion-comet/issues/3375
+ test("(ansi) array access out of bounds - GetArrayItem") {
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> "true",
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true") {
+ withTable("test_array_get_item") {
+ sql("CREATE TABLE test_array_get_item(arr ARRAY) USING parquet")
+ sql("INSERT INTO test_array_get_item VALUES (array(1, 2, 3))")
+ // Try to access array with out-of-bounds index
+ val exception = intercept[Exception] {
+ sql("select arr[5] from test_array_get_item").collect()
+ }
+ val errorMessage = exception.getMessage
+ // Verify error message contains the expected error code
+ assert(
+ errorMessage.contains("INVALID_ARRAY_INDEX"),
+ s"Error message should contain array index error: $errorMessage")
+
+ assert(errorMessage.contains("The index 5 is out of bounds. The array has 3 elements." +
+ " Use the SQL function `get()` to tolerate accessing element at invalid index and return NULL instead."))
+
+ assert(
+ errorMessage.contains("select arr[5] from test_array_get_item"),
+ s"Error message should contain SQL query text but got: $errorMessage")
+ }
+ }
+ }
+
+ // https://github.com/apache/datafusion-comet/issues/3375
+ test("(ansi) array access out of bounds - element_at with invalid index") {
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> "true",
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true") {
+ withTable("test_element_at_invalid") {
+ sql("CREATE TABLE test_element_at_invalid(arr ARRAY) USING parquet")
+ sql("INSERT INTO test_element_at_invalid VALUES (array(1, 2, 3))")
+ // Try to access array with out-of-bounds index using element_at
+ val exception = intercept[Exception] {
+ sql("select element_at(arr, 10) from test_element_at_invalid").collect()
+ }
+ val errorMessage = exception.getMessage
+ // Verify error message contains the expected error code
+ assert(
+ errorMessage.contains("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"),
+ s"Error message should contain array index error: $errorMessage")
+
+ assert(errorMessage.contains("The index 10 is out of bounds. The array has 3 elements." +
+ " Use `try_element_at` to tolerate accessing element at invalid index and return NULL instead"))
+
+ assert(
+ errorMessage.contains("select element_at(arr, 10) from test_element_at_invalid"),
+ s"Error message should contain SQL query text but got: $errorMessage")
+ }
+ }
+ }
+
+ // https://github.com/apache/datafusion-comet/issues/3375
+ test("(ansi) array access with zero index - element_at") {
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> "true",
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true") {
+ withTable("test_element_at_zero") {
+ sql("CREATE TABLE test_element_at_zero(arr ARRAY) USING parquet")
+ sql("INSERT INTO test_element_at_zero VALUES (array(1, 2, 3))")
+ // Try to access array with zero index (invalid in Spark)
+ val exception = intercept[Exception] {
+ sql("select element_at(arr, 0) from test_element_at_zero").collect()
+ }
+ val errorMessage = exception.getMessage
+ // Verify error message contains the expected error code
+ assert(
+ errorMessage.contains("INVALID_INDEX_OF_ZERO"),
+ s"Error message should contain zero index error: $errorMessage")
+
+ assert(
+ errorMessage.contains("The index 0 is invalid. An index shall be either < 0 or > 0" +
+ " (the first element has index 1)"))
+
+ assert(
+ errorMessage.contains("select element_at(arr, 0) from test_element_at_zero"),
+ s"Error message should contain SQL query text but got: $errorMessage")
+ }
+ }
+ }
}
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 72c2390d71..8bfae8a7d3 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -913,6 +913,33 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest((validDates ++ invalidDates ++ fuzzDates).toDF("a"), DataTypes.DateType)
}
+ // https://github.com/apache/datafusion-comet/issues/2215
+ test("(ansi) cast error message should include SQL query") {
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> "true",
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true") {
+ withTable("cast_error_msg") {
+ // Create a table with string data using DataFrame API
+ Seq("a").toDF("s").write.format("parquet").saveAsTable("cast_error_msg")
+ // Try to cast invalid string to date - should throw exception with SQL context
+ val exception = intercept[Exception] {
+ sql("select cast(s as date) from cast_error_msg").collect()
+ }
+ val errorMessage = exception.getMessage
+ // Verify error message contains the cast invalid input error
+ assert(
+ errorMessage.contains("CAST_INVALID_INPUT") ||
+ errorMessage.contains("cannot be cast to"),
+ s"Error message should contain cast error: $errorMessage")
+
+ assert(
+ errorMessage.contains("select cast(s as date) from cast_error_msg"),
+ s"Error message should contain SQL query text but got: $errorMessage")
+ }
+ }
+ }
+
test("cast StringType to TimestampType disabled by default") {
withSQLConf((SQLConf.SESSION_LOCAL_TIMEZONE.key, "UTC")) {
val values = Seq("2020-01-01T12:34:56.123456", "T2").toDF("a")
@@ -1526,20 +1553,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
if (CometSparkSessionExtensions.isSpark40Plus) {
// for Spark 4 we expect the SparkException to carry the message
assert(sparkMessage.contains("SQLSTATE"))
- if (sparkMessage.startsWith("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]")) {
- assert(
- sparkMessage.replace(".WITH_SUGGESTION] ", "]").startsWith(cometMessage))
- } else if (cometMessage.startsWith("[CAST_INVALID_INPUT]") || cometMessage
- .startsWith("[CAST_OVERFLOW]")) {
- assert(
- sparkMessage.startsWith(
- cometMessage
- .replace(
- "If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.",
- "")))
- } else {
- assert(sparkMessage.startsWith(cometMessage))
- }
+ // we compare only a prefix of the error message. Comet captures the query
+ // context eagerly, so it reports the call site at the line of code where the
+ // cast method was called, whereas Spark captures the context lazily and
+ // reports the call site at the line of code where the error is checked.
+ assert(sparkMessage.startsWith(cometMessage.substring(0, 40)))
} else {
// for Spark 3.4 we expect to reproduce the error message exactly
assert(cometMessage == sparkMessage)
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index f0f022868f..339061f5be 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -54,7 +54,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
val ARITHMETIC_OVERFLOW_EXCEPTION_MSG =
- """org.apache.comet.CometNativeException: [ARITHMETIC_OVERFLOW] integer overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error"""
+ """[ARITHMETIC_OVERFLOW] integer overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error"""
val DIVIDE_BY_ZERO_EXCEPTION_MSG =
"""Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""