diff --git a/common/src/main/java/org/apache/comet/exceptions/CometQueryExecutionException.java b/common/src/main/java/org/apache/comet/exceptions/CometQueryExecutionException.java new file mode 100644 index 0000000000..5ff19ea398 --- /dev/null +++ b/common/src/main/java/org/apache/comet/exceptions/CometQueryExecutionException.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.exceptions; + +import org.apache.comet.CometNativeException; + +/** + * Exception thrown from Comet native execution containing JSON-encoded error information. The + * message contains a JSON object with the following structure: + * + *
+ * {
+ *   "errorType": "DivideByZero",
+ *   "errorClass": "DIVIDE_BY_ZERO",
+ *   "params": { ... },
+ *   "context": { "sqlText": "...", "startOffset": 0, "stopOffset": 10 },
+ *   "hint": "Use `try_divide` to tolerate divisor being 0"
+ * }
+ * 
+ * + * CometExecIterator parses this JSON and converts it to the appropriate Spark exception by calling + * the corresponding QueryExecutionErrors.* method. + */ +public final class CometQueryExecutionException extends CometNativeException { + + /** + * Creates a new CometQueryExecutionException with a JSON-encoded error message. + * + * @param jsonMessage JSON string containing error information + */ + public CometQueryExecutionException(String jsonMessage) { + super(jsonMessage); + } + + /** + * Returns true if the message appears to be JSON-formatted. This is used to distinguish between + * JSON-encoded errors and legacy error messages. + * + * @return true if message starts with '{' and ends with '}' + */ + public boolean isJsonMessage() { + String msg = getMessage(); + return msg != null && msg.trim().startsWith("{") && msg.trim().endsWith("}"); + } +} diff --git a/native/Cargo.lock b/native/Cargo.lock index 3cd6ea29b7..450089b04f 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1948,6 +1948,7 @@ dependencies = [ "num", "rand 0.10.0", "regex", + "serde", "serde_json", "thiserror 2.0.18", "tokio", diff --git a/native/core/src/errors.rs b/native/core/src/errors.rs index ecac7af94e..1329614cd7 100644 --- a/native/core/src/errors.rs +++ b/native/core/src/errors.rs @@ -39,7 +39,7 @@ use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, js use crate::execution::operators::ExecutionError; use datafusion_comet_spark_expr::SparkError; -use jni::objects::{GlobalRef, JThrowable, JValue}; +use jni::objects::{GlobalRef, JThrowable}; use jni::JNIEnv; use lazy_static::lazy_static; use parquet::errors::ParquetError; @@ -223,9 +223,9 @@ impl jni::errors::ToException for CometError { class: "java/lang/NullPointerException".to_string(), msg: self.to_string(), }, - CometError::Spark { .. 
} => Exception { - class: "org/apache/spark/SparkException".to_string(), - msg: self.to_string(), + CometError::Spark(spark_err) => Exception { + class: spark_err.exception_class().to_string(), + msg: spark_err.to_string(), }, CometError::NumberIntFormat { source: s } => Exception { class: "java/lang/NumberFormatException".to_string(), @@ -392,33 +392,37 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option env.throw(<&JThrowable>::from(throwable.as_obj())), + // Handle DataFusion errors containing SparkError - serialize to JSON CometError::DataFusion { msg: _, source: DataFusionError::External(e), - } if matches!(e.downcast_ref(), Some(SparkError::CastOverFlow { .. })) => { - match e.downcast_ref() { - Some(SparkError::CastOverFlow { - value, - from_type, - to_type, - }) => { - let throwable: JThrowable = env - .new_object( - "org/apache/spark/sql/comet/CastOverflowException", - "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V", - &[ - JValue::Object(&env.new_string(value).unwrap()), - JValue::Object(&env.new_string(from_type).unwrap()), - JValue::Object(&env.new_string(to_type).unwrap()), - ], - ) - .unwrap() - .into(); - env.throw(throwable) + } => { + // Try SparkErrorWithContext first (includes context) + if let Some(spark_error_with_ctx) = + e.downcast_ref::() + { + let json_message = spark_error_with_ctx.to_json(); + env.throw_new( + "org/apache/comet/exceptions/CometQueryExecutionException", + json_message, + ) + } else if let Some(spark_error) = e.downcast_ref::() { + // Fall back to plain SparkError (no context) + throw_spark_error_as_json(env, spark_error) + } else { + // Not a SparkError, use generic exception + let exception = error.to_exception(); + match backtrace { + Some(backtrace_string) => env.throw_new( + exception.class, + to_stacktrace_string(exception.msg, backtrace_string).unwrap(), + ), + _ => env.throw_new(exception.class, exception.msg), } - _ => unreachable!(), } } + // Handle direct SparkError - 
serialize to JSON + CometError::Spark(spark_error) => throw_spark_error_as_json(env, spark_error), _ => { let exception = error.to_exception(); match backtrace { @@ -434,6 +438,21 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option jni::errors::Result<()> { + // Serialize error to JSON + let json_message = spark_error.to_json(); + + // Throw CometQueryExecutionException with JSON message + env.throw_new( + "org/apache/comet/exceptions/CometQueryExecutionException", + json_message, + ) +} + #[derive(Debug, Error)] enum StacktraceError { #[error("Unable to initialize message: {0}")] diff --git a/native/core/src/execution/expressions/arithmetic.rs b/native/core/src/execution/expressions/arithmetic.rs index a9749678db..320532d773 100644 --- a/native/core/src/execution/expressions/arithmetic.rs +++ b/native/core/src/execution/expressions/arithmetic.rs @@ -17,6 +17,122 @@ //! Arithmetic expression builders +use std::any::Any; +use std::fmt::{Display, Formatter}; +use std::hash::{Hash, Hasher}; + +use arrow::datatypes::{DataType, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion::common::DataFusionError; +use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_spark_expr::{QueryContext, SparkError, SparkErrorWithContext}; + +/// Wrapper expression that catches and wraps SparkError with QueryContext +/// for binary arithmetic operations. 
+#[derive(Debug)] +pub struct CheckedBinaryExpr { + /// The underlying physical expression (typically a ScalarFunctionExpr) + child: Arc, + /// Optional query context to attach to errors + query_context: Option>, +} + +impl CheckedBinaryExpr { + pub fn new(child: Arc, query_context: Option>) -> Self { + Self { + child, + query_context, + } + } +} + +impl Display for CheckedBinaryExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "CheckedBinaryExpr({})", self.child) + } +} + +impl PartialEq for CheckedBinaryExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + } +} + +impl Eq for CheckedBinaryExpr {} + +impl PartialEq for CheckedBinaryExpr { + fn eq(&self, other: &dyn Any) -> bool { + other + .downcast_ref::() + .map(|x| self.eq(x)) + .unwrap_or(false) + } +} + +impl Hash for CheckedBinaryExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + } +} + +impl PhysicalExpr for CheckedBinaryExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.child.fmt_sql(f) + } + + fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result { + self.child.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> datafusion::common::Result { + self.child.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result { + let result = self.child.evaluate(batch); + + // If there's an error and we have query_context, wrap it + match result { + Err(DataFusionError::External(e)) if self.query_context.is_some() => { + if let Some(spark_err) = e.downcast_ref::() { + let wrapped = SparkErrorWithContext::with_context( + spark_err.clone(), + Arc::clone(self.query_context.as_ref().unwrap()), + ); + Err(DataFusionError::External(Box::new(wrapped))) + } else { + Err(DataFusionError::External(e)) + } + } + other => other, + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn 
with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion::common::Result> { + match children.len() { + 1 => Ok(Arc::new(CheckedBinaryExpr::new( + Arc::clone(&children[0]), + self.query_context.clone(), + ))), + _ => Err(DataFusionError::Internal( + "CheckedBinaryExpr should have exactly one child".to_string(), + )), + } + } +} + /// Macro to generate arithmetic expression builders that need eval_mode handling #[macro_export] macro_rules! arithmetic_expr_builder { @@ -37,6 +153,7 @@ macro_rules! arithmetic_expr_builder { let eval_mode = $crate::execution::planner::from_protobuf_eval_mode(expr.eval_mode)?; planner.create_binary_expr( + spark_expr, // Pass the full spark_expr for query_context lookup expr.left.as_ref().unwrap(), expr.right.as_ref().unwrap(), expr.return_type.as_ref(), @@ -53,7 +170,6 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use datafusion::logical_expr::Operator as DataFusionOperator; -use datafusion::physical_expr::PhysicalExpr; use datafusion_comet_proto::spark_expression::Expr; use datafusion_comet_spark_expr::{create_modulo_expr, create_negate_expr, EvalMode}; @@ -95,6 +211,7 @@ impl ExpressionBuilder for IntegralDivideBuilder { let expr = extract_expr!(spark_expr, IntegralDivide); let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; planner.create_binary_expr_with_options( + spark_expr, // Pass the full spark_expr for query_context lookup expr.left.as_ref().unwrap(), expr.right.as_ref().unwrap(), expr.return_type.as_ref(), diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index fb3d35c512..a9d0d4370c 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -163,6 +163,7 @@ pub struct PhysicalPlanner { exec_context_id: i64, partition: i32, session_ctx: Arc, + query_context_registry: Arc, } impl Default for PhysicalPlanner { @@ -177,6 +178,7 @@ impl PhysicalPlanner { exec_context_id: TEST_EXEC_CONTEXT_ID, session_ctx, partition, + 
query_context_registry: datafusion_comet_spark_expr::create_query_context_map(), } } @@ -185,6 +187,7 @@ impl PhysicalPlanner { exec_context_id, partition: self.partition, session_ctx: Arc::clone(&self.session_ctx), + query_context_registry: Arc::clone(&self.query_context_registry), } } @@ -251,6 +254,26 @@ impl PhysicalPlanner { spark_expr: &Expr, input_schema: SchemaRef, ) -> Result, ExecutionError> { + // Register QueryContext if present + if let (Some(expr_id), Some(ctx_proto)) = + (spark_expr.expr_id, spark_expr.query_context.as_ref()) + { + // Deserialize QueryContext from protobuf + let query_ctx = datafusion_comet_spark_expr::QueryContext::new( + ctx_proto.sql_text.clone(), + ctx_proto.start_index, + ctx_proto.stop_index, + ctx_proto.object_type.clone(), + ctx_proto.object_name.clone(), + ctx_proto.line, + ctx_proto.start_position, + ); + + // Register query context for error reporting + let registry = &self.query_context_registry; + registry.register(expr_id, query_ctx); + } + // Try to use the modular registry first - this automatically handles any registered expression types if ExpressionRegistry::global().can_handle(spark_expr) { return ExpressionRegistry::global().create_expr(spark_expr, input_schema, self); @@ -369,10 +392,19 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + + // Look up query context from registry if expr_id is present + let query_context = spark_expr.expr_id.and_then(|expr_id| { + let registry = &self.query_context_registry; + registry.get(expr_id) + }); + Ok(Arc::new(Cast::new( child, datatype, SparkCastOptions::new(eval_mode, &expr.timezone, expr.allow_incompat), + spark_expr.expr_id, + query_context, ))) } ExprStruct::CheckOverflow(expr) => { @@ -380,10 +412,18 @@ impl PhysicalPlanner { let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); 
let fail_on_error = expr.fail_on_error; + // Look up query context from registry if expr_id is present + let query_context = spark_expr.expr_id.and_then(|expr_id| { + let registry = &self.query_context_registry; + registry.get(expr_id) + }); + Ok(Arc::new(CheckOverflow::new( child, data_type, fail_on_error, + spark_expr.expr_id, + query_context, ))) } ExprStruct::ScalarFunc(expr) => { @@ -397,6 +437,8 @@ impl PhysicalPlanner { None, true, false, + None, // No expr_id for internal map_extract wrapper + Arc::clone(&self.query_context_registry), ))), // DataFusion 49 hardcodes return type for MD5 built in function as UTF8View // which is not yet supported in Comet @@ -405,6 +447,8 @@ impl PhysicalPlanner { func?, DataType::Utf8, SparkCastOptions::new_without_timezone(EvalMode::Try, true), + None, + None, ))), _ => func, } @@ -519,6 +563,8 @@ impl PhysicalPlanner { Arc::clone(&child), DataType::Utf8, spark_cast_options, + None, + None, )); Ok(Arc::new(IfExpr::new( Arc::new(IsNullExpr::new(child)), @@ -544,6 +590,8 @@ impl PhysicalPlanner { default_value, expr.one_based, expr.fail_on_error, + spark_expr.expr_id, + Arc::clone(&self.query_context_registry), ))) } ExprStruct::GetArrayStructFields(expr) => { @@ -634,8 +682,10 @@ impl PhysicalPlanner { } } + #[allow(clippy::too_many_arguments)] pub fn create_binary_expr( &self, + spark_expr: &Expr, left: &Expr, right: &Expr, return_type: Option<&spark_expression::DataType>, @@ -644,6 +694,7 @@ impl PhysicalPlanner { eval_mode: EvalMode, ) -> Result, ExecutionError> { self.create_binary_expr_with_options( + spark_expr, left, right, return_type, @@ -657,6 +708,7 @@ impl PhysicalPlanner { #[allow(clippy::too_many_arguments)] pub fn create_binary_expr_with_options( &self, + spark_expr: &Expr, left: &Expr, right: &Expr, return_type: Option<&spark_expression::DataType>, @@ -665,6 +717,12 @@ impl PhysicalPlanner { options: BinaryExprOptions, eval_mode: EvalMode, ) -> Result, ExecutionError> { + // Look up query context from 
registry if expr_id is present + let query_context = spark_expr.expr_id.and_then(|expr_id| { + let registry = &self.query_context_registry; + registry.get(expr_id) + }); + let left = self.create_expr(left, Arc::clone(&input_schema))?; let right = self.create_expr(right, Arc::clone(&input_schema))?; match ( @@ -688,17 +746,23 @@ impl PhysicalPlanner { left, DataType::Decimal256(p1, s1), SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), + None, + None, )); let right = Arc::new(Cast::new( right, DataType::Decimal256(p2, s2), SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), + None, + None, )); let child = Arc::new(BinaryExpr::new(left, op, right)); Ok(Arc::new(Cast::new( child, data_type, SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), + None, + None, ))) } ( @@ -787,13 +851,17 @@ impl PhysicalPlanner { None, eval_mode, )?; - Ok(Arc::new(ScalarFunctionExpr::new( + let scalar_expr = Arc::new(ScalarFunctionExpr::new( op_str, fun_expr, vec![left, right], Arc::new(Field::new(op_str, data_type, true)), Arc::new(ConfigOptions::default()), - ))) + )); + + // Wrap with CheckedBinaryExpr to add query_context to errors + use crate::execution::expressions::arithmetic::CheckedBinaryExpr; + Ok(Arc::new(CheckedBinaryExpr::new(scalar_expr, query_context))) } else { Ok(Arc::new(BinaryExpr::new(left, op, right))) } @@ -1804,6 +1872,26 @@ impl PhysicalPlanner { spark_expr: &AggExpr, schema: SchemaRef, ) -> Result { + // Register QueryContext if present + if let (Some(expr_id), Some(ctx_proto)) = + (spark_expr.expr_id, spark_expr.query_context.as_ref()) + { + // Deserialize QueryContext from protobuf + let query_ctx = datafusion_comet_spark_expr::QueryContext::new( + ctx_proto.sql_text.clone(), + ctx_proto.start_index, + ctx_proto.stop_index, + ctx_proto.object_type.clone(), + ctx_proto.object_name.clone(), + ctx_proto.line, + ctx_proto.start_position, + ); + + // Register query context for error reporting + let registry = 
&self.query_context_registry; + registry.register(expr_id, query_ctx); + } + match spark_expr.expr_struct.as_ref().unwrap() { AggExprStruct::Count(expr) => { assert!(!expr.children.is_empty()); @@ -1854,8 +1942,12 @@ impl PhysicalPlanner { let builder = match datatype { DataType::Decimal128(_, _) => { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - let func = - AggregateUDF::new_from_impl(SumDecimal::try_new(datatype, eval_mode)?); + let func = AggregateUDF::new_from_impl(SumDecimal::try_new( + datatype, + eval_mode, + spark_expr.expr_id, + Arc::clone(&self.query_context_registry), + )?); AggregateExprBuilder::new(Arc::new(func), vec![child]) } DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { @@ -1887,8 +1979,14 @@ impl PhysicalPlanner { let builder = match datatype { DataType::Decimal128(_, _) => { - let func = - AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype)); + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + let func = AggregateUDF::new_from_impl(AvgDecimal::new( + datatype, + input_datatype, + eval_mode, + spark_expr.expr_id, + Arc::clone(&self.query_context_registry), + )); AggregateExprBuilder::new(Arc::new(func), vec![child]) } _ => { @@ -3023,14 +3121,21 @@ fn create_case_expr( Arc::clone(&x.1), coerce_type.clone(), cast_options.clone(), + None, + None, )); (Arc::clone(&x.0), t) }) .collect::, Arc)>>(); let else_phy_expr: Option> = else_expr.clone().map(|x| { - Arc::new(Cast::new(x, coerce_type.clone(), cast_options.clone())) - as Arc + Arc::new(Cast::new( + x, + coerce_type.clone(), + cast_options.clone(), + None, + None, + )) as Arc }); Ok(Arc::new(CaseExpr::try_new( None, @@ -3717,9 +3822,13 @@ mod tests { type_info: None, }), })), + query_context: None, + expr_id: None, }; let right = spark_expression::Expr { expr_struct: Some(Literal(lit)), + query_context: None, + expr_id: None, }; let expr = spark_expression::Expr { @@ -3727,6 +3836,8 @@ mod tests { left: Some(Box::new(left)), 
right: Some(Box::new(right)), }))), + query_context: None, + expr_id: None, }; Operator { @@ -3782,6 +3893,8 @@ mod tests { index, datatype: Some(create_proto_datatype()), })), + query_context: None, + expr_id: None, } } @@ -3847,6 +3960,8 @@ mod tests { type_info: None, }), })), + query_context: None, + expr_id: None, }; let array_col_1 = spark_expression::Expr { @@ -3857,6 +3972,8 @@ mod tests { type_info: None, }), })), + query_context: None, + expr_id: None, }; let projection = Operator { @@ -3870,6 +3987,8 @@ mod tests { return_type: None, fail_on_error: false, })), + query_context: None, + expr_id: None, }], })), }; @@ -3925,6 +4044,140 @@ mod tests { }); } + #[test] + fn test_array_repeat() { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let planner = PhysicalPlanner::new(Arc::from(session_ctx), 0); + + // Mock scan operator with 3 INT32 columns + let op_scan = Operator { + plan_id: 0, + children: vec![], + op_struct: Some(OpStruct::Scan(spark_operator::Scan { + fields: vec![ + spark_expression::DataType { + type_id: 3, // Int32 + type_info: None, + }, + spark_expression::DataType { + type_id: 3, // Int32 + type_info: None, + }, + spark_expression::DataType { + type_id: 3, // Int32 + type_info: None, + }, + ], + source: "".to_string(), + arrow_ffi_safe: false, + })), + }; + + // Mock expression to read a INT32 column with position 0 + let array_col = spark_expression::Expr { + expr_struct: Some(Bound(spark_expression::BoundReference { + index: 0, + datatype: Some(spark_expression::DataType { + type_id: 3, + type_info: None, + }), + })), + query_context: None, + expr_id: None, + }; + + // Mock expression to read a INT32 column with position 1 + let array_col_1 = spark_expression::Expr { + expr_struct: Some(Bound(spark_expression::BoundReference { + index: 1, + datatype: Some(spark_expression::DataType { + type_id: 3, + type_info: None, + }), + })), + query_context: None, + expr_id: None, + }; + + // Make a projection 
operator with array_repeat(array_col, array_col_1) + let projection = Operator { + children: vec![op_scan], + plan_id: 0, + op_struct: Some(OpStruct::Projection(spark_operator::Projection { + project_list: vec![spark_expression::Expr { + expr_struct: Some(ExprStruct::ScalarFunc(spark_expression::ScalarFunc { + func: "array_repeat".to_string(), + args: vec![array_col, array_col_1], + return_type: None, + fail_on_error: false, + })), + query_context: None, + expr_id: None, + }], + })), + }; + + // Create a physical plan + let (mut scans, datafusion_plan) = + planner.create_plan(&projection, &mut vec![], 1).unwrap(); + + // Start executing the plan in a separate thread + // The plan waits for incoming batches and emitting result as input comes + let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap(); + + let runtime = tokio::runtime::Runtime::new().unwrap(); + // create async channel + let (tx, mut rx) = mpsc::channel(1); + + // Send data as input to the plan being executed in a separate thread + runtime.spawn(async move { + // create data batch + // 0, 1, 2 + // 3, 4, 5 + // 6, null, null + let a = Int32Array::from(vec![Some(0), Some(3), Some(6)]); + let b = Int32Array::from(vec![Some(1), Some(4), None]); + let c = Int32Array::from(vec![Some(2), Some(5), None]); + let input_batch1 = InputBatch::Batch(vec![Arc::new(a), Arc::new(b), Arc::new(c)], 3); + let input_batch2 = InputBatch::EOF; + + let batches = vec![input_batch1, input_batch2]; + + for batch in batches.into_iter() { + tx.send(batch).await.unwrap(); + } + }); + + // Wait for the plan to finish executing and assert the result + runtime.block_on(async move { + loop { + let batch = rx.recv().await.unwrap(); + scans[0].set_input_batch(batch); + match poll!(stream.next()) { + Poll::Ready(Some(batch)) => { + assert!(batch.is_ok(), "got error {}", batch.unwrap_err()); + let batch = batch.unwrap(); + let expected = [ + "+--------------+", + "| col_0 |", + "+--------------+", + "| [0] |", + "| [3, 
3, 3, 3] |", + "| [] |", + "+--------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + Poll::Ready(None) => { + break; + } + _ => {} + } + } + }); + } + /// Executes a `test_data_query` SQL query /// and saves the result into a temp folder using parquet format /// Read the file back to the memory using a custom schema @@ -4313,6 +4566,8 @@ mod tests { type_info: None, }), })), + expr_id: None, + query_context: None, }; // Create bound reference for the INT8 column (index 1) @@ -4324,6 +4579,8 @@ mod tests { type_info: None, }), })), + expr_id: None, + query_context: None, }; // Create a Subtract expression: date_col - int8_col @@ -4339,6 +4596,8 @@ mod tests { }), eval_mode: 0, // Legacy mode }))), + expr_id: None, + query_context: None, }; // Create a projection operator with the subtract expression diff --git a/native/core/src/parquet/schema_adapter.rs b/native/core/src/parquet/schema_adapter.rs index 42f0e7fc61..0ad61df426 100644 --- a/native/core/src/parquet/schema_adapter.rs +++ b/native/core/src/parquet/schema_adapter.rs @@ -349,7 +349,13 @@ impl SparkPhysicalExprAdapter { cast_options.allow_cast_unsigned_ints = self.parquet_options.allow_cast_unsigned_ints; cast_options.is_adapting_schema = true; - let spark_cast = Arc::new(Cast::new(child, target_type.clone(), cast_options)); + let spark_cast = Arc::new(Cast::new( + child, + target_type.clone(), + cast_options, + None, + None, + )); return Ok(Transformed::yes(spark_cast as Arc)); } diff --git a/native/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/native/org/apache/spark/sql/errors/QueryExecutionErrors.scala new file mode 100644 index 0000000000..7981468394 --- /dev/null +++ b/native/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -0,0 +1,2718 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.errors + +import java.io.{FileNotFoundException, IOException} +import java.lang.reflect.InvocationTargetException +import java.net.{URISyntaxException, URL} +import java.time.DateTimeException +import java.util.concurrent.TimeoutException + +import com.fasterxml.jackson.core.{JsonParser, JsonToken} +import org.apache.hadoop.fs.{FileAlreadyExistsException, FileStatus, Path} +import org.apache.hadoop.fs.permission.FsPermission +import org.codehaus.commons.compiler.{CompileException, InternalCompilerException} + +import org.apache.spark._ +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.memory.SparkOutOfMemoryError +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.ScalaReflection.Schema +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval +import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode} 
+import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode} +import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{CircularBuffer, Utils} + +/** + * Object for grouping error messages from (most) exceptions thrown during query execution. + * This does not include exceptions thrown during the eager execution of commands, which are + * grouped into [[QueryCompilationErrors]]. + */ +private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionErrors { + + def cannotEvaluateExpressionError(expression: Expression): Throwable = { + SparkException.internalError(s"Cannot evaluate expression: $expression") + } + + def cannotGenerateCodeForExpressionError(expression: Expression): Throwable = { + SparkException.internalError(s"Cannot generate code for expression: $expression") + } + + def cannotTerminateGeneratorError(generator: UnresolvedGenerator): Throwable = { + SparkException.internalError(s"Cannot terminate expression: $generator") + } + + def castingCauseOverflowError(t: Any, from: DataType, to: DataType): ArithmeticException = { + new SparkArithmeticException( + errorClass = "CAST_OVERFLOW", + messageParameters = Map( + "value" -> toSQLValue(t, from), + "sourceType" -> toSQLType(from), + "targetType" -> toSQLType(to), + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = Array.empty, + summary = "") + } + + def castingCauseOverflowErrorInTableInsert( + from: DataType, + to: DataType, + 
columnName: String): ArithmeticException = { + new SparkArithmeticException( + errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT", + messageParameters = Map( + "sourceType" -> toSQLType(from), + "targetType" -> toSQLType(to), + "columnName" -> toSQLId(columnName)), + context = Array.empty, + summary = "" + ) + } + + def cannotChangeDecimalPrecisionError( + value: Decimal, + decimalPrecision: Int, + decimalScale: Int, + context: SQLQueryContext = null): ArithmeticException = { + new SparkArithmeticException( + errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", + messageParameters = Map( + "value" -> value.toPlainString, + "precision" -> decimalPrecision.toString, + "scale" -> decimalScale.toString, + "config" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def invalidInputSyntaxForBooleanError( + s: UTF8String, + context: SQLQueryContext): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "CAST_INVALID_INPUT", + messageParameters = Map( + "expression" -> toSQLValue(s, StringType), + "sourceType" -> toSQLType(StringType), + "targetType" -> toSQLType(BooleanType), + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def invalidInputInCastToNumberError( + to: DataType, + s: UTF8String, + context: SQLQueryContext): SparkNumberFormatException = { + new SparkNumberFormatException( + errorClass = "CAST_INVALID_INPUT", + messageParameters = Map( + "expression" -> toSQLValue(s, StringType), + "sourceType" -> toSQLType(StringType), + "targetType" -> toSQLType(to), + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def invalidInputInConversionError( + to: DataType, + s: UTF8String, + fmt: UTF8String, + hint: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "CONVERSION_INVALID_INPUT", + 
messageParameters = Map( + "str" -> toSQLValue(s, StringType), + "fmt" -> toSQLValue(fmt, StringType), + "targetType" -> toSQLType(to), + "suggestion" -> toSQLId(hint))) + } + + def cannotCastFromNullTypeError(to: DataType): Throwable = { + new SparkException( + errorClass = "CANNOT_CAST_DATATYPE", + messageParameters = Map( + "sourceType" -> NullType.typeName, + "targetType" -> to.typeName), + cause = null) + } + + def cannotCastError(from: DataType, to: DataType): Throwable = { + new SparkException( + errorClass = "CANNOT_CAST_DATATYPE", + messageParameters = Map( + "sourceType" -> from.typeName, + "targetType" -> to.typeName), + cause = null) + } + + def cannotParseDecimalError(): Throwable = { + new SparkRuntimeException( + errorClass = "CANNOT_PARSE_DECIMAL", + messageParameters = Map.empty) + } + + def dataTypeUnsupportedError(dataType: String, failure: String): Throwable = { + DataTypeErrors.dataTypeUnsupportedError(dataType, failure) + } + + def failedExecuteUserDefinedFunctionError(functionName: String, inputTypes: String, + outputType: String, e: Throwable): Throwable = { + new SparkException( + errorClass = "FAILED_EXECUTE_UDF", + messageParameters = Map( + "functionName" -> toSQLId(functionName), + "signature" -> inputTypes, + "result" -> outputType), + cause = e) + } + + def divideByZeroError(context: SQLQueryContext): ArithmeticException = { + new SparkArithmeticException( + errorClass = "DIVIDE_BY_ZERO", + messageParameters = Map("config" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def intervalDividedByZeroError(context: SQLQueryContext): ArithmeticException = { + new SparkArithmeticException( + errorClass = "INTERVAL_DIVIDED_BY_ZERO", + messageParameters = Map.empty, + context = getQueryContext(context), + summary = getSummary(context)) + } + + def invalidArrayIndexError( + index: Int, + numElements: Int, + context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + new 
SparkArrayIndexOutOfBoundsException( + errorClass = "INVALID_ARRAY_INDEX", + messageParameters = Map( + "indexValue" -> toSQLValue(index, IntegerType), + "arraySize" -> toSQLValue(numElements, IntegerType), + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def invalidElementAtIndexError( + index: Int, + numElements: Int, + context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + new SparkArrayIndexOutOfBoundsException( + errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", + messageParameters = Map( + "indexValue" -> toSQLValue(index, IntegerType), + "arraySize" -> toSQLValue(numElements, IntegerType), + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def invalidBitmapPositionError(bitPosition: Long, + bitmapNumBytes: Long): ArrayIndexOutOfBoundsException = { + new SparkArrayIndexOutOfBoundsException( + errorClass = "INVALID_BITMAP_POSITION", + messageParameters = Map( + "bitPosition" -> s"$bitPosition", + "bitmapNumBytes" -> s"$bitmapNumBytes", + "bitmapNumBits" -> s"${bitmapNumBytes * 8}"), + context = Array.empty, + summary = "") + } + + def invalidFractionOfSecondError(): DateTimeException = { + new SparkDateTimeException( + errorClass = "INVALID_FRACTION_OF_SECOND", + messageParameters = Map( + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key) + ), + context = Array.empty, + summary = "") + } + + def ansiDateTimeParseError(e: Exception): SparkDateTimeException = { + new SparkDateTimeException( + errorClass = "CANNOT_PARSE_TIMESTAMP", + messageParameters = Map( + "message" -> e.getMessage, + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = Array.empty, + summary = "") + } + + def ansiDateTimeError(e: Exception): SparkDateTimeException = { + new SparkDateTimeException( + errorClass = "_LEGACY_ERROR_TEMP_2000", + messageParameters = Map( + "message" -> e.getMessage, + 
"ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = Array.empty, + summary = "") + } + + def ansiIllegalArgumentError(message: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2000", + messageParameters = Map( + "message" -> message, + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key))) + } + + def ansiIllegalArgumentError(e: IllegalArgumentException): IllegalArgumentException = { + ansiIllegalArgumentError(e.getMessage) + } + + def overflowInSumOfDecimalError(context: SQLQueryContext): ArithmeticException = { + arithmeticOverflowError("Overflow in sum of decimals", context = context) + } + + def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException = { + arithmeticOverflowError("Overflow in integral divide", "try_divide", context) + } + + def overflowInConvError(context: SQLQueryContext): ArithmeticException = { + arithmeticOverflowError("Overflow in function conv()", context = context) + } + + def mapSizeExceedArraySizeWhenZipMapError(size: Int): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2003", + messageParameters = Map( + "size" -> size.toString(), + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + } + + def literalTypeUnsupportedError(v: Any): RuntimeException = { + new SparkRuntimeException( + errorClass = "UNSUPPORTED_FEATURE.LITERAL_TYPE", + messageParameters = Map( + "value" -> v.toString, + "type" -> v.getClass.toString)) + } + + def pivotColumnUnsupportedError(v: Any, dataType: DataType): RuntimeException = { + new SparkRuntimeException( + errorClass = "UNSUPPORTED_FEATURE.PIVOT_TYPE", + messageParameters = Map( + "value" -> v.toString, + "type" -> toSQLType(dataType))) + } + + def noDefaultForDataTypeError(dataType: DataType): SparkException = { + SparkException.internalError(s"No default value for type: ${toSQLType(dataType)}.") + } + + def 
orderedOperationUnsupportedByDataTypeError( + dataType: DataType): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2005", + messageParameters = Map("dataType" -> dataType.toString())) + } + + def orderedOperationUnsupportedByDataTypeError( + dataType: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2005", + messageParameters = Map("dataType" -> dataType)) + } + + def invalidRegexGroupIndexError( + funcName: String, + groupCount: Int, + groupIndex: Int): RuntimeException = { + new SparkRuntimeException( + errorClass = "INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX", + messageParameters = Map( + "parameter" -> toSQLId("idx"), + "functionName" -> toSQLId(funcName), + "groupCount" -> groupCount.toString(), + "groupIndex" -> groupIndex.toString())) + } + + def invalidUrlError(url: UTF8String, e: URISyntaxException): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "INVALID_URL", + messageParameters = Map( + "url" -> url.toString, + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + cause = e) + } + + def illegalUrlError(url: UTF8String, e: IllegalArgumentException): Throwable = { + new SparkIllegalArgumentException( + errorClass = "CANNOT_DECODE_URL", + messageParameters = Map("url" -> url.toString), + cause = e + ) + } + + def mergeUnsupportedByWindowFunctionError(funcName: String): Throwable = { + SparkException.internalError( + s"The aggregate window function ${toSQLId(funcName)} does not support merging.") + } + + def dataTypeUnexpectedError(dataType: DataType): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2011", + messageParameters = Map("dataType" -> dataType.catalogString)) + } + + def typeUnsupportedError(dataType: DataType): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = 
"_LEGACY_ERROR_TEMP_2011", + messageParameters = Map("dataType" -> dataType.toString())) + } + + def negativeValueUnexpectedError( + frequencyExpression : Expression): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2013", + messageParameters = Map("frequencyExpression" -> frequencyExpression.sql)) + } + + def addNewFunctionMismatchedWithFunctionError(funcName: String): Throwable = { + SparkException.internalError( + "Cannot add new function to generated class, " + + s"failed to match ${toSQLId(funcName)} at `addNewFunction`.") + } + + def cannotGenerateCodeForIncomparableTypeError( + codeType: String, dataType: DataType): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2015", + messageParameters = Map( + "codeType" -> codeType, + "dataType" -> dataType.catalogString)) + } + + def cannotInterpolateClassIntoCodeBlockError(arg: Any): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2016", + messageParameters = Map("arg" -> arg.getClass.getName)) + } + + def customCollectionClsNotResolvedError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2017", + messageParameters = Map.empty) + } + + def classUnsupportedByMapObjectsError(cls: Class[_]): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2018", + messageParameters = Map("cls" -> cls.getName)) + } + + def nullAsMapKeyNotAllowedError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "NULL_MAP_KEY", + messageParameters = Map.empty) + } + + def methodNotDeclaredError(name: String): Throwable = { + SparkException.internalError( + s"""A method named "$name" is not declared in any enclosing class nor any supertype""") + } + + def methodNotFoundError( + cls: Class[_], + functionName: String, + argClasses: Seq[Class[_]]): 
Throwable = { + SparkException.internalError( + s"Couldn't find method $functionName with arguments " + + s"${argClasses.mkString("(", ", ", ")")} on $cls.") + } + + def constructorNotFoundError(cls: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2020", + messageParameters = Map("cls" -> cls)) + } + + def unsupportedNaturalJoinTypeError(joinType: JoinType): SparkException = { + SparkException.internalError( + s"Unsupported natural join type ${joinType.toString}") + } + + def notExpectedUnresolvedEncoderError(attr: AttributeReference): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2023", + messageParameters = Map("attr" -> attr.toString())) + } + + def unsupportedEncoderError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2024", + messageParameters = Map.empty) + } + + def notOverrideExpectedMethodsError( + className: String, m1: String, m2: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2025", + messageParameters = Map("className" -> className, "m1" -> m1, "m2" -> m2)) + } + + def failToConvertValueToJsonError( + value: AnyRef, cls: Class[_], dataType: DataType): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2026", + messageParameters = Map( + "value" -> value.toString(), + "cls" -> cls.toString(), + "dataType" -> dataType.toString())) + } + + def unexpectedOperatorInCorrelatedSubquery( + op: LogicalPlan, pos: String = ""): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2027", + messageParameters = Map("op" -> op.toString(), "pos" -> pos)) + } + + def unsupportedRoundingMode(roundMode: BigDecimal.RoundingMode.Value): SparkException = { + DataTypeErrors.unsupportedRoundingMode(roundMode) + } + + def resolveCannotHandleNestedSchema(plan: LogicalPlan): SparkRuntimeException = { + new 
SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2030", + messageParameters = Map("plan" -> plan.toString())) + } + + def inputExternalRowCannotBeNullError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2031", + messageParameters = Map.empty) + } + + def fieldCannotBeNullMsg(index: Int, fieldName: String): String = { + s"The ${index}th field '$fieldName' of input row cannot be null." + } + + def fieldCannotBeNullError(index: Int, fieldName: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2032", + messageParameters = Map("fieldCannotBeNullMsg" -> fieldCannotBeNullMsg(index, fieldName))) + } + + def unableToCreateDatabaseAsFailedToCreateDirectoryError( + dbDefinition: CatalogDatabase, e: IOException): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2033", + messageParameters = Map( + "name" -> dbDefinition.name, + "locationUri" -> dbDefinition.locationUri.toString()), + cause = e) + } + + def unableToDropDatabaseAsFailedToDeleteDirectoryError( + dbDefinition: CatalogDatabase, e: IOException): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2034", + messageParameters = Map( + "name" -> dbDefinition.name, + "locationUri" -> dbDefinition.locationUri.toString()), + cause = e) + } + + def unableToCreateTableAsFailedToCreateDirectoryError( + table: String, defaultTableLocation: Path, e: IOException): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2035", + messageParameters = Map( + "table" -> table, + "defaultTableLocation" -> defaultTableLocation.toString()), + cause = e) + } + + def unableToDeletePartitionPathError(partitionPath: Path, e: IOException): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2036", + messageParameters = Map("partitionPath" -> partitionPath.toString()), + cause = e) + } + + def unableToDropTableAsFailedToDeleteDirectoryError( + table: String, dir: 
Path, e: IOException): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2037", + messageParameters = Map("table" -> table, "dir" -> dir.toString()), + cause = e) + } + + def unableToRenameTableAsFailedToRenameDirectoryError( + oldName: String, newName: String, oldDir: Path, e: IOException): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2038", + messageParameters = Map( + "oldName" -> oldName, + "newName" -> newName, + "oldDir" -> oldDir.toString()), + cause = e) + } + + def unableToCreatePartitionPathError(partitionPath: Path, e: IOException): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2039", + messageParameters = Map("partitionPath" -> partitionPath.toString()), + cause = e) + } + + def unableToRenamePartitionPathError(oldPartPath: Path, e: IOException): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2040", + messageParameters = Map("oldPartPath" -> oldPartPath.toString()), + cause = e) + } + + def methodNotImplementedError(methodName: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2041", + messageParameters = Map("methodName" -> methodName)) + } + + def arithmeticOverflowError(e: ArithmeticException): SparkArithmeticException = { + new SparkArithmeticException( + errorClass = "_LEGACY_ERROR_TEMP_2042", + messageParameters = Map( + "message" -> e.getMessage, + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = Array.empty, + summary = "") + } + + def unaryMinusCauseOverflowError(originValue: Int): SparkArithmeticException = { + new SparkArithmeticException( + errorClass = "_LEGACY_ERROR_TEMP_2043", + messageParameters = Map("sqlValue" -> toSQLValue(originValue, IntegerType)), + context = Array.empty, + summary = "") + } + + def binaryArithmeticCauseOverflowError( + eval1: Short, symbol: String, eval2: Short): SparkArithmeticException = { + new SparkArithmeticException( + 
errorClass = "BINARY_ARITHMETIC_OVERFLOW", + messageParameters = Map( + "value1" -> toSQLValue(eval1, ShortType), + "symbol" -> symbol, + "value2" -> toSQLValue(eval2, ShortType)), + context = Array.empty, + summary = "") + } + + def intervalArithmeticOverflowError( + message: String, + hint: String = "", + context: SQLQueryContext): ArithmeticException = { + val alternative = if (hint.nonEmpty) { + s" Use '$hint' to tolerate overflow and return NULL instead." + } else "" + new SparkArithmeticException( + errorClass = "INTERVAL_ARITHMETIC_OVERFLOW", + messageParameters = Map( + "message" -> message, + "alternative" -> alternative), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def failedToCompileMsg(e: Exception): String = { + s"failed to compile: $e" + } + + def internalCompilerError(e: InternalCompilerException): Throwable = { + new InternalCompilerException(failedToCompileMsg(e), e) + } + + def compilerError(e: CompileException): Throwable = { + new CompileException(failedToCompileMsg(e), e.getLocation) + } + + def unsupportedTableChangeError(e: IllegalArgumentException): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2045", + messageParameters = Map("message" -> e.getMessage), + cause = e) + } + + def notADatasourceRDDPartitionError(split: Partition): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2046", + messageParameters = Map("split" -> split.toString()), + cause = null) + } + + def dataPathNotSpecifiedError(): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2047", + messageParameters = Map.empty) + } + + def createStreamingSourceNotSpecifySchemaError(): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2048", + messageParameters = Map.empty) + } + + def streamedOperatorUnsupportedByDataSourceError( + className: String, operator: String): 
SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2049", + messageParameters = Map("className" -> className, "operator" -> operator)) + } + + def nonTimeWindowNotSupportedInStreamingError( + windowFuncList: Seq[String], + columnNameList: Seq[String], + windowSpecList: Seq[String], + origin: Origin): AnalysisException = { + new AnalysisException( + errorClass = "NON_TIME_WINDOW_NOT_SUPPORTED_IN_STREAMING", + messageParameters = Map( + "windowFunc" -> windowFuncList.map(toSQLStmt(_)).mkString(","), + "columnName" -> columnNameList.map(toSQLId(_)).mkString(","), + "windowSpec" -> windowSpecList.map(toSQLStmt(_)).mkString(",")), + origin = origin) + } + + def multiplePathsSpecifiedError(allPaths: Seq[String]): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2050", + messageParameters = Map("paths" -> allPaths.mkString(", "))) + } + + def dataSourceNotFoundError( + provider: String, error: Throwable): SparkClassNotFoundException = { + new SparkClassNotFoundException( + errorClass = "DATA_SOURCE_NOT_FOUND", + messageParameters = Map("provider" -> provider), + cause = error) + } + + def removedClassInSpark2Error(className: String, e: Throwable): SparkClassNotFoundException = { + new SparkClassNotFoundException( + errorClass = "_LEGACY_ERROR_TEMP_2052", + messageParameters = Map("className" -> className), + cause = e) + } + + def incompatibleDataSourceRegisterError(e: Throwable): Throwable = { + new SparkClassNotFoundException( + errorClass = "INCOMPATIBLE_DATASOURCE_REGISTER", + messageParameters = Map("message" -> e.getMessage), + cause = e) + } + + def sparkUpgradeInReadingDatesError( + format: String, config: String, option: String): SparkUpgradeException = { + new SparkUpgradeException( + errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.READ_ANCIENT_DATETIME", + messageParameters = Map( + "format" -> format, + "config" -> toSQLConf(config), + 
"option" -> toDSOption(option)), + cause = null + ) + } + + def sparkUpgradeInWritingDatesError(format: String, config: String): SparkUpgradeException = { + new SparkUpgradeException( + errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.WRITE_ANCIENT_DATETIME", + messageParameters = Map( + "format" -> format, + "config" -> toSQLConf(config)), + cause = null + ) + } + + def buildReaderUnsupportedForFileFormatError( + format: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2053", + messageParameters = Map("format" -> format)) + } + + def taskFailedWhileWritingRowsError(path: String, cause: Throwable): Throwable = { + new SparkException( + errorClass = "TASK_WRITE_FAILED", + messageParameters = Map("path" -> path), + cause = cause) + } + + def readCurrentFileNotFoundError(e: FileNotFoundException): SparkFileNotFoundException = { + new SparkFileNotFoundException( + errorClass = "_LEGACY_ERROR_TEMP_2055", + messageParameters = Map("message" -> e.getMessage)) + } + + def saveModeUnsupportedError(saveMode: Any, pathExists: Boolean): Throwable = { + val errorSubClass = if (pathExists) "EXISTENT_PATH" else "NON_EXISTENT_PATH" + new SparkIllegalArgumentException( + errorClass = s"UNSUPPORTED_SAVE_MODE.$errorSubClass", + messageParameters = Map("saveMode" -> toSQLValue(saveMode, StringType))) + } + + def cannotClearOutputDirectoryError(staticPrefixPath: Path): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2056", + messageParameters = Map("staticPrefixPath" -> staticPrefixPath.toString()), + cause = null) + } + + def cannotClearPartitionDirectoryError(path: Path): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2057", + messageParameters = Map("path" -> path.toString()), + cause = null) + } + + def failedToCastValueToDataTypeForPartitionColumnError( + value: String, dataType: DataType, columnName: String): SparkRuntimeException = { + new 
SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2058", + messageParameters = Map( + "value" -> value, + "dataType" -> dataType.toString(), + "columnName" -> columnName)) + } + + def endOfStreamError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2059", + messageParameters = Map.empty, + cause = null) + } + + def fallbackV1RelationReportsInconsistentSchemaError( + v2Schema: StructType, v1Schema: StructType): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2060", + messageParameters = Map("v2Schema" -> v2Schema.toString(), "v1Schema" -> v1Schema.toString())) + } + + def noRecordsFromEmptyDataReaderError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2061", + messageParameters = Map.empty, + cause = null) + } + + def fileNotFoundError(e: FileNotFoundException): SparkFileNotFoundException = { + new SparkFileNotFoundException( + errorClass = "_LEGACY_ERROR_TEMP_2062", + messageParameters = Map("message" -> e.getMessage)) + } + + def unsupportedSchemaColumnConvertError( + filePath: String, + column: String, + logicalType: String, + physicalType: String, + e: Exception): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2063", + messageParameters = Map( + "filePath" -> filePath, + "column" -> column, + "logicalType" -> logicalType, + "physicalType" -> physicalType), + cause = e) + } + + def cannotReadFilesError( + e: Throwable, + path: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2064", + messageParameters = Map("path" -> path), + cause = e) + } + + def cannotCreateColumnarReaderError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2065", + messageParameters = Map.empty, + cause = null) + } + + def invalidNamespaceNameError(namespace: Array[String]): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2066", + 
messageParameters = Map("namespace" -> namespace.quoted)) + } + + def unsupportedPartitionTransformError( + transform: Transform): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2067", + messageParameters = Map("transform" -> transform.toString())) + } + + def missingDatabaseLocationError(): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2068", + messageParameters = Map.empty) + } + + def cannotRemoveReservedPropertyError(property: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2069", + messageParameters = Map("property" -> property)) + } + + def writingJobFailedError(cause: Throwable): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2070", + messageParameters = Map.empty, + cause = cause) + } + + def commitDeniedError( + partId: Int, taskId: Long, attemptId: Int, stageId: Int, stageAttempt: Int): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2071", + messageParameters = Map( + "partId" -> partId.toString(), + "taskId" -> taskId.toString(), + "attemptId" -> attemptId.toString(), + "stageId" -> stageId.toString(), + "stageAttempt" -> stageAttempt.toString()), + cause = null) + } + + def cannotCreateJDBCTableWithPartitionsError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2073", + messageParameters = Map.empty) + } + + def unsupportedUserSpecifiedSchemaError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2074", + messageParameters = Map.empty) + } + + def writeUnsupportedForBinaryFileDataSourceError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2075", + messageParameters = Map.empty) + } + + def 
fileLengthExceedsMaxLengthError(status: FileStatus, maxLength: Int): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2076", + messageParameters = Map( + "path" -> status.getPath.toString(), + "len" -> status.getLen.toString(), + "maxLength" -> maxLength.toString()), + cause = null) + } + + def unsupportedFieldNameError(fieldName: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2077", + messageParameters = Map("fieldName" -> fieldName)) + } + + def cannotSpecifyBothJdbcTableNameAndQueryError( + jdbcTableName: String, jdbcQueryString: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2078", + messageParameters = Map( + "jdbcTableName" -> jdbcTableName, + "jdbcQueryString" -> jdbcQueryString)) + } + + def missingJdbcTableNameAndQueryError( + jdbcTableName: String, jdbcQueryString: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2079", + messageParameters = Map( + "jdbcTableName" -> jdbcTableName, + "jdbcQueryString" -> jdbcQueryString)) + } + + def emptyOptionError(optionName: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2080", + messageParameters = Map("optionName" -> optionName)) + } + + def invalidJdbcTxnIsolationLevelError( + jdbcTxnIsolationLevel: String, value: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2081", + messageParameters = Map("value" -> value, "jdbcTxnIsolationLevel" -> jdbcTxnIsolationLevel)) + } + + def cannotGetJdbcTypeError(dt: DataType): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2082", + messageParameters = Map("catalogString" -> dt.catalogString)) + } + + def unrecognizedSqlTypeError(jdbcTypeId: String, typeName: String): Throwable = { + new 
SparkSQLException( + errorClass = "UNRECOGNIZED_SQL_TYPE", + messageParameters = Map("typeName" -> typeName, "jdbcType" -> jdbcTypeId)) + } + + def unsupportedJdbcTypeError(content: String): SparkSQLException = { + new SparkSQLException( + errorClass = "_LEGACY_ERROR_TEMP_2083", + messageParameters = Map("content" -> content)) + } + + def unsupportedArrayElementTypeBasedOnBinaryError(dt: DataType): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2084", + messageParameters = Map("catalogString" -> dt.catalogString)) + } + + def nestedArraysUnsupportedError(): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2085", + messageParameters = Map.empty) + } + + def cannotTranslateNonNullValueForFieldError(pos: Int): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2086", + messageParameters = Map("pos" -> pos.toString())) + } + + def invalidJdbcNumPartitionsError( + n: Int, jdbcNumPartitions: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2087", + messageParameters = Map("n" -> n.toString(), "jdbcNumPartitions" -> jdbcNumPartitions)) + } + + def multiActionAlterError(tableName: String): Throwable = { + new SparkSQLFeatureNotSupportedException( + errorClass = "UNSUPPORTED_FEATURE.MULTI_ACTION_ALTER", + messageParameters = Map("tableName" -> tableName)) + } + + def dataTypeUnsupportedYetError(dataType: DataType): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2088", + messageParameters = Map("dataType" -> dataType.toString())) + } + + def unsupportedOperationForDataTypeError( + dataType: DataType): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2089", + messageParameters = Map("catalogString" -> 
dataType.catalogString)) + } + + def inputFilterNotFullyConvertibleError(owner: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2090", + messageParameters = Map("owner" -> owner), + cause = null) + } + + def cannotReadFooterForFileError(file: Path, e: Exception): Throwable = { + new SparkException( + errorClass = "CANNOT_READ_FILE_FOOTER", + messageParameters = Map("file" -> file.toString()), + cause = e) + } + + def foundDuplicateFieldInCaseInsensitiveModeError( + requiredFieldName: String, matchedOrcFields: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2093", + messageParameters = Map( + "requiredFieldName" -> requiredFieldName, + "matchedOrcFields" -> matchedOrcFields)) + } + + def foundDuplicateFieldInFieldIdLookupModeError( + requiredId: Int, matchedFields: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2094", + messageParameters = Map( + "requiredId" -> requiredId.toString(), + "matchedFields" -> matchedFields)) + } + + def failedToMergeIncompatibleSchemasError( + left: StructType, right: StructType, e: Throwable): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2095", + messageParameters = Map("left" -> left.toString(), "right" -> right.toString()), + cause = e) + } + + def ddlUnsupportedTemporarilyError(ddl: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2096", + messageParameters = Map("ddl" -> ddl)) + } + + def executeBroadcastTimeoutError(timeout: Long, ex: Option[TimeoutException]): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2097", + messageParameters = Map( + "timeout" -> timeout.toString(), + "broadcastTimeout" -> toSQLConf(SQLConf.BROADCAST_TIMEOUT.key), + "autoBroadcastJoinThreshold" -> toSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key)), + cause = ex.orNull) + } + + def 
cannotCompareCostWithTargetCostError(cost: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2098", + messageParameters = Map("cost" -> cost)) + } + + def notSupportTypeError(dataType: DataType): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2100", + messageParameters = Map("dataType" -> dataType.toString()), + cause = null) + } + + def notSupportNonPrimitiveTypeError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2101", + messageParameters = Map.empty) + } + + def unsupportedTypeError(dataType: DataType): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2102", + messageParameters = Map("catalogString" -> dataType.catalogString), + cause = null) + } + + def useDictionaryEncodingWhenDictionaryOverflowError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2103", + messageParameters = Map.empty, + cause = null) + } + + def endOfIteratorError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2104", + messageParameters = Map.empty, + cause = null) + } + + def cannotAllocateMemoryToGrowBytesToBytesMapError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2105", + messageParameters = Map.empty, + cause = null) + } + + def cannotAcquireMemoryToBuildLongHashedRelationError(size: Long, got: Long): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2106", + messageParameters = Map("size" -> size.toString(), "got" -> got.toString()), + cause = null) + } + + def cannotAcquireMemoryToBuildUnsafeHashedRelationError(): Throwable = { + new SparkOutOfMemoryError( + "_LEGACY_ERROR_TEMP_2107") + } + + def rowLargerThan256MUnsupportedError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2108", + messageParameters = Map.empty) + } + + def 
cannotBuildHashedRelationWithUniqueKeysExceededError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2109", + messageParameters = Map.empty) + } + + def cannotBuildHashedRelationLargerThan8GError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2110", + messageParameters = Map.empty) + } + + def failedToPushRowIntoRowQueueError(rowQueue: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2111", + messageParameters = Map("rowQueue" -> rowQueue), + cause = null) + } + + def unexpectedWindowFunctionFrameError(frame: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2112", + messageParameters = Map("frame" -> frame)) + } + + def cannotParseStatisticAsPercentileError( + stats: String, e: NumberFormatException): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2113", + messageParameters = Map("stats" -> stats), + cause = e) + } + + def statisticNotRecognizedError(stats: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2114", + messageParameters = Map("stats" -> stats)) + } + + def unknownColumnError(unknownColumn: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2115", + messageParameters = Map("unknownColumn" -> unknownColumn)) + } + + def unexpectedAccumulableUpdateValueError(o: Any): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2116", + messageParameters = Map("o" -> o.toString())) + } + + def unscaledValueTooLargeForPrecisionError( + value: Decimal, + decimalPrecision: Int, + decimalScale: Int, + context: SQLQueryContext = null): ArithmeticException = { + 
DataTypeErrors.unscaledValueTooLargeForPrecisionError( + value, decimalPrecision, decimalScale, context) + } + + def decimalPrecisionExceedsMaxPrecisionError( + precision: Int, maxPrecision: Int): SparkArithmeticException = { + DataTypeErrors.decimalPrecisionExceedsMaxPrecisionError(precision, maxPrecision) + } + + def outOfDecimalTypeRangeError(str: UTF8String): SparkArithmeticException = { + new SparkArithmeticException( + errorClass = "NUMERIC_OUT_OF_SUPPORTED_RANGE", + messageParameters = Map( + "value" -> str.toString), + context = Array.empty, + summary = "") + } + + def unsupportedArrayTypeError(clazz: Class[_]): SparkRuntimeException = { + DataTypeErrors.unsupportedJavaTypeError(clazz) + } + + def unsupportedJavaTypeError(clazz: Class[_]): SparkRuntimeException = { + DataTypeErrors.unsupportedJavaTypeError(clazz) + } + + def failedParsingStructTypeError(raw: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "FAILED_PARSE_STRUCT_TYPE", + messageParameters = Map("raw" -> toSQLValue(raw, StringType))) + } + + def cannotMergeDecimalTypesWithIncompatibleScaleError( + leftScale: Int, rightScale: Int): Throwable = { + DataTypeErrors.cannotMergeDecimalTypesWithIncompatibleScaleError(leftScale, rightScale) + } + + def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = { + DataTypeErrors.cannotMergeIncompatibleDataTypesError(left, right) + } + + def exceedMapSizeLimitError(size: Int): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2126", + messageParameters = Map( + "size" -> size.toString(), + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + } + + def duplicateMapKeyFoundError(key: Any): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "DUPLICATED_MAP_KEY", + messageParameters = Map( + "key" -> key.toString(), + "mapKeyDedupPolicy" -> toSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key))) + } + + def 
mapDataKeyArrayLengthDiffersFromValueArrayLengthError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2128", + messageParameters = Map.empty) + } + + def registeringStreamingQueryListenerError(e: Exception): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2131", + messageParameters = Map.empty, + cause = e) + } + + def concurrentQueryInstanceError(): Throwable = { + new SparkConcurrentModificationException( + errorClass = "CONCURRENT_QUERY", + messageParameters = Map.empty[String, String]) + } + + def concurrentStreamLogUpdate(batchId: Long): Throwable = { + new SparkException( + errorClass = "CONCURRENT_STREAM_LOG_UPDATE", + messageParameters = Map("batchId" -> batchId.toString), + cause = null) + } + + def cannotParseJsonArraysAsStructsError(recordStr: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_JSON_ARRAYS_AS_STRUCTS", + messageParameters = Map( + "badRecord" -> recordStr, + "failFastMode" -> FailFastMode.name)) + } + + def cannotParseStringAsDataTypeError(parser: JsonParser, token: JsonToken, dataType: DataType) + : SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2133", + messageParameters = Map( + "fieldName" -> parser.getCurrentName, + "fieldValue" -> parser.getText, + "token" -> token.toString(), + "dataType" -> dataType.toString())) + } + + def emptyJsonFieldValueError(dataType: DataType): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "EMPTY_JSON_FIELD_VALUE", + messageParameters = Map("dataType" -> toSQLType(dataType))) + } + + def cannotParseJSONFieldError(parser: JsonParser, jsonType: JsonToken, dataType: DataType) + : SparkRuntimeException = { + cannotParseJSONFieldError(parser.getCurrentName, parser.getText, jsonType, dataType) + } + + def cannotParseJSONFieldError( + fieldName: String, + fieldValue: String, + jsonType: JsonToken, + dataType: 
DataType): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "CANNOT_PARSE_JSON_FIELD", + messageParameters = Map( + "fieldName" -> toSQLValue(fieldName, StringType), + "fieldValue" -> fieldValue, + "jsonType" -> jsonType.toString(), + "dataType" -> toSQLType(dataType))) + } + + def rootConverterReturnNullError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "INVALID_JSON_ROOT_FIELD", + messageParameters = Map.empty) + } + + def attributesForTypeUnsupportedError(schema: Schema): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2142", + messageParameters = Map( + "schema" -> schema.toString())) + } + + def paramExceedOneCharError(paramName: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2145", + messageParameters = Map( + "paramName" -> paramName)) + } + + def paramIsNotIntegerError(paramName: String, value: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2146", + messageParameters = Map( + "paramName" -> paramName, + "value" -> value)) + } + + def paramIsNotBooleanValueError(paramName: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2147", + messageParameters = Map( + "paramName" -> paramName), + cause = null) + } + + def foundNullValueForNotNullableFieldError(name: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2148", + messageParameters = Map( + "name" -> name)) + } + + def malformedCSVRecordError(badRecord: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "MALFORMED_CSV_RECORD", + messageParameters = Map("badRecord" -> badRecord)) + } + + def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2150", + messageParameters = Map.empty) + } + + 
def expressionDecodingError(e: Exception, expressions: Seq[Expression]): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2151", + messageParameters = Map( + "e" -> e.toString(), + "expressions" -> expressions.map( + _.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")), + cause = e) + } + + def expressionEncodingError(e: Exception, expressions: Seq[Expression]): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2152", + messageParameters = Map( + "e" -> e.toString(), + "expressions" -> expressions.map( + _.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")), + cause = e) + } + + def classHasUnexpectedSerializerError( + clsName: String, objSerializer: Expression): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2153", + messageParameters = Map( + "clsName" -> clsName, + "objSerializer" -> objSerializer.toString())) + } + + def unsupportedOperandTypeForSizeFunctionError( + dataType: DataType): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2156", + messageParameters = Map( + "dataType" -> dataType.getClass.getCanonicalName)) + } + + def unexpectedValueForStartInFunctionError(prettyName: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2157", + messageParameters = Map( + "prettyName" -> prettyName)) + } + + def unexpectedValueForLengthInFunctionError(prettyName: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2158", + messageParameters = Map( + "prettyName" -> prettyName)) + } + + def invalidIndexOfZeroError(context: SQLQueryContext): RuntimeException = { + new SparkRuntimeException( + errorClass = "INVALID_INDEX_OF_ZERO", + cause = null, + messageParameters = Map.empty, + context = getQueryContext(context), + summary = getSummary(context)) + } + + def 
concatArraysWithElementsExceedLimitError(numberOfElements: Long): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2159", + messageParameters = Map( + "numberOfElements" -> numberOfElements.toString(), + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + } + + def flattenArraysWithElementsExceedLimitError(numberOfElements: Long): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2160", + messageParameters = Map( + "numberOfElements" -> numberOfElements.toString(), + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + } + + def createArrayWithElementsExceedLimitError(count: Any): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2161", + messageParameters = Map( + "count" -> count.toString(), + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + } + + def unionArrayWithElementsExceedLimitError(length: Int): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2162", + messageParameters = Map( + "length" -> length.toString(), + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + } + + def initialTypeNotTargetDataTypeError( + dataType: DataType, target: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2163", + messageParameters = Map( + "dataType" -> dataType.catalogString, + "target" -> target)) + } + + def initialTypeNotTargetDataTypesError(dataType: DataType): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2164", + messageParameters = Map( + "dataType" -> dataType.catalogString, + "arrayType" -> ArrayType.simpleString, + "structType" -> StructType.simpleString, + "mapType" -> MapType.simpleString)) + } + + def 
malformedRecordsDetectedInSchemaInferenceError(e: Throwable): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2165", + messageParameters = Map( + "failFastMode" -> FailFastMode.name), + cause = e) + } + + def malformedJSONError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2166", + messageParameters = Map.empty, + cause = null) + } + + def malformedRecordsDetectedInSchemaInferenceError(dataType: DataType): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2167", + messageParameters = Map( + "failFastMode" -> FailFastMode.name, + "dataType" -> dataType.catalogString), + cause = null) + } + + def decorrelateInnerQueryThroughPlanUnsupportedError( + plan: LogicalPlan): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2168", + messageParameters = Map( + "plan" -> plan.nodeName)) + } + + def methodCalledInAnalyzerNotAllowedError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2169", + messageParameters = Map.empty) + } + + def cannotSafelyMergeSerdePropertiesError( + props1: Map[String, String], + props2: Map[String, String], + conflictKeys: Set[String]): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2170", + messageParameters = Map( + "props1" -> props1.map { case (k, v) => s"$k=$v" }.mkString("{", ",", "}"), + "props2" -> props2.map { case (k, v) => s"$k=$v" }.mkString("{", ",", "}"), + "conflictKeys" -> conflictKeys.mkString(", "))) + } + + def pairUnsupportedAtFunctionError( + r1: ValueInterval, + r2: ValueInterval, + function: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2171", + messageParameters = Map( + "r1" -> r1.toString(), + "r2" -> r2.toString(), + "function" -> function)) + } + + def 
onceStrategyIdempotenceIsBrokenForBatchError[TreeType <: TreeNode[_]]( + batchName: String, plan: TreeType, reOptimized: TreeType): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2172", + messageParameters = Map( + "batchName" -> batchName, + "plan" -> sideBySide(plan.treeString, reOptimized.treeString).mkString("\n"))) + } + + def ruleIdNotFoundForRuleError(ruleName: String): Throwable = { + new SparkException( + errorClass = "RULE_ID_NOT_FOUND", + messageParameters = Map("ruleName" -> ruleName), + cause = null) + } + + def cannotCreateArrayWithElementsExceedLimitError( + numElements: Long, additionalErrorMessage: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2176", + messageParameters = Map( + "numElements" -> numElements.toString(), + "maxRoundedArrayLength"-> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), + "additionalErrorMessage" -> additionalErrorMessage)) + } + + def malformedRecordsDetectedInRecordParsingError( + badRecord: String, e: BadRecordException): Throwable = { + new SparkException( + errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + messageParameters = Map( + "badRecord" -> badRecord, + "failFastMode" -> FailFastMode.name), + cause = e) + } + + def remoteOperationsUnsupportedError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2178", + messageParameters = Map.empty) + } + + def invalidKerberosConfigForHiveServer2Error(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2179", + messageParameters = Map.empty, + cause = null) + } + + def parentSparkUIToAttachTabNotFoundError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2180", + messageParameters = Map.empty, + cause = null) + } + + def inferSchemaUnsupportedForHiveError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = 
"_LEGACY_ERROR_TEMP_2181", + messageParameters = Map.empty) + } + + def requestedPartitionsMismatchTablePartitionsError( + table: CatalogTable, partition: Map[String, Option[String]]): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2182", + messageParameters = Map( + "tableIdentifier" -> table.identifier.table, + "partitionKeys" -> partition.keys.mkString(","), + "partitionColumnNames" -> table.partitionColumnNames.mkString(",")), + cause = null) + } + + def dynamicPartitionKeyNotAmongWrittenPartitionPathsError(key: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2183", + messageParameters = Map( + "key" -> toSQLValue(key, StringType)), + cause = null) + } + + def cannotRemovePartitionDirError(partitionPath: Path): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2184", + messageParameters = Map( + "partitionPath" -> partitionPath.toString())) + } + + def cannotCreateStagingDirError(message: String, e: IOException): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2185", + messageParameters = Map( + "message" -> message), + cause = e) + } + + def serDeInterfaceNotFoundError(e: NoClassDefFoundError): SparkClassNotFoundException = { + new SparkClassNotFoundException( + errorClass = "_LEGACY_ERROR_TEMP_2186", + messageParameters = Map.empty, + cause = e) + } + + def convertHiveTableToCatalogTableError( + e: SparkException, dbName: String, tableName: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2187", + messageParameters = Map( + "message" -> e.getMessage, + "dbName" -> dbName, + "tableName" -> tableName), + cause = e) + } + + def cannotRecognizeHiveTypeError( + e: ParseException, fieldType: String, fieldName: String): Throwable = { + new SparkException( + errorClass = "CANNOT_RECOGNIZE_HIVE_TYPE", + messageParameters = Map( + "fieldType" -> toSQLType(fieldType), + "fieldName" -> 
toSQLId(fieldName)), + cause = e) + } + + def getTablesByTypeUnsupportedByHiveVersionError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2189", + messageParameters = Map.empty) + } + + def dropTableWithPurgeUnsupportedError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2190", + messageParameters = Map.empty) + } + + def alterTableWithDropPartitionAndPurgeUnsupportedError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2191", + messageParameters = Map.empty) + } + + def invalidPartitionFilterError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2192", + messageParameters = Map.empty) + } + + def getPartitionMetadataByFilterError(e: InvocationTargetException): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2193", + messageParameters = Map( + "hiveMetastorePartitionPruningFallbackOnException" -> + SQLConf.HIVE_METASTORE_PARTITION_PRUNING_FALLBACK_ON_EXCEPTION.key), + cause = e) + } + + def unsupportedHiveMetastoreVersionError( + version: String, key: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2194", + messageParameters = Map( + "version" -> version, + "key" -> key)) + } + + def loadHiveClientCausesNoClassDefFoundError( + cnf: NoClassDefFoundError, + execJars: Seq[URL], + key: String, + e: InvocationTargetException): SparkClassNotFoundException = { + new SparkClassNotFoundException( + errorClass = "_LEGACY_ERROR_TEMP_2195", + messageParameters = Map( + "cnf" -> cnf.toString(), + "execJars" -> execJars.mkString(", "), + "key" -> key), + cause = e) + } + + def cannotFetchTablesOfDatabaseError(dbName: String, e: Exception): Throwable = { + new SparkException( + 
errorClass = "_LEGACY_ERROR_TEMP_2196", + messageParameters = Map( + "dbName" -> dbName), + cause = e) + } + + def illegalLocationClauseForViewPartitionError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2197", + messageParameters = Map.empty, + cause = null) + } + + def renamePathAsExistsPathError(srcPath: Path, dstPath: Path): Throwable = { + new SparkFileAlreadyExistsException( + errorClass = "FAILED_RENAME_PATH", + messageParameters = Map( + "sourcePath" -> srcPath.toString, + "targetPath" -> dstPath.toString)) + } + + def renameAsExistsPathError(dstPath: Path): SparkFileAlreadyExistsException = { + new SparkFileAlreadyExistsException( + errorClass = "_LEGACY_ERROR_TEMP_2198", + messageParameters = Map( + "dstPath" -> dstPath.toString())) + } + + def renameSrcPathNotFoundError(srcPath: Path): Throwable = { + new SparkFileNotFoundException( + errorClass = "RENAME_SRC_PATH_NOT_FOUND", + messageParameters = Map("sourcePath" -> srcPath.toString)) + } + + def failedRenameTempFileError(srcPath: Path, dstPath: Path): Throwable = { + new SparkException( + errorClass = "FAILED_RENAME_TEMP_FILE", + messageParameters = Map( + "srcPath" -> srcPath.toString(), + "dstPath" -> dstPath.toString()), + cause = null) + } + + def legacyMetadataPathExistsError(metadataPath: Path, legacyMetadataPath: Path): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2200", + messageParameters = Map( + "metadataPath" -> metadataPath.toString(), + "legacyMetadataPath" -> legacyMetadataPath.toString(), + "StreamingCheckpointEscaptedPathCheckEnabled" -> + SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key), + cause = null) + } + + def partitionColumnNotFoundInSchemaError( + col: String, schema: StructType): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2201", + messageParameters = Map( + "col" -> col, + "schema" -> schema.toString())) + } + + def stateNotDefinedOrAlreadyRemovedError(): Throwable 
= { + new NoSuchElementException("State is either not defined or has already been removed") + } + + def cannotSetTimeoutDurationError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2203", + messageParameters = Map.empty) + } + + def cannotGetEventTimeWatermarkError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2204", + messageParameters = Map.empty) + } + + def cannotSetTimeoutTimestampError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2205", + messageParameters = Map.empty) + } + + def batchMetadataFileNotFoundError(batchMetadataFile: Path): SparkFileNotFoundException = { + new SparkFileNotFoundException( + errorClass = "BATCH_METADATA_NOT_FOUND", + messageParameters = Map( + "batchMetadataFile" -> batchMetadataFile.toString())) + } + + def multiStreamingQueriesUsingPathConcurrentlyError( + path: String, e: FileAlreadyExistsException): SparkConcurrentModificationException = { + new SparkConcurrentModificationException( + errorClass = "_LEGACY_ERROR_TEMP_2207", + messageParameters = Map( + "path" -> path), + cause = e) + } + + def addFilesWithAbsolutePathUnsupportedError( + commitProtocol: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2208", + messageParameters = Map( + "commitProtocol" -> commitProtocol)) + } + + def microBatchUnsupportedByDataSourceError( + srcName: String, + disabledSources: String, + table: Table): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2209", + messageParameters = Map( + "srcName" -> srcName, + "disabledSources" -> disabledSources, + "table" -> table.toString())) + } + + def cannotExecuteStreamingRelationExecError(): SparkUnsupportedOperationException = { + new 
SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2210", + messageParameters = Map.empty) + } + + def invalidStreamingOutputModeError( + outputMode: Option[OutputMode]): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2211", + messageParameters = Map( + "outputMode" -> outputMode.toString())) + } + + def invalidCatalogNameError(name: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2212", + messageParameters = Map( + "name" -> name), + cause = null) + } + + def catalogPluginClassNotFoundError(name: String): Throwable = { + new CatalogNotFoundException( + s"Catalog '$name' plugin class not found: spark.sql.catalog.$name is not defined") + } + + def catalogPluginClassNotImplementedError(name: String, pluginClassName: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2214", + messageParameters = Map( + "name" -> name, + "pluginClassName" -> pluginClassName), + cause = null) + } + + def catalogPluginClassNotFoundForCatalogError( + name: String, + pluginClassName: String, + e: Exception): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2215", + messageParameters = Map( + "name" -> name, + "pluginClassName" -> pluginClassName), + cause = e) + } + + def catalogFailToFindPublicNoArgConstructorError( + name: String, + pluginClassName: String, + e: Exception): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2216", + messageParameters = Map( + "name" -> name, + "pluginClassName" -> pluginClassName), + cause = e) + } + + def catalogFailToCallPublicNoArgConstructorError( + name: String, + pluginClassName: String, + e: Exception): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2217", + messageParameters = Map( + "name" -> name, + "pluginClassName" -> pluginClassName), + cause = e) + } + + def cannotInstantiateAbstractCatalogPluginClassError( + name: String, + 
pluginClassName: String, + e: Exception): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2218", + messageParameters = Map( + "name" -> name, + "pluginClassName" -> pluginClassName), + cause = e.getCause) + } + + def failedToInstantiateConstructorForCatalogError( + name: String, + pluginClassName: String, + e: Exception): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2219", + messageParameters = Map( + "name" -> name, + "pluginClassName" -> pluginClassName), + cause = e.getCause) + } + + def noSuchElementExceptionError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2220", + messageParameters = Map.empty, + cause = null) + } + + def sqlConfigNotFoundError(key: String): SparkNoSuchElementException = { + new SparkNoSuchElementException( + errorClass = "SQL_CONF_NOT_FOUND", + messageParameters = Map("sqlConf" -> toSQLConf(key))) + } + + def cannotMutateReadOnlySQLConfError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2222", + messageParameters = Map.empty) + } + + def cannotCloneOrCopyReadOnlySQLConfError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2223", + messageParameters = Map.empty) + } + + def cannotGetSQLConfInSchedulerEventLoopThreadError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2224", + messageParameters = Map.empty, + cause = null) + } + + def unsupportedOperationExceptionError(): SparkUnsupportedOperationException = { + DataTypeErrors.unsupportedOperationExceptionError() + } + + def nullLiteralsCannotBeCastedError(name: String): SparkUnsupportedOperationException = { + DataTypeErrors.nullLiteralsCannotBeCastedError(name) + } + + def notUserDefinedTypeError(name: String, userClass: String): Throwable = { + DataTypeErrors.notUserDefinedTypeError(name, userClass) + } + + def 
cannotLoadUserDefinedTypeError(name: String, userClass: String): Throwable = { + DataTypeErrors.cannotLoadUserDefinedTypeError(name, userClass) + } + + def notPublicClassError(name: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2229", + messageParameters = Map( + "name" -> name)) + } + + def primitiveTypesNotSupportedError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2230", + messageParameters = Map.empty) + } + + def fieldIndexOnRowWithoutSchemaError(): SparkUnsupportedOperationException = { + DataTypeErrors.fieldIndexOnRowWithoutSchemaError() + } + + def valueIsNullError(index: Int): Throwable = { + DataTypeErrors.valueIsNullError(index) + } + + def onlySupportDataSourcesProvidingFileFormatError(providingClass: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2233", + messageParameters = Map( + "providingClass" -> providingClass), + cause = null) + } + + def cannotRestorePermissionsForPathError(permission: FsPermission, path: Path): Throwable = { + new SparkSecurityException( + errorClass = "CANNOT_RESTORE_PERMISSIONS_FOR_PATH", + messageParameters = Map( + "permission" -> permission.toString, + "path" -> path.toString)) + } + + def failToSetOriginalACLBackError( + aclEntries: String, path: Path, e: Throwable): SparkSecurityException = { + new SparkSecurityException( + errorClass = "_LEGACY_ERROR_TEMP_2234", + messageParameters = Map( + "aclEntries" -> aclEntries, + "path" -> path.toString(), + "message" -> e.getMessage)) + } + + def multiFailuresInStageMaterializationError(error: Throwable): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2235", + messageParameters = Map.empty, + cause = error) + } + + def unrecognizedCompressionSchemaTypeIDError(typeId: Int): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = 
"_LEGACY_ERROR_TEMP_2236", + messageParameters = Map( + "typeId" -> typeId.toString())) + } + + def getParentLoggerNotImplementedError( + className: String): SparkSQLFeatureNotSupportedException = { + new SparkSQLFeatureNotSupportedException( + errorClass = "_LEGACY_ERROR_TEMP_2237", + messageParameters = Map( + "className" -> className)) + } + + def cannotCreateParquetConverterForTypeError( + t: DecimalType, parquetType: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2238", + messageParameters = Map( + "typeName" -> t.typeName, + "parquetType" -> parquetType)) + } + + def cannotCreateParquetConverterForDecimalTypeError( + t: DecimalType, parquetType: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2239", + messageParameters = Map( + "t" -> t.json, + "parquetType" -> parquetType)) + } + + def cannotCreateParquetConverterForDataTypeError( + t: DataType, parquetType: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2240", + messageParameters = Map( + "t" -> t.json, + "parquetType" -> parquetType)) + } + + def cannotAddMultiPartitionsOnNonatomicPartitionTableError( + tableName: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2241", + messageParameters = Map( + "tableName" -> tableName)) + } + + def userSpecifiedSchemaUnsupportedByDataSourceError( + provider: TableProvider): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2242", + messageParameters = Map( + "provider" -> provider.getClass.getSimpleName)) + } + + def cannotDropMultiPartitionsOnNonatomicPartitionTableError( + tableName: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2243", + messageParameters = Map( + "tableName" -> tableName)) + } + + 
def truncateMultiPartitionUnsupportedError( + tableName: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2244", + messageParameters = Map( + "tableName" -> tableName)) + } + + def overwriteTableByUnsupportedExpressionError(table: Table): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2245", + messageParameters = Map( + "table" -> table.toString()), + cause = null) + } + + def dynamicPartitionOverwriteUnsupportedByTableError(table: Table): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2246", + messageParameters = Map( + "table" -> table.toString()), + cause = null) + } + + def failedMergingSchemaError( + leftSchema: StructType, + rightSchema: StructType, + e: SparkException): Throwable = { + new SparkException( + errorClass = "CANNOT_MERGE_SCHEMAS", + messageParameters = Map("left" -> toSQLType(leftSchema), "right" -> toSQLType(rightSchema)), + cause = e) + } + + def cannotBroadcastTableOverMaxTableRowsError( + maxBroadcastTableRows: Long, numRows: Long): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2248", + messageParameters = Map( + "maxBroadcastTableRows" -> maxBroadcastTableRows.toString(), + "numRows" -> numRows.toString()), + cause = null) + } + + def cannotBroadcastTableOverMaxTableBytesError( + maxBroadcastTableBytes: Long, dataSize: Long): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2249", + messageParameters = Map( + "maxBroadcastTableBytes" -> Utils.bytesToString(maxBroadcastTableBytes), + "dataSize" -> Utils.bytesToString(dataSize)), + cause = null) + } + + def notEnoughMemoryToBuildAndBroadcastTableError( + oe: OutOfMemoryError, tables: Seq[TableIdentifier]): Throwable = { + val analyzeTblMsg = if (tables.nonEmpty) { + " or analyze these tables through: " + + s"${tables.map(t => s"ANALYZE TABLE $t COMPUTE STATISTICS;").mkString(" ")}." + } else { + "." 
+ } + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2250", + messageParameters = Map( + "autoBroadcastjoinThreshold" -> SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, + "driverMemory" -> SparkLauncher.DRIVER_MEMORY, + "analyzeTblMsg" -> analyzeTblMsg), + cause = oe.getCause) + } + + def executeCodePathUnsupportedError(execName: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2251", + messageParameters = Map( + "execName" -> execName)) + } + + def cannotMergeClassWithOtherClassError( + className: String, otherClass: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2252", + messageParameters = Map( + "className" -> className, + "otherClass" -> otherClass)) + } + + def continuousProcessingUnsupportedByDataSourceError( + sourceName: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2253", + messageParameters = Map( + "sourceName" -> sourceName)) + } + + def failedToReadDataError(failureReason: Throwable): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2254", + messageParameters = Map.empty, + cause = failureReason) + } + + def failedToGenerateEpochMarkerError(failureReason: Throwable): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2255", + messageParameters = Map.empty, + cause = failureReason) + } + + def foreachWriterAbortedDueToTaskFailureError(): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2256", + messageParameters = Map.empty, + cause = null) + } + + def incorrectRampUpRate(rowsPerSecond: Long, + maxSeconds: Long, + rampUpTimeSeconds: Long): Throwable = { + new SparkRuntimeException( + errorClass = "INCORRECT_RAMP_UP_RATE", + messageParameters = Map( + "rowsPerSecond" -> rowsPerSecond.toString, + "maxSeconds" -> maxSeconds.toString, + "rampUpTimeSeconds" -> 
rampUpTimeSeconds.toString + )) + } + + def incorrectEndOffset(rowsPerSecond: Long, + maxSeconds: Long, + endSeconds: Long): Throwable = { + SparkException.internalError( + s"Max offset with ${rowsPerSecond.toString} rowsPerSecond is ${maxSeconds.toString}, " + + s"but it's ${endSeconds.toString} now.") + } + + def failedToReadDeltaFileError(fileToRead: Path, clazz: String, keySize: Int): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2258", + messageParameters = Map( + "fileToRead" -> fileToRead.toString(), + "clazz" -> clazz, + "keySize" -> keySize.toString()), + cause = null) + } + + def failedToReadSnapshotFileError(fileToRead: Path, clazz: String, message: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2259", + messageParameters = Map( + "fileToRead" -> fileToRead.toString(), + "clazz" -> clazz, + "message" -> message), + cause = null) + } + + def cannotPurgeAsBreakInternalStateError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2260", + messageParameters = Map.empty) + } + + def cleanUpSourceFilesUnsupportedError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2261", + messageParameters = Map.empty) + } + + def latestOffsetNotCalledError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2262", + messageParameters = Map.empty) + } + + def legacyCheckpointDirectoryExistsError( + checkpointPath: Path, legacyCheckpointDir: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2263", + messageParameters = Map( + "checkpointPath" -> checkpointPath.toString(), + "legacyCheckpointDir" -> legacyCheckpointDir, + "StreamingCheckpointEscapedPathCheckEnabled" + -> SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key), + cause = null) + } + + def subprocessExitedError( + 
exitCode: Int, stderrBuffer: CircularBuffer, cause: Throwable): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2264", + messageParameters = Map( + "exitCode" -> exitCode.toString(), + "stderrBuffer" -> stderrBuffer.toString()), + cause = cause) + } + + def outputDataTypeUnsupportedByNodeWithoutSerdeError( + nodeName: String, dt: DataType): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2265", + messageParameters = Map( + "nodeName" -> nodeName, + "dt" -> dt.getClass.getSimpleName), + cause = null) + } + + def invalidStartIndexError(numRows: Int, startIndex: Int): SparkArrayIndexOutOfBoundsException = { + new SparkArrayIndexOutOfBoundsException( + errorClass = "_LEGACY_ERROR_TEMP_2266", + messageParameters = Map( + "numRows" -> numRows.toString(), + "startIndex" -> startIndex.toString()), + context = Array.empty, + summary = "") + } + + def concurrentModificationOnExternalAppendOnlyUnsafeRowArrayError( + className: String): SparkConcurrentModificationException = { + new SparkConcurrentModificationException( + errorClass = "_LEGACY_ERROR_TEMP_2267", + messageParameters = Map( + "className" -> className)) + } + + def doExecuteBroadcastNotImplementedError( + nodeName: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2268", + messageParameters = Map( + "nodeName" -> nodeName)) + } + + def defaultDatabaseNotExistsError(defaultDatabase: String): Throwable = { + new SparkException( + errorClass = "DEFAULT_DATABASE_NOT_EXISTS", + messageParameters = Map("defaultDatabase" -> defaultDatabase), + cause = null + ) + } + + def databaseNameConflictWithSystemPreservedDatabaseError(globalTempDB: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2269", + messageParameters = Map( + "globalTempDB" -> globalTempDB, + "globalTempDatabase" -> GLOBAL_TEMP_DATABASE.key), + cause = null) + } + + def commentOnTableUnsupportedError(): 
SparkSQLFeatureNotSupportedException = { + new SparkSQLFeatureNotSupportedException( + errorClass = "_LEGACY_ERROR_TEMP_2270", + messageParameters = Map.empty) + } + + def unsupportedUpdateColumnNullabilityError(): SparkSQLFeatureNotSupportedException = { + new SparkSQLFeatureNotSupportedException( + errorClass = "_LEGACY_ERROR_TEMP_2271", + messageParameters = Map.empty) + } + + def renameColumnUnsupportedForOlderMySQLError(): SparkSQLFeatureNotSupportedException = { + new SparkSQLFeatureNotSupportedException( + errorClass = "_LEGACY_ERROR_TEMP_2272", + messageParameters = Map.empty) + } + + def failedToExecuteQueryError(e: Throwable): SparkException = { + val message = "Hit an error when executing a query" + + (if (e.getMessage == null) "" else s": ${e.getMessage}") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2273", + messageParameters = Map( + "message" -> message), + cause = e) + } + + def nestedFieldUnsupportedError(colName: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.REPLACE_NESTED_COLUMN", + messageParameters = Map("colName" -> toSQLId(colName))) + } + + def transformationsAndActionsNotInvokedByDriverError(): Throwable = { + new SparkException( + errorClass = "CANNOT_INVOKE_IN_TRANSFORMATIONS", + messageParameters = Map.empty, + cause = null) + } + + def repeatedPivotsUnsupportedError(clause: String, operation: String): Throwable = { + new SparkUnsupportedOperationException( + errorClass = "REPEATED_CLAUSE", + messageParameters = Map("clause" -> clause, "operation" -> operation)) + } + + def pivotNotAfterGroupByUnsupportedError(): Throwable = { + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.PIVOT_AFTER_GROUP_BY", + messageParameters = Map.empty[String, String]) + } + + private val aesFuncName = toSQLId("aes_encrypt") + "/" + toSQLId("aes_decrypt") + + def invalidAesKeyLengthError(actualLength: Int): RuntimeException = { + new 
SparkRuntimeException( + errorClass = "INVALID_PARAMETER_VALUE.AES_KEY_LENGTH", + messageParameters = Map( + "parameter" -> toSQLId("key"), + "functionName" -> aesFuncName, + "actualLength" -> actualLength.toString())) + } + + def aesModeUnsupportedError(mode: String, padding: String): RuntimeException = { + new SparkRuntimeException( + errorClass = "UNSUPPORTED_FEATURE.AES_MODE", + messageParameters = Map( + "mode" -> mode, + "padding" -> padding, + "functionName" -> aesFuncName)) + } + + def aesCryptoError(detailMessage: String): RuntimeException = { + new SparkRuntimeException( + errorClass = "INVALID_PARAMETER_VALUE.AES_CRYPTO_ERROR", + messageParameters = Map( + "parameter" -> (toSQLId("expr") + ", " + toSQLId("key")), + "functionName" -> aesFuncName, + "detailMessage" -> detailMessage)) + } + + def invalidAesIvLengthError(mode: String, actualLength: Int): RuntimeException = { + new SparkRuntimeException( + errorClass = "INVALID_PARAMETER_VALUE.AES_IV_LENGTH", + messageParameters = Map( + "mode" -> mode, + "parameter" -> toSQLId("iv"), + "functionName" -> aesFuncName, + "actualLength" -> actualLength.toString())) + } + + def aesUnsupportedIv(mode: String): RuntimeException = { + new SparkRuntimeException( + errorClass = "UNSUPPORTED_FEATURE.AES_MODE_IV", + messageParameters = Map( + "mode" -> mode, + "functionName" -> toSQLId("aes_encrypt"))) + } + + def aesUnsupportedAad(mode: String): RuntimeException = { + new SparkRuntimeException( + errorClass = "UNSUPPORTED_FEATURE.AES_MODE_AAD", + messageParameters = Map( + "mode" -> mode, + "functionName" -> toSQLId("aes_encrypt"))) + } + + def hiveTableWithAnsiIntervalsError( + table: TableIdentifier): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.HIVE_WITH_ANSI_INTERVALS", + messageParameters = Map("tableName" -> toSQLId(table.nameParts))) + } + + def cannotConvertOrcTimestampToTimestampNTZError(): Throwable = { + new 
SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST", + messageParameters = Map( + "orcType" -> toSQLType(TimestampType), + "toType" -> toSQLType(TimestampNTZType))) + } + + def cannotConvertOrcTimestampNTZToTimestampLTZError(): Throwable = { + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST", + messageParameters = Map( + "orcType" -> toSQLType(TimestampNTZType), + "toType" -> toSQLType(TimestampType))) + } + + def writePartitionExceedConfigSizeWhenDynamicPartitionError( + numWrittenParts: Int, + maxDynamicPartitions: Int, + maxDynamicPartitionsKey: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2277", + messageParameters = Map( + "numWrittenParts" -> numWrittenParts.toString(), + "maxDynamicPartitionsKey" -> maxDynamicPartitionsKey, + "maxDynamicPartitions" -> maxDynamicPartitions.toString(), + "numWrittenParts" -> numWrittenParts.toString()), + cause = null) + } + + def invalidNumberFormatError( + dataType: DataType, input: String, format: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "INVALID_FORMAT.MISMATCH_INPUT", + messageParameters = Map( + "inputType" -> toSQLType(dataType), + "input" -> input, + "format" -> format)) + } + + def unsupportedMultipleBucketTransformsError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.MULTIPLE_BUCKET_TRANSFORMS", + messageParameters = Map.empty) + } + + def unsupportedCommentNamespaceError( + namespace: String): SparkSQLFeatureNotSupportedException = { + new SparkSQLFeatureNotSupportedException( + errorClass = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", + messageParameters = Map("namespace" -> toSQLId(namespace))) + } + + def unsupportedRemoveNamespaceCommentError( + namespace: String): SparkSQLFeatureNotSupportedException = { + new SparkSQLFeatureNotSupportedException( + errorClass = 
"UNSUPPORTED_FEATURE.REMOVE_NAMESPACE_COMMENT", + messageParameters = Map("namespace" -> toSQLId(namespace))) + } + + def unsupportedDropNamespaceError( + namespace: String): SparkSQLFeatureNotSupportedException = { + new SparkSQLFeatureNotSupportedException( + errorClass = "UNSUPPORTED_FEATURE.DROP_NAMESPACE", + messageParameters = Map("namespace" -> toSQLId(namespace))) + } + + def exceedMaxLimit(limit: Int): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "EXCEED_LIMIT_LENGTH", + messageParameters = Map("limit" -> limit.toString) + ) + } + + def timestampAddOverflowError(micros: Long, amount: Int, unit: String): ArithmeticException = { + new SparkArithmeticException( + errorClass = "DATETIME_OVERFLOW", + messageParameters = Map( + "operation" -> (s"add ${toSQLValue(amount, IntegerType)} $unit to " + + s"${toSQLValue(DateTimeUtils.microsToInstant(micros), TimestampType)}")), + context = Array.empty, + summary = "") + } + + def invalidBucketFile(path: String): Throwable = { + new SparkException( + errorClass = "INVALID_BUCKET_FILE", + messageParameters = Map("path" -> path), + cause = null) + } + + def multipleRowSubqueryError(context: SQLQueryContext): Throwable = { + new SparkException( + errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + messageParameters = Map.empty, + cause = null, + context = getQueryContext(context), + summary = getSummary(context)) + } + + def comparatorReturnsNull(firstValue: String, secondValue: String): Throwable = { + new SparkException( + errorClass = "COMPARATOR_RETURNS_NULL", + messageParameters = Map("firstValue" -> firstValue, "secondValue" -> secondValue), + cause = null) + } + + def invalidPatternError( + funcName: String, + pattern: String, + cause: Throwable): RuntimeException = { + new SparkRuntimeException( + errorClass = "INVALID_PARAMETER_VALUE.PATTERN", + messageParameters = Map( + "parameter" -> toSQLId("regexp"), + "functionName" -> toSQLId(funcName), + "value" -> toSQLValue(pattern, StringType)), + 
cause = cause) + } + + def tooManyArrayElementsError( + numElements: Int, + elementSize: Int): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "TOO_MANY_ARRAY_ELEMENTS", + messageParameters = Map( + "numElements" -> numElements.toString, + "size" -> elementSize.toString)) + } + + def invalidEmptyLocationError(location: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "INVALID_EMPTY_LOCATION", + messageParameters = Map("location" -> location)) + } + + def malformedProtobufMessageDetectedInMessageParsingError(e: Throwable): Throwable = { + new SparkException( + errorClass = "MALFORMED_PROTOBUF_MESSAGE", + messageParameters = Map( + "failFastMode" -> FailFastMode.name), + cause = e) + } + + def locationAlreadyExists(tableId: TableIdentifier, location: Path): Throwable = { + new SparkRuntimeException( + errorClass = "LOCATION_ALREADY_EXISTS", + messageParameters = Map( + "location" -> toSQLValue(location.toString, StringType), + "identifier" -> toSQLId(tableId.nameParts))) + } + + def cannotConvertCatalystValueToProtobufEnumTypeError( + sqlColumn: Seq[String], + protobufColumn: String, + data: String, + enumString: String): Throwable = { + new AnalysisException( + errorClass = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", + messageParameters = Map( + "sqlColumn" -> toSQLId(sqlColumn), + "protobufColumn" -> protobufColumn, + "data" -> data, + "enumString" -> enumString)) + } + + def hllInvalidLgK(function: String, min: Int, max: Int, value: String): Throwable = { + new SparkRuntimeException( + errorClass = "HLL_INVALID_LG_K", + messageParameters = Map( + "function" -> toSQLId(function), + "min" -> toSQLValue(min, IntegerType), + "max" -> toSQLValue(max, IntegerType), + "value" -> value)) + } + + def hllInvalidInputSketchBuffer(function: String): Throwable = { + new SparkRuntimeException( + errorClass = "HLL_INVALID_INPUT_SKETCH_BUFFER", + messageParameters = Map( + "function" -> 
toSQLId(function))) + } + + def hllUnionDifferentLgK(left: Int, right: Int, function: String): Throwable = { + new SparkRuntimeException( + errorClass = "HLL_UNION_DIFFERENT_LG_K", + messageParameters = Map( + "left" -> toSQLValue(left, IntegerType), + "right" -> toSQLValue(right, IntegerType), + "function" -> toSQLId(function))) + } + + def mergeCardinalityViolationError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "MERGE_CARDINALITY_VIOLATION", + messageParameters = Map.empty) + } + + def unsupportedPurgePartitionError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.PURGE_PARTITION", + messageParameters = Map.empty) + } + + def unsupportedPurgeTableError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "UNSUPPORTED_FEATURE.PURGE_TABLE", + messageParameters = Map.empty) + } +} diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 944505ba1c..32cbc0ce13 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -89,6 +89,37 @@ message Expr { FromJson from_json = 66; ToCsv to_csv = 67; } + + // Optional QueryContext for error reporting (contains SQL text and position) + optional QueryContext query_context = 90; + + // Unique expression ID for context lookup during error creation + optional uint64 expr_id = 91; +} + +// QueryContext provides SQL query context for error messages. +// Mirrors Spark's SQLQueryContext for rich error reporting. 
+message QueryContext { + // Full SQL query text + string sql_text = 1; + + // Character offset where expression starts (0-based) + int32 start_index = 2; + + // Character offset where expression ends (0-based, inclusive) + int32 stop_index = 3; + + // Type of SQL object (e.g., "VIEW", "Project", "Filter") + optional string object_type = 4; + + // Name of object (e.g., view name, column name) + optional string object_name = 5; + + // Line number in SQL query (1-based) + int32 line = 6; + + // Column position within the line (0-based) + int32 start_position = 7; } message AggExpr { @@ -109,6 +140,12 @@ message AggExpr { Correlation correlation = 15; BloomFilterAgg bloomFilterAgg = 16; } + + // Optional QueryContext for error reporting (contains SQL text and position) + optional QueryContext query_context = 90; + + // Unique expression ID for context lookup during error creation + optional uint64 expr_id = 91; } enum StatisticsType { diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index e7c238f7eb..41f70c5027 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -33,6 +33,7 @@ datafusion = { workspace = true } chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } +serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" thiserror = { workspace = true } futures = { workspace = true } diff --git a/native/spark-expr/benches/aggregate.rs b/native/spark-expr/benches/aggregate.rs index 47e2cf61c3..7944d392f6 100644 --- a/native/spark-expr/benches/aggregate.rs +++ b/native/spark-expr/benches/aggregate.rs @@ -70,6 +70,9 @@ fn criterion_benchmark(c: &mut Criterion) { let comet_avg_decimal = Arc::new(AggregateUDF::new_from_impl(AvgDecimal::new( DataType::Decimal128(38, 10), DataType::Decimal128(38, 10), + datafusion_comet_spark_expr::EvalMode::Legacy, + None, + datafusion_comet_spark_expr::create_query_context_map(), ))); b.to_async(&rt).iter(|| { black_box(agg_test( @@ -97,7 
+100,13 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("sum_decimal_comet", |b| { let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl( - SumDecimal::try_new(DataType::Decimal128(38, 10), EvalMode::Legacy).unwrap(), + SumDecimal::try_new( + DataType::Decimal128(38, 10), + EvalMode::Legacy, + None, + datafusion_comet_spark_expr::create_query_context_map(), + ) + .unwrap(), )); b.to_async(&rt).iter(|| { black_box(agg_test( diff --git a/native/spark-expr/benches/cast_from_boolean.rs b/native/spark-expr/benches/cast_from_boolean.rs index dbb986df91..04bd72dc01 100644 --- a/native/spark-expr/benches/cast_from_boolean.rs +++ b/native/spark-expr/benches/cast_from_boolean.rs @@ -27,14 +27,62 @@ fn criterion_benchmark(c: &mut Criterion) { let expr = Arc::new(Column::new("a", 0)); let boolean_batch = create_boolean_batch(); let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false); - let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); - let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone()); - let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); - let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options.clone()); - let cast_to_f32 = Cast::new(expr.clone(), DataType::Float32, spark_cast_options.clone()); - let cast_to_f64 = Cast::new(expr.clone(), DataType::Float64, spark_cast_options.clone()); - let cast_to_str = Cast::new(expr.clone(), DataType::Utf8, spark_cast_options.clone()); - let cast_to_decimal = Cast::new(expr, DataType::Decimal128(10, 4), spark_cast_options); + let cast_to_i8 = Cast::new( + expr.clone(), + DataType::Int8, + spark_cast_options.clone(), + None, + None, + ); + let cast_to_i16 = Cast::new( + expr.clone(), + DataType::Int16, + spark_cast_options.clone(), + None, + None, + ); + let cast_to_i32 = Cast::new( + expr.clone(), + DataType::Int32, + spark_cast_options.clone(), + None, + None, 
+ ); + let cast_to_i64 = Cast::new( + expr.clone(), + DataType::Int64, + spark_cast_options.clone(), + None, + None, + ); + let cast_to_f32 = Cast::new( + expr.clone(), + DataType::Float32, + spark_cast_options.clone(), + None, + None, + ); + let cast_to_f64 = Cast::new( + expr.clone(), + DataType::Float64, + spark_cast_options.clone(), + None, + None, + ); + let cast_to_str = Cast::new( + expr.clone(), + DataType::Utf8, + spark_cast_options.clone(), + None, + None, + ); + let cast_to_decimal = Cast::new( + expr, + DataType::Decimal128(10, 4), + spark_cast_options, + None, + None, + ); let mut group = c.benchmark_group("cast_bool".to_string()); group.bench_function("i8", |b| { diff --git a/native/spark-expr/benches/cast_from_string.rs b/native/spark-expr/benches/cast_from_string.rs index 9b2cb73fb4..1fbb7c535b 100644 --- a/native/spark-expr/benches/cast_from_string.rs +++ b/native/spark-expr/benches/cast_from_string.rs @@ -34,10 +34,34 @@ fn criterion_benchmark(c: &mut Criterion) { (EvalMode::Try, "try"), ] { let spark_cast_options = SparkCastOptions::new(mode, "", false); - let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); - let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone()); - let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); - let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options); + let cast_to_i8 = Cast::new( + expr.clone(), + DataType::Int8, + spark_cast_options.clone(), + None, + None, + ); + let cast_to_i16 = Cast::new( + expr.clone(), + DataType::Int16, + spark_cast_options.clone(), + None, + None, + ); + let cast_to_i32 = Cast::new( + expr.clone(), + DataType::Int32, + spark_cast_options.clone(), + None, + None, + ); + let cast_to_i64 = Cast::new( + expr.clone(), + DataType::Int64, + spark_cast_options, + None, + None, + ); let mut group = c.benchmark_group(format!("cast_string_to_int/{}", mode_name)); 
group.bench_function("i8", |b| { @@ -57,8 +81,20 @@ fn criterion_benchmark(c: &mut Criterion) { // Benchmark decimal truncation (Legacy mode only) let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "", false); - let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); - let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options); + let cast_to_i32 = Cast::new( + expr.clone(), + DataType::Int32, + spark_cast_options.clone(), + None, + None, + ); + let cast_to_i64 = Cast::new( + expr.clone(), + DataType::Int64, + spark_cast_options, + None, + None, + ); let mut group = c.benchmark_group("cast_string_to_int/legacy_decimals"); group.bench_function("i32", |b| { @@ -81,6 +117,8 @@ fn criterion_benchmark(c: &mut Criterion) { expr.clone(), DataType::Decimal128(38, 10), spark_cast_options, + None, + None, ); let mut group = c.benchmark_group(format!("cast_string_to_decimal/{}", mode_name)); diff --git a/native/spark-expr/benches/cast_int_to_timestamp.rs b/native/spark-expr/benches/cast_int_to_timestamp.rs index 20143d2b0e..4479627ae6 100644 --- a/native/spark-expr/benches/cast_int_to_timestamp.rs +++ b/native/spark-expr/benches/cast_int_to_timestamp.rs @@ -35,7 +35,13 @@ fn criterion_benchmark(c: &mut Criterion) { // Int8 -> Timestamp let batch_i8 = create_int8_batch(); let expr_i8 = Arc::new(Column::new("a", 0)); - let cast_i8_to_ts = Cast::new(expr_i8, timestamp_type.clone(), spark_cast_options.clone()); + let cast_i8_to_ts = Cast::new( + expr_i8, + timestamp_type.clone(), + spark_cast_options.clone(), + None, + None, + ); group.bench_function("cast_i8_to_timestamp", |b| { b.iter(|| cast_i8_to_ts.evaluate(&batch_i8).unwrap()); }); @@ -43,7 +49,13 @@ fn criterion_benchmark(c: &mut Criterion) { // Int16 -> Timestamp let batch_i16 = create_int16_batch(); let expr_i16 = Arc::new(Column::new("a", 0)); - let cast_i16_to_ts = Cast::new(expr_i16, timestamp_type.clone(), spark_cast_options.clone()); + let 
cast_i16_to_ts = Cast::new( + expr_i16, + timestamp_type.clone(), + spark_cast_options.clone(), + None, + None, + ); group.bench_function("cast_i16_to_timestamp", |b| { b.iter(|| cast_i16_to_ts.evaluate(&batch_i16).unwrap()); }); @@ -51,7 +63,13 @@ fn criterion_benchmark(c: &mut Criterion) { // Int32 -> Timestamp let batch_i32 = create_int32_batch(); let expr_i32 = Arc::new(Column::new("a", 0)); - let cast_i32_to_ts = Cast::new(expr_i32, timestamp_type.clone(), spark_cast_options.clone()); + let cast_i32_to_ts = Cast::new( + expr_i32, + timestamp_type.clone(), + spark_cast_options.clone(), + None, + None, + ); group.bench_function("cast_i32_to_timestamp", |b| { b.iter(|| cast_i32_to_ts.evaluate(&batch_i32).unwrap()); }); @@ -59,7 +77,13 @@ fn criterion_benchmark(c: &mut Criterion) { // Int64 -> Timestamp let batch_i64 = create_int64_batch(); let expr_i64 = Arc::new(Column::new("a", 0)); - let cast_i64_to_ts = Cast::new(expr_i64, timestamp_type.clone(), spark_cast_options.clone()); + let cast_i64_to_ts = Cast::new( + expr_i64, + timestamp_type.clone(), + spark_cast_options.clone(), + None, + None, + ); group.bench_function("cast_i64_to_timestamp", |b| { b.iter(|| cast_i64_to_ts.evaluate(&batch_i64).unwrap()); }); diff --git a/native/spark-expr/benches/cast_numeric.rs b/native/spark-expr/benches/cast_numeric.rs index bd14e4cb24..989cbf4d2c 100644 --- a/native/spark-expr/benches/cast_numeric.rs +++ b/native/spark-expr/benches/cast_numeric.rs @@ -26,9 +26,21 @@ fn criterion_benchmark(c: &mut Criterion) { let batch = create_int32_batch(); let expr = Arc::new(Column::new("a", 0)); let spark_cast_options = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false); - let cast_i32_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); - let cast_i32_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone()); - let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options); + let cast_i32_to_i8 = Cast::new( + 
expr.clone(), + DataType::Int8, + spark_cast_options.clone(), + None, + None, + ); + let cast_i32_to_i16 = Cast::new( + expr.clone(), + DataType::Int16, + spark_cast_options.clone(), + None, + None, + ); + let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options, None, None); let mut group = c.benchmark_group("cast_int_to_int"); group.bench_function("cast_i32_to_i8", |b| { diff --git a/native/spark-expr/src/agg_funcs/avg_decimal.rs b/native/spark-expr/src/agg_funcs/avg_decimal.rs index b3b2731d9d..773ddea050 100644 --- a/native/spark-expr/src/agg_funcs/avg_decimal.rs +++ b/native/spark-expr/src/agg_funcs/avg_decimal.rs @@ -31,6 +31,7 @@ use datafusion::physical_expr::expressions::format_state_name; use std::{any::Any, sync::Arc}; use crate::utils::{build_bool_state, is_valid_decimal_precision, unlikely}; +use crate::{decimal_sum_overflow_error, EvalMode, SparkErrorWithContext}; use arrow::array::ArrowNativeTypeOp; use arrow::datatypes::{ DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, MAX_DECIMAL128_FOR_EACH_PRECISION, @@ -55,20 +56,53 @@ fn avg_return_type(_name: &str, data_type: &DataType) -> Result { } /// AVG aggregate expression -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct AvgDecimal { signature: Signature, sum_data_type: DataType, result_data_type: DataType, + eval_mode: EvalMode, + expr_id: Option, + registry: Arc, +} + +// Manually implement PartialEq, Eq, and Hash excluding the registry field +impl PartialEq for AvgDecimal { + fn eq(&self, other: &Self) -> bool { + self.sum_data_type == other.sum_data_type + && self.result_data_type == other.result_data_type + && self.eval_mode == other.eval_mode + && self.expr_id == other.expr_id + } +} + +impl Eq for AvgDecimal {} + +impl std::hash::Hash for AvgDecimal { + fn hash(&self, state: &mut H) { + self.sum_data_type.hash(state); + self.result_data_type.hash(state); + self.eval_mode.hash(state); + self.expr_id.hash(state); + } } impl AvgDecimal { /// Create a 
new AVG aggregate function - pub fn new(result_type: DataType, sum_type: DataType) -> Self { + pub fn new( + result_type: DataType, + sum_type: DataType, + eval_mode: EvalMode, + expr_id: Option, + registry: Arc, + ) -> Self { Self { signature: Signature::user_defined(Immutable), result_data_type: result_type, sum_data_type: sum_type, + eval_mode, + expr_id, + registry, } } } @@ -87,6 +121,9 @@ impl AggregateUDFImpl for AvgDecimal { *sum_precision, *target_precision, *target_scale, + self.eval_mode, + self.expr_id, + Arc::clone(&self.registry), ))) } _ => not_impl_err!( @@ -138,6 +175,9 @@ impl AggregateUDFImpl for AvgDecimal { *target_scale, *sum_precision, *sum_scale, + self.eval_mode, + self.expr_id, + Arc::clone(&self.registry), ))) } _ => not_impl_err!( @@ -180,10 +220,21 @@ struct AvgDecimalAccumulator { sum_precision: u8, target_precision: u8, target_scale: i8, + eval_mode: EvalMode, + expr_id: Option, + registry: Arc, } impl AvgDecimalAccumulator { - pub fn new(sum_scale: i8, sum_precision: u8, target_precision: u8, target_scale: i8) -> Self { + pub fn new( + sum_scale: i8, + sum_precision: u8, + target_precision: u8, + target_scale: i8, + eval_mode: EvalMode, + expr_id: Option, + registry: Arc, + ) -> Self { Self { sum: None, count: 0, @@ -193,10 +244,27 @@ impl AvgDecimalAccumulator { sum_precision, target_precision, target_scale, + eval_mode, + expr_id, + registry, } } - fn update_single(&mut self, values: &Decimal128Array, idx: usize) { + /// Wrap a SparkError with QueryContext if expr_id is available + fn wrap_error_with_context( + &self, + error: crate::SparkError, + ) -> datafusion::common::DataFusionError { + if let Some(expr_id) = self.expr_id { + if let Some(query_ctx) = self.registry.get(expr_id) { + let wrapped = SparkErrorWithContext::with_context(error, query_ctx); + return datafusion::common::DataFusionError::External(Box::new(wrapped)); + } + } + datafusion::common::DataFusionError::from(error) + } + + fn update_single(&mut self, values: 
&Decimal128Array, idx: usize) -> Result<()> { let v = unsafe { values.value_unchecked(idx) }; let (new_sum, is_overflow) = match self.sum { Some(sum) => sum.overflowing_add(v), @@ -204,9 +272,10 @@ impl AvgDecimalAccumulator { }; if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) { - // Overflow: set buffer accumulator to null + // Overflow: set to null. Error will be thrown during evaluate in ANSI mode. + // This matches Spark's DecimalAddNoOverflowCheck behavior. self.is_not_null = false; - return; + return Ok(()); } self.sum = Some(new_sum); @@ -214,11 +283,13 @@ impl AvgDecimalAccumulator { if let Some(new_count) = self.count.checked_add(1) { self.count = new_count; } else { + // Count overflow: set to null. Error will be thrown during evaluate in ANSI mode. self.is_not_null = false; - return; + return Ok(()); } self.is_not_null = true; + Ok(()) } } @@ -240,37 +311,44 @@ impl Accumulator for AvgDecimalAccumulator { // of the computation return Ok(()); } - let values = &values[0]; let data = values.as_primitive::(); - self.is_empty = self.is_empty && values.len() == values.null_count(); - if values.null_count() == 0 { for i in 0..data.len() { - self.update_single(data, i); + self.update_single(data, i)?; } } else { for i in 0..data.len() { if data.is_null(i) { continue; } - self.update_single(data, i); + self.update_single(data, i)?; } } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let partial_sums = states[0].as_primitive::(); + let partial_counts = states[1].as_primitive::(); + + // Update is_empty: if any partial state has data, we're not empty + if self.is_empty { + self.is_empty = partial_counts.len() == partial_counts.null_count(); + } + // counts are summed - self.count += sum(states[1].as_primitive::()).unwrap_or_default(); + self.count += sum(partial_counts).unwrap_or_default(); // sums are summed - if let Some(x) = sum(states[0].as_primitive::()) { + if let Some(x) = sum(partial_sums) { let v = 
self.sum.get_or_insert(0); let (result, overflowed) = v.overflowing_add(x); - if overflowed { - // Set to None if overflow happens + + if overflowed || !is_valid_decimal_precision(result, self.sum_precision) { + // Overflow during merge: set to null, error will be thrown during evaluate in ANSI mode + self.is_not_null = false; self.sum = None; } else { *v = result; @@ -280,6 +358,19 @@ impl Accumulator for AvgDecimalAccumulator { } fn evaluate(&mut self) -> Result { + // Check for overflow during sum accumulation in ANSI mode. + // This matches Spark's DecimalDivideWithOverflowCheck behavior. + if self.sum.is_none() && !self.is_empty && self.eval_mode == EvalMode::Ansi { + let error = decimal_sum_overflow_error(); + return Err(self.wrap_error_with_context(error)); + } + + // Also check if is_not_null is false (indicates overflow) + if !self.is_not_null && self.count > 0 && self.eval_mode == EvalMode::Ansi { + let error = decimal_sum_overflow_error(); + return Err(self.wrap_error_with_context(error)); + } + let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32); let target_min = MIN_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize]; let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as usize]; @@ -328,9 +419,17 @@ struct AvgDecimalGroupsAccumulator { /// This is input_precision + 10 to be consistent with Spark sum_precision: u8, sum_scale: i8, + + /// Evaluation mode for error handling + eval_mode: EvalMode, + /// Optional expression ID for query context lookup during error creation + expr_id: Option, + /// Session-scoped query context registry for error reporting + registry: Arc, } impl AvgDecimalGroupsAccumulator { + #[allow(clippy::too_many_arguments)] pub fn new( return_data_type: &DataType, sum_data_type: &DataType, @@ -338,6 +437,9 @@ impl AvgDecimalGroupsAccumulator { target_scale: i8, sum_precision: u8, sum_scale: i8, + eval_mode: EvalMode, + expr_id: Option, + registry: Arc, ) -> Self { Self { 
is_not_null: BooleanBufferBuilder::new(0), @@ -349,19 +451,38 @@ impl AvgDecimalGroupsAccumulator { sum_scale, counts: vec![], sums: vec![], + eval_mode, + expr_id, + registry, } } + /// Wrap a SparkError with QueryContext if expr_id is available + fn wrap_error_with_context( + &self, + error: crate::SparkError, + ) -> datafusion::common::DataFusionError { + if let Some(expr_id) = self.expr_id { + if let Some(query_ctx) = self.registry.get(expr_id) { + let wrapped = SparkErrorWithContext::with_context(error, query_ctx); + return datafusion::common::DataFusionError::External(Box::new(wrapped)); + } + } + datafusion::common::DataFusionError::from(error) + } + #[inline] - fn update_single(&mut self, group_index: usize, value: i128) { + fn update_single(&mut self, group_index: usize, value: i128) -> Result<()> { let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value); self.counts[group_index] += 1; self.sums[group_index] = new_sum; if unlikely(is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision)) { - // Overflow: set buffer accumulator to null + // Overflow: set to null. Error will be thrown during evaluate in ANSI mode. + // This matches Spark's DecimalAddNoOverflowCheck behavior. 
self.is_not_null.set_bit(group_index, false); } + Ok(()) } } @@ -392,14 +513,14 @@ impl GroupsAccumulator for AvgDecimalGroupsAccumulator { let iter = group_indices.iter().zip(data.iter()); if values.null_count() == 0 { for (&group_index, &value) in iter { - self.update_single(group_index, value); + self.update_single(group_index, value)?; } } else { for (idx, (&group_index, &value)) in iter.enumerate() { if values.is_null(idx) { continue; } - self.update_single(group_index, value); + self.update_single(group_index, value)?; } } Ok(()) @@ -425,13 +546,30 @@ impl GroupsAccumulator for AvgDecimalGroupsAccumulator { // update sums self.sums.resize(total_num_groups, 0); + // Ensure bit capacity BEFORE setting any bits + ensure_bit_capacity(&mut self.is_not_null, total_num_groups); + let iter2 = group_indices.iter().zip(partial_sums.values().iter()); - for (&group_index, &new_value) in iter2 { - let sum = &mut self.sums[group_index]; - *sum = sum.add_wrapping(new_value); - } + for (idx, (&group_index, &new_value)) in iter2.enumerate() { + // Check if partial sum is null (indicates overflow in that partition) + if partial_sums.is_null(idx) { + self.is_not_null.set_bit(group_index, false); + continue; + } - ensure_bit_capacity(&mut self.is_not_null, total_num_groups); + let sum = self.sums[group_index]; + let (new_sum, is_overflow) = sum.overflowing_add(new_value); + + if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) { + if self.eval_mode == EvalMode::Ansi { + let error = decimal_sum_overflow_error(); + return Err(self.wrap_error_with_context(error)); + } + self.is_not_null.set_bit(group_index, false); + } else { + self.sums[group_index] = new_sum; + } + } if partial_counts.null_count() != 0 { for (index, &group_index) in group_indices.iter().enumerate() { if partial_counts.is_null(index) { @@ -457,6 +595,13 @@ impl GroupsAccumulator for AvgDecimalGroupsAccumulator { let target_max = MAX_DECIMAL128_FOR_EACH_PRECISION[self.target_precision as 
usize]; for (is_not_null, (sum, count)) in nulls.into_iter().zip(iter) { + // Check for overflow during sum accumulation in ANSI mode. + // This matches Spark's DecimalDivideWithOverflowCheck behavior. + if !is_not_null && count > 0 && self.eval_mode == EvalMode::Ansi { + let error = decimal_sum_overflow_error(); + return Err(self.wrap_error_with_context(error)); + } + if !is_not_null || count == 0 { builder.append_null(); continue; diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs index 50645391fd..bf5569b00b 100644 --- a/native/spark-expr/src/agg_funcs/sum_decimal.rs +++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs @@ -16,7 +16,7 @@ // under the License. use crate::utils::is_valid_decimal_precision; -use crate::{arithmetic_overflow_error, EvalMode}; +use crate::{decimal_sum_overflow_error, EvalMode, SparkErrorWithContext}; use arrow::array::{ cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array, }; @@ -29,7 +29,7 @@ use datafusion::logical_expr::{ }; use std::{any::Any, sync::Arc}; -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug)] pub struct SumDecimal { /// Aggregate function signature signature: Signature, @@ -41,10 +41,43 @@ pub struct SumDecimal { /// Decimal scale scale: i8, eval_mode: EvalMode, + /// Optional expression ID for query context lookup during error creation + expr_id: Option, + /// Session-scoped query context registry for error reporting + registry: Arc, +} + +// Manually implement PartialEq, Eq, and Hash excluding the registry field +// since registry is only for error reporting and doesn't affect function behavior +impl PartialEq for SumDecimal { + fn eq(&self, other: &Self) -> bool { + self.precision == other.precision + && self.scale == other.scale + && self.eval_mode == other.eval_mode + && self.expr_id == other.expr_id + && self.result_type == other.result_type + } +} + +impl Eq for SumDecimal {} + +impl std::hash::Hash for SumDecimal { + 
fn hash(&self, state: &mut H) { + self.precision.hash(state); + self.scale.hash(state); + self.eval_mode.hash(state); + self.expr_id.hash(state); + self.result_type.hash(state); + } } impl SumDecimal { - pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { + pub fn try_new( + data_type: DataType, + eval_mode: EvalMode, + expr_id: Option, + registry: Arc, + ) -> DFResult { let (precision, scale) = match data_type { DataType::Decimal128(p, s) => (p, s), _ => { @@ -59,6 +92,8 @@ impl SumDecimal { precision, scale, eval_mode, + expr_id, + registry, }) } } @@ -73,6 +108,8 @@ impl AggregateUDFImpl for SumDecimal { self.precision, self.scale, self.eval_mode, + self.expr_id, + Arc::clone(&self.registry), ))) } @@ -110,6 +147,8 @@ impl AggregateUDFImpl for SumDecimal { self.result_type.clone(), self.precision, self.eval_mode, + self.expr_id, + Arc::clone(&self.registry), ))) } @@ -137,10 +176,18 @@ struct SumDecimalAccumulator { precision: u8, scale: i8, eval_mode: EvalMode, + expr_id: Option, + registry: Arc, } impl SumDecimalAccumulator { - fn new(precision: u8, scale: i8, eval_mode: EvalMode) -> Self { + fn new( + precision: u8, + scale: i8, + eval_mode: EvalMode, + expr_id: Option, + registry: Arc, + ) -> Self { // For decimal sum, always track is_empty regardless of eval_mode // This matches Spark's behavior where DecimalType always uses shouldTrackIsEmpty = true Self { @@ -149,6 +196,8 @@ impl SumDecimalAccumulator { precision, scale, eval_mode, + expr_id, + registry, } } @@ -164,7 +213,8 @@ impl SumDecimalAccumulator { if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { if self.eval_mode == EvalMode::Ansi { - return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + let error = decimal_sum_overflow_error(); + return Err(self.wrap_error_with_context(error)); } self.sum = None; self.is_empty = false; @@ -175,6 +225,17 @@ impl SumDecimalAccumulator { self.is_empty = false; Ok(()) } + + /// Wrap a SparkError with 
QueryContext if expr_id is available + fn wrap_error_with_context(&self, error: crate::SparkError) -> DataFusionError { + if let Some(expr_id) = self.expr_id { + if let Some(query_ctx) = self.registry.get(expr_id) { + let wrapped = SparkErrorWithContext::with_context(error, query_ctx); + return DataFusionError::External(Box::new(wrapped)); + } + } + DataFusionError::from(error) + } } impl Accumulator for SumDecimalAccumulator { @@ -292,7 +353,8 @@ impl Accumulator for SumDecimalAccumulator { if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { if self.eval_mode == EvalMode::Ansi { - return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + let error = decimal_sum_overflow_error(); + return Err(self.wrap_error_with_context(error)); } else { self.sum = None; self.is_empty = false; @@ -311,16 +373,26 @@ struct SumDecimalGroupsAccumulator { result_type: DataType, precision: u8, eval_mode: EvalMode, + expr_id: Option, + registry: Arc, } impl SumDecimalGroupsAccumulator { - fn new(result_type: DataType, precision: u8, eval_mode: EvalMode) -> Self { + fn new( + result_type: DataType, + precision: u8, + eval_mode: EvalMode, + expr_id: Option, + registry: Arc, + ) -> Self { Self { sum: Vec::new(), is_empty: Vec::new(), result_type, precision, eval_mode, + expr_id, + registry, } } @@ -330,6 +402,17 @@ impl SumDecimalGroupsAccumulator { self.is_empty.resize(total_num_groups, true); } + /// Wrap a SparkError with QueryContext if expr_id is available + fn wrap_error_with_context(&self, error: crate::SparkError) -> DataFusionError { + if let Some(expr_id) = self.expr_id { + if let Some(query_ctx) = self.registry.get(expr_id) { + let wrapped = SparkErrorWithContext::with_context(error, query_ctx); + return DataFusionError::External(Box::new(wrapped)); + } + } + DataFusionError::from(error) + } + #[inline] fn update_single(&mut self, group_index: usize, value: i128) -> DFResult<()> { // For decimal sum, always check for overflow regardless of 
eval_mode @@ -342,7 +425,8 @@ impl SumDecimalGroupsAccumulator { if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { if self.eval_mode == EvalMode::Ansi { - return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + let error = decimal_sum_overflow_error(); + return Err(self.wrap_error_with_context(error)); } self.sum[group_index] = None; } else { @@ -503,7 +587,8 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { if self.eval_mode == EvalMode::Ansi { - return Err(DataFusionError::from(arithmetic_overflow_error("decimal"))); + let error = decimal_sum_overflow_error(); + return Err(self.wrap_error_with_context(error)); } else { self.sum[group_index] = None; self.is_empty[group_index] = false; @@ -542,7 +627,13 @@ mod tests { #[test] fn invalid_data_type() { - assert!(SumDecimal::try_new(DataType::Int32, EvalMode::Legacy).is_err()); + assert!(SumDecimal::try_new( + DataType::Int32, + EvalMode::Legacy, + None, + crate::create_query_context_map(), + ) + .is_err()); } #[tokio::test] @@ -566,6 +657,8 @@ mod tests { let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new( data_type.clone(), EvalMode::Legacy, + None, + crate::create_query_context_map(), )?)); let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) diff --git a/native/spark-expr/src/array_funcs/list_extract.rs b/native/spark-expr/src/array_funcs/list_extract.rs index b912f0c7f6..2cbd62a2c6 100644 --- a/native/spark-expr/src/array_funcs/list_extract.rs +++ b/native/spark-expr/src/array_funcs/list_extract.rs @@ -31,13 +31,17 @@ use std::{ sync::Arc, }; -#[derive(Debug, Eq)] +use crate::SparkError; + +#[derive(Debug, Clone)] pub struct ListExtract { child: Arc, ordinal: Arc, default_value: Option>, one_based: bool, fail_on_error: bool, + expr_id: Option, + registry: Arc, } impl Hash for ListExtract { @@ -47,8 +51,11 @@ impl Hash for ListExtract { 
self.default_value.hash(state); self.one_based.hash(state); self.fail_on_error.hash(state); + self.expr_id.hash(state); + // Exclude registry from hash } } + impl PartialEq for ListExtract { fn eq(&self, other: &Self) -> bool { self.child.eq(&other.child) @@ -56,9 +63,13 @@ impl PartialEq for ListExtract { && self.default_value.eq(&other.default_value) && self.one_based.eq(&other.one_based) && self.fail_on_error.eq(&other.fail_on_error) + && self.expr_id.eq(&other.expr_id) + // Exclude registry from equality check } } +impl Eq for ListExtract {} + impl ListExtract { pub fn new( child: Arc, @@ -66,6 +77,8 @@ impl ListExtract { default_value: Option>, one_based: bool, fail_on_error: bool, + expr_id: Option, + registry: Arc, ) -> Self { Self { child, @@ -73,6 +86,8 @@ impl ListExtract { default_value, one_based, fail_on_error, + expr_id, + registry, } } @@ -84,6 +99,17 @@ impl ListExtract { ))), } } + + /// Wrap a SparkError with QueryContext if expr_id is available + fn wrap_error_with_context(&self, error: SparkError) -> DataFusionError { + if let Some(expr_id) = self.expr_id { + if let Some(query_ctx) = self.registry.get(expr_id) { + let wrapped = crate::SparkErrorWithContext::with_context(error, query_ctx); + return DataFusionError::External(Box::new(wrapped)); + } + } + DataFusionError::from(error) + } } impl PhysicalExpr for ListExtract { @@ -127,11 +153,15 @@ impl PhysicalExpr for ListExtract { .transpose()? 
.unwrap_or(self.data_type(&batch.schema())?.try_into())?; - let adjust_index = if self.one_based { - one_based_index - } else { - zero_based_index - }; + // Create error wrapper closure that has access to self + let error_wrapper = |error: SparkError| self.wrap_error_with_context(error); + + let adjust_index: Box DataFusionResult>> = + if self.one_based { + Box::new(|idx, len| one_based_index(idx, len, &error_wrapper)) + } else { + Box::new(|idx, len| zero_based_index(idx, len, &error_wrapper)) + }; match child_value.data_type() { DataType::List(_) => { @@ -143,7 +173,9 @@ impl PhysicalExpr for ListExtract { index_array, &default_value, self.fail_on_error, + self.one_based, adjust_index, + &error_wrapper, ) } DataType::LargeList(_) => { @@ -155,7 +187,9 @@ impl PhysicalExpr for ListExtract { index_array, &default_value, self.fail_on_error, + self.one_based, adjust_index, + &error_wrapper, ) } data_type => Err(DataFusionError::Internal(format!( @@ -179,17 +213,21 @@ impl PhysicalExpr for ListExtract { self.default_value.clone(), self.one_based, self.fail_on_error, + self.expr_id, + Arc::clone(&self.registry), ))), _ => internal_err!("ListExtract should have exactly two children"), } } } -fn one_based_index(index: i32, len: usize) -> DataFusionResult> { +fn one_based_index( + index: i32, + len: usize, + error_wrapper: &impl Fn(SparkError) -> DataFusionError, +) -> DataFusionResult> { if index == 0 { - return Err(DataFusionError::Execution( - "Invalid index of 0 for one-based ListExtract".to_string(), - )); + return Err(error_wrapper(SparkError::InvalidIndexOfZero)); } let abs_index = index.abs().as_usize(); @@ -204,7 +242,11 @@ fn one_based_index(index: i32, len: usize) -> DataFusionResult> { } } -fn zero_based_index(index: i32, len: usize) -> DataFusionResult> { +fn zero_based_index( + index: i32, + len: usize, + _error_wrapper: &impl Fn(SparkError) -> DataFusionError, +) -> DataFusionResult> { if index < 0 { Ok(None) } else { @@ -222,7 +264,9 @@ fn list_extract( 
index_array: &Int32Array, default_value: &ScalarValue, fail_on_error: bool, + one_based: bool, adjust_index: impl Fn(i32, usize) -> DataFusionResult>, + error_wrapper: &impl Fn(SparkError) -> DataFusionError, ) -> DataFusionResult { let values = list_array.values(); let offsets = list_array.offsets(); @@ -242,9 +286,22 @@ fn list_extract( } else if list_array.is_null(row) { mutable.extend_nulls(1); } else if fail_on_error { - return Err(DataFusionError::Execution( - "Index out of bounds for array".to_string(), - )); + // Throw appropriate error based on whether this is element_at (one_based=true) + // or GetArrayItem (one_based=false) + let error = if one_based { + // element_at function + SparkError::InvalidElementAtIndex { + index_value: *index, + array_size: len as i32, + } + } else { + // GetArrayItem (arr[index]) + SparkError::InvalidArrayIndex { + index_value: *index, + array_size: len as i32, + } + }; + return Err(error_wrapper(error)); } else { mutable.extend(1, 0, 1); } @@ -283,8 +340,18 @@ mod test { let null_default = ScalarValue::Int32(None); - let ColumnarValue::Array(result) = - list_extract(&list, &indices, &null_default, false, zero_based_index)? + // Simple error wrapper for tests - just converts SparkError to DataFusionError + let error_wrapper = |error: SparkError| DataFusionError::from(error); + + let ColumnarValue::Array(result) = list_extract( + &list, + &indices, + &null_default, + false, + false, + |idx, len| zero_based_index(idx, len, &error_wrapper), + &error_wrapper, + )? else { unreachable!() }; @@ -296,8 +363,15 @@ mod test { let zero_default = ScalarValue::Int32(Some(0)); - let ColumnarValue::Array(result) = - list_extract(&list, &indices, &zero_default, false, zero_based_index)? + let ColumnarValue::Array(result) = list_extract( + &list, + &indices, + &zero_default, + false, + false, + |idx, len| zero_based_index(idx, len, &error_wrapper), + &error_wrapper, + )? 
else { unreachable!() }; diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs index db288fa32a..3de6b47311 100644 --- a/native/spark-expr/src/conversion_funcs/boolean.rs +++ b/native/spark-expr/src/conversion_funcs/boolean.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::SparkResult; +use crate::{SparkError, SparkResult}; use arrow::array::{ArrayRef, AsArray, Decimal128Array}; use arrow::datatypes::DataType; use std::sync::Arc; @@ -40,7 +40,22 @@ pub fn cast_boolean_to_decimal( .iter() .map(|v| v.map(|b| if b { scaled_val } else { 0 })) .collect(); - Ok(Arc::new(result.with_precision_and_scale(precision, scale)?)) + + // Convert Arrow decimal overflow errors to SparkError + let decimal_array = result + .with_precision_and_scale(precision, scale) + .map_err(|e| { + if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_)) + && e.to_string().contains("too large to store in a Decimal128") + { + // Use the scaled value as it's the only non-zero value that could overflow + crate::error::decimal_overflow_error(scaled_val, precision, scale) + } else { + SparkError::Arrow(Arc::new(e)) + } + })?; + + Ok(Arc::new(decimal_array)) } #[cfg(test)] diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index e83877de1b..49198e6efa 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -83,6 +83,8 @@ pub struct Cast { pub child: Arc, pub data_type: DataType, pub cast_options: SparkCastOptions, + pub expr_id: Option, + pub query_context: Option>, } impl PartialEq for Cast { @@ -549,11 +551,15 @@ impl Cast { child: Arc, data_type: DataType, cast_options: SparkCastOptions, + expr_id: Option, + query_context: Option>, ) -> Self { Self { child, data_type, cast_options, + expr_id, + query_context, } } } @@ -1256,7 +1262,22 @@ where let res = 
Arc::new( cast_array - .with_precision_and_scale(precision, scale)? + .with_precision_and_scale(precision, scale) + .map_err(|e| { + if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_)) + && e.to_string().contains("too large to store in a Decimal128") + { + // Extract the overflowing value from the cast_array + // In practice, this should be caught above, but handle as a fallback + SparkError::NumericValueOutOfRange { + value: "overflow".to_string(), + precision, + scale, + } + } else { + SparkError::Arrow(Arc::new(e)) + } + })? .finish(), ) as ArrayRef; Ok(res) @@ -1332,7 +1353,23 @@ where } } Ok(Arc::new( - builder.with_precision_and_scale(precision, scale)?.finish(), + builder + .with_precision_and_scale(precision, scale) + .map_err(|e| { + if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_)) + && e.to_string().contains("too large to store in a Decimal128") + { + // Fallback error handling - should be caught above in most cases + SparkError::NumericValueOutOfRange { + value: "overflow".to_string(), + precision, + scale, + } + } else { + SparkError::Arrow(Arc::new(e)) + } + })? 
+ .finish(), )) } @@ -1598,7 +1635,23 @@ impl PhysicalExpr for Cast { fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { let arg = self.child.evaluate(batch)?; - spark_cast(arg, &self.data_type, &self.cast_options) + let result = spark_cast(arg, &self.data_type, &self.cast_options); + + // If there's an error and we have query_context, wrap it + match result { + Err(DataFusionError::External(e)) if self.query_context.is_some() => { + if let Some(spark_err) = e.downcast_ref::() { + let wrapped = crate::SparkErrorWithContext::with_context( + spark_err.clone(), + Arc::clone(self.query_context.as_ref().unwrap()), + ); + Err(DataFusionError::External(Box::new(wrapped))) + } else { + Err(DataFusionError::External(e)) + } + } + other => other, + } } fn children(&self) -> Vec<&Arc> { @@ -1614,6 +1667,8 @@ impl PhysicalExpr for Cast { Arc::clone(&children[0]), self.data_type.clone(), self.cast_options.clone(), + self.expr_id, + self.query_context.clone(), ))), _ => internal_err!("Cast should have exactly one child"), } diff --git a/native/spark-expr/src/conversion_funcs/string.rs b/native/spark-expr/src/conversion_funcs/string.rs index 531d334d15..7c193716d0 100644 --- a/native/spark-expr/src/conversion_funcs/string.rs +++ b/native/spark-expr/src/conversion_funcs/string.rs @@ -324,7 +324,21 @@ fn cast_string_to_decimal128_impl( Ok(Arc::new( decimal_builder - .with_precision_and_scale(precision, scale)? + .with_precision_and_scale(precision, scale) + .map_err(|e| { + if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_)) + && e.to_string().contains("too large to store in a Decimal128") + { + // Fallback error handling + SparkError::NumericValueOutOfRange { + value: "overflow".to_string(), + precision, + scale, + } + } else { + SparkError::Arrow(Arc::new(e)) + } + })? .finish(), )) } @@ -375,7 +389,21 @@ fn cast_string_to_decimal256_impl( Ok(Arc::new( decimal_builder - .with_precision_and_scale(precision, scale)? 
+ .with_precision_and_scale(precision, scale) + .map_err(|e| { + if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_)) + && e.to_string().contains("too large to store in a Decimal128") + { + // Fallback error handling + SparkError::NumericValueOutOfRange { + value: "overflow".to_string(), + precision, + scale, + } + } else { + SparkError::Arrow(Arc::new(e)) + } + })? .finish(), )) } diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs index c39a05cd4e..ae3b5c0eda 100644 --- a/native/spark-expr/src/error.rs +++ b/native/spark-expr/src/error.rs @@ -17,11 +17,11 @@ use arrow::error::ArrowError; use datafusion::common::DataFusionError; +use std::sync::Arc; -#[derive(thiserror::Error, Debug)] +#[derive(thiserror::Error, Debug, Clone)] pub enum SparkError { - // Note that this message format is based on Spark 3.4 and is more detailed than the message - // returned by Spark 3.3 + // This list was generated from the Spark code. Many of the exceptions are not yet used by Comet #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ because it is malformed. Correct the value as per the syntax, or change its target type. \ Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ @@ -32,7 +32,7 @@ pub enum SparkError { to_type: String, }, - #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] + #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] NumericValueOutOfRange { value: String, precision: u8, @@ -51,24 +51,646 @@ pub enum SparkError { to_type: String, }, + #[error("[CANNOT_PARSE_DECIMAL] Cannot parse decimal.")] + CannotParseDecimal, + #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] ArithmeticOverflow { from_type: String }, + #[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. Use `try_divide` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntegralDivideOverflow, + + #[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + DecimalSumOverflow, + #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] DivideByZero, + #[error("[REMAINDER_BY_ZERO] Division by zero. Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + RemainderByZero, + + #[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")] + IntervalDividedByZero, + + #[error("[BINARY_ARITHMETIC_OVERFLOW] {value1} {symbol} {value2} caused overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + BinaryArithmeticOverflow { + value1: String, + symbol: String, + value2: String, + function_name: String, + }, + + #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION] Interval arithmetic overflow. Use `{function_name}` to tolerate overflow and return NULL instead. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntervalArithmeticOverflowWithSuggestion { function_name: String }, + + #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION] Interval arithmetic overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntervalArithmeticOverflowWithoutSuggestion, + + #[error("[DATETIME_OVERFLOW] Datetime arithmetic overflow.")] + DatetimeOverflow, + + #[error("[INVALID_ARRAY_INDEX] The index {index_value} is out of bounds. The array has {array_size} elements. Use the SQL function get() to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidArrayIndex { index_value: i32, array_size: i32 }, + + #[error("[INVALID_ARRAY_INDEX_IN_ELEMENT_AT] The index {index_value} is out of bounds. The array has {array_size} elements. Use try_element_at to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidElementAtIndex { index_value: i32, array_size: i32 }, + + #[error("[INVALID_BITMAP_POSITION] The bit position {bit_position} is out of bounds. The bitmap has {bitmap_num_bytes} bytes ({bitmap_num_bits} bits).")] + InvalidBitmapPosition { + bit_position: i64, + bitmap_num_bytes: i64, + bitmap_num_bits: i64, + }, + + #[error("[INVALID_INDEX_OF_ZERO] The index 0 is invalid. 
An index shall be either < 0 or > 0 (the first element has index 1).")] + InvalidIndexOfZero, + + #[error("[DUPLICATED_MAP_KEY] Cannot create map with duplicate keys: {key}.")] + DuplicatedMapKey { key: String }, + + #[error("[NULL_MAP_KEY] Cannot use null as map key.")] + NullMapKey, + + #[error("[MAP_KEY_VALUE_DIFF_SIZES] The key array and value array of a map must have the same length.")] + MapKeyValueDiffSizes, + + #[error("[EXCEED_LIMIT_LENGTH] Cannot create a map with {size} elements which exceeds the limit {max_size}.")] + ExceedMapSizeLimit { size: i32, max_size: i32 }, + + #[error("[COLLECTION_SIZE_LIMIT_EXCEEDED] Cannot create array with {num_elements} elements which exceeds the limit {max_elements}.")] + CollectionSizeLimitExceeded { + num_elements: i64, + max_elements: i64, + }, + + #[error("[NOT_NULL_ASSERT_VIOLATION] The field `{field_name}` cannot be null.")] + NotNullAssertViolation { field_name: String }, + + #[error("[VALUE_IS_NULL] The value of field `{field_name}` at row {row_index} is null.")] + ValueIsNull { field_name: String, row_index: i32 }, + + #[error("[CANNOT_PARSE_TIMESTAMP] Cannot parse timestamp: {message}. Try using `{suggested_func}` instead.")] + CannotParseTimestamp { + message: String, + suggested_func: String, + }, + + #[error("[INVALID_FRACTION_OF_SECOND] The fraction of second {value} is invalid. Valid values are in the range [0, 60]. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidFractionOfSecond { value: f64 }, + + #[error("[INVALID_UTF8_STRING] Invalid UTF-8 string: {hex_string}.")] + InvalidUtf8String { hex_string: String }, + + #[error("[UNEXPECTED_POSITIVE_VALUE] The {parameter_name} parameter must be less than or equal to 0. The actual value is {actual_value}.")] + UnexpectedPositiveValue { + parameter_name: String, + actual_value: i32, + }, + + #[error("[UNEXPECTED_NEGATIVE_VALUE] The {parameter_name} parameter must be greater than or equal to 0. 
The actual value is {actual_value}.")] + UnexpectedNegativeValue { + parameter_name: String, + actual_value: i32, + }, + + #[error("[INVALID_PARAMETER_VALUE] Invalid regex group index {group_index} in function `{function_name}`. Group count is {group_count}.")] + InvalidRegexGroupIndex { + function_name: String, + group_count: i32, + group_index: i32, + }, + + #[error("[DATATYPE_CANNOT_ORDER] Cannot order by type: {data_type}.")] + DatatypeCannotOrder { data_type: String }, + + #[error("[SCALAR_SUBQUERY_TOO_MANY_ROWS] Scalar subquery returned more than one row.")] + ScalarSubqueryTooManyRows, + #[error("ArrowError: {0}.")] - Arrow(ArrowError), + Arrow(Arc), #[error("InternalError: {0}.")] Internal(String), } +impl SparkError { + /// Serialize this error to JSON format for JNI transfer + pub fn to_json(&self) -> String { + let error_class = self.error_class().unwrap_or(""); + + // Create a JSON structure with errorType, errorClass, and params + match serde_json::to_string(&serde_json::json!({ + "errorType": self.error_type_name(), + "errorClass": error_class, + "params": self.params_as_json(), + })) { + Ok(json) => json, + Err(e) => { + // Fallback if serialization fails + format!( + "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", + e + ) + } + } + } + + /// Get the error type name for JSON serialization + fn error_type_name(&self) -> &'static str { + match self { + SparkError::CastInvalidValue { .. } => "CastInvalidValue", + SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange", + SparkError::NumericOutOfRange { .. } => "NumericOutOfRange", + SparkError::CastOverFlow { .. } => "CastOverFlow", + SparkError::CannotParseDecimal => "CannotParseDecimal", + SparkError::ArithmeticOverflow { .. 
} => "ArithmeticOverflow", + SparkError::IntegralDivideOverflow => "IntegralDivideOverflow", + SparkError::DecimalSumOverflow => "DecimalSumOverflow", + SparkError::DivideByZero => "DivideByZero", + SparkError::RemainderByZero => "RemainderByZero", + SparkError::IntervalDividedByZero => "IntervalDividedByZero", + SparkError::BinaryArithmeticOverflow { .. } => "BinaryArithmeticOverflow", + SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { + "IntervalArithmeticOverflowWithSuggestion" + } + SparkError::IntervalArithmeticOverflowWithoutSuggestion => { + "IntervalArithmeticOverflowWithoutSuggestion" + } + SparkError::DatetimeOverflow => "DatetimeOverflow", + SparkError::InvalidArrayIndex { .. } => "InvalidArrayIndex", + SparkError::InvalidElementAtIndex { .. } => "InvalidElementAtIndex", + SparkError::InvalidBitmapPosition { .. } => "InvalidBitmapPosition", + SparkError::InvalidIndexOfZero => "InvalidIndexOfZero", + SparkError::DuplicatedMapKey { .. } => "DuplicatedMapKey", + SparkError::NullMapKey => "NullMapKey", + SparkError::MapKeyValueDiffSizes => "MapKeyValueDiffSizes", + SparkError::ExceedMapSizeLimit { .. } => "ExceedMapSizeLimit", + SparkError::CollectionSizeLimitExceeded { .. } => "CollectionSizeLimitExceeded", + SparkError::NotNullAssertViolation { .. } => "NotNullAssertViolation", + SparkError::ValueIsNull { .. } => "ValueIsNull", + SparkError::CannotParseTimestamp { .. } => "CannotParseTimestamp", + SparkError::InvalidFractionOfSecond { .. } => "InvalidFractionOfSecond", + SparkError::InvalidUtf8String { .. } => "InvalidUtf8String", + SparkError::UnexpectedPositiveValue { .. } => "UnexpectedPositiveValue", + SparkError::UnexpectedNegativeValue { .. } => "UnexpectedNegativeValue", + SparkError::InvalidRegexGroupIndex { .. } => "InvalidRegexGroupIndex", + SparkError::DatatypeCannotOrder { .. 
} => "DatatypeCannotOrder", + SparkError::ScalarSubqueryTooManyRows => "ScalarSubqueryTooManyRows", + SparkError::Arrow(_) => "Arrow", + SparkError::Internal(_) => "Internal", + } + } + + /// Extract parameters as JSON value + fn params_as_json(&self) -> serde_json::Value { + match self { + SparkError::CastInvalidValue { + value, + from_type, + to_type, + } => { + serde_json::json!({ + "value": value, + "fromType": from_type, + "toType": to_type, + }) + } + SparkError::NumericValueOutOfRange { + value, + precision, + scale, + } => { + serde_json::json!({ + "value": value, + "precision": precision, + "scale": scale, + }) + } + SparkError::NumericOutOfRange { value } => { + serde_json::json!({ + "value": value, + }) + } + SparkError::CastOverFlow { + value, + from_type, + to_type, + } => { + serde_json::json!({ + "value": value, + "fromType": from_type, + "toType": to_type, + }) + } + SparkError::ArithmeticOverflow { from_type } => { + serde_json::json!({ + "fromType": from_type, + }) + } + SparkError::BinaryArithmeticOverflow { + value1, + symbol, + value2, + function_name, + } => { + serde_json::json!({ + "value1": value1, + "symbol": symbol, + "value2": value2, + "functionName": function_name, + }) + } + SparkError::IntervalArithmeticOverflowWithSuggestion { function_name } => { + serde_json::json!({ + "functionName": function_name, + }) + } + SparkError::InvalidArrayIndex { + index_value, + array_size, + } => { + serde_json::json!({ + "indexValue": index_value, + "arraySize": array_size, + }) + } + SparkError::InvalidElementAtIndex { + index_value, + array_size, + } => { + serde_json::json!({ + "indexValue": index_value, + "arraySize": array_size, + }) + } + SparkError::InvalidBitmapPosition { + bit_position, + bitmap_num_bytes, + bitmap_num_bits, + } => { + serde_json::json!({ + "bitPosition": bit_position, + "bitmapNumBytes": bitmap_num_bytes, + "bitmapNumBits": bitmap_num_bits, + }) + } + SparkError::DuplicatedMapKey { key } => { + serde_json::json!({ + "key": 
key, + }) + } + SparkError::ExceedMapSizeLimit { size, max_size } => { + serde_json::json!({ + "size": size, + "maxSize": max_size, + }) + } + SparkError::CollectionSizeLimitExceeded { + num_elements, + max_elements, + } => { + serde_json::json!({ + "numElements": num_elements, + "maxElements": max_elements, + }) + } + SparkError::NotNullAssertViolation { field_name } => { + serde_json::json!({ + "fieldName": field_name, + }) + } + SparkError::ValueIsNull { + field_name, + row_index, + } => { + serde_json::json!({ + "fieldName": field_name, + "rowIndex": row_index, + }) + } + SparkError::CannotParseTimestamp { + message, + suggested_func, + } => { + serde_json::json!({ + "message": message, + "suggestedFunc": suggested_func, + }) + } + SparkError::InvalidFractionOfSecond { value } => { + serde_json::json!({ + "value": value, + }) + } + SparkError::InvalidUtf8String { hex_string } => { + serde_json::json!({ + "hexString": hex_string, + }) + } + SparkError::UnexpectedPositiveValue { + parameter_name, + actual_value, + } => { + serde_json::json!({ + "parameterName": parameter_name, + "actualValue": actual_value, + }) + } + SparkError::UnexpectedNegativeValue { + parameter_name, + actual_value, + } => { + serde_json::json!({ + "parameterName": parameter_name, + "actualValue": actual_value, + }) + } + SparkError::InvalidRegexGroupIndex { + function_name, + group_count, + group_index, + } => { + serde_json::json!({ + "functionName": function_name, + "groupCount": group_count, + "groupIndex": group_index, + }) + } + SparkError::DatatypeCannotOrder { data_type } => { + serde_json::json!({ + "dataType": data_type, + }) + } + SparkError::Arrow(e) => { + serde_json::json!({ + "message": e.to_string(), + }) + } + SparkError::Internal(msg) => { + serde_json::json!({ + "message": msg, + }) + } + // Simple errors with no parameters + _ => serde_json::json!({}), + } + } + + /// Returns the appropriate Spark exception class for this error + pub fn exception_class(&self) -> &'static 
str { + match self { + // ArithmeticException + SparkError::DivideByZero + | SparkError::RemainderByZero + | SparkError::IntervalDividedByZero + | SparkError::NumericValueOutOfRange { .. } + | SparkError::NumericOutOfRange { .. } // Comet-specific extension + | SparkError::ArithmeticOverflow { .. } + | SparkError::IntegralDivideOverflow + | SparkError::DecimalSumOverflow + | SparkError::BinaryArithmeticOverflow { .. } + | SparkError::IntervalArithmeticOverflowWithSuggestion { .. } + | SparkError::IntervalArithmeticOverflowWithoutSuggestion + | SparkError::DatetimeOverflow => "org/apache/spark/SparkArithmeticException", + + // CastOverflow gets special handling with CastOverflowException + SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException", + + // NumberFormatException (for cast invalid input errors) + SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException", + + // ArrayIndexOutOfBoundsException + SparkError::InvalidArrayIndex { .. } + | SparkError::InvalidElementAtIndex { .. } + | SparkError::InvalidBitmapPosition { .. } + | SparkError::InvalidIndexOfZero => "org/apache/spark/SparkArrayIndexOutOfBoundsException", + + // RuntimeException + SparkError::CannotParseDecimal + | SparkError::DuplicatedMapKey { .. } + | SparkError::NullMapKey + | SparkError::MapKeyValueDiffSizes + | SparkError::ExceedMapSizeLimit { .. } + | SparkError::CollectionSizeLimitExceeded { .. } + | SparkError::NotNullAssertViolation { .. } + | SparkError::ValueIsNull { .. } // Comet-specific extension + | SparkError::UnexpectedPositiveValue { .. } + | SparkError::UnexpectedNegativeValue { .. } + | SparkError::InvalidRegexGroupIndex { .. } + | SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException", + + // DateTimeException + SparkError::CannotParseTimestamp { .. } + | SparkError::InvalidFractionOfSecond { .. 
} => "org/apache/spark/SparkDateTimeException", + + // IllegalArgumentException + SparkError::DatatypeCannotOrder { .. } + | SparkError::InvalidUtf8String { .. } => "org/apache/spark/SparkIllegalArgumentException", + + // Generic errors + SparkError::Arrow(_) | SparkError::Internal(_) => "org/apache/spark/SparkException", + } + } + + /// Returns the Spark error class code for this error + pub fn error_class(&self) -> Option<&'static str> { + match self { + // Cast errors + SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"), + SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"), + SparkError::NumericValueOutOfRange { .. } => { + Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION") + } + SparkError::NumericOutOfRange { .. } => Some("NUMERIC_OUT_OF_SUPPORTED_RANGE"), + SparkError::CannotParseDecimal => Some("CANNOT_PARSE_DECIMAL"), + + // Arithmetic errors + SparkError::DivideByZero => Some("DIVIDE_BY_ZERO"), + SparkError::RemainderByZero => Some("REMAINDER_BY_ZERO"), + SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"), + SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"), + SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"), + SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"), + SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"), + SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { + Some("INTERVAL_ARITHMETIC_OVERFLOW") + } + SparkError::IntervalArithmeticOverflowWithoutSuggestion => { + Some("INTERVAL_ARITHMETIC_OVERFLOW") + } + SparkError::DatetimeOverflow => Some("DATETIME_OVERFLOW"), + + // Array index errors + SparkError::InvalidArrayIndex { .. } => Some("INVALID_ARRAY_INDEX"), + SparkError::InvalidElementAtIndex { .. } => Some("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"), + SparkError::InvalidBitmapPosition { .. 
} => Some("INVALID_BITMAP_POSITION"), + SparkError::InvalidIndexOfZero => Some("INVALID_INDEX_OF_ZERO"), + + // Map/Collection errors + SparkError::DuplicatedMapKey { .. } => Some("DUPLICATED_MAP_KEY"), + SparkError::NullMapKey => Some("NULL_MAP_KEY"), + SparkError::MapKeyValueDiffSizes => Some("MAP_KEY_VALUE_DIFF_SIZES"), + SparkError::ExceedMapSizeLimit { .. } => Some("EXCEED_LIMIT_LENGTH"), + SparkError::CollectionSizeLimitExceeded { .. } => { + Some("COLLECTION_SIZE_LIMIT_EXCEEDED") + } + + // Null validation errors + SparkError::NotNullAssertViolation { .. } => Some("NOT_NULL_ASSERT_VIOLATION"), + SparkError::ValueIsNull { .. } => Some("VALUE_IS_NULL"), + + // DateTime errors + SparkError::CannotParseTimestamp { .. } => Some("CANNOT_PARSE_TIMESTAMP"), + SparkError::InvalidFractionOfSecond { .. } => Some("INVALID_FRACTION_OF_SECOND"), + + // String/UTF8 errors + SparkError::InvalidUtf8String { .. } => Some("INVALID_UTF8_STRING"), + + // Function parameter errors + SparkError::UnexpectedPositiveValue { .. } => Some("UNEXPECTED_POSITIVE_VALUE"), + SparkError::UnexpectedNegativeValue { .. } => Some("UNEXPECTED_NEGATIVE_VALUE"), + + // Regex errors + SparkError::InvalidRegexGroupIndex { .. } => Some("INVALID_PARAMETER_VALUE"), + + // Unsupported operation errors + SparkError::DatatypeCannotOrder { .. } => Some("DATATYPE_CANNOT_ORDER"), + + // Subquery errors + SparkError::ScalarSubqueryTooManyRows => Some("SCALAR_SUBQUERY_TOO_MANY_ROWS"), + + // Generic errors (no error class) + SparkError::Arrow(_) | SparkError::Internal(_) => None, + } + } +} + +/// Convert decimal overflow to SparkError::NumericValueOutOfRange. +/// +/// Creates the appropriate SparkError when a decimal value exceeds the precision limit for Decimal128 storage. 
+/// +/// # Arguments +/// * `value` - The i128 decimal value that overflowed +/// * `precision` - The target precision +/// * `scale` - The scale of the decimal +/// +/// # Returns +/// SparkError::NumericValueOutOfRange with the value, precision, and scale +pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError { + SparkError::NumericValueOutOfRange { + value: value.to_string(), + precision, + scale, + } +} + pub type SparkResult = Result; +/// Wrapper that adds QueryContext to SparkError +/// +/// This allows attaching SQL context information (query text, line/position, object name) to errors +#[derive(Debug, Clone)] +pub struct SparkErrorWithContext { + /// The underlying SparkError + pub error: SparkError, + /// Optional QueryContext for SQL location information + pub context: Option>, +} + +impl SparkErrorWithContext { + /// Create a SparkErrorWithContext without context + pub fn new(error: SparkError) -> Self { + Self { + error, + context: None, + } + } + + /// Create a SparkErrorWithContext with QueryContext + pub fn with_context(error: SparkError, context: Arc) -> Self { + Self { + error, + context: Some(context), + } + } + + /// Serialize to JSON including optional context field + /// + /// JSON structure: + /// ```json + /// { + /// "errorType": "DivideByZero", + /// "errorClass": "DIVIDE_BY_ZERO", + /// "params": {}, + /// "context": { + /// "sqlText": "SELECT a/b FROM t", + /// "startIndex": 7, + /// "stopIndex": 9, + /// "line": 1, + /// "startPosition": 7 + /// }, + /// "summary": "== SQL (line 1, position 8) ==\n..." 
+ /// } + /// ``` + pub fn to_json(&self) -> String { + let mut json_obj = serde_json::json!({ + "errorType": self.error.error_type_name(), + "errorClass": self.error.error_class().unwrap_or(""), + "params": self.error.params_as_json(), + }); + + if let Some(ctx) = &self.context { + // Serialize context fields + json_obj["context"] = serde_json::json!({ + "sqlText": ctx.sql_text.as_str(), + "startIndex": ctx.start_index, + "stopIndex": ctx.stop_index, + "objectType": ctx.object_type, + "objectName": ctx.object_name, + "line": ctx.line, + "startPosition": ctx.start_position, + }); + + // Add formatted summary + json_obj["summary"] = serde_json::json!(ctx.format_summary()); + } + + serde_json::to_string(&json_obj).unwrap_or_else(|e| { + format!( + "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", + e + ) + }) + } +} + +impl std::fmt::Display for SparkErrorWithContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.error)?; + if let Some(ctx) = &self.context { + write!(f, "\n{}", ctx.format_summary())?; + } + Ok(()) + } +} + +impl std::error::Error for SparkErrorWithContext {} + +impl From for SparkErrorWithContext { + fn from(error: SparkError) -> Self { + SparkErrorWithContext::new(error) + } +} + +impl From for DataFusionError { + fn from(value: SparkErrorWithContext) -> Self { + DataFusionError::External(Box::new(value)) + } +} + impl From for SparkError { fn from(value: ArrowError) -> Self { - SparkError::Arrow(value) + SparkError::Arrow(Arc::new(value)) } } @@ -77,3 +699,171 @@ impl From for DataFusionError { DataFusionError::External(Box::new(value)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_divide_by_zero_json() { + let error = SparkError::DivideByZero; + let json = error.to_json(); + + assert!(json.contains("\"errorType\":\"DivideByZero\"")); + assert!(json.contains("\"errorClass\":\"DIVIDE_BY_ZERO\"")); + + // Verify it's valid JSON + let parsed: serde_json::Value = 
serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "DivideByZero"); + assert_eq!(parsed["errorClass"], "DIVIDE_BY_ZERO"); + } + + #[test] + fn test_remainder_by_zero_json() { + let error = SparkError::RemainderByZero; + let json = error.to_json(); + + assert!(json.contains("\"errorType\":\"RemainderByZero\"")); + assert!(json.contains("\"errorClass\":\"REMAINDER_BY_ZERO\"")); + } + + #[test] + fn test_binary_overflow_json() { + let error = SparkError::BinaryArithmeticOverflow { + value1: "32767".to_string(), + symbol: "+".to_string(), + value2: "1".to_string(), + function_name: "try_add".to_string(), + }; + let json = error.to_json(); + + // Verify it's valid JSON + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "BinaryArithmeticOverflow"); + assert_eq!(parsed["errorClass"], "BINARY_ARITHMETIC_OVERFLOW"); + assert_eq!(parsed["params"]["value1"], "32767"); + assert_eq!(parsed["params"]["symbol"], "+"); + assert_eq!(parsed["params"]["value2"], "1"); + assert_eq!(parsed["params"]["functionName"], "try_add"); + } + + #[test] + fn test_invalid_array_index_json() { + let error = SparkError::InvalidArrayIndex { + index_value: 10, + array_size: 3, + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "InvalidArrayIndex"); + assert_eq!(parsed["errorClass"], "INVALID_ARRAY_INDEX"); + assert_eq!(parsed["params"]["indexValue"], 10); + assert_eq!(parsed["params"]["arraySize"], 3); + } + + #[test] + fn test_numeric_value_out_of_range_json() { + let error = SparkError::NumericValueOutOfRange { + value: "999.99".to_string(), + precision: 5, + scale: 2, + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "NumericValueOutOfRange"); + assert_eq!( + parsed["errorClass"], + "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION" + ); + 
assert_eq!(parsed["params"]["value"], "999.99"); + assert_eq!(parsed["params"]["precision"], 5); + assert_eq!(parsed["params"]["scale"], 2); + } + + #[test] + fn test_cast_invalid_value_json() { + let error = SparkError::CastInvalidValue { + value: "abc".to_string(), + from_type: "STRING".to_string(), + to_type: "INT".to_string(), + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "CastInvalidValue"); + assert_eq!(parsed["errorClass"], "CAST_INVALID_INPUT"); + assert_eq!(parsed["params"]["value"], "abc"); + assert_eq!(parsed["params"]["fromType"], "STRING"); + assert_eq!(parsed["params"]["toType"], "INT"); + } + + #[test] + fn test_duplicated_map_key_json() { + let error = SparkError::DuplicatedMapKey { + key: "duplicate_key".to_string(), + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "DuplicatedMapKey"); + assert_eq!(parsed["errorClass"], "DUPLICATED_MAP_KEY"); + assert_eq!(parsed["params"]["key"], "duplicate_key"); + } + + #[test] + fn test_null_map_key_json() { + let error = SparkError::NullMapKey; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "NullMapKey"); + assert_eq!(parsed["errorClass"], "NULL_MAP_KEY"); + // Params should be an empty object + assert_eq!(parsed["params"], serde_json::json!({})); + } + + #[test] + fn test_error_class_mapping() { + // Test that error_class() returns the correct error class + assert_eq!( + SparkError::DivideByZero.error_class(), + Some("DIVIDE_BY_ZERO") + ); + assert_eq!( + SparkError::RemainderByZero.error_class(), + Some("REMAINDER_BY_ZERO") + ); + assert_eq!( + SparkError::InvalidArrayIndex { + index_value: 0, + array_size: 0 + } + .error_class(), + Some("INVALID_ARRAY_INDEX") + ); + assert_eq!(SparkError::NullMapKey.error_class(), 
Some("NULL_MAP_KEY")); + } + + #[test] + fn test_exception_class_mapping() { + // Test that exception_class() returns the correct Java exception class + assert_eq!( + SparkError::DivideByZero.exception_class(), + "org/apache/spark/SparkArithmeticException" + ); + assert_eq!( + SparkError::InvalidArrayIndex { + index_value: 0, + array_size: 0 + } + .exception_class(), + "org/apache/spark/SparkArrayIndexOutOfBoundsException" + ); + assert_eq!( + SparkError::NullMapKey.exception_class(), + "org/apache/spark/SparkRuntimeException" + ); + } +} diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 40eb180ab8..368223c00f 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -20,6 +20,7 @@ #![deny(clippy::clone_on_ref_ptr)] mod error; +mod query_context; pub mod kernels; pub use kernels::temporal::date_trunc_dyn; @@ -75,7 +76,7 @@ pub use datetime_funcs::{ SparkDateDiff, SparkDateTrunc, SparkHour, SparkMakeDate, SparkMinute, SparkSecond, SparkUnixTimestamp, TimestampTruncExpr, }; -pub use error::{SparkError, SparkResult}; +pub use error::{SparkError, SparkErrorWithContext, SparkResult}; pub use hash_funcs::*; pub use json_funcs::{FromJson, ToJson}; pub use math_funcs::{ @@ -83,6 +84,7 @@ pub use math_funcs::{ spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero, }; +pub use query_context::{create_query_context_map, QueryContext, QueryContextMap}; pub use string_funcs::*; /// Spark supports three evaluation modes when evaluating expressions, which affect @@ -119,6 +121,10 @@ pub(crate) fn arithmetic_overflow_error(from_type: &str) -> SparkError { } } +pub(crate) fn decimal_sum_overflow_error() -> SparkError { + SparkError::DecimalSumOverflow +} + pub(crate) fn divide_by_zero_error() -> SparkError { SparkError::DivideByZero } diff --git a/native/spark-expr/src/math_funcs/internal/checkoverflow.rs 
b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs index c7caab0594..a9e8f6748d 100644 --- a/native/spark-expr/src/math_funcs/internal/checkoverflow.rs +++ b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs @@ -25,6 +25,8 @@ use datafusion::common::{DataFusionError, ScalarValue}; use datafusion::logical_expr::ColumnarValue; use datafusion::physical_expr::PhysicalExpr; use std::hash::Hash; + +use crate::SparkError; use std::{ any::Any, fmt::{Display, Formatter}, @@ -40,6 +42,8 @@ pub struct CheckOverflow { pub child: Arc, pub data_type: DataType, pub fail_on_error: bool, + pub expr_id: Option, + pub query_context: Option>, } impl Hash for CheckOverflow { @@ -59,11 +63,19 @@ impl PartialEq for CheckOverflow { } impl CheckOverflow { - pub fn new(child: Arc, data_type: DataType, fail_on_error: bool) -> Self { + pub fn new( + child: Arc, + data_type: DataType, + fail_on_error: bool, + expr_id: Option, + query_context: Option>, + ) -> Self { Self { child, data_type, fail_on_error, + expr_id, + query_context, } } } @@ -113,8 +125,41 @@ impl PhysicalExpr for CheckOverflow { let decimal_array = as_primitive_array::(&array); let casted_array = if self.fail_on_error { - // Returning error if overflow - decimal_array.validate_decimal_precision(*precision)?; + // Returning error if overflow - convert decimal overflow to SparkError + decimal_array + .validate_decimal_precision(*precision) + .map_err(|e| { + if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_)) + && e.to_string().contains("too large to store in a Decimal128") { + // Find the first overflowing value + let overflow_value = decimal_array + .iter() + .find(|v| { + if let Some(val) = v { + arrow::array::types::Decimal128Type::validate_decimal_precision( + *val, *precision, *scale + ).is_err() + } else { + false + } + }) + .and_then(|v| v) + .unwrap_or(0); + + let spark_error = crate::error::decimal_overflow_error(overflow_value, *precision, *scale); + + // Wrap with query_context if 
present + if let Some(ctx) = &self.query_context { + DataFusionError::External(Box::new( + crate::SparkErrorWithContext::with_context(spark_error, Arc::clone(ctx)) + )) + } else { + DataFusionError::External(Box::new(spark_error)) + } + } else { + DataFusionError::ArrowError(Box::new(e), None) + } + })?; decimal_array } else { // Overflowing gets null value @@ -123,7 +168,33 @@ impl PhysicalExpr for CheckOverflow { let new_array = Decimal128Array::from(casted_array.into_data()) .with_precision_and_scale(*precision, *scale) - .map(|a| Arc::new(a) as ArrayRef)?; + .map(|a| Arc::new(a) as ArrayRef) + .map_err(|e| { + if matches!(e, arrow::error::ArrowError::InvalidArgumentError(_)) + && e.to_string().contains("too large to store in a Decimal128") + { + // Fallback error handling + let spark_error = SparkError::NumericValueOutOfRange { + value: "overflow".to_string(), + precision: *precision, + scale: *scale, + }; + + // Wrap with query_context if present + if let Some(ctx) = &self.query_context { + DataFusionError::External(Box::new( + crate::SparkErrorWithContext::with_context( + spark_error, + Arc::clone(ctx), + ), + )) + } else { + DataFusionError::External(Box::new(spark_error)) + } + } else { + DataFusionError::ArrowError(Box::new(e), None) + } + })?; Ok(ColumnarValue::Array(new_array)) } @@ -163,6 +234,8 @@ impl PhysicalExpr for CheckOverflow { Arc::clone(&children[0]), self.data_type.clone(), self.fail_on_error, + self.expr_id, + self.query_context.clone(), ))) } } diff --git a/native/spark-expr/src/math_funcs/modulo_expr.rs b/native/spark-expr/src/math_funcs/modulo_expr.rs index 733d653d2c..5df14548e2 100644 --- a/native/spark-expr/src/math_funcs/modulo_expr.rs +++ b/native/spark-expr/src/math_funcs/modulo_expr.rs @@ -93,11 +93,15 @@ pub fn create_modulo_expr( left, DataType::Decimal256(p1, s1), SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), + None, + None, )); let right_256 = Arc::new(Cast::new( right_non_ansi_safe, DataType::Decimal256(p2, 
s2), SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), + None, + None, )); // The UDF's return type must match what Arrow's rem function will actually return. @@ -118,6 +122,8 @@ pub fn create_modulo_expr( modulo_scalar_func, data_type, SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), + None, + None, ))) } _ => create_modulo_scalar_function( diff --git a/native/spark-expr/src/query_context.rs b/native/spark-expr/src/query_context.rs new file mode 100644 index 0000000000..e6591135e0 --- /dev/null +++ b/native/spark-expr/src/query_context.rs @@ -0,0 +1,402 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Query execution context for error reporting +//! +//! This module provides QueryContext which mirrors Spark's SQLQueryContext +//! for providing SQL text, line/position information, and error location +//! pointers in exception messages. + +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +/// Based on Spark's SQLQueryContext for error reporting. +/// +/// Contains information about where an error occurred in a SQL query, +/// including the full SQL text, line/column positions, and object context. 
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct QueryContext { + /// Full SQL query text + #[serde(rename = "sqlText")] + pub sql_text: Arc, + + /// Start offset in SQL text (0-based, character index) + #[serde(rename = "startIndex")] + pub start_index: i32, + + /// Stop offset in SQL text (0-based, character index, inclusive) + #[serde(rename = "stopIndex")] + pub stop_index: i32, + + /// Object type (e.g., "VIEW", "Project", "Filter") + #[serde(rename = "objectType", skip_serializing_if = "Option::is_none")] + pub object_type: Option, + + /// Object name (e.g., view name, column name) + #[serde(rename = "objectName", skip_serializing_if = "Option::is_none")] + pub object_name: Option, + + /// Line number in SQL query (1-based) + pub line: i32, + + /// Column position within the line (0-based) + #[serde(rename = "startPosition")] + pub start_position: i32, +} + +impl QueryContext { + #[allow(clippy::too_many_arguments)] + pub fn new( + sql_text: String, + start_index: i32, + stop_index: i32, + object_type: Option, + object_name: Option, + line: i32, + start_position: i32, + ) -> Self { + Self { + sql_text: Arc::new(sql_text), + start_index, + stop_index, + object_type, + object_name, + line, + start_position, + } + } + + /// Convert a character index to a byte offset in the SQL text. + /// Returns None if the character index is out of range. + fn char_index_to_byte_offset(&self, char_index: usize) -> Option { + self.sql_text + .char_indices() + .nth(char_index) + .map(|(byte_offset, _)| byte_offset) + } + + /// Generate a summary string showing SQL fragment with error location. 
+ /// (From SQLQueryContext.summary) + /// + /// Format example: + /// ```text + /// == SQL of VIEW v1 (line 1, position 8) == + /// SELECT a/b FROM t + /// ^^^ + /// ``` + pub fn format_summary(&self) -> String { + let start_char = self.start_index.max(0) as usize; + // stop_index is inclusive; fragment covers [start, stop] + let stop_char = (self.stop_index + 1).max(0) as usize; + + let fragment = match ( + self.char_index_to_byte_offset(start_char), + // stop_char may equal sql_text.chars().count() (one past the end) + self.char_index_to_byte_offset(stop_char).or_else(|| { + if stop_char == self.sql_text.chars().count() { + Some(self.sql_text.len()) + } else { + None + } + }), + ) { + (Some(start_byte), Some(stop_byte)) => &self.sql_text[start_byte..stop_byte], + _ => "", + }; + + // Build the header line + let mut summary = String::from("== SQL"); + + if let Some(obj_type) = &self.object_type { + if !obj_type.is_empty() { + summary.push_str(" of "); + summary.push_str(obj_type); + + if let Some(obj_name) = &self.object_name { + if !obj_name.is_empty() { + summary.push(' '); + summary.push_str(obj_name); + } + } + } + } + + summary.push_str(&format!( + " (line {}, position {}) ==\n", + self.line, + self.start_position + 1 // Convert 0-based to 1-based for display + )); + + // Add the SQL text with fragment highlighted + summary.push_str(&self.sql_text); + summary.push('\n'); + + // Add caret pointer + let caret_position = self.start_position.max(0) as usize; + summary.push_str(&" ".repeat(caret_position)); + // fragment.chars().count() gives the correct display width for non-ASCII + summary.push_str(&"^".repeat(fragment.chars().count().max(1))); + + summary + } + + /// Returns the SQL fragment that caused the error. 
+ pub fn fragment(&self) -> String { + let start_char = self.start_index.max(0) as usize; + let stop_char = (self.stop_index + 1).max(0) as usize; + + match ( + self.char_index_to_byte_offset(start_char), + self.char_index_to_byte_offset(stop_char).or_else(|| { + if stop_char == self.sql_text.chars().count() { + Some(self.sql_text.len()) + } else { + None + } + }), + ) { + (Some(start_byte), Some(stop_byte)) => self.sql_text[start_byte..stop_byte].to_string(), + _ => String::new(), + } + } +} + +use std::collections::HashMap; +use std::sync::RwLock; + +/// Map that stores QueryContext information for expressions during execution. +/// +/// This map is populated during plan deserialization and accessed +/// during error creation to attach SQL context to exceptions. +#[derive(Debug)] +pub struct QueryContextMap { + /// Map from expression ID to QueryContext + contexts: RwLock>>, +} + +impl QueryContextMap { + pub fn new() -> Self { + Self { + contexts: RwLock::new(HashMap::new()), + } + } + + /// Register a QueryContext for an expression ID. + /// + /// If the expression ID already exists, it will be replaced. + /// + /// # Arguments + /// * `expr_id` - Unique expression identifier from protobuf + /// * `context` - QueryContext containing SQL text and position info + pub fn register(&self, expr_id: u64, context: QueryContext) { + let mut contexts = self.contexts.write().unwrap(); + contexts.insert(expr_id, Arc::new(context)); + } + + /// Get the QueryContext for an expression ID. + /// + /// Returns None if no context is registered for this expression. + /// + /// # Arguments + /// * `expr_id` - Expression identifier to look up + pub fn get(&self, expr_id: u64) -> Option> { + let contexts = self.contexts.read().unwrap(); + contexts.get(&expr_id).cloned() + } + + /// Clear all registered contexts. + /// + /// This is typically called after plan execution completes to free memory. 
+ pub fn clear(&self) { + let mut contexts = self.contexts.write().unwrap(); + contexts.clear(); + } + + /// Return the number of registered contexts (for debugging/testing) + pub fn len(&self) -> usize { + let contexts = self.contexts.read().unwrap(); + contexts.len() + } + + /// Check if the map is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Default for QueryContextMap { + fn default() -> Self { + Self::new() + } +} + +/// Create a new session-scoped QueryContextMap. +/// +/// This should be called once per SessionContext during plan creation +/// and passed to expressions that need query context for error reporting. +pub fn create_query_context_map() -> Arc { + Arc::new(QueryContextMap::new()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_query_context_creation() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("Divide".to_string()), + Some("a/b".to_string()), + 1, + 7, + ); + + assert_eq!(*ctx.sql_text, "SELECT a/b FROM t"); + assert_eq!(ctx.start_index, 7); + assert_eq!(ctx.stop_index, 9); + assert_eq!(ctx.object_type, Some("Divide".to_string())); + assert_eq!(ctx.object_name, Some("a/b".to_string())); + assert_eq!(ctx.line, 1); + assert_eq!(ctx.start_position, 7); + } + + #[test] + fn test_query_context_serialization() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("Divide".to_string()), + Some("a/b".to_string()), + 1, + 7, + ); + + let json = serde_json::to_string(&ctx).unwrap(); + let deserialized: QueryContext = serde_json::from_str(&json).unwrap(); + + assert_eq!(ctx, deserialized); + } + + #[test] + fn test_format_summary() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("VIEW".to_string()), + Some("v1".to_string()), + 1, + 7, + ); + + let summary = ctx.format_summary(); + + assert!(summary.contains("== SQL of VIEW v1 (line 1, position 8) ==")); + assert!(summary.contains("SELECT a/b FROM 
t")); + assert!(summary.contains("^^^")); // Three carets for "a/b" + } + + #[test] + fn test_format_summary_without_object() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let summary = ctx.format_summary(); + + assert!(summary.contains("== SQL (line 1, position 8) ==")); + assert!(summary.contains("SELECT a/b FROM t")); + } + + #[test] + fn test_fragment() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + assert_eq!(ctx.fragment(), "a/b"); + } + + #[test] + fn test_arc_string_sharing() { + let ctx1 = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let ctx2 = ctx1.clone(); + + // Arc should share the same allocation + assert!(Arc::ptr_eq(&ctx1.sql_text, &ctx2.sql_text)); + } + + #[test] + fn test_json_with_optional_fields() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let json = serde_json::to_string(&ctx).unwrap(); + + // Should not serialize objectType and objectName when None + assert!(!json.contains("objectType")); + assert!(!json.contains("objectName")); + } + + #[test] + fn test_map_register_and_get() { + let map = QueryContextMap::new(); + + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + map.register(1, ctx.clone()); + + let retrieved = map.get(1).unwrap(); + assert_eq!(*retrieved.sql_text, "SELECT a/b FROM t"); + assert_eq!(retrieved.start_index, 7); + } + + #[test] + fn test_map_get_nonexistent() { + let map = QueryContextMap::new(); + assert!(map.get(999).is_none()); + } + + #[test] + fn test_map_clear() { + let map = QueryContextMap::new(); + + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + map.register(1, ctx); + assert_eq!(map.len(), 1); + + map.clear(); + assert_eq!(map.len(), 0); + assert!(map.is_empty()); + } + + // Verify that fragment() and format_summary() correctly handle SQL text that + // 
contains multi-byte characters + + #[test] + fn test_fragment_non_ascii_accented() { + // "é" is a 2-byte UTF-8 sequence (U+00E9). + // SQL: "SELECT café FROM t" + // 0123456789... + // char indices: c=7, a=8, f=9, é=10, ' '=11 ... FROM = 12.. + // start_index=7, stop_index=10 should yield "café" + let sql = "SELECT café FROM t".to_string(); + let ctx = QueryContext::new(sql, 7, 10, None, None, 1, 7); + assert_eq!(ctx.fragment(), "café"); + } +} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index f17d8f4f72..28c1645718 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -36,6 +36,7 @@ import org.apache.spark.util.SerializableConfiguration import org.apache.comet.CometConf._ import org.apache.comet.Tracing.withTrace +import org.apache.comet.exceptions.CometQueryExecutionException import org.apache.comet.parquet.CometFileKeyUnwrapper import org.apache.comet.serde.Config.ConfigMap import org.apache.comet.vector.NativeUtil @@ -151,6 +152,11 @@ class CometExecIterator( }) }) } catch { + // Handle CometQueryExecutionException with JSON payload first + case e: CometQueryExecutionException => + logError(s"Native execution for task $taskAttemptId failed", e) + throw SparkErrorConverter.convertToSparkException(e) + case e: CometNativeException => // it is generally considered bad practice to log and then rethrow an // exception, but it really helps debugging to be able to see which task diff --git a/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala b/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala new file mode 100644 index 0000000000..a8dea4cf46 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{QueryContext, SparkException} +import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.sql.comet.shims.ShimSparkErrorConverter + +import com.fasterxml.jackson.core.JsonParseException + +import org.apache.comet.exceptions.CometQueryExecutionException + +/** + * Converts CometQueryExecutionException from native code (with JSON payload) to appropriate Spark + * QueryExecutionErrors.* exceptions + * + * Parses the JSON-encoded error information from native execution and delegates to the + * version-specific ShimSparkErrorConverter trait for conversion to proper Spark exception types. + * + * The ShimSparkErrorConverter handles all error cases using the correct QueryExecutionErrors API + * for each Spark version. 
+ */ +object SparkErrorConverter extends ShimSparkErrorConverter { + + implicit val formats: DefaultFormats.type = DefaultFormats + + case class QueryContextJson( + sqlText: String, + startIndex: Int, + stopIndex: Int, + objectType: Option[String], + objectName: Option[String], + line: Int, + startPosition: Int) + + case class ErrorJson( + errorType: String, + errorClass: Option[String], + params: Option[Map[String, Any]], + context: Option[QueryContextJson], + summary: Option[String]) + + /** + * Parse JSON from exception and convert to appropriate Spark exception. + * + * @param e + * the CometQueryExecutionException with JSON message + * @return + * the corresponding Spark exception, or the original exception if parsing fails + */ + def convertToSparkException(e: CometQueryExecutionException): Throwable = { + try { + if (!e.isJsonMessage()) { + // Not JSON, return original exception + return e + } + } catch { + // Only catch JSON parsing/mapping exceptions - let conversion exceptions propagate + case _: MappingException | _: JsonParseException => + return e + } + + val json = parse(e.getMessage) + val errorJson = json.extract[ErrorJson] + val params = errorJson.params.getOrElse(Map.empty) + val errorClass = errorJson.errorClass.getOrElse("UNKNOWN_ERROR_TEMP_COMET") + + // Build Spark SQLQueryContext if context is present (Not all errors carry the query context) + val sparkContext: Array[QueryContext] = errorJson.context match { + case Some(ctx) => + Array( + SQLQueryContext( + sqlText = Some(ctx.sqlText), + line = Some(ctx.line), + startPosition = Some(ctx.startPosition), + originStartIndex = Some(ctx.startIndex), + originStopIndex = Some(ctx.stopIndex), + originObjectType = ctx.objectType, + originObjectName = ctx.objectName)) + case None => Array.empty[QueryContext] // No context + } + + val summary: String = errorJson.summary.orNull + + // Delegate to version-specific shim - let conversion exceptions propagate + val optEx = 
convertErrorType(errorJson.errorType, errorClass, params, sparkContext, summary) + optEx match { + case Some(exception) => + // successfully converted - return the proper typed exception + exception + + case None => + // Unknown error type - fallback to generic SparkException + new SparkException( + errorClass = errorClass, + messageParameters = paramsToStringMap(params), + cause = null) + } + } + + /** + * Convert parameter map to string-keyed map for SparkException. + */ + private[comet] def paramsToStringMap(params: Map[String, Any]): Map[String, String] = { + params.map { case (k, v) => (k, v.toString) } + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 9d13ccd9ed..e9fb200144 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -19,6 +19,8 @@ package org.apache.comet.serde +import java.util.concurrent.atomic.AtomicLong + import scala.jdk.CollectionConverters._ import org.apache.spark.internal.Logging @@ -267,6 +269,68 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[VariancePop] -> CometVariancePop, classOf[VarianceSamp] -> CometVarianceSamp) + // A unique id for each expression. ~used to look up QueryContext during error creation. + private val exprIdCounter = new AtomicLong(0) + + private def nextExprId(): Long = exprIdCounter.incrementAndGet() + + /** + * Extract SQL context information (query text, line/position, object name) from the + * expression's origin. 
+ * + * @param expr + * The Spark expression to extract context from + * @return + * Some(QueryContext) if origin is present, None otherwise + */ + private def extractQueryContext(expr: Expression): Option[ExprOuterClass.QueryContext] = { + val contexts = expr.origin.getQueryContext + if (contexts != null && contexts.length > 0) { + try { + val ctx = contexts(0) + // Check if this is a SQLQueryContext with additional fields + ctx match { + case sqlCtx: org.apache.spark.sql.catalyst.trees.SQLQueryContext => + val builder = ExprOuterClass.QueryContext + .newBuilder() + .setSqlText(sqlCtx.sqlText.getOrElse("")) + .setStartIndex(sqlCtx.originStartIndex.getOrElse(ctx.startIndex)) + .setStopIndex(sqlCtx.originStopIndex.getOrElse(ctx.stopIndex)) + .setLine(sqlCtx.line.getOrElse(0)) + .setStartPosition(sqlCtx.startPosition.getOrElse(0)) + + // Add optional fields if present + sqlCtx.originObjectType.foreach(builder.setObjectType) + sqlCtx.originObjectName.foreach(builder.setObjectName) + + Some(builder.build()) + case _ => + // Fallback: use only QueryContext interface methods + val builder = ExprOuterClass.QueryContext + .newBuilder() + .setSqlText(ctx.fragment()) + .setStartIndex(ctx.startIndex()) + .setStopIndex(ctx.stopIndex()) + .setLine(0) + .setStartPosition(0) + + if (ctx.objectType() != null && !ctx.objectType().isEmpty) { + builder.setObjectType(ctx.objectType()) + } + if (ctx.objectName() != null && !ctx.objectName().isEmpty) { + builder.setObjectName(ctx.objectName()) + } + + Some(builder.build()) + } + } catch { + case _: Exception => None + } + } else { + None + } + } + def supportedDataType(dt: DataType, allowComplex: Boolean = false): Boolean = dt match { case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType | @@ -403,7 +467,7 @@ object QueryPlanSerde extends Logging with CometExprShim { val fn = aggExpr.aggregateFunction val cometExpr = 
aggrSerdeMap.get(fn.getClass) - cometExpr match { + val protoAggExprOpt = cometExpr match { case Some(handler) => val aggHandler = handler.asInstanceOf[CometAggregateExpressionSerde[AggregateFunction]] val exprConfName = aggHandler.getExprConfigName(fn) @@ -450,6 +514,16 @@ object QueryPlanSerde extends Logging with CometExprShim { fn.children: _*) None } + + // Attach QueryContext and expr_id to the aggregate expression + protoAggExprOpt.map { protoAggExpr => + val builder = protoAggExpr.toBuilder + builder.setExprId(nextExprId()) + extractQueryContext(fn).foreach { ctx => + builder.setQueryContext(ctx) + } + builder.build() + } } def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode = { @@ -550,22 +624,32 @@ object QueryPlanSerde extends Logging with CometExprShim { } } - versionSpecificExprToProtoInternal(expr, inputs, binding).orElse(expr match { + versionSpecificExprToProtoInternal(expr, inputs, binding) + .orElse(expr match { - case UnaryExpression(child) if expr.prettyName == "promote_precision" => - // `UnaryExpression` includes `PromotePrecision` for Spark 3.3 - // `PromotePrecision` is just a wrapper, don't need to serialize it. - exprToProtoInternal(child, inputs, binding) + case UnaryExpression(child) if expr.prettyName == "promote_precision" => + // `UnaryExpression` includes `PromotePrecision` for Spark 3.3 + // `PromotePrecision` is just a wrapper, don't need to serialize it. 
+ exprToProtoInternal(child, inputs, binding) - case expr => - QueryPlanSerde.exprSerdeMap.get(expr.getClass) match { - case Some(handler) => - convert(expr, handler.asInstanceOf[CometExpressionSerde[Expression]]) - case _ => - withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) - None + case expr => + QueryPlanSerde.exprSerdeMap.get(expr.getClass) match { + case Some(handler) => + convert(expr, handler.asInstanceOf[CometExpressionSerde[Expression]]) + case _ => + withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) + None + } + }) + .map { protoExpr => + // Attach QueryContext and expr_id to the expression + val builder = protoExpr.toBuilder + builder.setExprId(nextExprId()) + extractQueryContext(expr).foreach { ctx => + builder.setQueryContext(ctx) } - }) + builder.build() + } } /** diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala new file mode 100644 index 0000000000..83f3a7f12a --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.{QueryContext, SparkException} +import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Spark 3.4 implementation for converting error types to proper Spark exceptions. + * + * Handles all error cases using the Spark 3.4 QueryExecutionErrors API. Differs from the 3.5 + * version in 4 places where the API changed between 3.4 and 3.5. + */ +trait ShimSparkErrorConverter { + + private def sqlCtx(context: Array[QueryContext]): SQLQueryContext = + context.headOption.map(_.asInstanceOf[SQLQueryContext]).getOrElse(null) + + def convertErrorType( + errorType: String, + errorClass: String, + params: Map[String, Any], + context: Array[QueryContext], + summary: String): Option[Throwable] = { + val _ = (errorClass, summary) + + errorType match { + + case "DivideByZero" => + Some(QueryExecutionErrors.divideByZeroError(sqlCtx(context))) + + case "RemainderByZero" => + Some( + new SparkException( + errorClass = "REMAINDER_BY_ZERO", + messageParameters = params.map { case (k, v) => (k, v.toString) }, + cause = null)) + + case "IntervalDividedByZero" => + Some(QueryExecutionErrors.intervalDividedByZeroError(sqlCtx(context))) + + case "BinaryArithmeticOverflow" => + // Spark 3.x does not take functionName parameter + Some( + QueryExecutionErrors.binaryArithmeticCauseOverflowError( + params("value1").toString.toShort, + params("symbol").toString, + params("value2").toString.toShort)) + + case "ArithmeticOverflow" => + val fromType = params("fromType").toString + Some(QueryExecutionErrors.arithmeticOverflowError(fromType + " overflow", "")) + + case "IntegralDivideOverflow" => + Some(QueryExecutionErrors.overflowInIntegralDivideError(sqlCtx(context))) 
+ + case "DecimalSumOverflow" => + // Spark 3.x takes SQLQueryContext, not QueryContext + Some(QueryExecutionErrors.overflowInSumOfDecimalError(sqlCtx(context))) + + case "NumericValueOutOfRange" => + val decimal = Decimal(params("value").toString) + Some( + QueryExecutionErrors.cannotChangeDecimalPrecisionError( + decimal, + params("precision").toString.toInt, + params("scale").toString.toInt, + sqlCtx(context))) + + case "DatetimeOverflow" => + Some( + new SparkException( + errorClass = "DATETIME_OVERFLOW", + messageParameters = params.map { case (k, v) => (k, v.toString) }, + cause = null)) + + case "InvalidArrayIndex" => + Some( + QueryExecutionErrors.invalidArrayIndexError( + params("indexValue").toString.toInt, + params("arraySize").toString.toInt, + sqlCtx(context))) + + case "InvalidElementAtIndex" => + Some( + QueryExecutionErrors.invalidElementAtIndexError( + params("indexValue").toString.toInt, + params("arraySize").toString.toInt, + sqlCtx(context))) + + case "InvalidIndexOfZero" => + Some(QueryExecutionErrors.invalidIndexOfZeroError(sqlCtx(context))) + + case "InvalidBitmapPosition" => + // invalidBitmapPositionError does not exist in Spark 3.4 + Some( + new SparkException( + errorClass = "INVALID_BITMAP_POSITION", + messageParameters = params.map { case (k, v) => (k, v.toString) }, + cause = null)) + + case "DuplicatedMapKey" => + Some(QueryExecutionErrors.duplicateMapKeyFoundError(params("key"))) + + case "NullMapKey" => + Some(QueryExecutionErrors.nullAsMapKeyNotAllowedError()) + + case "MapKeyValueDiffSizes" => + Some(QueryExecutionErrors.mapDataKeyArrayLengthDiffersFromValueArrayLengthError()) + + case "ExceedMapSizeLimit" => + Some(QueryExecutionErrors.exceedMapSizeLimitError(params("size").toString.toInt)) + + case "CollectionSizeLimitExceeded" => + // createArrayWithElementsExceedLimitError takes (count: Any) in Spark 3.4 + Some( + QueryExecutionErrors.createArrayWithElementsExceedLimitError( + params("numElements").toString.toLong)) + + case 
"NotNullAssertViolation" => + Some( + QueryExecutionErrors.foundNullValueForNotNullableFieldError( + params("fieldName").toString)) + + case "ValueIsNull" => + Some( + QueryExecutionErrors.fieldCannotBeNullError( + params.getOrElse("rowIndex", 0).toString.toInt, + params("fieldName").toString)) + + case "CannotParseTimestamp" => + Some( + QueryExecutionErrors.ansiDateTimeParseError(new Exception(params("message").toString))) + + case "InvalidFractionOfSecond" => + Some(QueryExecutionErrors.invalidFractionOfSecondError()) + + case "CastInvalidValue" => + val str = UTF8String.fromString(params("value").toString) + val targetType = getDataType(params("toType").toString) + Some( + QueryExecutionErrors + .invalidInputInCastToNumberError(targetType, str, sqlCtx(context))) + + case "CastOverFlow" => + val fromType = getDataType(params("fromType").toString) + val toType = getDataType(params("toType").toString) + val valueStr = params("value").toString + + val typedValue: Any = fromType match { + case _: DecimalType => + val cleanStr = if (valueStr.endsWith("BD")) valueStr.dropRight(2) else valueStr + Decimal(cleanStr) + case ByteType => + val cleanStr = if (valueStr.endsWith("T")) valueStr.dropRight(1) else valueStr + cleanStr.toByte + case ShortType => + val cleanStr = if (valueStr.endsWith("S")) valueStr.dropRight(1) else valueStr + cleanStr.toShort + case IntegerType => valueStr.toInt + case LongType => + val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr + cleanStr.toLong + case FloatType => valueStr.toFloat + case DoubleType => valueStr.toDouble + case StringType => UTF8String.fromString(valueStr) + case _ => valueStr + } + + Some(QueryExecutionErrors.castingCauseOverflowError(typedValue, fromType, toType)) + + case "CannotParseDecimal" => + Some(QueryExecutionErrors.cannotParseDecimalError()) + + case "InvalidUtf8String" => + // invalidUTF8StringError does not exist in Spark 3.x; use generic fallback + Some( + new SparkException( + 
errorClass = "INVALID_UTF8_STRING", + messageParameters = params.map { case (k, v) => (k, v.toString) }, + cause = null)) + + case "UnexpectedPositiveValue" => + Some( + QueryExecutionErrors.unexpectedValueForStartInFunctionError( + params("parameterName").toString)) + + case "UnexpectedNegativeValue" => + Some( + QueryExecutionErrors.unexpectedValueForLengthInFunctionError( + params("parameterName").toString)) + + case "InvalidRegexGroupIndex" => + // invalidRegexGroupIndexError does not exist in Spark 3.4 + Some( + new SparkException( + errorClass = "INVALID_REGEX_GROUP_INDEX", + messageParameters = params.map { case (k, v) => (k, v.toString) }, + cause = null)) + + case "DatatypeCannotOrder" => + // orderedOperationUnsupportedByDataTypeError takes DataType in Spark 3.4, not String + Some( + QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( + getDataType(params("dataType").toString))) + + case "ScalarSubqueryTooManyRows" => + // multipleRowScalarSubqueryError was renamed to multipleRowSubqueryError in Spark 3.x + Some(QueryExecutionErrors.multipleRowSubqueryError(sqlCtx(context))) + + case "IntervalArithmeticOverflowWithSuggestion" => + // Spark 3.x uses a single intervalArithmeticOverflowError method + Some( + QueryExecutionErrors.intervalArithmeticOverflowError( + "Interval arithmetic overflow", + params.get("functionName").map(_.toString).getOrElse(""), + sqlCtx(context))) + + case "IntervalArithmeticOverflowWithoutSuggestion" => + Some( + QueryExecutionErrors + .intervalArithmeticOverflowError("Interval arithmetic overflow", "", sqlCtx(context))) + + case _ => + None + } + } + + private def getDataType(typeName: String): DataType = { + typeName.toUpperCase match { + case "BYTE" | "TINYINT" => ByteType + case "SHORT" | "SMALLINT" => ShortType + case "INT" | "INTEGER" => IntegerType + case "LONG" | "BIGINT" => LongType + case "FLOAT" | "REAL" => FloatType + case "DOUBLE" => DoubleType + case "DECIMAL" => DecimalType.SYSTEM_DEFAULT + case "STRING" 
| "VARCHAR" => StringType + case "BINARY" => BinaryType + case "BOOLEAN" => BooleanType + case "DATE" => DateType + case "TIMESTAMP" => TimestampType + case _ => + try { + DataType.fromDDL(typeName) + } catch { + case _: Exception => StringType + } + } + } +} diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala new file mode 100644 index 0000000000..44c34c1185 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.{QueryContext, SparkException} +import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Spark 3.5 implementation for converting error types to proper Spark exceptions. + * + * Handles all error cases using the Spark 3.5 QueryExecutionErrors API. 
The 4 cases with API + * differences from Spark 4.0 are handled with Spark 3.x-specific calls. + */ +trait ShimSparkErrorConverter { + + private def sqlCtx(context: Array[QueryContext]): SQLQueryContext = + context.headOption.map(_.asInstanceOf[SQLQueryContext]).getOrElse(null) + + def convertErrorType( + errorType: String, + errorClass: String, + params: Map[String, Any], + context: Array[QueryContext], + summary: String): Option[Throwable] = { + val _ = (errorClass, summary) + + errorType match { + + case "DivideByZero" => + Some(QueryExecutionErrors.divideByZeroError(sqlCtx(context))) + + case "RemainderByZero" => + Some( + new SparkException( + errorClass = "REMAINDER_BY_ZERO", + messageParameters = params.map { case (k, v) => (k, v.toString) }, + cause = null)) + + case "IntervalDividedByZero" => + Some(QueryExecutionErrors.intervalDividedByZeroError(sqlCtx(context))) + + case "BinaryArithmeticOverflow" => + // Spark 3.x does not take functionName parameter + Some( + QueryExecutionErrors.binaryArithmeticCauseOverflowError( + params("value1").toString.toShort, + params("symbol").toString, + params("value2").toString.toShort)) + + case "ArithmeticOverflow" => + val fromType = params("fromType").toString + Some(QueryExecutionErrors.arithmeticOverflowError(fromType + " overflow", "")) + + case "IntegralDivideOverflow" => + Some(QueryExecutionErrors.overflowInIntegralDivideError(sqlCtx(context))) + + case "DecimalSumOverflow" => + // Spark 3.x takes SQLQueryContext, not QueryContext + Some(QueryExecutionErrors.overflowInSumOfDecimalError(sqlCtx(context))) + + case "NumericValueOutOfRange" => + val decimal = Decimal(params("value").toString) + Some( + QueryExecutionErrors.cannotChangeDecimalPrecisionError( + decimal, + params("precision").toString.toInt, + params("scale").toString.toInt, + sqlCtx(context))) + + case "DatetimeOverflow" => + Some( + new SparkException( + errorClass = "DATETIME_OVERFLOW", + messageParameters = params.map { case (k, v) => (k, 
v.toString) }, + cause = null)) + + case "InvalidArrayIndex" => + Some( + QueryExecutionErrors.invalidArrayIndexError( + params("indexValue").toString.toInt, + params("arraySize").toString.toInt, + sqlCtx(context))) + + case "InvalidElementAtIndex" => + Some( + QueryExecutionErrors.invalidElementAtIndexError( + params("indexValue").toString.toInt, + params("arraySize").toString.toInt, + sqlCtx(context))) + + case "InvalidIndexOfZero" => + Some(QueryExecutionErrors.invalidIndexOfZeroError(sqlCtx(context))) + + case "InvalidBitmapPosition" => + Some( + QueryExecutionErrors.invalidBitmapPositionError( + params("bitPosition").toString.toLong, + params("bitmapNumBytes").toString.toLong)) + + case "DuplicatedMapKey" => + Some(QueryExecutionErrors.duplicateMapKeyFoundError(params("key"))) + + case "NullMapKey" => + Some(QueryExecutionErrors.nullAsMapKeyNotAllowedError()) + + case "MapKeyValueDiffSizes" => + Some(QueryExecutionErrors.mapDataKeyArrayLengthDiffersFromValueArrayLengthError()) + + case "ExceedMapSizeLimit" => + Some(QueryExecutionErrors.exceedMapSizeLimitError(params("size").toString.toInt)) + + case "CollectionSizeLimitExceeded" => + Some( + QueryExecutionErrors.createArrayWithElementsExceedLimitError( + "array", + params("numElements").toString.toLong)) + + case "NotNullAssertViolation" => + Some( + QueryExecutionErrors.foundNullValueForNotNullableFieldError( + params("fieldName").toString)) + + case "ValueIsNull" => + Some( + QueryExecutionErrors.fieldCannotBeNullError( + params.getOrElse("rowIndex", 0).toString.toInt, + params("fieldName").toString)) + + case "CannotParseTimestamp" => + Some( + QueryExecutionErrors.ansiDateTimeParseError(new Exception(params("message").toString))) + + case "InvalidFractionOfSecond" => + Some(QueryExecutionErrors.invalidFractionOfSecondError()) + + case "CastInvalidValue" => + val str = UTF8String.fromString(params("value").toString) + val targetType = getDataType(params("toType").toString) + Some( + QueryExecutionErrors + 
.invalidInputInCastToNumberError(targetType, str, sqlCtx(context))) + + case "CastOverFlow" => + val fromType = getDataType(params("fromType").toString) + val toType = getDataType(params("toType").toString) + val valueStr = params("value").toString + + val typedValue: Any = fromType match { + case _: DecimalType => + val cleanStr = if (valueStr.endsWith("BD")) valueStr.dropRight(2) else valueStr + Decimal(cleanStr) + case ByteType => + val cleanStr = if (valueStr.endsWith("T")) valueStr.dropRight(1) else valueStr + cleanStr.toByte + case ShortType => + val cleanStr = if (valueStr.endsWith("S")) valueStr.dropRight(1) else valueStr + cleanStr.toShort + case IntegerType => valueStr.toInt + case LongType => + val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr + cleanStr.toLong + case FloatType => valueStr.toFloat + case DoubleType => valueStr.toDouble + case StringType => UTF8String.fromString(valueStr) + case _ => valueStr + } + + Some(QueryExecutionErrors.castingCauseOverflowError(typedValue, fromType, toType)) + + case "CannotParseDecimal" => + Some(QueryExecutionErrors.cannotParseDecimalError()) + + case "InvalidUtf8String" => + // invalidUTF8StringError does not exist in Spark 3.x; use generic fallback + Some( + new SparkException( + errorClass = "INVALID_UTF8_STRING", + messageParameters = params.map { case (k, v) => (k, v.toString) }, + cause = null)) + + case "UnexpectedPositiveValue" => + Some( + QueryExecutionErrors.unexpectedValueForStartInFunctionError( + params("parameterName").toString)) + + case "UnexpectedNegativeValue" => + Some( + QueryExecutionErrors.unexpectedValueForLengthInFunctionError( + params("parameterName").toString)) + + case "InvalidRegexGroupIndex" => + Some( + QueryExecutionErrors.invalidRegexGroupIndexError( + params("functionName").toString, + params("groupCount").toString.toInt, + params("groupIndex").toString.toInt)) + + case "DatatypeCannotOrder" => + Some( + 
QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( + params("dataType").toString)) + + case "ScalarSubqueryTooManyRows" => + // multipleRowScalarSubqueryError was renamed to multipleRowSubqueryError in Spark 3.x + Some(QueryExecutionErrors.multipleRowSubqueryError(sqlCtx(context))) + + case "IntervalArithmeticOverflowWithSuggestion" => + // Spark 3.x uses a single intervalArithmeticOverflowError method + Some( + QueryExecutionErrors.intervalArithmeticOverflowError( + "Interval arithmetic overflow", + params.get("functionName").map(_.toString).getOrElse(""), + sqlCtx(context))) + + case "IntervalArithmeticOverflowWithoutSuggestion" => + Some( + QueryExecutionErrors + .intervalArithmeticOverflowError("Interval arithmetic overflow", "", sqlCtx(context))) + + case _ => + None + } + } + + private def getDataType(typeName: String): DataType = { + typeName.toUpperCase match { + case "BYTE" | "TINYINT" => ByteType + case "SHORT" | "SMALLINT" => ShortType + case "INT" | "INTEGER" => IntegerType + case "LONG" | "BIGINT" => LongType + case "FLOAT" | "REAL" => FloatType + case "DOUBLE" => DoubleType + case "DECIMAL" => DecimalType.SYSTEM_DEFAULT + case "STRING" | "VARCHAR" => StringType + case "BINARY" => BinaryType + case "BOOLEAN" => BooleanType + case "DATE" => DateType + case "TIMESTAMP" => TimestampType + case _ => + try { + DataType.fromDDL(typeName) + } catch { + case _: Exception => StringType + } + } + } +} diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala new file mode 100644 index 0000000000..b19c7688be --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.QueryContext +import org.apache.spark.SparkException +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Spark 4.0-specific implementation for converting error types to proper Spark exceptions. + */ +trait ShimSparkErrorConverter { + + /** + * Convert error type string and parameters to appropriate Spark exception. Version-specific + * implementations call the correct QueryExecutionErrors.* methods. 
+ * + * @param _errorType + * The error type from JSON (e.g., "DivideByZero") + * @param _errorClass + * The Spark error class (e.g., "DIVIDE_BY_ZERO") + * @param params + * Error parameters from JSON + * @param context + * QueryContext array with SQL text and position information + * @param _summary + * Formatted summary string showing error location + * @return + * Some(Throwable) with the specific exception type from QueryExecutionErrors, or None if unknown + */ + def convertErrorType( + errorType: String, + _errorClass: String, + params: Map[String, Any], + context: Array[QueryContext], + _summary: String): Option[Throwable] = { + + errorType match { + + case "DivideByZero" => + Some(QueryExecutionErrors.divideByZeroError(context.headOption.orNull)) + + case "RemainderByZero" => + // SPARK 4.0 REMOVED remainderByZeroError so we use generic arithmetic exception + Some( + new SparkException( + errorClass = "REMAINDER_BY_ZERO", + messageParameters = params.map { case (k, v) => (k, v.toString) }, + cause = null)) + + case "IntervalDividedByZero" => + Some(QueryExecutionErrors.intervalDividedByZeroError(context.headOption.orNull)) + + case "BinaryArithmeticOverflow" => + Some( + QueryExecutionErrors.binaryArithmeticCauseOverflowError( + params("value1").toString.toShort, + params("symbol").toString, + params("value2").toString.toShort, + params("functionName").toString)) + + case "ArithmeticOverflow" => + val fromType = params("fromType").toString + Some(QueryExecutionErrors.arithmeticOverflowError(fromType + " overflow", "")) + + case "IntegralDivideOverflow" => + Some(QueryExecutionErrors.overflowInIntegralDivideError(context.headOption.orNull)) + + case "DecimalSumOverflow" => + Some(QueryExecutionErrors.overflowInSumOfDecimalError(context.headOption.orNull, "")) + + case "NumericValueOutOfRange" => + val decimal = Decimal(params("value").toString) + Some( + QueryExecutionErrors.cannotChangeDecimalPrecisionError( + decimal, + params("precision").toString.toInt, + 
params("scale").toString.toInt, + context.headOption.orNull)) + + case "DatetimeOverflow" => + // Spark 4.0 doesn't have datetimeOverflowError + Some( + new SparkException( + errorClass = "DATETIME_OVERFLOW", + messageParameters = params.map { case (k, v) => (k, v.toString) }, + cause = null)) + + case "InvalidArrayIndex" => + Some( + QueryExecutionErrors.invalidArrayIndexError( + params("indexValue").toString.toInt, + params("arraySize").toString.toInt, + context.headOption.orNull)) + + case "InvalidElementAtIndex" => + Some( + QueryExecutionErrors.invalidElementAtIndexError( + params("indexValue").toString.toInt, + params("arraySize").toString.toInt, + context.headOption.orNull)) + + case "InvalidIndexOfZero" => + Some(QueryExecutionErrors.invalidIndexOfZeroError(context.headOption.orNull)) + + case "InvalidBitmapPosition" => + Some( + QueryExecutionErrors.invalidBitmapPositionError( + params("bitPosition").toString.toLong, + params("bitmapNumBytes").toString.toLong)) + + case "DuplicatedMapKey" => + Some(QueryExecutionErrors.duplicateMapKeyFoundError(params("key"))) + + case "NullMapKey" => + Some(QueryExecutionErrors.nullAsMapKeyNotAllowedError()) + + case "MapKeyValueDiffSizes" => + Some(QueryExecutionErrors.mapDataKeyArrayLengthDiffersFromValueArrayLengthError()) + + case "ExceedMapSizeLimit" => + Some(QueryExecutionErrors.exceedMapSizeLimitError(params("size").toString.toInt)) + + case "CollectionSizeLimitExceeded" => + Some( + QueryExecutionErrors.createArrayWithElementsExceedLimitError( + "array", + params("numElements").toString.toLong)) + + case "NotNullAssertViolation" => + Some( + QueryExecutionErrors.foundNullValueForNotNullableFieldError( + params("fieldName").toString)) + + case "ValueIsNull" => + Some( + QueryExecutionErrors.fieldCannotBeNullError( + params.getOrElse("rowIndex", 0).toString.toInt, + params("fieldName").toString)) + + case "CannotParseTimestamp" => + Some( + QueryExecutionErrors.ansiDateTimeParseError( + new 
Exception(params("message").toString), + params("suggestedFunc").toString)) + + case "InvalidFractionOfSecond" => + Some(QueryExecutionErrors.invalidFractionOfSecondError(params("value").toString.toDouble)) + + case "CastInvalidValue" => + val str = UTF8String.fromString(params("value").toString) + val targetType = getDataType(params("toType").toString) + Some( + QueryExecutionErrors + .invalidInputInCastToNumberError(targetType, str, context.headOption.orNull)) + + case "CastOverFlow" => + val fromType = getDataType(params("fromType").toString) + val toType = getDataType(params("toType").toString) + val valueStr = params("value").toString + + // Convert string value to appropriate type for toSQLValue + val typedValue: Any = fromType match { + case _: DecimalType => + // Parse decimal string (may have "BD" suffix from BigDecimal.toString) + val cleanStr = if (valueStr.endsWith("BD")) valueStr.dropRight(2) else valueStr + Decimal(cleanStr) + case ByteType => + // Strip "T" suffix for TINYINT literals + val cleanStr = if (valueStr.endsWith("T")) valueStr.dropRight(1) else valueStr + cleanStr.toByte + case ShortType => + // Strip "S" suffix for SMALLINT literals + val cleanStr = if (valueStr.endsWith("S")) valueStr.dropRight(1) else valueStr + cleanStr.toShort + case IntegerType => valueStr.toInt + case LongType => + // Strip "L" suffix for BIGINT literals + val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr + cleanStr.toLong + case FloatType => valueStr.toFloat + case DoubleType => valueStr.toDouble + case StringType => UTF8String.fromString(valueStr) + case _ => valueStr // Fallback to string + } + + Some(QueryExecutionErrors.castingCauseOverflowError(typedValue, fromType, toType)) + + case "CannotParseDecimal" => + Some(QueryExecutionErrors.cannotParseDecimalError()) + + case "InvalidUtf8String" => + val hexStr = UTF8String.fromString(params("hexString").toString) + Some(QueryExecutionErrors.invalidUTF8StringError(hexStr)) + + case 
"UnexpectedPositiveValue" => + Some( + QueryExecutionErrors.unexpectedValueForStartInFunctionError( + params("parameterName").toString)) + + case "UnexpectedNegativeValue" => + Some( + QueryExecutionErrors.unexpectedValueForLengthInFunctionError( + params("parameterName").toString, + params("actualValue").toString.toInt)) + + case "InvalidRegexGroupIndex" => + Some( + QueryExecutionErrors.invalidRegexGroupIndexError( + params("functionName").toString, + params("groupCount").toString.toInt, + params("groupIndex").toString.toInt)) + + case "DatatypeCannotOrder" => + Some( + QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( + params("dataType").toString)) + + case "ScalarSubqueryTooManyRows" => + Some(QueryExecutionErrors.multipleRowScalarSubqueryError(context.headOption.orNull)) + + case "IntervalArithmeticOverflowWithSuggestion" => + Some( + QueryExecutionErrors.withSuggestionIntervalArithmeticOverflowError( + params.get("functionName").map(_.toString).getOrElse(""), + context.headOption.orNull)) + + case "IntervalArithmeticOverflowWithoutSuggestion" => + Some( + QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError( + context.headOption.orNull)) + + case _ => + // Unknown error type - return None to trigger fallback + None + } + } + + private def getDataType(typeName: String): DataType = { + typeName.toUpperCase match { + case "BYTE" | "TINYINT" => ByteType + case "SHORT" | "SMALLINT" => ShortType + case "INT" | "INTEGER" => IntegerType + case "LONG" | "BIGINT" => LongType + case "FLOAT" | "REAL" => FloatType + case "DOUBLE" => DoubleType + case "DECIMAL" => DecimalType.SYSTEM_DEFAULT + case "STRING" | "VARCHAR" => StringType + case "BINARY" => BinaryType + case "BOOLEAN" => BooleanType + case "DATE" => DateType + case "TIMESTAMP" => TimestampType + case _ => + try { + DataType.fromDDL(typeName) + } catch { + case _: Exception => StringType + } + } + } +} diff --git 
a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index b22d0f72db..21b4276a64 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -922,4 +922,92 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } } } + + // https://github.com/apache/datafusion-comet/issues/3375 + test("(ansi) array access out of bounds - GetArrayItem") { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true") { + withTable("test_array_get_item") { + sql("CREATE TABLE test_array_get_item(arr ARRAY<INT>) USING parquet") + sql("INSERT INTO test_array_get_item VALUES (array(1, 2, 3))") + // Try to access array with out-of-bounds index + val exception = intercept[Exception] { + sql("select arr[5] from test_array_get_item").collect() + } + val errorMessage = exception.getMessage + // Verify error message contains the expected error code + assert( + errorMessage.contains("INVALID_ARRAY_INDEX"), + s"Error message should contain array index error: $errorMessage") + + assert(errorMessage.contains("The index 5 is out of bounds. The array has 3 elements." 
+ " Use the SQL function `get()` to tolerate accessing element at invalid index and return NULL instead.")) + + assert( + errorMessage.contains("select arr[5] from test_array_get_item"), + s"Error message should contain SQL query text but got: $errorMessage") + } + } + } + + // https://github.com/apache/datafusion-comet/issues/3375 + test("(ansi) array access out of bounds - element_at with invalid index") { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true") { + withTable("test_element_at_invalid") { + sql("CREATE TABLE test_element_at_invalid(arr ARRAY<INT>) USING parquet") + sql("INSERT INTO test_element_at_invalid VALUES (array(1, 2, 3))") + // Try to access array with out-of-bounds index using element_at + val exception = intercept[Exception] { + sql("select element_at(arr, 10) from test_element_at_invalid").collect() + } + val errorMessage = exception.getMessage + // Verify error message contains the expected error code + assert( + errorMessage.contains("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"), + s"Error message should contain array index error: $errorMessage") + + assert(errorMessage.contains("The index 10 is out of bounds. The array has 3 elements." 
+ " Use `try_element_at` to tolerate accessing element at invalid index and return NULL instead")) + + assert( + errorMessage.contains("select element_at(arr, 10) from test_element_at_invalid"), + s"Error message should contain SQL query text but got: $errorMessage") + } + } + } + + // https://github.com/apache/datafusion-comet/issues/3375 + test("(ansi) array access with zero index - element_at") { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true") { + withTable("test_element_at_zero") { + sql("CREATE TABLE test_element_at_zero(arr ARRAY<INT>) USING parquet") + sql("INSERT INTO test_element_at_zero VALUES (array(1, 2, 3))") + // Try to access array with zero index (invalid in Spark) + val exception = intercept[Exception] { + sql("select element_at(arr, 0) from test_element_at_zero").collect() + } + val errorMessage = exception.getMessage + // Verify error message contains the expected error code + assert( + errorMessage.contains("INVALID_INDEX_OF_ZERO"), + s"Error message should contain zero index error: $errorMessage") + + assert( + errorMessage.contains("The index 0 is invalid. 
An index shall be either < 0 or > 0" + + " (the first element has index 1)")) + + assert( + errorMessage.contains("select element_at(arr, 0) from test_element_at_zero"), + s"Error message should contain SQL query text but got: $errorMessage") + } + } + } } diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 72c2390d71..8bfae8a7d3 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -913,6 +913,33 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest((validDates ++ invalidDates ++ fuzzDates).toDF("a"), DataTypes.DateType) } + // https://github.com/apache/datafusion-comet/issues/2215 + test("(ansi) cast error message should include SQL query") { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true") { + withTable("cast_error_msg") { + // Create a table with string data using DataFrame API + Seq("a").toDF("s").write.format("parquet").saveAsTable("cast_error_msg") + // Try to cast invalid string to date - should throw exception with SQL context + val exception = intercept[Exception] { + sql("select cast(s as date) from cast_error_msg").collect() + } + val errorMessage = exception.getMessage + // Verify error message contains the cast invalid input error + assert( + errorMessage.contains("CAST_INVALID_INPUT") || + errorMessage.contains("cannot be cast to"), + s"Error message should contain cast error: $errorMessage") + + assert( + errorMessage.contains("select cast(s as date) from cast_error_msg"), + s"Error message should contain SQL query text but got: $errorMessage") + } + } + } + test("cast StringType to TimestampType disabled by default") { withSQLConf((SQLConf.SESSION_LOCAL_TIMEZONE.key, "UTC")) { val values = Seq("2020-01-01T12:34:56.123456", "T2").toDF("a") @@ -1526,20 
+1553,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { if (CometSparkSessionExtensions.isSpark40Plus) { // for Spark 4 we expect to sparkException carries the message assert(sparkMessage.contains("SQLSTATE")) - if (sparkMessage.startsWith("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]")) { - assert( - sparkMessage.replace(".WITH_SUGGESTION] ", "]").startsWith(cometMessage)) - } else if (cometMessage.startsWith("[CAST_INVALID_INPUT]") || cometMessage - .startsWith("[CAST_OVERFLOW]")) { - assert( - sparkMessage.startsWith( - cometMessage - .replace( - "If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.", - ""))) - } else { - assert(sparkMessage.startsWith(cometMessage)) - } + // we compare a subset of the error message. Comet grabs the query + // context eagerly so it displays the call site at the + // line of code where the cast method was called, whereas spark grabs the context + // lazily and displays the call site at the line of code where the error is checked. + assert(sparkMessage.startsWith(cometMessage.substring(0, 40))) } else { // for Spark 3.4 we expect to reproduce the error message exactly assert(cometMessage == sparkMessage) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index f0f022868f..339061f5be 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -54,7 +54,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } val ARITHMETIC_OVERFLOW_EXCEPTION_MSG = - """org.apache.comet.CometNativeException: [ARITHMETIC_OVERFLOW] integer overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error""" + """[ARITHMETIC_OVERFLOW] integer overflow. 
If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error""" val DIVIDE_BY_ZERO_EXCEPTION_MSG = """Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""