diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 480eafdcb7..349e5e21d5 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -294,6 +294,39 @@ object CometConf extends ShimCometConf { val COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("localTableScan", defaultValue = false) + val COMET_EXEC_GRACE_HASH_JOIN_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.graceHashJoin.enabled") + .category(CATEGORY_EXEC) + .doc( + "Whether to enable Grace Hash Join. When enabled, Comet will use a Grace Hash Join " + + "operator that partitions both sides into buckets and can spill to disk when memory " + + "is tight. Supports all join types. This is an experimental feature.") + .booleanConf + .createWithDefault(false) + + val COMET_EXEC_GRACE_HASH_JOIN_NUM_PARTITIONS: ConfigEntry[Int] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.graceHashJoin.numPartitions") + .category(CATEGORY_EXEC) + .doc("The number of partitions (buckets) to use for Grace Hash Join. A higher number " + + "reduces the size of each partition but increases overhead.") + .intConf + .checkValue(v => v > 0, "The number of partitions must be positive.") + .createWithDefault(16) + + val COMET_EXEC_GRACE_HASH_JOIN_FAST_PATH_THRESHOLD: ConfigEntry[Int] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.graceHashJoin.fastPathThreshold") + .category(CATEGORY_EXEC) + .doc( + "Total memory budget in bytes for Grace Hash Join fast-path hash tables across " + + "all concurrent tasks. This is divided by spark.executor.cores to get the per-task " + + "threshold. When a build side fits in memory and is smaller than the per-task " + + "threshold, the join executes as a single HashJoinExec without spilling. " + + "Set to 0 to disable the fast path. 
Larger values risk OOM because HashJoinExec " + + "creates non-spillable hash tables.") + .intConf + .checkValue(v => v >= 0, "The fast path threshold must be non-negative.") + .createWithDefault(10 * 1024 * 1024) // 10 MB + val COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.columnarToRow.native.enabled") .category(CATEGORY_EXEC) diff --git a/docs/source/contributor-guide/grace-hash-join-design.md b/docs/source/contributor-guide/grace-hash-join-design.md new file mode 100644 index 0000000000..c85898adaf --- /dev/null +++ b/docs/source/contributor-guide/grace-hash-join-design.md @@ -0,0 +1,333 @@ + + +# Grace Hash Join Design Document + +## Overview + +Grace Hash Join (GHJ) is an operator for Apache DataFusion Comet that replaces Spark's `ShuffledHashJoinExec` with a spill-capable hash join. It partitions both build and probe sides into N buckets by hashing join keys, then joins each bucket independently. When memory is tight, partitions spill to disk using Arrow IPC format and are joined later using streaming reads. + +GHJ supports all join types (Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark, RightSemi, RightAnti, RightMark) and handles skewed data through recursive repartitioning. + +## Motivation + +Spark's `ShuffledHashJoinExec` loads the entire build side into a hash table in memory. When the build side is large or executor memory is constrained, this causes OOM failures. DataFusion's built-in `HashJoinExec` has the same limitation — its `HashJoinInput` consumer is marked `can_spill: false`. + +GHJ solves this by: + +1. Partitioning both sides into smaller buckets that fit in memory individually +2. Spilling partitions to disk when memory pressure is detected +3. 
Joining partitions independently, reading spilled data back via streaming I/O + +## Configuration + +| Config Key | Type | Default | Description | +| ---------------------------------------------- | ------- | ------- | ----------------------------------- | +| `spark.comet.exec.graceHashJoin.enabled` | boolean | `false` | Enable Grace Hash Join | +| `spark.comet.exec.graceHashJoin.numPartitions` | int | `16` | Number of hash partitions (buckets) | + +## Architecture + +### Plan Integration + +``` +Spark ShuffledHashJoinExec + → CometExecRule identifies ShuffledHashJoinExec + → CometHashJoinExec.createExec() checks config + → If GHJ enabled: CometGraceHashJoinExec (serialized to protobuf) + → JNI → PhysicalPlanner (Rust) creates GraceHashJoinExec +``` + +The `RewriteJoin` rule additionally converts `SortMergeJoinExec` to `ShuffledHashJoinExec` so that GHJ can intercept sort-merge joins as well. + +### Key Data Structures + +``` +GraceHashJoinExec ExecutionPlan implementation +├── left/right Child input plans +├── on Join key pairs [(left_key, right_key)] +├── filter Optional post-join filter +├── join_type Inner/Left/Right/Full/Semi/Anti/Mark +├── num_partitions Number of hash buckets (default 16) +├── build_left Whether left input is the build side +└── schema Output schema + +HashPartition Per-bucket state during partitioning +├── build_batches In-memory build-side RecordBatches +├── probe_batches In-memory probe-side RecordBatches +├── build_spill_writer Optional SpillWriter for build data +├── probe_spill_writer Optional SpillWriter for probe data +├── build_mem_size Tracked memory for build side +└── probe_mem_size Tracked memory for probe side + +FinishedPartition State after spill writers are closed +├── build_batches In-memory build batches (if not spilled) +├── probe_batches In-memory probe batches (if not spilled) +├── build_spill_file Temp file for spilled build data +└── probe_spill_file Temp file for spilled probe data +``` + +## Execution Phases + +### 
Overview + +``` +execute() + │ + ├─ Phase 1: Partition build side + │ Hash-partition all build input into N buckets. + │ Spill the largest bucket on memory pressure. + │ + ├─ Phase 2: Partition probe side + │ Hash-partition probe input into N buckets. + │ Spill ALL non-spilled buckets on first memory pressure. + │ + └─ Phase 3: Join each partition (sequential) + For each bucket, create a per-partition HashJoinExec. + Spilled probes use streaming SpillReaderExec. + Oversized builds trigger recursive repartitioning. + Only one partition's HashJoinInput exists at a time. +``` + +### Phase 1: Build-Side Partitioning + +For each incoming batch from the build input: + +1. Evaluate join key expressions and compute hash values +2. Assign each row to a partition: `partition_id = hash % num_partitions` +3. Use the prefix-sum algorithm (from the shuffle operator) to efficiently extract contiguous row groups per partition via `arrow::compute::take()` +4. For each partition's sub-batch: + - If the partition is already spilled, append to its `SpillWriter` + - Otherwise, call `reservation.try_grow(batch_size)` + - On failure: spill the largest non-spilled partition, retry + - If still fails: spill this partition and write to disk + +**Memory tracking**: All in-memory build data is tracked in a shared `MutableReservation` registered as `can_spill: true`. This is critical — it makes GHJ a cooperative citizen in DataFusion's memory pool, allowing other operators to trigger memory reclamation. + +### Phase 2: Probe-Side Partitioning + +Same hash-partitioning algorithm as Phase 1, with key differences: + +1. **Spilled build implies spilled probe**: If a partition's build side was spilled, the probe side must also be spilled for consistency during the join phase. Both sides need to be on disk (or both in memory). + +2. **Aggressive spilling strategy**: On first memory pressure event, spill ALL non-spilled partitions (both build and probe sides). 
This prevents a pattern where spilling one partition frees memory, new probe data accumulates in remaining partitions, pressure returns, another partition is spilled, etc. With multiple concurrent GHJ instances sharing a memory pool, this "whack-a-mole" pattern never converges. + +3. **Probe memory tracked in same reservation**: The shared `MutableReservation` from Phase 1 continues to track probe-side memory. + +### Phase 3: Per-Partition Joins + +Partitions are joined **sequentially** — one at a time — so only one `HashJoinInput` consumer exists at any moment. This keeps peak memory at ~1/N of what a single large hash table would require. DataFusion manages parallelism externally by calling `execute(partition)` from multiple async tasks; GHJ does not spawn its own internal parallelism. + +The GHJ reservation is freed before Phase 3 begins, since the partition data has been moved into `FinishedPartition` structs and each per-partition `HashJoinExec` will track its own memory via `HashJoinInput`. 
+ +**In-memory probe** → `join_partition_recursive()`: + +- Concatenate build and probe sub-batches +- Create `HashJoinExec` with both sides as `MemorySourceConfig` +- If build too large for hash table: recursively repartition (up to `MAX_RECURSION_DEPTH = 3` levels, yielding up to 16^3 = 4096 effective partitions) + +**Spilled probe** → `join_with_spilled_probe()`: + +- Build side loaded from memory or disk via `spawn_blocking` (to avoid blocking the async executor) +- Probe side streamed via `SpillReaderExec` (never fully loaded into memory) +- If build too large: fall back to eager probe read + recursive repartitioning + +## Spill Mechanism + +### Writing + +`SpillWriter` wraps Arrow IPC `StreamWriter` for incremental appends: + +- Uses `BufWriter` with 1 MB buffer (vs 8 KB default) for throughput +- Batches are appended one at a time — no need to rewrite the file +- `finish()` flushes the writer and returns the `RefCountedTempFile` + +Temp files are created via DataFusion's `DiskManager`, which handles allocation and cleanup. + +### Reading + +Two read paths depending on whether the full data is needed: + +**Eager read** (`read_spilled_batches`): Opens file, reads all batches into `Vec`. Used for small build-side spill files. + +**Streaming read** (`SpillReaderExec`): An `ExecutionPlan` that reads batches on-demand: + +- Spawns a `tokio::task::spawn_blocking` to read from the file on a blocking thread pool +- Uses an `mpsc` channel (capacity 4) to feed batches to the async executor +- Coalesces small sub-batches into ~8192-row chunks before sending, reducing per-batch overhead in the downstream hash join kernel +- The `RefCountedTempFile` handle is moved into the blocking closure to keep the file alive until reading completes + +### Spill I/O Optimization + +Spill files contain many tiny sub-batches because each incoming batch is partitioned into N pieces. Without coalescing, a spill file with 1M rows might contain 10,000+ batches of ~100 rows each. 
The coalescing step in `SpillReaderExec` merges these into ~122 batches of ~8192 rows, dramatically reducing: + +- Channel send/recv overhead +- Hash join kernel invocations +- Per-batch `RecordBatch` construction costs + +## Memory Management + +### Reservation Model + +GHJ uses a single `MemoryReservation` registered as a spillable consumer (`with_can_spill(true)`). This reservation: + +- Tracks all in-memory build and probe data across all partitions +- Grows via `try_grow()` before each batch is added to memory +- Shrinks via `shrink()` when partitions are spilled to disk +- Acts as a cooperative memory citizen — DataFusion's memory pool can account for GHJ's memory when other operators request allocations + +### Why Spillable Registration Matters + +DataFusion's memory pool (typically `FairSpillPool`) divides memory between spillable and non-spillable consumers. Non-spillable consumers (`can_spill: false`) like `HashJoinInput` from regular `HashJoinExec` get a guaranteed fraction. When non-spillable consumers exhaust their allocation, the pool returns an error. + +GHJ registers as spillable so the pool can account for its memory when computing fair shares. During Phases 1 and 2, the reservation tracks all in-memory partition data and triggers spilling when `try_grow` fails. Before Phase 3, the reservation is freed — the data is now owned by `FinishedPartition` structs and will be tracked by each per-partition `HashJoinExec`'s own `HashJoinInput` reservation. + +### Concurrent GHJ Instances + +In a typical Spark executor, multiple tasks run concurrently, each potentially executing a GHJ. All instances share the same DataFusion memory pool. 
This creates contention: + +- Instance A spills a partition, freeing memory +- Instance B immediately claims that memory for its probe data +- Instance A needs memory for the next batch, finds none available +- Both instances thrash between spilling and accumulating + +The "spill ALL non-spilled partitions" strategy in Phase 2 addresses this by making each instance's spill decision atomic — once triggered, the instance moves all its data to disk in one operation, preventing interleaving with other instances. + +## Hash Partitioning Algorithm + +### Prefix-Sum Approach + +Instead of N separate `take()` kernel calls (one per partition), GHJ uses a prefix-sum algorithm from the shuffle operator: + +1. **Hash**: Compute hash values for all rows +2. **Assign**: Map each row to a partition: `partition_id = hash % N` +3. **Count**: Count rows per partition +4. **Prefix-sum**: Accumulate counts into start offsets +5. **Scatter**: Place row indices into contiguous regions per partition +6. **Take**: Single `arrow::compute::take()` per partition using the precomputed indices + +This is O(rows) with excellent cache locality, compared to O(rows × partitions) for the naive approach. + +### Hash Seed Variation + +GHJ hashes on the same join keys that Spark already used for its shuffle exchange, but this is not redundant. Spark's shuffle uses Murmur3 to assign rows to exchange partitions, so all rows arriving at a given Spark partition share the same `murmur3(key) % num_spark_partitions` value — but they have diverse actual key values. 
GHJ then hashes those same keys with a **different hash function** (ahash via `RandomState` with fixed seeds), producing a completely different distribution: + +``` +Spark shuffle: murmur3(key) % 200 → all rows land in partition 42 +GHJ level 0: ahash(key, seed0) % 16 → rows spread across buckets 0-15 +GHJ level 1: ahash(key, seed1) % 16 → further redistribution within each bucket +``` + +The hash function uses different random seeds at each recursion level: + +```rust +fn partition_random_state(recursion_level: usize) -> RandomState { + RandomState::with_seeds( + 0x517cc1b727220a95 ^ (recursion_level as u64), + 0x3a8b7c9d1e2f4056, 0, 0, + ) +} +``` + +This ensures that rows which hash to the same partition at level 0 are distributed across different sub-partitions at level 1, breaking up hash collisions. The only case where repartitioning cannot help is true data skew — many rows with the _same_ key value. No amount of rehashing can separate identical keys, which is why there is a `MAX_RECURSION_DEPTH = 3` limit, after which GHJ returns a `ResourcesExhausted` error. + +## Recursive Repartitioning + +When a partition's build side is too large for a hash table (tested via `try_grow(build_size * 3)`), GHJ recursively repartitions: + +1. Sub-partition both build and probe into 16 new buckets using a different hash seed +2. Recursively join each sub-partition +3. Maximum depth: 3 (yielding up to 16^3 = 4096 effective partitions) +4. If still too large at max depth: return `ResourcesExhausted` error + +The 3x multiplier accounts for hash table overhead (the `JoinHashMap` typically uses 2-3x the raw data size). + +## Build Side Selection + +GHJ respects Spark's build side selection (`BuildLeft` or `BuildRight`). 
The `build_left` flag determines: + +- Which input is consumed in Phase 1 (build) vs Phase 2 (probe) +- How join key expressions are mapped (left keys → build keys if `build_left`) +- How `HashJoinExec` is constructed (build side is always left in `CollectLeft` mode) + +When `build_left = false`, the `HashJoinExec` is created with swapped inputs and then `swap_inputs()` is called to produce correct output column ordering. + +## Metrics + +| Metric | Description | +| --------------------- | ------------------------------------------- | +| `build_time` | Time spent partitioning the build side | +| `probe_time` | Time spent partitioning the probe side | +| `spill_count` | Number of partition spill events | +| `spilled_bytes` | Total bytes written to spill files | +| `build_input_rows` | Total rows from build input | +| `build_input_batches` | Total batches from build input | +| `input_rows` | Total rows from probe input | +| `input_batches` | Total batches from probe input | +| `output_rows` | Total output rows (from `BaselineMetrics`) | +| `elapsed_compute` | Total compute time (from `BaselineMetrics`) | + +## Lessons Learned + +### 1. Memory pool cooperation is non-negotiable + +Any optimization that removes the spillable reservation from the memory pool during Phases 1 and 2 breaks other operators. The pool's ability to handle pressure depends on having at least one spillable consumer. The reservation is freed before Phase 3 only because each per-partition `HashJoinExec` tracks its own memory. + +### 2. Spill one partition at a time doesn't work with concurrency + +With N concurrent GHJ instances sharing a pool, spilling the "largest partition" frees memory that other instances immediately claim. The effective free memory after spilling is near zero. Spilling ALL non-spilled partitions atomically prevents this race. + +### 3. Probe-side memory must be tracked + +The original implementation only tracked build-side memory in the reservation. 
Untracked probe-side accumulation (e.g., 170M rows at 6.5GB per executor) caused OOM before any spilling could occur. + +### 4. The join phase can be the OOM bottleneck, not the partition phase + +Even with proper spilling during partitioning, eagerly loading all spilled probe data in the join phase reintroduces the OOM. `SpillReaderExec` with streaming reads solved this. + +### 5. Small batches from spill files kill performance + +Hash-partitioning creates N sub-batches per input batch. With N=16 partitions and 1000-row input batches, spill files contain ~62-row sub-batches. Reading and joining millions of tiny batches has massive per-batch overhead. Coalescing to ~8192-row batches on read reduces overhead by 100x+. + +### 6. A fast path that skips partitioning creates non-spillable memory pressure + +An earlier design included a "fast path" that skipped Phases 2 and 3 when the build side appeared small: it concatenated all build data into a single `HashJoinExec` and streamed the probe directly through it. This was removed because: + +- **`HashJoinInput` is non-spillable.** `HashJoinExec` registers its hash table memory as `can_spill: false`. A single large `HashJoinInput` cannot be reclaimed under memory pressure. +- **`build_mem_size` severely underestimates actual memory.** The proportional estimate (`total_batch_size * sub_rows / total_rows`) used during partitioning can undercount by 5-20x because it doesn't account for per-array overhead in sub-batches created by `take()`. A build side estimated at 45 MB could actually be 460 MB, producing a 1.3 GB hash table. +- **The 3x memory check is a point-in-time snapshot.** Even with accurate sizes, the check (`try_grow(build_bytes * 3)`) passes when other operators haven't allocated yet. By the time the hash table is built, concurrent operators (broadcast hash joins, other GHJ instances) have consumed pool space, and the total exceeds the pool limit. 
+- **The slow path handles small builds efficiently.** With 16 partitions processed sequentially, each hash table is ~1/16 of the total. The overhead of partitioning the probe side is modest compared to the memory safety gained. + +In TPC-DS q72 (which has 2 GHJ operators and 8 broadcast hash joins sharing a pool), the fast path created a 1.3 GB non-spillable hash table in a ~954 MB pool, causing OOM. The slow path keeps peak hash table memory at ~86 MB per partition. + +### 7. DataFusion's HashJoinExec is not spill-capable + +`HashJoinInput` is registered with `can_spill: false`. There is no way to make `HashJoinExec` yield memory under pressure. This is a fundamental DataFusion limitation that GHJ works around by managing memory at the partition level — keeping each per-partition hash table small and processing them one at a time. + +### 8. Internal parallelism fights the runtime + +An earlier design spawned each partition's join as a separate `tokio::task` for parallel execution. This was removed because DataFusion already manages parallelism by calling `execute(partition)` from multiple async tasks. Internal parallelism creates concurrent `HashJoinInput` reservations that compete for pool space and is redundant with the runtime's own scheduling. 
+ +## Future Work + +- **Bloom filter pre-filtering**: For inner joins with tiny build sides, a bloom filter could skip probe batches that have no matching keys, reducing both I/O and computation +- **Adaptive partition count**: Dynamically choose the number of partitions based on input size rather than a fixed default +- **Spill file compression**: Compress Arrow IPC data on disk to reduce I/O volume at the cost of CPU +- **Memory-mapped spill files**: Use mmap instead of sequential reads for random access patterns during repartitioning +- **Upstream DataFusion spill support**: Contribute spill capability to DataFusion's `HashJoinExec` to eliminate the need for a separate GHJ operator diff --git a/native/Cargo.lock b/native/Cargo.lock index 0977bb96dc..5d301e902f 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1842,6 +1842,7 @@ dependencies = [ name = "datafusion-comet" version = "0.14.0" dependencies = [ + "ahash 0.8.12", "arrow", "assertables", "async-trait", diff --git a/native/Cargo.toml b/native/Cargo.toml index 3aa3cd0abf..7a960eaf18 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -34,7 +34,7 @@ edition = "2021" rust-version = "1.88" [workspace.dependencies] -arrow = { version = "57.3.0", features = ["prettyprint", "ffi", "chrono-tz"] } +arrow = { version = "57.3.0", features = ["prettyprint", "ffi", "chrono-tz", "ipc_compression"] } async-trait = { version = "0.1" } bytes = { version = "1.11.1" } parquet = { version = "57.2.0", default-features = false, features = ["experimental"] } diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index bcf2dad8c8..a72b439dce 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -35,6 +35,7 @@ include = [ publish = false [dependencies] +ahash = "0.8" arrow = { workspace = true } parquet = { workspace = true, default-features = false, features = ["experimental", "arrow"] } futures = { workspace = true } diff --git a/native/core/src/execution/jni_api.rs 
b/native/core/src/execution/jni_api.rs index 0193f3012c..f832d30cee 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -173,6 +173,8 @@ struct ExecutionContext { pub memory_pool_config: MemoryPoolConfig, /// Whether to log memory usage on each call to execute_plan pub tracing_enabled: bool, + /// Spark configuration map for comet-specific settings + pub spark_conf: HashMap, } /// Accept serialized query plan and return the address of the native query plan. @@ -320,6 +322,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( explain_native, memory_pool_config, tracing_enabled, + spark_conf: spark_config, }); Ok(Box::into_raw(exec_context) as i64) @@ -531,7 +534,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let start = Instant::now(); let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) - .with_exec_id(exec_context_id); + .with_exec_id(exec_context_id) + .with_spark_conf(exec_context.spark_conf.clone()); let (scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(), diff --git a/native/core/src/execution/operators/grace_hash_join.rs b/native/core/src/execution/operators/grace_hash_join.rs new file mode 100644 index 0000000000..f749d47114 --- /dev/null +++ b/native/core/src/execution/operators/grace_hash_join.rs @@ -0,0 +1,2625 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Grace Hash Join operator for Apache DataFusion Comet. +//! +//! Partitions both build and probe sides into N buckets by hashing join keys, +//! then performs per-partition hash joins. Spills partitions to disk (Arrow IPC) +//! when memory is tight. +//! +//! Supports all join types. Recursively repartitions oversized partitions +//! up to `MAX_RECURSION_DEPTH` levels. + +use std::any::Any; +use std::fmt; +use std::fs::File; +use std::io::{BufReader, BufWriter}; +use std::sync::Arc; +use std::sync::Mutex; + +use ahash::RandomState; +use arrow::array::UInt32Array; +use arrow::compute::{concat_batches, take}; +use arrow::datatypes::SchemaRef; +use arrow::ipc::reader::StreamReader; +use arrow::ipc::writer::{IpcWriteOptions, StreamWriter}; +use arrow::ipc::CompressionType; +use arrow::record_batch::RecordBatch; +use datafusion::common::hash_utils::create_hashes; +use datafusion::common::{DataFusionError, JoinType, NullEquality, Result as DFResult}; +use datafusion::execution::context::TaskContext; +use datafusion::execution::disk_manager::RefCountedTempFile; +use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::joins::utils::JoinFilter; +use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, 
Time, +}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, +}; +use futures::stream::{self, StreamExt, TryStreamExt}; +use futures::Stream; +use log::info; +use tokio::sync::mpsc; + +/// Global atomic counter for unique GHJ instance IDs (debug tracing). +static GHJ_INSTANCE_COUNTER: std::sync::atomic::AtomicUsize = + std::sync::atomic::AtomicUsize::new(0); + +/// Type alias for join key expression pairs. +type JoinOnRef<'a> = &'a [(Arc, Arc)]; + +/// Number of partitions (buckets) for the grace hash join. +const DEFAULT_NUM_PARTITIONS: usize = 16; + +/// Maximum recursion depth for repartitioning oversized partitions. +/// At depth 3 with 16 partitions per level, effective partitions = 16^3 = 4096. +const MAX_RECURSION_DEPTH: usize = 3; + +/// I/O buffer size for spill file reads and writes. The default BufReader/BufWriter +/// size (8 KB) is far too small for multi-GB spill files. 1 MB provides good +/// sequential throughput while keeping per-partition memory overhead modest. +const SPILL_IO_BUFFER_SIZE: usize = 1024 * 1024; + +/// Target number of rows per coalesced batch when reading spill files. +/// Spill files contain many tiny sub-batches (from partitioning). Coalescing +/// into larger batches reduces per-batch overhead in the hash join kernel +/// and channel send/recv costs. +const SPILL_READ_COALESCE_TARGET: usize = 8192; + +/// Target build-side size per merged partition. After Phase 2, adjacent +/// `FinishedPartition`s are merged so each group has roughly this much +/// build data, reducing the number of per-partition HashJoinExec calls. +const TARGET_PARTITION_BUILD_SIZE: usize = 32 * 1024 * 1024; + +/// Random state for hashing join keys into partitions. Uses fixed seeds +/// different from DataFusion's HashJoinExec to avoid correlation. 
+/// The `recursion_level` is XORed into the seed so that recursive +/// repartitioning uses different hash functions at each level. +fn partition_random_state(recursion_level: usize) -> RandomState { + RandomState::with_seeds( + 0x517cc1b727220a95 ^ (recursion_level as u64), + 0x3a8b7c9d1e2f4056, + 0, + 0, + ) +} + +// --------------------------------------------------------------------------- +// SpillWriter: incremental append to Arrow IPC spill files +// --------------------------------------------------------------------------- + +/// Wraps an Arrow IPC `StreamWriter` for incremental spill writes. +/// Avoids the O(n²) read-rewrite pattern by keeping the writer open. +struct SpillWriter { + writer: StreamWriter>, + temp_file: RefCountedTempFile, + bytes_written: usize, +} + +impl SpillWriter { + /// Create a new spill writer backed by a temp file. + fn new(temp_file: RefCountedTempFile, schema: &SchemaRef) -> DFResult { + let file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(temp_file.path()) + .map_err(|e| DataFusionError::Execution(format!("Failed to open spill file: {e}")))?; + let buf_writer = BufWriter::with_capacity(SPILL_IO_BUFFER_SIZE, file); + let write_options = + IpcWriteOptions::default().try_with_compression(Some(CompressionType::LZ4_FRAME))?; + let writer = StreamWriter::try_new_with_options(buf_writer, schema, write_options)?; + Ok(Self { + writer, + temp_file, + bytes_written: 0, + }) + } + + /// Append a single batch to the spill file. + fn write_batch(&mut self, batch: &RecordBatch) -> DFResult<()> { + if batch.num_rows() > 0 { + self.bytes_written += batch.get_array_memory_size(); + self.writer.write(batch)?; + } + Ok(()) + } + + /// Append multiple batches to the spill file. + fn write_batches(&mut self, batches: &[RecordBatch]) -> DFResult<()> { + for batch in batches { + self.write_batch(batch)?; + } + Ok(()) + } + + /// Finish writing. Must be called before reading back. 
+    /// Finalize the IPC stream and hand back the temp file plus total bytes written.
+    fn finish(mut self) -> DFResult<(RefCountedTempFile, usize)> {
+        self.writer.finish()?;
+        Ok((self.temp_file, self.bytes_written))
+    }
+}
+
+// ---------------------------------------------------------------------------
+// SpillReaderExec: streaming ExecutionPlan for reading spill files
+// ---------------------------------------------------------------------------
+
+/// An ExecutionPlan that streams record batches from an Arrow IPC spill file.
+/// Used during the join phase so that spilled probe data is read on-demand
+/// instead of loaded entirely into memory.
+#[derive(Debug)]
+struct SpillReaderExec {
+    /// Ref-counted handle; keeps the temp file alive while this plan exists.
+    spill_file: RefCountedTempFile,
+    /// Schema of the batches stored in the spill file.
+    schema: SchemaRef,
+    /// Cached plan properties (single unknown partition, bounded, incremental).
+    cache: PlanProperties,
+}
+
+impl SpillReaderExec {
+    fn new(spill_file: RefCountedTempFile, schema: SchemaRef) -> Self {
+        let cache = PlanProperties::new(
+            EquivalenceProperties::new(Arc::clone(&schema)),
+            Partitioning::UnknownPartitioning(1),
+            datafusion::physical_plan::execution_plan::EmissionType::Incremental,
+            datafusion::physical_plan::execution_plan::Boundedness::Bounded,
+        );
+        Self {
+            spill_file,
+            schema,
+            cache,
+        }
+    }
+}
+
+impl DisplayAs for SpillReaderExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "SpillReaderExec")
+    }
+}
+
+impl ExecutionPlan for SpillReaderExec {
+    fn name(&self) -> &str {
+        "SpillReaderExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        // Leaf node: batches come from the spill file, not a child plan.
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> DFResult<SendableRecordBatchStream> {
+        let schema = Arc::clone(&self.schema);
+        let coalesce_schema = Arc::clone(&self.schema);
+        let path = self.spill_file.path().to_path_buf();
+        // Move the spill file handle into the blocking closure to keep
+        // the temp file alive until the reader is done.
+        let spill_file_handle = self.spill_file.clone();
+
+        // Use a channel so file I/O runs on a blocking thread and doesn't
+        // block the async executor. This lets select_all interleave multiple
+        // partition streams effectively.
+        let (tx, rx) = mpsc::channel::<DFResult<RecordBatch>>(4);
+
+        tokio::task::spawn_blocking(move || {
+            let _keep_alive = spill_file_handle;
+            let file = match File::open(&path) {
+                Ok(f) => f,
+                Err(e) => {
+                    let _ = tx.blocking_send(Err(DataFusionError::Execution(format!(
+                        "Failed to open spill file: {e}"
+                    ))));
+                    return;
+                }
+            };
+            let reader = match StreamReader::try_new(
+                BufReader::with_capacity(SPILL_IO_BUFFER_SIZE, file),
+                None,
+            ) {
+                Ok(r) => r,
+                Err(e) => {
+                    let _ = tx.blocking_send(Err(DataFusionError::ArrowError(Box::new(e), None)));
+                    return;
+                }
+            };
+
+            // Coalesce small sub-batches into larger ones to reduce per-batch
+            // overhead in the downstream hash join.
+            let mut pending: Vec<RecordBatch> = Vec::new();
+            let mut pending_rows = 0usize;
+
+            for batch_result in reader {
+                let batch = match batch_result {
+                    Ok(b) => b,
+                    Err(e) => {
+                        let _ =
+                            tx.blocking_send(Err(DataFusionError::ArrowError(Box::new(e), None)));
+                        return;
+                    }
+                };
+                if batch.num_rows() == 0 {
+                    continue;
+                }
+                pending_rows += batch.num_rows();
+                pending.push(batch);
+
+                if pending_rows >= SPILL_READ_COALESCE_TARGET {
+                    let merged = if pending.len() == 1 {
+                        Ok(pending.pop().unwrap())
+                    } else {
+                        concat_batches(&coalesce_schema, &pending)
+                            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
+                    };
+                    pending.clear();
+                    pending_rows = 0;
+                    // Receiver dropped means the consumer is gone; stop reading.
+                    if tx.blocking_send(merged).is_err() {
+                        return;
+                    }
+                }
+            }
+
+            // Flush remaining
+            if !pending.is_empty() {
+                let merged = if pending.len() == 1 {
+                    Ok(pending.pop().unwrap())
+                } else {
+                    concat_batches(&coalesce_schema, &pending)
+                        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
+                };
+                let _ = tx.blocking_send(merged);
+            }
+        });
+
+        // Adapt the channel receiver into a RecordBatch stream.
+        let batch_stream = futures::stream::unfold(rx, |mut rx| async move {
+            rx.recv().await.map(|batch| (batch, rx))
+        });
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            schema,
+            batch_stream,
+        )))
+    }
+}
+
+// ---------------------------------------------------------------------------
+// StreamSourceExec: wrap an existing stream as an ExecutionPlan
+// ---------------------------------------------------------------------------
+
+/// An ExecutionPlan that yields batches from a pre-existing stream.
+/// Used in the fast path to feed the probe side's live stream into
+/// a `HashJoinExec` without buffering or spilling.
+///
+/// Single-use: the stream is taken on the first `execute` call; a second
+/// call returns an internal error.
+struct StreamSourceExec {
+    stream: Mutex<Option<SendableRecordBatchStream>>,
+    schema: SchemaRef,
+    cache: PlanProperties,
+}
+
+impl StreamSourceExec {
+    fn new(stream: SendableRecordBatchStream, schema: SchemaRef) -> Self {
+        let cache = PlanProperties::new(
+            EquivalenceProperties::new(Arc::clone(&schema)),
+            Partitioning::UnknownPartitioning(1),
+            datafusion::physical_plan::execution_plan::EmissionType::Incremental,
+            datafusion::physical_plan::execution_plan::Boundedness::Bounded,
+        );
+        Self {
+            stream: Mutex::new(Some(stream)),
+            schema,
+            cache,
+        }
+    }
+}
+
+impl fmt::Debug for StreamSourceExec {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("StreamSourceExec").finish()
+    }
+}
+
+impl DisplayAs for StreamSourceExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "StreamSourceExec")
+    }
+}
+
+impl ExecutionPlan for StreamSourceExec {
+    fn name(&self) -> &str {
+        "StreamSourceExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> DFResult<SendableRecordBatchStream> {
+        // Take the stream out of the slot; errors if already consumed.
+        self.stream
+            .lock()
+            .map_err(|e| DataFusionError::Internal(format!("lock poisoned: {e}")))?
+            .take()
+            .ok_or_else(|| {
+                DataFusionError::Internal("StreamSourceExec: stream already consumed".to_string())
+            })
+    }
+}
+
+// ---------------------------------------------------------------------------
+// GraceHashJoinMetrics
+// ---------------------------------------------------------------------------
+
+/// Production metrics for the Grace Hash Join operator.
+struct GraceHashJoinMetrics {
+    /// Baseline metrics (output rows, elapsed compute)
+    baseline: BaselineMetrics,
+    /// Time spent partitioning the build side
+    build_time: Time,
+    /// Time spent partitioning the probe side
+    probe_time: Time,
+    /// Number of spill events
+    spill_count: Count,
+    /// Total bytes spilled to disk
+    spilled_bytes: Count,
+    /// Number of build-side input rows
+    build_input_rows: Count,
+    /// Number of build-side input batches
+    build_input_batches: Count,
+    /// Number of probe-side input rows
+    input_rows: Count,
+    /// Number of probe-side input batches
+    input_batches: Count,
+}
+
+impl GraceHashJoinMetrics {
+    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
+        Self {
+            baseline: BaselineMetrics::new(metrics, partition),
+            build_time: MetricBuilder::new(metrics).subset_time("build_time", partition),
+            probe_time: MetricBuilder::new(metrics).subset_time("probe_time", partition),
+            spill_count: MetricBuilder::new(metrics).spill_count(partition),
+            spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition),
+            build_input_rows: MetricBuilder::new(metrics).counter("build_input_rows", partition),
+            build_input_batches: MetricBuilder::new(metrics)
+                .counter("build_input_batches", partition),
+            input_rows: MetricBuilder::new(metrics).counter("input_rows", partition),
+            input_batches: MetricBuilder::new(metrics).counter("input_batches", partition),
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// GraceHashJoinExec
+// 
--------------------------------------------------------------------------- + +/// Grace Hash Join execution plan. +/// +/// Partitions both sides into N buckets, then joins each bucket independently +/// using DataFusion's HashJoinExec. Spills partitions to disk when memory +/// pressure is detected. +#[derive(Debug)] +pub struct GraceHashJoinExec { + /// Left input + left: Arc, + /// Right input + right: Arc, + /// Join key pairs: (left_key, right_key) + on: Vec<(Arc, Arc)>, + /// Optional join filter applied after key matching + filter: Option, + /// Join type + join_type: JoinType, + /// Number of hash partitions + num_partitions: usize, + /// Whether left is the build side (true) or right is (false) + build_left: bool, + /// Maximum build-side bytes for the fast path (0 = disabled) + fast_path_threshold: usize, + /// Output schema + schema: SchemaRef, + /// Plan properties cache + cache: PlanProperties, + /// Metrics + metrics: ExecutionPlanMetricsSet, +} + +impl GraceHashJoinExec { + #[allow(clippy::too_many_arguments)] + pub fn try_new( + left: Arc, + right: Arc, + on: Vec<(Arc, Arc)>, + filter: Option, + join_type: &JoinType, + num_partitions: usize, + build_left: bool, + fast_path_threshold: usize, + ) -> DFResult { + // Build the output schema using HashJoinExec's logic. + // HashJoinExec expects left=build, right=probe. When build_left=false, + // we swap inputs + keys + join type for schema derivation, then store + // original values for our own partitioning logic. 
+ let hash_join = HashJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?; + let (schema, cache) = if build_left { + (hash_join.schema(), hash_join.properties().clone()) + } else { + // Swap to get correct output schema for build-right + let swapped = hash_join.swap_inputs(PartitionMode::CollectLeft)?; + (swapped.schema(), swapped.properties().clone()) + }; + + Ok(Self { + left, + right, + on, + filter, + join_type: *join_type, + num_partitions: if num_partitions == 0 { + DEFAULT_NUM_PARTITIONS + } else { + num_partitions + }, + build_left, + fast_path_threshold, + schema, + cache, + metrics: ExecutionPlanMetricsSet::new(), + }) + } +} + +impl DisplayAs for GraceHashJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default + | DisplayFormatType::Verbose + | DisplayFormatType::TreeRender => { + let on: Vec = self.on.iter().map(|(l, r)| format!("({l}, {r})")).collect(); + write!( + f, + "GraceHashJoinExec: join_type={:?}, on=[{}], num_partitions={}", + self.join_type, + on.join(", "), + self.num_partitions, + ) + } + } + } +} + +impl ExecutionPlan for GraceHashJoinExec { + fn name(&self) -> &str { + "GraceHashJoinExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DFResult> { + Ok(Arc::new(GraceHashJoinExec::try_new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.on.clone(), + self.filter.clone(), + &self.join_type, + self.num_partitions, + self.build_left, + self.fast_path_threshold, + )?)) + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DFResult { + 
info!( + "GraceHashJoin: execute() called. build_left={}, join_type={:?}, \ + num_partitions={}, fast_path_threshold={}\n left: {}\n right: {}", + self.build_left, + self.join_type, + self.num_partitions, + self.fast_path_threshold, + DisplayableExecutionPlan::new(self.left.as_ref()).one_line(), + DisplayableExecutionPlan::new(self.right.as_ref()).one_line(), + ); + let left_stream = self.left.execute(partition, Arc::clone(&context))?; + let right_stream = self.right.execute(partition, Arc::clone(&context))?; + + let join_metrics = GraceHashJoinMetrics::new(&self.metrics, partition); + + // Determine build/probe streams and schemas based on build_left. + // The internal execution always treats first arg as build, second as probe. + let (build_stream, probe_stream, build_schema, probe_schema, build_on, probe_on) = + if self.build_left { + let build_keys: Vec<_> = self.on.iter().map(|(l, _)| Arc::clone(l)).collect(); + let probe_keys: Vec<_> = self.on.iter().map(|(_, r)| Arc::clone(r)).collect(); + ( + left_stream, + right_stream, + self.left.schema(), + self.right.schema(), + build_keys, + probe_keys, + ) + } else { + // Build right: right is build side, left is probe side + let build_keys: Vec<_> = self.on.iter().map(|(_, r)| Arc::clone(r)).collect(); + let probe_keys: Vec<_> = self.on.iter().map(|(l, _)| Arc::clone(l)).collect(); + ( + right_stream, + left_stream, + self.right.schema(), + self.left.schema(), + build_keys, + probe_keys, + ) + }; + + let on = self.on.clone(); + let filter = self.filter.clone(); + let join_type = self.join_type; + let num_partitions = self.num_partitions; + let build_left = self.build_left; + let fast_path_threshold = self.fast_path_threshold; + let output_schema = Arc::clone(&self.schema); + + let result_stream = futures::stream::once(async move { + execute_grace_hash_join( + build_stream, + probe_stream, + build_on, + probe_on, + on, + filter, + join_type, + num_partitions, + build_left, + fast_path_threshold, + build_schema, + 
probe_schema, + output_schema, + context, + join_metrics, + ) + .await + }) + .try_flatten(); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + result_stream, + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +// --------------------------------------------------------------------------- +// Per-partition state +// --------------------------------------------------------------------------- + +/// Per-partition state tracking buffered data or spill writers. +struct HashPartition { + /// In-memory build-side batches for this partition. + build_batches: Vec, + /// In-memory probe-side batches for this partition. + probe_batches: Vec, + /// Incremental spill writer for build side (if spilling). + build_spill_writer: Option, + /// Incremental spill writer for probe side (if spilling). + probe_spill_writer: Option, + /// Approximate memory used by build-side batches in this partition. + build_mem_size: usize, + /// Approximate memory used by probe-side batches in this partition. + probe_mem_size: usize, +} + +impl HashPartition { + fn new() -> Self { + Self { + build_batches: Vec::new(), + probe_batches: Vec::new(), + build_spill_writer: None, + probe_spill_writer: None, + build_mem_size: 0, + probe_mem_size: 0, + } + } + + /// Whether the build side has been spilled to disk. + fn build_spilled(&self) -> bool { + self.build_spill_writer.is_some() + } +} + +// --------------------------------------------------------------------------- +// Main execution logic +// --------------------------------------------------------------------------- + +/// Main execution logic for the grace hash join. +/// +/// `build_stream`/`probe_stream`: already swapped based on build_left. +/// `build_keys`/`probe_keys`: key expressions for their respective sides. +/// `original_on`: original (left_key, right_key) pairs for HashJoinExec. +/// `build_left`: whether left is build side (affects HashJoinExec construction). 
+#[allow(clippy::too_many_arguments)] +async fn execute_grace_hash_join( + build_stream: SendableRecordBatchStream, + probe_stream: SendableRecordBatchStream, + build_keys: Vec>, + probe_keys: Vec>, + original_on: Vec<(Arc, Arc)>, + filter: Option, + join_type: JoinType, + num_partitions: usize, + build_left: bool, + fast_path_threshold: usize, + build_schema: SchemaRef, + probe_schema: SchemaRef, + _output_schema: SchemaRef, + context: Arc, + metrics: GraceHashJoinMetrics, +) -> DFResult>> { + let ghj_id = GHJ_INSTANCE_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // Set up memory reservation (shared across build and probe phases) + let mut reservation = MutableReservation( + MemoryConsumer::new("GraceHashJoinExec") + .with_can_spill(true) + .register(&context.runtime_env().memory_pool), + ); + + info!( + "GHJ#{}: started. build_left={}, join_type={:?}, pool reserved={}", + ghj_id, + build_left, + join_type, + context.runtime_env().memory_pool.reserved(), + ); + + let mut partitions: Vec = + (0..num_partitions).map(|_| HashPartition::new()).collect(); + + let mut scratch = ScratchSpace::default(); + + // Phase 1: Partition the build side + { + let _timer = metrics.build_time.timer(); + partition_build_side( + build_stream, + &build_keys, + num_partitions, + &build_schema, + &mut partitions, + &mut reservation, + &context, + &metrics, + &mut scratch, + ) + .await?; + } + + // Log build-side partition summary + { + let pool = &context.runtime_env().memory_pool; + let total_build_rows: usize = partitions + .iter() + .flat_map(|p| p.build_batches.iter()) + .map(|b| b.num_rows()) + .sum(); + let total_build_bytes: usize = partitions.iter().map(|p| p.build_mem_size).sum(); + let spilled_count = partitions.iter().filter(|p| p.build_spilled()).count(); + info!( + "GraceHashJoin: build phase complete. {} partitions ({} spilled), \ + total build: {} rows, {} bytes. 
Memory pool reserved={}", + num_partitions, + spilled_count, + total_build_rows, + total_build_bytes, + pool.reserved(), + ); + for (i, p) in partitions.iter().enumerate() { + if !p.build_batches.is_empty() || p.build_spilled() { + let rows: usize = p.build_batches.iter().map(|b| b.num_rows()).sum(); + info!( + "GraceHashJoin: partition[{}] build: {} batches, {} rows, {} bytes, spilled={}", + i, + p.build_batches.len(), + rows, + p.build_mem_size, + p.build_spilled(), + ); + } + } + } + + // Fast path: if no build partitions spilled and the build side is + // genuinely tiny, skip probe partitioning and stream the probe directly + // through a single HashJoinExec. This avoids spilling gigabytes of + // probe data to disk for a trivial hash table (e.g. 10-row build side). + // + // The threshold uses actual batch sizes (not the unreliable proportional + // estimate). The configured value is divided by spark.executor.cores in + // the planner so each concurrent task gets its fair share. + // Configurable via spark.comet.exec.graceHashJoin.fastPathThreshold. + + let build_spilled = partitions.iter().any(|p| p.build_spilled()); + let actual_build_bytes: usize = partitions + .iter() + .flat_map(|p| p.build_batches.iter()) + .map(|b| b.get_array_memory_size()) + .sum(); + + if !build_spilled && fast_path_threshold > 0 && actual_build_bytes <= fast_path_threshold { + let total_build_rows: usize = partitions + .iter() + .flat_map(|p| p.build_batches.iter()) + .map(|b| b.num_rows()) + .sum(); + info!( + "GHJ#{}: fast path — build side tiny ({} rows, {} bytes). \ + Streaming probe directly through HashJoinExec. pool reserved={}", + ghj_id, + total_build_rows, + actual_build_bytes, + context.runtime_env().memory_pool.reserved(), + ); + + // Release our reservation — HashJoinExec tracks its own memory. 
+ reservation.free(); + + let build_data: Vec = partitions + .into_iter() + .flat_map(|p| p.build_batches) + .collect(); + + let build_source = memory_source_exec(build_data, &build_schema)?; + + let probe_source: Arc = Arc::new(StreamSourceExec::new( + probe_stream, + Arc::clone(&probe_schema), + )); + + let (left_source, right_source): (Arc, Arc) = + if build_left { + (build_source, probe_source) + } else { + (probe_source, build_source) + }; + + info!( + "GraceHashJoin: FAST PATH creating HashJoinExec, \ + build_left={}, actual_build_bytes={}", + build_left, actual_build_bytes, + ); + + let stream = if build_left { + let hash_join = HashJoinExec::try_new( + left_source, + right_source, + original_on, + filter, + &join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?; + info!( + "GraceHashJoin: FAST PATH plan:\n{}", + DisplayableExecutionPlan::new(&hash_join).indent(true) + ); + hash_join.execute(0, context_for_join_output(&context))? + } else { + let hash_join = Arc::new(HashJoinExec::try_new( + left_source, + right_source, + original_on, + filter, + &join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?); + let swapped = hash_join.swap_inputs(PartitionMode::CollectLeft)?; + info!( + "GraceHashJoin: FAST PATH (swapped) plan:\n{}", + DisplayableExecutionPlan::new(swapped.as_ref()).indent(true) + ); + swapped.execute(0, context_for_join_output(&context))? + }; + + let output_metrics = metrics.baseline.clone(); + let result_stream = stream.inspect_ok(move |batch| { + output_metrics.record_output(batch.num_rows()); + }); + + return Ok(result_stream.boxed()); + } + + let total_build_rows: usize = partitions + .iter() + .flat_map(|p| p.build_batches.iter()) + .map(|b| b.num_rows()) + .sum(); + info!( + "GHJ#{}: slow path — build spilled={}, {} rows, {} bytes (actual). \ + join_type={:?}, build_left={}. pool reserved={}. 
Partitioning probe side.", + ghj_id, + build_spilled, + total_build_rows, + actual_build_bytes, + join_type, + build_left, + context.runtime_env().memory_pool.reserved(), + ); + + // Phase 2: Partition the probe side + { + let _timer = metrics.probe_time.timer(); + partition_probe_side( + probe_stream, + &probe_keys, + num_partitions, + &probe_schema, + &mut partitions, + &mut reservation, + &build_schema, + &context, + &metrics, + &mut scratch, + ) + .await?; + } + + // Log probe-side partition summary + { + let total_probe_rows: usize = partitions + .iter() + .flat_map(|p| p.probe_batches.iter()) + .map(|b| b.num_rows()) + .sum(); + let total_probe_bytes: usize = partitions.iter().map(|p| p.probe_mem_size).sum(); + let probe_spilled = partitions + .iter() + .filter(|p| p.probe_spill_writer.is_some()) + .count(); + info!( + "GHJ#{}: probe phase complete. \ + total probe (in-memory): {} rows, {} bytes, {} spilled. \ + reservation={}, pool reserved={}", + ghj_id, + total_probe_rows, + total_probe_bytes, + probe_spilled, + reservation.0.size(), + context.runtime_env().memory_pool.reserved(), + ); + } + + // Finish all open spill writers before reading back + let finished_partitions = + finish_spill_writers(partitions, &build_schema, &probe_schema, &metrics)?; + + // Merge adjacent partitions to reduce the number of HashJoinExec calls. + // Compute desired partition count from total build bytes. 
+ let total_build_bytes: usize = finished_partitions.iter().map(|p| p.build_bytes).sum(); + let desired_partitions = if total_build_bytes > 0 { + let desired = total_build_bytes.div_ceil(TARGET_PARTITION_BUILD_SIZE); + desired.max(1).min(num_partitions) + } else { + 1 + }; + let original_partition_count = finished_partitions.len(); + let finished_partitions = merge_finished_partitions(finished_partitions, desired_partitions); + if finished_partitions.len() < original_partition_count { + info!( + "GraceHashJoin: merged {} partitions into {} (total build {} bytes, \ + target {} bytes/partition)", + original_partition_count, + finished_partitions.len(), + total_build_bytes, + TARGET_PARTITION_BUILD_SIZE, + ); + } + + // Release all remaining reservation before Phase 3. The in-memory + // partition data is now owned by finished_partitions and will be moved + // into per-partition HashJoinExec instances (which track memory via + // their own HashJoinInput reservations). Keeping our reservation alive + // would double-count the memory and starve other consumers. + info!( + "GHJ#{}: freeing reservation ({} bytes) before Phase 3. pool reserved={}", + ghj_id, + reservation.0.size(), + context.runtime_env().memory_pool.reserved(), + ); + reservation.free(); + + // Phase 3: Join partitions sequentially. + // We use a concurrency limit of 1 to avoid creating multiple simultaneous + // HashJoinInput reservations per task. With multiple Spark tasks sharing + // the same memory pool, even modest build sides (e.g. 22 MB) can exhaust + // memory when many tasks run concurrent hash table builds simultaneously. 
+ const MAX_CONCURRENT_PARTITIONS: usize = 1; + let semaphore = Arc::new(tokio::sync::Semaphore::new(MAX_CONCURRENT_PARTITIONS)); + let (tx, rx) = mpsc::channel::>(MAX_CONCURRENT_PARTITIONS * 2); + + for partition in finished_partitions { + let tx = tx.clone(); + let sem = Arc::clone(&semaphore); + let original_on = original_on.clone(); + let filter = filter.clone(); + let build_schema = Arc::clone(&build_schema); + let probe_schema = Arc::clone(&probe_schema); + let context = Arc::clone(&context); + + tokio::spawn(async move { + let _permit = match sem.acquire().await { + Ok(p) => p, + Err(_) => return, // semaphore closed + }; + match join_single_partition( + partition, + original_on, + filter, + join_type, + build_left, + build_schema, + probe_schema, + context, + ) + .await + { + Ok(streams) => { + for mut stream in streams { + while let Some(batch) = stream.next().await { + if tx.send(batch).await.is_err() { + return; + } + } + } + } + Err(e) => { + let _ = tx.send(Err(e)).await; + } + } + }); + } + drop(tx); + + let output_metrics = metrics.baseline.clone(); + let output_row_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let counter = Arc::clone(&output_row_count); + let jt = join_type; + let bl = build_left; + let result_stream = futures::stream::unfold(rx, |mut rx| async move { + rx.recv().await.map(|batch| (batch, rx)) + }) + .inspect_ok(move |batch| { + output_metrics.record_output(batch.num_rows()); + let prev = counter.fetch_add(batch.num_rows(), std::sync::atomic::Ordering::Relaxed); + let new_total = prev + batch.num_rows(); + // Log every ~1M rows to detect exploding joins + if new_total / 1_000_000 > prev / 1_000_000 { + info!( + "GraceHashJoin: slow path output: {} rows emitted so far \ + (join_type={:?}, build_left={})", + new_total, jt, bl, + ); + } + }); + + Ok(result_stream.boxed()) +} + +/// Wraps MemoryReservation to allow mutation through reference. 
+struct MutableReservation(MemoryReservation); + +impl MutableReservation { + fn try_grow(&mut self, additional: usize) -> DFResult<()> { + self.0.try_grow(additional) + } + + fn shrink(&mut self, amount: usize) { + self.0.shrink(amount); + } + + fn free(&mut self) -> usize { + self.0.free() + } +} + +// --------------------------------------------------------------------------- +// ScratchSpace: reusable buffers for efficient hash partitioning +// --------------------------------------------------------------------------- + +/// Reusable scratch buffers for partitioning batches. Uses a prefix-sum +/// algorithm (borrowed from the shuffle `multi_partition.rs`) to compute +/// contiguous row-index regions per partition in a single pass, avoiding +/// N separate `take()` kernel calls. +#[derive(Default)] +struct ScratchSpace { + /// Hash values for each row. + hashes: Vec, + /// Partition id assigned to each row. + partition_ids: Vec, + /// Row indices reordered so that each partition's rows are contiguous. + partition_row_indices: Vec, + /// `partition_starts[k]..partition_starts[k+1]` gives the slice of + /// `partition_row_indices` belonging to partition k. + partition_starts: Vec, +} + +impl ScratchSpace { + /// Compute hashes and partition ids, then build the prefix-sum index + /// structures for the given batch. 
+ fn compute_partitions( + &mut self, + batch: &RecordBatch, + keys: &[Arc], + num_partitions: usize, + recursion_level: usize, + ) -> DFResult<()> { + let num_rows = batch.num_rows(); + + // Evaluate key columns + let key_columns: Vec<_> = keys + .iter() + .map(|expr| expr.evaluate(batch).and_then(|cv| cv.into_array(num_rows))) + .collect::>>()?; + + // Hash + self.hashes.resize(num_rows, 0); + self.hashes.truncate(num_rows); + self.hashes.fill(0); + let random_state = partition_random_state(recursion_level); + create_hashes(&key_columns, &random_state, &mut self.hashes)?; + + // Assign partition ids + self.partition_ids.resize(num_rows, 0); + for (i, hash) in self.hashes[..num_rows].iter().enumerate() { + self.partition_ids[i] = (*hash as u32) % (num_partitions as u32); + } + + // Prefix-sum to get contiguous regions + self.map_partition_ids_to_starts_and_indices(num_partitions, num_rows); + + Ok(()) + } + + /// Prefix-sum algorithm from `multi_partition.rs`. + fn map_partition_ids_to_starts_and_indices(&mut self, num_partitions: usize, num_rows: usize) { + let partition_ids = &self.partition_ids[..num_rows]; + + // Count each partition size + let partition_counters = &mut self.partition_starts; + partition_counters.resize(num_partitions + 1, 0); + partition_counters.fill(0); + partition_ids + .iter() + .for_each(|pid| partition_counters[*pid as usize] += 1); + + // Accumulate into partition ends + let mut accum = 0u32; + for v in partition_counters.iter_mut() { + *v += accum; + accum = *v; + } + + // Build partition_row_indices (iterate in reverse to turn ends into starts) + self.partition_row_indices.resize(num_rows, 0); + for (index, pid) in partition_ids.iter().enumerate().rev() { + self.partition_starts[*pid as usize] -= 1; + let pos = self.partition_starts[*pid as usize]; + self.partition_row_indices[pos as usize] = index as u32; + } + } + + /// Get the row index slice for a given partition. 
+ fn partition_slice(&self, partition_id: usize) -> &[u32] { + let start = self.partition_starts[partition_id] as usize; + let end = self.partition_starts[partition_id + 1] as usize; + &self.partition_row_indices[start..end] + } + + /// Number of rows in a given partition. + fn partition_len(&self, partition_id: usize) -> usize { + (self.partition_starts[partition_id + 1] - self.partition_starts[partition_id]) as usize + } + + fn take_partition( + &self, + batch: &RecordBatch, + partition_id: usize, + ) -> DFResult> { + let row_indices = self.partition_slice(partition_id); + if row_indices.is_empty() { + return Ok(None); + } + let indices_array = UInt32Array::from(row_indices.to_vec()); + let columns: Vec<_> = batch + .columns() + .iter() + .map(|col| take(col.as_ref(), &indices_array, None)) + .collect::, _>>()?; + Ok(Some(RecordBatch::try_new(batch.schema(), columns)?)) + } +} + +// --------------------------------------------------------------------------- +// Spill reading +// --------------------------------------------------------------------------- + +/// Read record batches from a finished spill file. +fn read_spilled_batches( + spill_file: &RefCountedTempFile, + _schema: &SchemaRef, +) -> DFResult> { + let file = File::open(spill_file.path()) + .map_err(|e| DataFusionError::Execution(format!("Failed to open spill file: {e}")))?; + let reader = BufReader::with_capacity(SPILL_IO_BUFFER_SIZE, file); + let stream_reader = StreamReader::try_new(reader, None)?; + let batches: Vec = stream_reader.into_iter().collect::, _>>()?; + Ok(batches) +} + +// --------------------------------------------------------------------------- +// Phase 1: Build-side partitioning +// --------------------------------------------------------------------------- + +/// Phase 1: Read all build-side batches, hash-partition into N buckets. +/// Spills the largest partition when memory pressure is detected. 
+#[allow(clippy::too_many_arguments)] +async fn partition_build_side( + mut input: SendableRecordBatchStream, + keys: &[Arc], + num_partitions: usize, + schema: &SchemaRef, + partitions: &mut [HashPartition], + reservation: &mut MutableReservation, + context: &Arc, + metrics: &GraceHashJoinMetrics, + scratch: &mut ScratchSpace, +) -> DFResult<()> { + while let Some(batch) = input.next().await { + let batch = batch?; + if batch.num_rows() == 0 { + continue; + } + + metrics.build_input_batches.add(1); + metrics.build_input_rows.add(batch.num_rows()); + + // Track total batch size once, estimate per-partition proportionally + let total_batch_size = batch.get_array_memory_size(); + let total_rows = batch.num_rows(); + + scratch.compute_partitions(&batch, keys, num_partitions, 0)?; + + #[allow(clippy::needless_range_loop)] + for part_idx in 0..num_partitions { + if scratch.partition_len(part_idx) == 0 { + continue; + } + + let sub_rows = scratch.partition_len(part_idx); + let sub_batch = if sub_rows == total_rows { + batch.clone() + } else { + scratch.take_partition(&batch, part_idx)?.unwrap() + }; + let batch_size = if total_rows > 0 { + (total_batch_size as u64 * sub_rows as u64 / total_rows as u64) as usize + } else { + 0 + }; + + if partitions[part_idx].build_spilled() { + // This partition is already spilled; append incrementally + if let Some(ref mut writer) = partitions[part_idx].build_spill_writer { + writer.write_batch(&sub_batch)?; + } + } else { + // Try to reserve memory + if reservation.try_grow(batch_size).is_err() { + // Memory pressure: spill the largest in-memory partition + info!( + "GraceHashJoin: memory pressure during build, spilling largest partition" + ); + spill_largest_partition(partitions, schema, context, reservation, metrics)?; + + // Retry reservation after spilling + if reservation.try_grow(batch_size).is_err() { + // Still can't fit; spill this partition too + info!( + "GraceHashJoin: still under pressure, spilling partition {}", + 
part_idx + ); + spill_partition_build( + &mut partitions[part_idx], + schema, + context, + reservation, + metrics, + )?; + if let Some(ref mut writer) = partitions[part_idx].build_spill_writer { + writer.write_batch(&sub_batch)?; + } + continue; + } + } + + partitions[part_idx].build_mem_size += batch_size; + partitions[part_idx].build_batches.push(sub_batch); + } + } + } + + Ok(()) +} + +/// Spill the largest in-memory build partition to disk. +fn spill_largest_partition( + partitions: &mut [HashPartition], + schema: &SchemaRef, + context: &Arc, + reservation: &mut MutableReservation, + metrics: &GraceHashJoinMetrics, +) -> DFResult<()> { + // Find the largest non-spilled partition + let largest_idx = partitions + .iter() + .enumerate() + .filter(|(_, p)| !p.build_spilled() && !p.build_batches.is_empty()) + .max_by_key(|(_, p)| p.build_mem_size) + .map(|(idx, _)| idx); + + if let Some(idx) = largest_idx { + info!( + "GraceHashJoin: spilling partition {} ({} bytes, {} batches)", + idx, + partitions[idx].build_mem_size, + partitions[idx].build_batches.len() + ); + spill_partition_build(&mut partitions[idx], schema, context, reservation, metrics)?; + } + + Ok(()) +} + +/// Spill a single partition's build-side data to disk using SpillWriter. 
+/// Writes all buffered build batches of `partition` to a fresh temp file,
+/// releases their bytes from `reservation`, and records spill metrics.
+/// The open writer is kept on the partition so later batches can be appended.
+fn spill_partition_build(
+    partition: &mut HashPartition,
+    schema: &SchemaRef,
+    context: &Arc,
+    reservation: &mut MutableReservation,
+    metrics: &GraceHashJoinMetrics,
+) -> DFResult<()> {
+    let temp_file = context
+        .runtime_env()
+        .disk_manager
+        .create_tmp_file("grace hash join build")?;
+
+    let mut writer = SpillWriter::new(temp_file, schema)?;
+    writer.write_batches(&partition.build_batches)?;
+
+    // Free memory
+    let freed = partition.build_mem_size;
+    reservation.shrink(freed);
+
+    metrics.spill_count.add(1);
+    metrics.spilled_bytes.add(freed);
+
+    partition.build_spill_writer = Some(writer);
+    partition.build_batches.clear();
+    partition.build_mem_size = 0;
+
+    Ok(())
+}
+
+/// Spill a single partition's probe-side data to disk using SpillWriter.
+///
+/// NOTE(review): this function unconditionally assigns a fresh writer to
+/// `probe_spill_writer`. If it were ever called while a writer already exists
+/// AND `probe_batches` is non-empty, the old writer (and rows already spilled
+/// through it) would be silently replaced. The only visible caller,
+/// `spill_partition_both_sides`, guards on `probe_spill_writer.is_none()`, so
+/// this is currently unreachable — but consider appending to an existing
+/// writer here instead of relying on the caller-side guard. TODO confirm.
+fn spill_partition_probe(
+    partition: &mut HashPartition,
+    schema: &SchemaRef,
+    context: &Arc,
+    reservation: &mut MutableReservation,
+    metrics: &GraceHashJoinMetrics,
+) -> DFResult<()> {
+    // Nothing buffered and already spilled: nothing to do.
+    if partition.probe_batches.is_empty() && partition.probe_spill_writer.is_some() {
+        return Ok(());
+    }
+
+    let temp_file = context
+        .runtime_env()
+        .disk_manager
+        .create_tmp_file("grace hash join probe")?;
+
+    let mut writer = SpillWriter::new(temp_file, schema)?;
+    writer.write_batches(&partition.probe_batches)?;
+
+    let freed = partition.probe_mem_size;
+    reservation.shrink(freed);
+
+    metrics.spill_count.add(1);
+    metrics.spilled_bytes.add(freed);
+
+    partition.probe_spill_writer = Some(writer);
+    partition.probe_batches.clear();
+    partition.probe_mem_size = 0;
+
+    Ok(())
+}
+
+/// Spill both build and probe sides of a partition to disk.
+/// When spilling during the probe phase, both sides must be spilled so the
+/// join phase reads both consistently from disk.
+fn spill_partition_both_sides(
+    partition: &mut HashPartition,
+    probe_schema: &SchemaRef,
+    build_schema: &SchemaRef,
+    context: &Arc,
+    reservation: &mut MutableReservation,
+    metrics: &GraceHashJoinMetrics,
+) -> DFResult<()> {
+    // Each side is spilled at most once here; already-spilled sides are skipped.
+    if !partition.build_spilled() {
+        spill_partition_build(partition, build_schema, context, reservation, metrics)?;
+    }
+    if partition.probe_spill_writer.is_none() {
+        spill_partition_probe(partition, probe_schema, context, reservation, metrics)?;
+    }
+    Ok(())
+}
+
+// ---------------------------------------------------------------------------
+// Phase 2: Probe-side partitioning
+// ---------------------------------------------------------------------------
+
+/// Phase 2: Read all probe-side batches, route to in-memory buffers or spill files.
+/// Tracks probe-side memory in the reservation and spills partitions when pressure
+/// is detected, preventing OOM when the probe side is much larger than the build side.
+#[allow(clippy::too_many_arguments)]
+async fn partition_probe_side(
+    mut input: SendableRecordBatchStream,
+    keys: &[Arc],
+    num_partitions: usize,
+    schema: &SchemaRef,
+    partitions: &mut [HashPartition],
+    reservation: &mut MutableReservation,
+    build_schema: &SchemaRef,
+    context: &Arc,
+    metrics: &GraceHashJoinMetrics,
+    scratch: &mut ScratchSpace,
+) -> DFResult<()> {
+    let mut probe_rows_accumulated: usize = 0;
+    while let Some(batch) = input.next().await {
+        let batch = batch?;
+        if batch.num_rows() == 0 {
+            continue;
+        }
+        // Progress log every 5M probe rows (milestone = row count / 5M).
+        let prev_milestone = probe_rows_accumulated / 5_000_000;
+        probe_rows_accumulated += batch.num_rows();
+        let new_milestone = probe_rows_accumulated / 5_000_000;
+        if new_milestone > prev_milestone {
+            info!(
+                "GraceHashJoin: probe accumulation progress: {} rows, \
+                 reservation={}, pool reserved={}",
+                probe_rows_accumulated,
+                reservation.0.size(),
+                context.runtime_env().memory_pool.reserved(),
+            );
+        }
+
+        metrics.input_batches.add(1);
+        metrics.input_rows.add(batch.num_rows());
+
+        let total_rows = batch.num_rows();
+        scratch.compute_partitions(&batch, keys, num_partitions, 0)?;
+
+        #[allow(clippy::needless_range_loop)]
+        for part_idx in 0..num_partitions {
+            if scratch.partition_len(part_idx) == 0 {
+                continue;
+            }
+
+            // Avoid a copy when the whole batch belongs to this partition.
+            let sub_batch = if scratch.partition_len(part_idx) == total_rows {
+                batch.clone()
+            } else {
+                scratch.take_partition(&batch, part_idx)?.unwrap()
+            };
+
+            if partitions[part_idx].build_spilled() {
+                // Build side was spilled, so spill probe side too
+                if partitions[part_idx].probe_spill_writer.is_none() {
+                    let temp_file = context
+                        .runtime_env()
+                        .disk_manager
+                        .create_tmp_file("grace hash join probe")?;
+                    let mut writer = SpillWriter::new(temp_file, schema)?;
+                    // Write any accumulated in-memory probe batches first
+                    if !partitions[part_idx].probe_batches.is_empty() {
+                        let freed = partitions[part_idx].probe_mem_size;
+                        let batches = std::mem::take(&mut partitions[part_idx].probe_batches);
+                        writer.write_batches(&batches)?;
+                        partitions[part_idx].probe_mem_size = 0;
+                        reservation.shrink(freed);
+                    }
+                    partitions[part_idx].probe_spill_writer = Some(writer);
+                }
+                if let Some(ref mut writer) = partitions[part_idx].probe_spill_writer {
+                    writer.write_batch(&sub_batch)?;
+                }
+            } else {
+                let batch_size = sub_batch.get_array_memory_size();
+                if reservation.try_grow(batch_size).is_err() {
+                    // Memory pressure: spill ALL non-spilled partitions.
+                    // With multiple concurrent GHJ instances sharing the pool,
+                    // partial spilling just lets data re-accumulate. Spilling
+                    // everything ensures all subsequent probe data goes directly
+                    // to disk, keeping in-memory footprint near zero.
+                    let total_in_memory: usize = partitions
+                        .iter()
+                        .filter(|p| !p.build_spilled())
+                        .map(|p| p.build_mem_size + p.probe_mem_size)
+                        .sum();
+                    let spillable_count = partitions.iter().filter(|p| !p.build_spilled()).count();
+
+                    info!(
+                        "GraceHashJoin: memory pressure during probe, \
+                         spilling all {} non-spilled partitions ({} bytes)",
+                        spillable_count, total_in_memory,
+                    );
+
+                    for i in 0..partitions.len() {
+                        if !partitions[i].build_spilled() {
+                            spill_partition_both_sides(
+                                &mut partitions[i],
+                                schema,
+                                build_schema,
+                                context,
+                                reservation,
+                                metrics,
+                            )?;
+                        }
+                    }
+                }
+
+                // After a spill-all above, `build_spilled()` is true for every
+                // partition, so the current sub-batch goes straight to disk.
+                // The failed `try_grow` is intentionally NOT retried: no memory
+                // is held for data that is written out.
+                if partitions[part_idx].build_spilled() {
+                    // Partition was just spilled above — write to spill writer
+                    if partitions[part_idx].probe_spill_writer.is_none() {
+                        let temp_file = context
+                            .runtime_env()
+                            .disk_manager
+                            .create_tmp_file("grace hash join probe")?;
+                        partitions[part_idx].probe_spill_writer =
+                            Some(SpillWriter::new(temp_file, schema)?);
+                    }
+                    if let Some(ref mut writer) = partitions[part_idx].probe_spill_writer {
+                        writer.write_batch(&sub_batch)?;
+                    }
+                } else {
+                    partitions[part_idx].probe_mem_size += batch_size;
+                    partitions[part_idx].probe_batches.push(sub_batch);
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+// ---------------------------------------------------------------------------
+// Finish spill writers
+// ---------------------------------------------------------------------------
+
+/// State of a finished partition ready for joining.
+/// After merging, a partition may hold multiple spill files from adjacent
+/// original partitions.
+// NOTE(review): the element types of these Vec fields were lost in this patch
+// rendering (shown as bare `Vec`); confirm against the upstream source.
+struct FinishedPartition {
+    build_batches: Vec,
+    probe_batches: Vec,
+    build_spill_files: Vec,
+    probe_spill_files: Vec,
+    /// Total build-side bytes (in-memory + spilled) for merge decisions.
+    build_bytes: usize,
+}
+
+/// Finish all open spill writers so files can be read back.
+/// Consumes the phase-1/2 `HashPartition`s and converts each into a
+/// `FinishedPartition`, sealing any open spill writers so their files can be
+/// read back during the join phase. A partition may end up with both
+/// in-memory batches and a spill file; both are carried forward.
+fn finish_spill_writers(
+    partitions: Vec,
+    _left_schema: &SchemaRef,
+    _right_schema: &SchemaRef,
+    _metrics: &GraceHashJoinMetrics,
+) -> DFResult> {
+    let mut finished = Vec::with_capacity(partitions.len());
+
+    for partition in partitions {
+        let (build_spill_files, spilled_build_bytes) =
+            if let Some(writer) = partition.build_spill_writer {
+                let (file, bytes) = writer.finish()?;
+                (vec![file], bytes)
+            } else {
+                (vec![], 0)
+            };
+
+        // Probe-side spilled byte count is deliberately discarded: only
+        // build-side bytes drive the later merge decisions (see
+        // `FinishedPartition::build_bytes`).
+        let probe_spill_files = if let Some(writer) = partition.probe_spill_writer {
+            let (file, _bytes) = writer.finish()?;
+            vec![file]
+        } else {
+            vec![]
+        };
+
+        finished.push(FinishedPartition {
+            build_bytes: partition.build_mem_size + spilled_build_bytes,
+            build_batches: partition.build_batches,
+            probe_batches: partition.probe_batches,
+            build_spill_files,
+            probe_spill_files,
+        });
+    }
+
+    Ok(finished)
+}
+
+/// Merge adjacent finished partitions to reduce the number of per-partition
+/// HashJoinExec calls. Groups adjacent partitions so each merged group has
+/// roughly `TARGET_PARTITION_BUILD_SIZE` bytes of build data.
+/// Merging is purely positional (adjacent groups of equal size), not
+/// size-balanced: each group gets `original_count / target_count` partitions,
+/// with the first `original_count % target_count` groups taking one extra.
+/// Correctness is unaffected because all rows of a hash bucket stay together.
+fn merge_finished_partitions(
+    partitions: Vec,
+    target_count: usize,
+) -> Vec {
+    let original_count = partitions.len();
+    // Nothing to merge (also covers target_count == 0 degenerating to identity).
+    if target_count >= original_count {
+        return partitions;
+    }
+
+    // Divide original_count partitions into target_count groups as evenly as possible
+    let base_group_size = original_count / target_count;
+    let remainder = original_count % target_count;
+
+    let mut merged = Vec::with_capacity(target_count);
+    let mut iter = partitions.into_iter();
+
+    for group_idx in 0..target_count {
+        // First `remainder` groups get one extra partition
+        let group_size = base_group_size + if group_idx < remainder { 1 } else { 0 };
+
+        let mut build_batches = Vec::new();
+        let mut probe_batches = Vec::new();
+        let mut build_spill_files = Vec::new();
+        let mut probe_spill_files = Vec::new();
+        let mut build_bytes = 0usize;
+
+        for _ in 0..group_size {
+            if let Some(p) = iter.next() {
+                build_batches.extend(p.build_batches);
+                probe_batches.extend(p.probe_batches);
+                build_spill_files.extend(p.build_spill_files);
+                probe_spill_files.extend(p.probe_spill_files);
+                build_bytes += p.build_bytes;
+            }
+        }
+
+        merged.push(FinishedPartition {
+            build_batches,
+            probe_batches,
+            build_spill_files,
+            probe_spill_files,
+            build_bytes,
+        });
+    }
+
+    merged
+}
+
+// ---------------------------------------------------------------------------
+// Phase 3: Per-partition hash joins
+// ---------------------------------------------------------------------------
+
+/// The output batch size for HashJoinExec within GHJ.
+///
+/// With the default Comet batch size (8192), HashJoinExec produces thousands
+/// of small output batches, causing significant per-batch overhead for large
+/// joins (e.g., 150M output rows = 18K batches at 8192).
+///
+/// 1M rows gives ~150 batches for a 150M row join — enough to avoid
+/// per-batch overhead while keeping each output batch at a few hundred MB.
+/// Cannot use `usize::MAX` because HashJoinExec pre-allocates Vec with
+/// capacity = batch_size in `get_matched_indices_with_limit_offset`.
+/// Cannot use 10M+ because output batches become multi-GB and cause OOM.
+const GHJ_OUTPUT_BATCH_SIZE: usize = 1_000_000;
+
+/// Create a TaskContext with a larger output batch size for HashJoinExec.
+///
+/// Input splitting is handled by StreamSourceExec (not batch_size).
+/// The configured batch size never shrinks: it is the max of
+/// `GHJ_OUTPUT_BATCH_SIZE` and the session's existing batch size.
+fn context_for_join_output(context: &Arc) -> Arc {
+    let batch_size = GHJ_OUTPUT_BATCH_SIZE.max(context.session_config().batch_size());
+    Arc::new(TaskContext::new(
+        context.task_id(),
+        context.session_id(),
+        context.session_config().clone().with_batch_size(batch_size),
+        context.scalar_functions().clone(),
+        context.aggregate_functions().clone(),
+        context.window_functions().clone(),
+        context.runtime_env(),
+    ))
+}
+
+/// Create a `StreamSourceExec` that yields `data` batches without splitting.
+///
+/// Unlike `DataSourceExec(MemorySourceConfig)`, `StreamSourceExec` does NOT
+/// wrap its output in `BatchSplitStream`. This is critical for the build side
+/// because Arrow's zero-copy `batch.slice()` shares underlying buffers, so
+/// `get_record_batch_memory_size()` reports the full buffer size for every
+/// slice — causing `collect_left_input` to vastly over-count memory and
+/// trigger spurious OOM. Additionally, using `batch_size` large enough to
+/// prevent splitting can cause Arrow i32 offset overflow for string columns.
+fn memory_source_exec(
+    data: Vec,
+    schema: &SchemaRef,
+) -> DFResult> {
+    let schema_clone = Arc::clone(schema);
+    let stream =
+        RecordBatchStreamAdapter::new(Arc::clone(schema), stream::iter(data.into_iter().map(Ok)));
+    Ok(Arc::new(StreamSourceExec::new(
+        Box::pin(stream),
+        schema_clone,
+    )))
+}
+
+/// Join a single partition: reads build-side spill (if any) via spawn_blocking,
+/// then delegates to `join_with_spilled_probe` or `join_partition_recursive`.
+/// Returns the resulting streams for this partition.
+///
+/// Takes all owned data so it can be called inside `tokio::spawn`.
+#[allow(clippy::too_many_arguments)]
+async fn join_single_partition(
+    partition: FinishedPartition,
+    original_on: Vec<(Arc, Arc)>,
+    filter: Option,
+    join_type: JoinType,
+    build_left: bool,
+    build_schema: SchemaRef,
+    probe_schema: SchemaRef,
+    context: Arc,
+) -> DFResult> {
+    // Get build-side batches (from memory or disk — build side is typically small).
+    // Use spawn_blocking for spill reads to avoid blocking the async executor.
+    let mut build_batches = partition.build_batches;
+    if !partition.build_spill_files.is_empty() {
+        let schema = Arc::clone(&build_schema);
+        let spill_files = partition.build_spill_files;
+        let spilled = tokio::task::spawn_blocking(move || {
+            let mut all = Vec::new();
+            for spill_file in &spill_files {
+                all.extend(read_spilled_batches(spill_file, &schema)?);
+            }
+            Ok::<_, DataFusionError>(all)
+        })
+        .await
+        .map_err(|e| {
+            DataFusionError::Execution(format!("GraceHashJoin: build spill read task failed: {e}"))
+        })??;
+        build_batches.extend(spilled);
+    }
+
+    // Coalesce many tiny sub-batches into single batches to reduce per-batch
+    // overhead in HashJoinExec. Per-partition data is bounded by
+    // TARGET_PARTITION_BUILD_SIZE so concat won't hit i32 offset overflow.
+    let build_batches = if build_batches.len() > 1 {
+        vec![concat_batches(&build_schema, &build_batches)?]
+    } else {
+        build_batches
+    };
+
+    let mut streams = Vec::new();
+
+    if !partition.probe_spill_files.is_empty() {
+        // Probe side has spill file(s). Also include any in-memory probe
+        // batches (possible after merging adjacent partitions).
+        join_with_spilled_probe(
+            build_batches,
+            partition.probe_spill_files,
+            partition.probe_batches,
+            &original_on,
+            &filter,
+            &join_type,
+            build_left,
+            &build_schema,
+            &probe_schema,
+            &context,
+            &mut streams,
+        )?;
+    } else {
+        // Probe side is in-memory: coalesce before joining
+        let probe_batches = if partition.probe_batches.len() > 1 {
+            vec![concat_batches(&probe_schema, &partition.probe_batches)?]
+        } else {
+            partition.probe_batches
+        };
+        join_partition_recursive(
+            build_batches,
+            probe_batches,
+            &original_on,
+            &filter,
+            &join_type,
+            build_left,
+            &build_schema,
+            &probe_schema,
+            &context,
+            1,
+            &mut streams,
+        )?;
+    }
+
+    Ok(streams)
+}
+
+/// Join a partition where the probe side was spilled to disk.
+/// Uses SpillReaderExec to stream probe data from the spill file instead of
+/// loading it all into memory. The build side (typically small) is loaded
+/// into a MemorySourceConfig for the hash table.
+#[allow(clippy::too_many_arguments)]
+fn join_with_spilled_probe(
+    build_batches: Vec,
+    probe_spill_files: Vec,
+    probe_in_memory: Vec,
+    original_on: JoinOnRef<'_>,
+    filter: &Option,
+    join_type: &JoinType,
+    build_left: bool,
+    build_schema: &SchemaRef,
+    probe_schema: &SchemaRef,
+    context: &Arc,
+    streams: &mut Vec,
+) -> DFResult<()> {
+    let probe_spill_files_count = probe_spill_files.len();
+
+    // Skip if build side is empty and join type requires it.
+    // Only build-side emptiness is checked here — probe emptiness would
+    // require reading the spill file — so this is conservative: it skips
+    // only join types for which an empty build provably yields no output.
+    let build_empty = build_batches.is_empty();
+    let skip = match join_type {
+        JoinType::Inner | JoinType::LeftSemi | JoinType::LeftAnti => {
+            if build_left {
+                build_empty
+            } else {
+                false // probe emptiness unknown without reading
+            }
+        }
+        JoinType::Left | JoinType::LeftMark => {
+            if build_left {
+                build_empty
+            } else {
+                false
+            }
+        }
+        JoinType::Right => {
+            if !build_left {
+                build_empty
+            } else {
+                false
+            }
+        }
+        _ => false,
+    };
+    if skip {
+        return Ok(());
+    }
+
+    let build_size: usize = build_batches
+        .iter()
+        .map(|b| b.get_array_memory_size())
+        .sum();
+    let build_rows: usize = build_batches.iter().map(|b| b.num_rows()).sum();
+    info!(
+        "GraceHashJoin: join_with_spilled_probe build: {} batches/{} rows/{} bytes, \
+         probe: streaming from spill file",
+        build_batches.len(),
+        build_rows,
+        build_size,
+    );
+
+    // If build side exceeds the target partition size, fall back to eager
+    // read + recursive repartitioning. This prevents creating HashJoinExec
+    // with oversized build sides that expand into huge hash tables.
+    let needs_repartition = build_size > TARGET_PARTITION_BUILD_SIZE;
+
+    if needs_repartition {
+        info!(
+            "GraceHashJoin: build too large for streaming probe ({} bytes > {} target), \
+             falling back to eager read + repartition",
+            build_size, TARGET_PARTITION_BUILD_SIZE,
+        );
+        let mut probe_batches = probe_in_memory;
+        for spill_file in &probe_spill_files {
+            probe_batches.extend(read_spilled_batches(spill_file, probe_schema)?);
+        }
+        return join_partition_recursive(
+            build_batches,
+            probe_batches,
+            original_on,
+            filter,
+            join_type,
+            build_left,
+            build_schema,
+            probe_schema,
+            context,
+            1,
+            streams,
+        );
+    }
+
+    // Concatenate build side into single batch. Per-partition data is bounded
+    // by TARGET_PARTITION_BUILD_SIZE so this won't hit i32 offset overflow.
+    let build_data = if build_batches.is_empty() {
+        vec![RecordBatch::new_empty(Arc::clone(build_schema))]
+    } else if build_batches.len() == 1 {
+        build_batches
+    } else {
+        vec![concat_batches(build_schema, &build_batches)?]
+    };
+
+    // Build side: StreamSourceExec to avoid BatchSplitStream splitting
+    let build_source = memory_source_exec(build_data, build_schema)?;
+
+    // Probe side: streaming from spill file(s).
+    // With a single spill file and no in-memory batches, use the streaming
+    // SpillReaderExec. Otherwise read eagerly since the merged group sizes
+    // are bounded by TARGET_PARTITION_BUILD_SIZE.
+    let probe_source: Arc =
+        if probe_spill_files.len() == 1 && probe_in_memory.is_empty() {
+            Arc::new(SpillReaderExec::new(
+                probe_spill_files.into_iter().next().unwrap(),
+                Arc::clone(probe_schema),
+            ))
+        } else {
+            let mut probe_batches = probe_in_memory;
+            for spill_file in &probe_spill_files {
+                probe_batches.extend(read_spilled_batches(spill_file, probe_schema)?);
+            }
+            let probe_data = if probe_batches.is_empty() {
+                vec![RecordBatch::new_empty(Arc::clone(probe_schema))]
+            } else {
+                vec![concat_batches(probe_schema, &probe_batches)?]
+            };
+            memory_source_exec(probe_data, probe_schema)?
+        };
+
+    // HashJoinExec expects left=build in CollectLeft mode
+    let (left_source, right_source) = if build_left {
+        (build_source as Arc, probe_source)
+    } else {
+        (probe_source, build_source as Arc)
+    };
+
+    info!(
+        "GraceHashJoin: SPILLED PROBE PATH creating HashJoinExec, \
+         build_left={}, build_size={}, probe_source={}",
+        build_left,
+        build_size,
+        if probe_spill_files_count == 1 {
+            "SpillReaderExec"
+        } else {
+            "StreamSourceExec"
+        },
+    );
+
+    // When build is the right side, construct with left=probe/right=build and
+    // swap so HashJoinExec still collects the (small) build side.
+    let stream = if build_left {
+        let hash_join = HashJoinExec::try_new(
+            left_source,
+            right_source,
+            original_on.to_vec(),
+            filter.clone(),
+            join_type,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNothing,
+        )?;
+        info!(
+            "GraceHashJoin: SPILLED PROBE PATH plan:\n{}",
+            DisplayableExecutionPlan::new(&hash_join).indent(true)
+        );
+        hash_join.execute(0, context_for_join_output(context))?
+    } else {
+        let hash_join = Arc::new(HashJoinExec::try_new(
+            left_source,
+            right_source,
+            original_on.to_vec(),
+            filter.clone(),
+            join_type,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNothing,
+        )?);
+        let swapped = hash_join.swap_inputs(PartitionMode::CollectLeft)?;
+        info!(
+            "GraceHashJoin: SPILLED PROBE PATH (swapped) plan:\n{}",
+            DisplayableExecutionPlan::new(swapped.as_ref()).indent(true)
+        );
+        swapped.execute(0, context_for_join_output(context))?
+    };
+
+    streams.push(stream);
+    Ok(())
+}
+
+/// Join a single partition, recursively repartitioning if the build side is too large.
+///
+/// `build_keys` / `probe_keys` for repartitioning are extracted from `original_on`
+/// based on `build_left`.
+#[allow(clippy::too_many_arguments)]
+fn join_partition_recursive(
+    build_batches: Vec,
+    probe_batches: Vec,
+    original_on: JoinOnRef<'_>,
+    filter: &Option,
+    join_type: &JoinType,
+    build_left: bool,
+    build_schema: &SchemaRef,
+    probe_schema: &SchemaRef,
+    context: &Arc,
+    recursion_level: usize,
+    streams: &mut Vec,
+) -> DFResult<()> {
+    // Skip partitions that cannot produce output based on join type.
+    // The join type uses Spark's left/right semantics. Map build/probe
+    // back to left/right based on build_left.
+    let (left_empty, right_empty) = if build_left {
+        (build_batches.is_empty(), probe_batches.is_empty())
+    } else {
+        (probe_batches.is_empty(), build_batches.is_empty())
+    };
+    // NOTE(review): this skip table looks wrong for the anti/mark types:
+    // LeftAnti with an empty RIGHT side should emit every left row (none has
+    // a match), and RightAnti/RightMark with an empty LEFT side should emit
+    // every right row — yet all three are skipped here, which would silently
+    // drop rows. Compare with `join_with_spilled_probe`, which deliberately
+    // does NOT skip LeftAnti when only the non-build side may be empty.
+    // Confirm whether an upstream guarantee keeps these cases unreachable.
+    let skip = match join_type {
+        JoinType::Inner | JoinType::LeftSemi | JoinType::LeftAnti => left_empty || right_empty,
+        JoinType::Left | JoinType::LeftMark => left_empty,
+        JoinType::Right => right_empty,
+        JoinType::Full => left_empty && right_empty,
+        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
+            left_empty || right_empty
+        }
+    };
+    if skip {
+        return Ok(());
+    }
+
+    // Check if build side is too large and needs recursive repartitioning.
+    let build_size: usize = build_batches
+        .iter()
+        .map(|b| b.get_array_memory_size())
+        .sum();
+    let build_rows: usize = build_batches.iter().map(|b| b.num_rows()).sum();
+    let probe_size: usize = probe_batches
+        .iter()
+        .map(|b| b.get_array_memory_size())
+        .sum();
+    let probe_rows: usize = probe_batches.iter().map(|b| b.num_rows()).sum();
+    let pool_reserved = context.runtime_env().memory_pool.reserved();
+    info!(
+        "GraceHashJoin: join_partition_recursive level={}, \
+         build: {} batches/{} rows/{} bytes, \
+         probe: {} batches/{} rows/{} bytes, \
+         pool reserved={}",
+        recursion_level,
+        build_batches.len(),
+        build_rows,
+        build_size,
+        probe_batches.len(),
+        probe_rows,
+        probe_size,
+        pool_reserved,
+    );
+    // Repartition if the build side exceeds the target size. This prevents
+    // creating HashJoinExec with oversized build sides whose hash tables
+    // can expand well beyond the raw data size and exhaust the memory pool.
+    let needs_repartition = build_size > TARGET_PARTITION_BUILD_SIZE;
+    if needs_repartition {
+        info!(
+            "GraceHashJoin: repartition needed at level {}: \
+             build_size={} > target={}, pool reserved={}",
+            recursion_level,
+            build_size,
+            TARGET_PARTITION_BUILD_SIZE,
+            context.runtime_env().memory_pool.reserved(),
+        );
+    }
+
+    if needs_repartition {
+        // Give up once the recursion budget is exhausted (e.g. extreme key
+        // skew where rehashing cannot shrink the partition further).
+        if recursion_level >= MAX_RECURSION_DEPTH {
+            let total_build_rows: usize = build_batches.iter().map(|b| b.num_rows()).sum();
+            return Err(DataFusionError::ResourcesExhausted(format!(
+                "GraceHashJoin: build side partition is still too large after {} levels of \
+                 repartitioning ({} bytes, {} rows). Consider increasing \
+                 spark.comet.exec.graceHashJoin.numPartitions or \
+                 spark.executor.memory.",
+                MAX_RECURSION_DEPTH, build_size, total_build_rows
+            )));
+        }
+
+        info!(
+            "GraceHashJoin: repartitioning oversized partition at level {} \
+             (build: {} bytes, {} batches)",
+            recursion_level,
+            build_size,
+            build_batches.len()
+        );
+
+        return repartition_and_join(
+            build_batches,
+            probe_batches,
+            original_on,
+            filter,
+            join_type,
+            build_left,
+            build_schema,
+            probe_schema,
+            context,
+            recursion_level,
+            streams,
+        );
+    }
+
+    // Concatenate sub-batches into single batches to reduce per-batch overhead
+    // in HashJoinExec. Per-partition data is bounded by TARGET_PARTITION_BUILD_SIZE
+    // so this won't hit i32 offset overflow.
+    let build_data = if build_batches.is_empty() {
+        vec![RecordBatch::new_empty(Arc::clone(build_schema))]
+    } else if build_batches.len() == 1 {
+        build_batches
+    } else {
+        vec![concat_batches(build_schema, &build_batches)?]
+    };
+    let probe_data = if probe_batches.is_empty() {
+        vec![RecordBatch::new_empty(Arc::clone(probe_schema))]
+    } else if probe_batches.len() == 1 {
+        probe_batches
+    } else {
+        vec![concat_batches(probe_schema, &probe_batches)?]
+    };
+
+    // Create per-partition hash join.
+    // HashJoinExec expects left=build (CollectLeft mode).
+    // Both sides use StreamSourceExec to avoid DataSourceExec's BatchSplitStream.
+    let build_source = memory_source_exec(build_data, build_schema)?;
+    let probe_source = memory_source_exec(probe_data, probe_schema)?;
+
+    let (left_source, right_source) = if build_left {
+        (build_source, probe_source)
+    } else {
+        (probe_source, build_source)
+    };
+
+    let pool_before_join = context.runtime_env().memory_pool.reserved();
+    info!(
+        "GraceHashJoin: RECURSIVE PATH creating HashJoinExec at level={}, \
+         build_left={}, build_size={}, probe_size={}, pool reserved={}",
+        recursion_level, build_left, build_size, probe_size, pool_before_join,
+    );
+
+    // Same build-left / swapped construction pattern as the spilled-probe path.
+    let stream = if build_left {
+        let hash_join = HashJoinExec::try_new(
+            left_source,
+            right_source,
+            original_on.to_vec(),
+            filter.clone(),
+            join_type,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNothing,
+        )?;
+        info!(
+            "GraceHashJoin: RECURSIVE PATH plan (level={}):\n{}",
+            recursion_level,
+            DisplayableExecutionPlan::new(&hash_join).indent(true)
+        );
+        hash_join.execute(0, context_for_join_output(context))?
+    } else {
+        let hash_join = Arc::new(HashJoinExec::try_new(
+            left_source,
+            right_source,
+            original_on.to_vec(),
+            filter.clone(),
+            join_type,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNothing,
+        )?);
+        let swapped = hash_join.swap_inputs(PartitionMode::CollectLeft)?;
+        info!(
+            "GraceHashJoin: RECURSIVE PATH (swapped, level={}) plan:\n{}",
+            recursion_level,
+            DisplayableExecutionPlan::new(swapped.as_ref()).indent(true)
+        );
+        swapped.execute(0, context_for_join_output(context))?
+    };
+
+    streams.push(stream);
+    Ok(())
+}
+
+/// Repartition build and probe batches into sub-partitions using a different
+/// hash seed, then recursively join each sub-partition.
+#[allow(clippy::too_many_arguments)]
+fn repartition_and_join(
+    build_batches: Vec,
+    probe_batches: Vec,
+    original_on: JoinOnRef<'_>,
+    filter: &Option,
+    join_type: &JoinType,
+    build_left: bool,
+    build_schema: &SchemaRef,
+    probe_schema: &SchemaRef,
+    context: &Arc,
+    recursion_level: usize,
+    streams: &mut Vec,
+) -> DFResult<()> {
+    let num_sub_partitions = DEFAULT_NUM_PARTITIONS;
+
+    // Extract build/probe key expressions from original_on
+    let (build_keys, probe_keys): (Vec<_>, Vec<_>) = if build_left {
+        original_on
+            .iter()
+            .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
+            .unzip()
+    } else {
+        original_on
+            .iter()
+            .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
+            .unzip()
+    };
+
+    let mut scratch = ScratchSpace::default();
+
+    // Sub-partition the build side. `recursion_level` is passed to
+    // compute_partitions, presumably varying the hash seed so rows spread
+    // differently than at the previous level — TODO confirm in ScratchSpace.
+    let mut build_sub: Vec> =
+        (0..num_sub_partitions).map(|_| Vec::new()).collect();
+    for batch in &build_batches {
+        let total_rows = batch.num_rows();
+        scratch.compute_partitions(batch, &build_keys, num_sub_partitions, recursion_level)?;
+        for (i, sub_vec) in build_sub.iter_mut().enumerate() {
+            if scratch.partition_len(i) == 0 {
+                continue;
+            }
+            if scratch.partition_len(i) == total_rows {
+                sub_vec.push(batch.clone());
+            } else if let Some(sub) = scratch.take_partition(batch, i)? {
+                sub_vec.push(sub);
+            }
+        }
+    }
+
+    // Sub-partition the probe side
+    let mut probe_sub: Vec> =
+        (0..num_sub_partitions).map(|_| Vec::new()).collect();
+    for batch in &probe_batches {
+        let total_rows = batch.num_rows();
+        scratch.compute_partitions(batch, &probe_keys, num_sub_partitions, recursion_level)?;
+        for (i, sub_vec) in probe_sub.iter_mut().enumerate() {
+            if scratch.partition_len(i) == 0 {
+                continue;
+            }
+            if scratch.partition_len(i) == total_rows {
+                sub_vec.push(batch.clone());
+            } else if let Some(sub) = scratch.take_partition(batch, i)? {
+                sub_vec.push(sub);
+            }
+        }
+    }
+
+    // Recursively join each sub-partition
+    for (build_part, probe_part) in build_sub.into_iter().zip(probe_sub.into_iter()) {
+        join_partition_recursive(
+            build_part,
+            probe_part,
+            original_on,
+            filter,
+            join_type,
+            build_left,
+            build_schema,
+            probe_schema,
+            context,
+            recursion_level + 1,
+            streams,
+        )?;
+    }
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Int32Array, StringArray};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::datasource::memory::MemorySourceConfig;
+    use datafusion::datasource::source::DataSourceExec;
+    use datafusion::execution::memory_pool::FairSpillPool;
+    use datafusion::execution::runtime_env::RuntimeEnvBuilder;
+    use datafusion::physical_expr::expressions::Column;
+    use datafusion::prelude::SessionConfig;
+    use datafusion::prelude::SessionContext;
+    use futures::TryStreamExt;
+
+    /// Build a two-column (id: Int32, val: Utf8) batch from parallel slices.
+    fn make_batch(ids: &[i32], values: &[&str]) -> RecordBatch {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("val", DataType::Utf8, false),
+        ]));
+        RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(Int32Array::from(ids.to_vec())),
+                Arc::new(StringArray::from(values.to_vec())),
+            ],
+        )
+        .unwrap()
+    }
+
+    #[tokio::test]
+    async fn test_grace_hash_join_basic() -> DFResult<()> {
+        let ctx = SessionContext::new();
+        let task_ctx = ctx.task_ctx();
+
+        let left_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("val", DataType::Utf8, false),
+        ]));
+        let right_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("val", DataType::Utf8, false),
+        ]));
+
+        let left_batches = vec![
+            make_batch(&[1, 2, 3, 4, 5], &["a", "b", "c", "d", "e"]),
+            make_batch(&[6, 7, 8], &["f", "g", "h"]),
+        ];
+        let right_batches = vec![
+            make_batch(&[2, 4, 6, 8], &["x", "y", "z", "w"]),
+            make_batch(&[1, 3, 5, 7], &["p", "q", "r", "s"]),
+        ];
+
+        let left_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+            &[left_batches],
+            Arc::clone(&left_schema),
+            None,
+        )?)));
+        let right_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+            &[right_batches],
+            Arc::clone(&right_schema),
+            None,
+        )?)));
+
+        let on = vec![(
+            Arc::new(Column::new("id", 0)) as Arc,
+            Arc::new(Column::new("id", 0)) as Arc,
+        )];
+
+        let grace_join = GraceHashJoinExec::try_new(
+            left_source,
+            right_source,
+            on,
+            None,
+            &JoinType::Inner,
+            4, // Use 4 partitions for testing
+            true,
+            10 * 1024 * 1024, // 10 MB fast path threshold
+        )?;
+
+        let stream = grace_join.execute(0, task_ctx)?;
+        let result_batches: Vec = stream.try_collect().await?;
+
+        // Count total rows - should be 8 (each left id matches exactly one right id)
+        let total_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum();
+        assert_eq!(total_rows, 8, "Expected 8 matching rows for inner join");
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_grace_hash_join_empty_partition() -> DFResult<()> {
+        let ctx = SessionContext::new();
+        let task_ctx = ctx.task_ctx();
+
+        let left_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+        let right_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let left_batches = vec![RecordBatch::try_new(
+            Arc::clone(&left_schema),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )?];
+        let right_batches = vec![RecordBatch::try_new(
+            Arc::clone(&right_schema),
+            vec![Arc::new(Int32Array::from(vec![10, 20, 30]))],
+        )?];
+
+        let left_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+            &[left_batches],
+            Arc::clone(&left_schema),
+            None,
+        )?)));
+        let right_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+            &[right_batches],
+            Arc::clone(&right_schema),
+            None,
+        )?)));
+
+        let on = vec![(
+            Arc::new(Column::new("id", 0)) as Arc,
+            Arc::new(Column::new("id", 0)) as Arc,
+        )];
+
+        let grace_join = GraceHashJoinExec::try_new(
+            left_source,
+            right_source,
+            on,
+            None,
+            &JoinType::Inner,
+            4,
+            true,
+            10 * 1024 * 1024, // 10 MB fast path threshold
+        )?;
+
+        let stream = grace_join.execute(0, task_ctx)?;
+        let result_batches: Vec = stream.try_collect().await?;
+
+        let total_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum();
+        assert_eq!(total_rows, 0, "Expected 0 rows for non-matching keys");
+
+        Ok(())
+    }
+
+    /// Helper to create a SessionContext with a bounded FairSpillPool.
+    fn context_with_memory_limit(pool_bytes: usize) -> SessionContext {
+        let pool = Arc::new(FairSpillPool::new(pool_bytes));
+        let runtime = RuntimeEnvBuilder::new()
+            .with_memory_pool(pool)
+            .build_arc()
+            .unwrap();
+        let config = SessionConfig::new();
+        SessionContext::new_with_config_rt(config, runtime)
+    }
+
+    /// Generate a batch of N rows with sequential IDs and a padding string
+    /// column to control memory size. Each row is ~100 bytes of padding.
+    fn make_large_batch(start_id: i32, count: usize) -> RecordBatch {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("val", DataType::Utf8, false),
+        ]));
+        let ids: Vec = (start_id..start_id + count as i32).collect();
+        let padding = "x".repeat(100);
+        let vals: Vec<&str> = (0..count).map(|_| padding.as_str()).collect();
+        RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(Int32Array::from(ids)),
+                Arc::new(StringArray::from(vals)),
+            ],
+        )
+        .unwrap()
+    }
+
+    /// Test that GHJ correctly repartitions a large build side instead of
+    /// creating an oversized HashJoinExec hash table that OOMs.
+    ///
+    /// Setup: 256 MB memory pool, ~80 MB build side, ~1 MB (10K-row) probe side.
+    /// Without repartitioning, the hash table would be ~240 MB and could
+    /// exhaust the 256 MB pool. With repartitioning (32 MB threshold),
+    /// the build side is split into sub-partitions of ~5 MB each.
+    #[tokio::test]
+    async fn test_grace_hash_join_repartitions_large_build() -> DFResult<()> {
+        // 256 MB pool — tight enough that a 80 MB build → ~240 MB hash table fails
+        let ctx = context_with_memory_limit(256 * 1024 * 1024);
+        let task_ctx = ctx.task_ctx();
+
+        let left_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("val", DataType::Utf8, false),
+        ]));
+        let right_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("val", DataType::Utf8, false),
+        ]));
+
+        // Build side: ~80 MB (800K rows × ~100 bytes)
+        let left_batches = vec![
+            make_large_batch(0, 200_000),
+            make_large_batch(200_000, 200_000),
+            make_large_batch(400_000, 200_000),
+            make_large_batch(600_000, 200_000),
+        ];
+        let build_bytes: usize = left_batches.iter().map(|b| b.get_array_memory_size()).sum();
+        eprintln!(
+            "Test build side: {} bytes ({} MB)",
+            build_bytes,
+            build_bytes / (1024 * 1024)
+        );
+
+        // Probe side: small (~1 MB, 10K rows)
+        let right_batches = vec![make_large_batch(0, 10_000)];
+
+        let left_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+            &[left_batches],
+            Arc::clone(&left_schema),
+            None,
+        )?)));
+        let right_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+            &[right_batches],
+            Arc::clone(&right_schema),
+            None,
+        )?)));
+
+        let on = vec![(
+            Arc::new(Column::new("id", 0)) as Arc,
+            Arc::new(Column::new("id", 0)) as Arc,
+        )];
+
+        // Disable fast path to force slow path
+        let grace_join = GraceHashJoinExec::try_new(
+            left_source,
+            right_source,
+            on,
+            None,
+            &JoinType::Inner,
+            16,
+            true, // build_left
+            0, // fast_path_threshold = 0 (disabled)
+        )?;
+
+        let stream = grace_join.execute(0, task_ctx)?;
+        let result_batches: Vec = stream.try_collect().await?;
+
+        let total_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum();
+        // All 10K probe rows match (IDs 0..10000 exist in build)
+        assert_eq!(total_rows, 10_000, "Expected 10000 matching rows");
+
+        Ok(())
+    }
+
+    /// Same test but with build_left=false to exercise the swap_inputs path.
+    #[tokio::test]
+    async fn test_grace_hash_join_repartitions_large_build_right() -> DFResult<()> {
+        let ctx = context_with_memory_limit(256 * 1024 * 1024);
+        let task_ctx = ctx.task_ctx();
+
+        let left_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("val", DataType::Utf8, false),
+        ]));
+        let right_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("val", DataType::Utf8, false),
+        ]));
+
+        // Probe side (left): small
+        let left_batches = vec![make_large_batch(0, 10_000)];
+
+        // Build side (right): ~80 MB
+        let right_batches = vec![
+            make_large_batch(0, 200_000),
+            make_large_batch(200_000, 200_000),
+            make_large_batch(400_000, 200_000),
+            make_large_batch(600_000, 200_000),
+        ];
+
+        let left_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+            &[left_batches],
+            Arc::clone(&left_schema),
+            None,
+        )?)));
+        let right_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+            &[right_batches],
+            Arc::clone(&right_schema),
+            None,
+        )?)));
+
+        let on = vec![(
+            Arc::new(Column::new("id", 0)) as Arc,
+            Arc::new(Column::new("id", 0)) as Arc,
+        )];
+
+        let grace_join = GraceHashJoinExec::try_new(
+            left_source,
+            right_source,
+            on,
+            None,
+            &JoinType::Inner,
+            16,
+            false, // build_left=false → right is build side
+            0, // fast_path_threshold = 0 (disabled)
+        )?;
+
+        let stream = grace_join.execute(0, task_ctx)?;
+        let result_batches: Vec = stream.try_collect().await?;
+
+        let total_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum();
+        assert_eq!(total_rows, 10_000, "Expected 10000 matching rows");
+
+        Ok(())
+    }
+}
diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs
index 07ee995367..ed1dce219e 100644
--- a/native/core/src/execution/operators/mod.rs
+++ b/native/core/src/execution/operators/mod.rs @@ -32,6 +32,8 @@ mod iceberg_scan; mod parquet_writer; pub use parquet_writer::ParquetWriterExec; mod csv_scan; +mod grace_hash_join; +pub use grace_hash_join::GraceHashJoinExec; pub mod projection; mod scan; pub use csv_scan::init_csv_datasource_exec; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index f84d6cc590..b8951d4d38 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -21,6 +21,8 @@ pub mod expression_registry; pub mod macros; pub mod operator_registry; +use log::info; + use crate::execution::operators::init_csv_datasource_exec; use crate::execution::operators::IcebergScanExec; use crate::{ @@ -163,6 +165,8 @@ pub struct PhysicalPlanner { exec_context_id: i64, partition: i32, session_ctx: Arc, + /// Spark configuration map, used to read comet-specific settings. + spark_conf: HashMap, } impl Default for PhysicalPlanner { @@ -177,6 +181,7 @@ impl PhysicalPlanner { exec_context_id: TEST_EXEC_CONTEXT_ID, session_ctx, partition, + spark_conf: HashMap::new(), } } @@ -185,9 +190,14 @@ impl PhysicalPlanner { exec_context_id, partition: self.partition, session_ctx: Arc::clone(&self.session_ctx), + spark_conf: self.spark_conf, } } + pub fn with_spark_conf(self, spark_conf: HashMap) -> Self { + Self { spark_conf, ..self } + } + /// Return session context of this planner. 
pub fn session_ctx(&self) -> &Arc { &self.session_ctx @@ -1531,6 +1541,67 @@ impl PhysicalPlanner { let left = Arc::clone(&join_params.left.native_plan); let right = Arc::clone(&join_params.right.native_plan); + // Check if Grace Hash Join is enabled + { + use crate::execution::spark_config::{ + SparkConfig, COMET_GRACE_HASH_JOIN_ENABLED, + COMET_GRACE_HASH_JOIN_FAST_PATH_THRESHOLD, + COMET_GRACE_HASH_JOIN_NUM_PARTITIONS, SPARK_EXECUTOR_CORES, + }; + let grace_enabled = self.spark_conf.get_bool(COMET_GRACE_HASH_JOIN_ENABLED); + + if grace_enabled { + let num_partitions = self + .spark_conf + .get_usize(COMET_GRACE_HASH_JOIN_NUM_PARTITIONS, 16); + let executor_cores = + self.spark_conf.get_usize(SPARK_EXECUTOR_CORES, 1).max(1); + // The configured threshold is the total budget across all + // concurrent tasks. Divide by executor cores so each task's + // fast-path hash table stays within its fair share. + let fast_path_threshold = self + .spark_conf + .get_usize(COMET_GRACE_HASH_JOIN_FAST_PATH_THRESHOLD, 10 * 1024 * 1024) + / executor_cores; + + let build_left = join.build_side == BuildSide::BuildLeft as i32; + + let grace_join = + Arc::new(crate::execution::operators::GraceHashJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + join_params.join_on, + join_params.join_filter, + &join_params.join_type, + num_partitions, + build_left, + fast_path_threshold, + )?); + + return Ok(( + scans, + Arc::new(SparkPlan::new( + spark_plan.plan_id, + grace_join, + vec![join_params.left, join_params.right], + )), + )); + } + } + + { + use crate::execution::spark_config::{ + SparkConfig, COMET_GRACE_HASH_JOIN_ENABLED, + }; + info!( + "PLANNER: creating plain HashJoinExec (NOT GraceHashJoin). 
\ + join_type={:?}, build_side={:?}, grace_enabled={}", + join_params.join_type, + join.build_side, + self.spark_conf.get_bool(COMET_GRACE_HASH_JOIN_ENABLED), + ); + } + let hash_join = Arc::new(HashJoinExec::try_new( left, right, diff --git a/native/core/src/execution/spark_config.rs b/native/core/src/execution/spark_config.rs index 277c0eb43b..ef528c4405 100644 --- a/native/core/src/execution/spark_config.rs +++ b/native/core/src/execution/spark_config.rs @@ -23,6 +23,11 @@ pub(crate) const COMET_EXPLAIN_NATIVE_ENABLED: &str = "spark.comet.explain.nativ pub(crate) const COMET_MAX_TEMP_DIRECTORY_SIZE: &str = "spark.comet.maxTempDirectorySize"; pub(crate) const COMET_DEBUG_MEMORY: &str = "spark.comet.debug.memory"; pub(crate) const SPARK_EXECUTOR_CORES: &str = "spark.executor.cores"; +pub(crate) const COMET_GRACE_HASH_JOIN_ENABLED: &str = "spark.comet.exec.graceHashJoin.enabled"; +pub(crate) const COMET_GRACE_HASH_JOIN_NUM_PARTITIONS: &str = + "spark.comet.exec.graceHashJoin.numPartitions"; +pub(crate) const COMET_GRACE_HASH_JOIN_FAST_PATH_THRESHOLD: &str = + "spark.comet.exec.graceHashJoin.fastPathThreshold"; pub(crate) trait SparkConfig { fn get_bool(&self, name: &str) -> bool; diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index 8c75df1d45..2d2222129c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -225,6 +225,33 @@ object CometMetricNode { "join_time" -> SQLMetrics.createNanoTimingMetric(sc, "Total time for joining")) } + /** + * SQL Metrics for GraceHashJoin + */ + def graceHashJoinMetrics(sc: SparkContext): Map[String, SQLMetric] = { + Map( + "build_time" -> + SQLMetrics.createNanoTimingMetric(sc, "Total time for partitioning build-side"), + "probe_time" -> + SQLMetrics.createNanoTimingMetric(sc, "Total time for partitioning probe-side"), + 
"join_time" -> + SQLMetrics.createNanoTimingMetric(sc, "Total time for per-partition joins"), + "spill_count" -> SQLMetrics.createMetric(sc, "Count of spills"), + "spilled_bytes" -> SQLMetrics.createSizeMetric(sc, "Total spilled bytes"), + "build_input_rows" -> + SQLMetrics.createMetric(sc, "Number of rows consumed by build-side"), + "build_input_batches" -> + SQLMetrics.createMetric(sc, "Number of batches consumed by build-side"), + "input_rows" -> + SQLMetrics.createMetric(sc, "Number of rows consumed by probe-side"), + "input_batches" -> + SQLMetrics.createMetric(sc, "Number of batches consumed by probe-side"), + "output_batches" -> SQLMetrics.createMetric(sc, "Number of batches produced"), + "output_rows" -> SQLMetrics.createMetric(sc, "Number of rows produced"), + "elapsed_compute" -> + SQLMetrics.createNanoTimingMetric(sc, "Total elapsed compute time")) + } + /** * SQL Metrics for DataFusion SortMergeJoin */ diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..fe0ed016f4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1724,19 +1724,35 @@ object CometHashJoinExec extends CometOperatorSerde[HashJoin] with CometHashJoin doConvert(join, builder, childOp: _*) override def createExec(nativeOp: Operator, op: HashJoin): CometNativeExec = { - CometHashJoinExec( - nativeOp, - op, - op.output, - op.outputOrdering, - op.leftKeys, - op.rightKeys, - op.joinType, - op.condition, - op.buildSide, - op.left, - op.right, - SerializedPlan(None)) + if (CometConf.COMET_EXEC_GRACE_HASH_JOIN_ENABLED.get()) { + CometGraceHashJoinExec( + nativeOp, + op, + op.output, + op.outputOrdering, + op.leftKeys, + op.rightKeys, + op.joinType, + op.condition, + op.buildSide, + op.left, + op.right, + SerializedPlan(None)) + } else { + CometHashJoinExec( + nativeOp, + op, + op.output, + 
op.outputOrdering, + op.leftKeys, + op.rightKeys, + op.joinType, + op.condition, + op.buildSide, + op.left, + op.right, + SerializedPlan(None)) + } } } @@ -1795,6 +1811,61 @@ case class CometHashJoinExec( CometMetricNode.hashJoinMetrics(sparkContext) } +case class CometGraceHashJoinExec( + override val nativeOp: Operator, + override val originalPlan: SparkPlan, + override val output: Seq[Attribute], + override val outputOrdering: Seq[SortOrder], + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + buildSide: BuildSide, + override val left: SparkPlan, + override val right: SparkPlan, + override val serializedPlanOpt: SerializedPlan) + extends CometBinaryExec { + + override def outputPartitioning: Partitioning = joinType match { + case _: InnerLike => + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case LeftExistence(_) => left.outputPartitioning + case x => + throw new IllegalArgumentException(s"GraceHashJoin should not take $x as the JoinType") + } + + override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = + this.copy(left = newLeft, right = newRight) + + override def stringArgs: Iterator[Any] = + Iterator(leftKeys, rightKeys, joinType, buildSide, condition, left, right) + + override def equals(obj: Any): Boolean = { + obj match { + case other: CometGraceHashJoinExec => + this.output == other.output && + this.leftKeys == other.leftKeys && + this.rightKeys == other.rightKeys && + this.condition == other.condition && + this.buildSide == other.buildSide && + this.left == other.left && + this.right == other.right && + this.serializedPlanOpt == other.serializedPlanOpt + case _ => + false + } + } + + override def hashCode(): Int = + Objects.hashCode(output, leftKeys, 
rightKeys, condition, buildSide, left, right) + + override lazy val metrics: Map[String, SQLMetric] = + CometMetricNode.graceHashJoinMetrics(sparkContext) +} + case class CometBroadcastHashJoinExec( override val nativeOp: Operator, override val originalPlan: SparkPlan, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index 6111b9c0d4..79411f3a3f 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -19,17 +19,20 @@ package org.apache.comet.exec +import scala.util.Random + import org.scalactic.source.Position import org.scalatest.Tag import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometGraceHashJoinExec} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.{DataTypes, Decimal, StructField, StructType} import org.apache.comet.CometConf +import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} class CometJoinSuite extends CometTestBase { import testImplicits._ @@ -446,4 +449,254 @@ class CometJoinSuite extends CometTestBase { """.stripMargin)) } } + + // Common SQL config for Grace Hash Join tests + private val graceHashJoinConf: Seq[(String, String)] = Seq( + CometConf.COMET_EXEC_GRACE_HASH_JOIN_ENABLED.key -> "true", + CometConf.COMET_EXEC_GRACE_HASH_JOIN_NUM_PARTITIONS.key -> "4", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") + + test("Grace 
HashJoin - all join types") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 100).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 100).map(i => (i % 10, i + 2)), "tbl_b") { + // Inner join + checkSparkAnswer( + sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Left join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Right join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Full outer join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Left semi join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT SEMI JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Left anti join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT ANTI JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + } + } + } + } + + test("Grace HashJoin - with filter condition") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 100).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 100).map(i => (i % 10, i + 2)), "tbl_b") { + checkSparkAnswer( + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")) + + checkSparkAnswer( + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")) + + checkSparkAnswer( + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")) + } + } + } + } + + test("Grace HashJoin - various data types") { + withSQLConf(graceHashJoinConf: _*) { + // String keys + withParquetTable((0 until 50).map(i => (s"key_${i % 10}", i)), "str_a") { + withParquetTable((0 until 50).map(i => (s"key_${i % 5}", i * 2)), "str_b") { + 
checkSparkAnswer( + sql( + "SELECT /*+ SHUFFLE_HASH(str_a) */ * FROM str_a JOIN str_b ON str_a._1 = str_b._1")) + } + } + + // Decimal keys + withParquetTable((0 until 50).map(i => (Decimal(i % 10), i)), "dec_a") { + withParquetTable((0 until 50).map(i => (Decimal(i % 5), i * 2)), "dec_b") { + checkSparkAnswer( + sql( + "SELECT /*+ SHUFFLE_HASH(dec_a) */ * FROM dec_a JOIN dec_b ON dec_a._1 = dec_b._1")) + } + } + } + } + + test("Grace HashJoin - empty tables") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable(Seq.empty[(Int, Int)], "empty_a") { + withParquetTable((0 until 10).map(i => (i, i)), "nonempty_b") { + // Empty left side + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(empty_a) */ * FROM empty_a JOIN nonempty_b ON empty_a._1 = nonempty_b._1")) + + // Empty left with left join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(empty_a) */ * FROM empty_a LEFT JOIN nonempty_b ON empty_a._1 = nonempty_b._1")) + + // Empty right side + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(nonempty_b) */ * FROM nonempty_b JOIN empty_a ON nonempty_b._1 = empty_a._1")) + + // Empty right with right join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(nonempty_b) */ * FROM nonempty_b RIGHT JOIN empty_a ON nonempty_b._1 = empty_a._1")) + } + } + } + } + + test("Grace HashJoin - self join") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 50).map(i => (i, i % 10)), "self_tbl") { + checkSparkAnswer( + sql("SELECT /*+ SHUFFLE_HASH(a) */ * FROM self_tbl a JOIN self_tbl b ON a._2 = b._2")) + } + } + } + + test("Grace HashJoin - build side selection") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 100).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 100).map(i => (i % 10, i + 2)), "tbl_b") { + // Build left (hint on left table) + checkSparkAnswer( + sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Build right (hint on right table) + checkSparkAnswer( + sql( + 
"SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Left join build right + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Right join build left + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + } + } + } + } + + test("Grace HashJoin - plan shows CometGraceHashJoinExec") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 50).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 50).map(i => (i % 10, i + 2)), "tbl_b") { + val df = sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator(df, Seq(classOf[CometGraceHashJoinExec])) + } + } + } + } + + test("Grace HashJoin - multiple key columns") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 50).map(i => (i, i % 5, i % 3)), "multi_a") { + withParquetTable((0 until 50).map(i => (i % 10, i % 5, i % 3)), "multi_b") { + checkSparkAnswer( + sql("SELECT /*+ SHUFFLE_HASH(multi_a) */ * FROM multi_a JOIN multi_b " + + "ON multi_a._2 = multi_b._2 AND multi_a._3 = multi_b._3")) + } + } + } + } + + // Schema with types that work well as join keys (no NaN/float issues) + private val fuzzJoinSchema = StructType( + Seq( + StructField("c_int", DataTypes.IntegerType), + StructField("c_long", DataTypes.LongType), + StructField("c_str", DataTypes.StringType), + StructField("c_date", DataTypes.DateType), + StructField("c_dec", DataTypes.createDecimalType(10, 2)), + StructField("c_short", DataTypes.ShortType), + StructField("c_bool", DataTypes.BooleanType))) + + private val joinTypes = + Seq("JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL JOIN", "LEFT SEMI JOIN", "LEFT ANTI JOIN") + + test("Grace HashJoin fuzz - all join types with generated data") { + val dataGenOptions = + DataGenOptions(allowNull = true, generateNegativeZero = false, generateNaN = false) 
+ + withSQLConf(graceHashJoinConf: _*) { + withTempPath { dir => + val path1 = s"${dir.getAbsolutePath}/fuzz_left" + val path2 = s"${dir.getAbsolutePath}/fuzz_right" + val random = new Random(42) + + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + ParquetGenerator + .makeParquetFile(random, spark, path1, fuzzJoinSchema, 200, dataGenOptions) + ParquetGenerator + .makeParquetFile(random, spark, path2, fuzzJoinSchema, 200, dataGenOptions) + } + + spark.read.parquet(path1).createOrReplaceTempView("fuzz_l") + spark.read.parquet(path2).createOrReplaceTempView("fuzz_r") + + for (jt <- joinTypes) { + // Join on int column + checkSparkAnswer(sql( + s"SELECT /*+ SHUFFLE_HASH(fuzz_l) */ * FROM fuzz_l $jt fuzz_r ON fuzz_l.c_int = fuzz_r.c_int")) + + // Join on string column + checkSparkAnswer(sql( + s"SELECT /*+ SHUFFLE_HASH(fuzz_l) */ * FROM fuzz_l $jt fuzz_r ON fuzz_l.c_str = fuzz_r.c_str")) + + // Join on decimal column + checkSparkAnswer(sql( + s"SELECT /*+ SHUFFLE_HASH(fuzz_l) */ * FROM fuzz_l $jt fuzz_r ON fuzz_l.c_dec = fuzz_r.c_dec")) + } + } + } + } + + test("Grace HashJoin fuzz - with spilling") { + val dataGenOptions = + DataGenOptions(allowNull = true, generateNegativeZero = false, generateNaN = false) + + // Use very small memory pool to force spilling + withSQLConf( + (graceHashJoinConf ++ Seq( + CometConf.COMET_ONHEAP_MEMORY_OVERHEAD.key -> "10000000", + CometConf.COMET_EXEC_GRACE_HASH_JOIN_NUM_PARTITIONS.key -> "8")): _*) { + withTempPath { dir => + val path1 = s"${dir.getAbsolutePath}/spill_left" + val path2 = s"${dir.getAbsolutePath}/spill_right" + val random = new Random(99) + + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + ParquetGenerator + .makeParquetFile(random, spark, path1, fuzzJoinSchema, 500, dataGenOptions) + ParquetGenerator + .makeParquetFile(random, spark, path2, fuzzJoinSchema, 500, dataGenOptions) + } + + spark.read.parquet(path1).createOrReplaceTempView("spill_l") + 
spark.read.parquet(path2).createOrReplaceTempView("spill_r") + + for (jt <- joinTypes) { + checkSparkAnswer(sql( + s"SELECT /*+ SHUFFLE_HASH(spill_l) */ * FROM spill_l $jt spill_r ON spill_l.c_int = spill_r.c_int")) + } + } + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJoinBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJoinBenchmark.scala new file mode 100644 index 0000000000..67550e3970 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJoinBenchmark.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.{CometConf, CometSparkSessionExtensions} + +/** + * Benchmark to compare join implementations: Spark Sort Merge Join, Comet Sort Merge Join, Comet + * Hash Join, and Comet Grace Hash Join across all join types. 
+ * + * To run this benchmark: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 make \ + * benchmark-org.apache.spark.sql.benchmark.CometJoinBenchmark + * }}} + * + * Results will be written to "spark/benchmarks/CometJoinBenchmark-**results.txt". + */ +object CometJoinBenchmark extends CometBenchmarkBase { + override def getSparkSession: SparkSession = { + val conf = new SparkConf() + .setAppName("CometJoinBenchmark") + .set("spark.master", "local[5]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .set("spark.executor.memoryOverhead", "10g") + .set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + + val sparkSession = SparkSession.builder + .config(conf) + .withExtensions(new CometSparkSessionExtensions) + .getOrCreate() + + sparkSession.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(CometConf.COMET_ENABLED.key, "false") + sparkSession.conf.set(CometConf.COMET_EXEC_ENABLED.key, "false") + sparkSession.conf.set(CometConf.COMET_ONHEAP_MEMORY_OVERHEAD.key, "10g") + sparkSession.conf.set("parquet.enable.dictionary", "false") + sparkSession.conf.set("spark.sql.shuffle.partitions", "2") + + sparkSession + } + + /** Base Comet exec config — shuffle mode auto, no SMJ replacement by default. 
*/ + private val cometBaseConf = Map( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "auto", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") + + private def prepareTwoTables(dir: java.io.File, rows: Int, keyCardinality: Int): Unit = { + val left = spark + .range(rows) + .selectExpr( + s"id % $keyCardinality as key", + "id as l_val1", + "cast(id * 1.5 as double) as l_val2") + prepareTable(dir, left) + spark.read.parquet(dir.getCanonicalPath + "/parquetV1").createOrReplaceTempView("left_table") + + val rightDir = new java.io.File(dir, "right") + rightDir.mkdirs() + val right = spark + .range(rows) + .selectExpr( + s"id % $keyCardinality as key", + "id as r_val1", + "cast(id * 2.5 as double) as r_val2") + right.write + .mode("overwrite") + .option("compression", "snappy") + .parquet(rightDir.getCanonicalPath) + spark.read.parquet(rightDir.getCanonicalPath).createOrReplaceTempView("right_table") + } + + private def addJoinCases(benchmark: Benchmark, query: String): Unit = { + // 1. Spark Sort Merge Join (baseline — no Comet) + benchmark.addCase("Spark Sort Merge Join") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + "spark.sql.join.preferSortMergeJoin" -> "true") { + spark.sql(query).noop() + } + } + + // 2. Comet Sort Merge Join (Spark plans SMJ, Comet executes it natively) + benchmark.addCase("Comet Sort Merge Join") { _ => + withSQLConf( + (cometBaseConf ++ Map( + CometConf.COMET_REPLACE_SMJ.key -> "false", + CometConf.COMET_EXEC_GRACE_HASH_JOIN_ENABLED.key -> "false", + "spark.sql.join.preferSortMergeJoin" -> "true")).toSeq: _*) { + spark.sql(query).noop() + } + } + + // 3. 
Comet Hash Join (replace SMJ with ShuffledHashJoin, Comet executes) + benchmark.addCase("Comet Hash Join") { _ => + withSQLConf( + (cometBaseConf ++ Map( + CometConf.COMET_REPLACE_SMJ.key -> "true", + CometConf.COMET_EXEC_GRACE_HASH_JOIN_ENABLED.key -> "false")).toSeq: _*) { + spark.sql(query).noop() + } + } + + // 4. Comet Grace Hash Join (replace SMJ, use grace hash join) + benchmark.addCase("Comet Grace Hash Join") { _ => + withSQLConf( + (cometBaseConf ++ Map( + CometConf.COMET_REPLACE_SMJ.key -> "true", + CometConf.COMET_EXEC_GRACE_HASH_JOIN_ENABLED.key -> "true")).toSeq: _*) { + spark.sql(query).noop() + } + } + } + + private def joinBenchmark(joinType: String, rows: Int, keyCardinality: Int): Unit = { + val joinClause = joinType match { + case "Inner" => "JOIN" + case "Left" => "LEFT JOIN" + case "Right" => "RIGHT JOIN" + case "Full" => "FULL OUTER JOIN" + case "LeftSemi" => "LEFT SEMI JOIN" + case "LeftAnti" => "LEFT ANTI JOIN" + } + + val selectCols = joinType match { + case "LeftSemi" | "LeftAnti" => "l.key, l.l_val1, l.l_val2" + case _ => "l.key, l.l_val1, r.r_val1" + } + + val query = + s"SELECT $selectCols FROM left_table l $joinClause right_table r ON l.key = r.key" + + val benchmark = + new Benchmark( + s"$joinType Join (rows=$rows, cardinality=$keyCardinality)", + rows, + output = output) + + addJoinCases(benchmark, query) + benchmark.run() + } + + private def joinWithFilterBenchmark(rows: Int, keyCardinality: Int): Unit = { + val query = + "SELECT l.key, l.l_val1, r.r_val1 FROM left_table l " + + "JOIN right_table r ON l.key = r.key WHERE l.l_val1 > r.r_val1" + + val benchmark = + new Benchmark( + s"Inner Join with Filter (rows=$rows, cardinality=$keyCardinality)", + rows, + output = output) + + addJoinCases(benchmark, query) + benchmark.run() + } + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + val rows = 1024 * 1024 * 2 + val keyCardinality = rows / 10 // ~10 matches per key + + withTempPath { dir => + prepareTwoTables(dir, 
rows, keyCardinality) + + runBenchmark("Join Benchmark") { + for (joinType <- Seq("Inner", "Left", "Right", "Full", "LeftSemi", "LeftAnti")) { + joinBenchmark(joinType, rows, keyCardinality) + } + joinWithFilterBenchmark(rows, keyCardinality) + } + } + } +}