diff --git a/src/duckdb_py/pyrelation.cpp b/src/duckdb_py/pyrelation.cpp
index 1a711562..8a70f0d2 100644
--- a/src/duckdb_py/pyrelation.cpp
+++ b/src/duckdb_py/pyrelation.cpp
@@ -992,6 +992,19 @@ py::object DuckDBPyRelation::ToArrowCapsule(const py::object &requested_schema)
 		if (!rel) {
 			return py::none();
 		}
+		// The PyCapsule protocol doesn't allow custom parameters, so we use the same
+		// default batch size as fetch_arrow_table / fetch_record_batch.
+		idx_t batch_size = 1000000;
+		auto &config = ClientConfig::GetConfig(*rel->context->GetContext());
+		ScopedConfigSetting scoped_setting(
+		    config,
+		    [&batch_size](ClientConfig &config) {
+			    config.get_result_collector = [&batch_size](ClientContext &context,
+			                                                PreparedStatementData &data) -> PhysicalOperator & {
+				    return PhysicalArrowCollector::Create(context, data, batch_size);
+			    };
+		    },
+		    [](ClientConfig &config) { config.get_result_collector = nullptr; });
 		ExecuteOrThrow();
 	}
 	AssertResultOpen();
@@ -1003,7 +1016,8 @@ py::object DuckDBPyRelation::ToArrowCapsule(const py::object &requested_schema)
 PolarsDataFrame DuckDBPyRelation::ToPolars(idx_t batch_size, bool lazy) {
 	if (!lazy) {
 		auto arrow = ToArrowTableInternal(batch_size, true);
-		return py::cast(pybind11::module_::import("polars").attr("DataFrame")(arrow));
+		return py::cast(
+		    pybind11::module_::import("polars").attr("from_arrow")(arrow, py::arg("rechunk") = false));
 	}
 	auto &import_cache = *DuckDBPyConnection::ImportCache();
 	auto lazy_frame_produce = import_cache.duckdb.polars_io.duckdb_source();
diff --git a/src/duckdb_py/pyresult.cpp b/src/duckdb_py/pyresult.cpp
index 644d6393..b4ed691b 100644
--- a/src/duckdb_py/pyresult.cpp
+++ b/src/duckdb_py/pyresult.cpp
@@ -496,6 +496,81 @@ duckdb::pyarrow::RecordBatchReader DuckDBPyResult::FetchRecordBatchReader(idx_t
 	return py::cast(record_batch_reader);
 }
 
+// Wraps pre-built Arrow arrays from an ArrowQueryResult into an ArrowArrayStream.
+// This avoids the double materialization that happens when using ResultArrowArrayStreamWrapper
+// with an ArrowQueryResult (which throws NotImplementedException from FetchInternal).
+struct ArrowQueryResultStreamWrapper {
+	ArrowQueryResultStreamWrapper(unique_ptr<QueryResult> result_p) : result(std::move(result_p)), index(0) {
+		auto &arrow_result = result->Cast<ArrowQueryResult>();
+		arrays = arrow_result.ConsumeArrays();
+		types = result->types;
+		names = result->names;
+		client_properties = result->client_properties;
+
+		stream.private_data = this;
+		stream.get_schema = GetSchema;
+		stream.get_next = GetNext;
+		stream.release = Release;
+		stream.get_last_error = GetLastError;
+	}
+
+	static int GetSchema(ArrowArrayStream *stream, ArrowSchema *out) {
+		if (!stream->release) {
+			return -1;
+		}
+		auto self = reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
+		out->release = nullptr;
+		try {
+			ArrowConverter::ToArrowSchema(out, self->types, self->names, self->client_properties);
+		} catch (std::runtime_error &e) {
+			self->last_error = e.what();
+			return -1;
+		}
+		return 0;
+	}
+
+	static int GetNext(ArrowArrayStream *stream, ArrowArray *out) {
+		if (!stream->release) {
+			return -1;
+		}
+		auto self = reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
+		if (self->index >= self->arrays.size()) {
+			out->release = nullptr;
+			return 0;
+		}
+		*out = self->arrays[self->index]->arrow_array;
+		self->arrays[self->index]->arrow_array.release = nullptr;
+		self->index++;
+		return 0;
+	}
+
+	static void Release(ArrowArrayStream *stream) {
+		if (!stream || !stream->release) {
+			return;
+		}
+		stream->release = nullptr;
+		delete reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
+	}
+
+	static const char *GetLastError(ArrowArrayStream *stream) {
+		if (!stream->release) {
+			return "stream was released";
+		}
+		auto self = reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
+		return self->last_error.c_str();
+	}
+
+	ArrowArrayStream stream;
+	unique_ptr<QueryResult> result;
+	vector<unique_ptr<ArrowArrayWrapper>> arrays;
+	vector<LogicalType> types;
+	vector<string> names;
+	ClientProperties client_properties;
+	idx_t index;
+	string last_error;
+};
+
+// Destructor for capsules that own a heap-allocated ArrowArrayStream (slow path).
 static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) {
 	auto data = PyCapsule_GetPointer(object, "arrow_array_stream");
 	if (!data) {
 		return;
 	}
@@ -508,7 +583,28 @@ static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) {
 	delete stream;
 }
 
+// Destructor for capsules pointing at an embedded ArrowArrayStream (fast path).
+// The stream is owned by an ArrowQueryResultStreamWrapper; Release() frees both.
+static void ArrowArrayStreamEmbeddedPyCapsuleDestructor(PyObject *object) {
+	auto data = PyCapsule_GetPointer(object, "arrow_array_stream");
+	if (!data) {
+		return;
+	}
+	auto stream = reinterpret_cast<ArrowArrayStream *>(data);
+	if (stream->release) {
+		stream->release(stream);
+	}
+}
+
 py::object DuckDBPyResult::FetchArrowCapsule(idx_t rows_per_batch) {
+	if (result && result->type == QueryResultType::ARROW_RESULT) {
+		// Fast path: yield pre-built Arrow arrays directly.
+		// The wrapper is heap-allocated; Release() deletes it via private_data.
+		// The capsule points at the embedded stream field, so no separate heap allocation is needed.
+		auto wrapper = new ArrowQueryResultStreamWrapper(std::move(result));
+		return py::capsule(&wrapper->stream, "arrow_array_stream", ArrowArrayStreamEmbeddedPyCapsuleDestructor);
+	}
+	// Existing slow path for MaterializedQueryResult / StreamQueryResult
 	auto stream_p = FetchArrowArrayStream(rows_per_batch);
 	auto stream = new ArrowArrayStream();
 	*stream = stream_p;
diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py
index 1f825799..6d0319f4 100644
--- a/tests/fast/arrow/test_arrow_pycapsule.py
+++ b/tests/fast/arrow/test_arrow_pycapsule.py
@@ -2,6 +2,7 @@
 
 import duckdb
 
+pa = pytest.importorskip("pyarrow")
 pl = pytest.importorskip("polars")
 
 
@@ -11,6 +12,73 @@ def polars_supports_capsule():
     return Version(pl.__version__) >= Version("1.4.1")
 
 
+class TestArrowPyCapsuleExport:
+    """Tests for the PyCapsule export path (rel.__arrow_c_stream__).
+
+    Validates that the fast path (PhysicalArrowCollector + ArrowQueryResultStreamWrapper)
+    produces correct data, matching to_arrow_table() across types and edge cases.
+    """
+
+    def test_capsule_matches_to_arrow_table(self):
+        """Fast path produces identical data to to_arrow_table for various types."""
+        conn = duckdb.connect()
+        sql = """
+            SELECT
+                i AS int_col,
+                i::DOUBLE AS double_col,
+                'row_' || i::VARCHAR AS str_col,
+                i % 2 = 0 AS bool_col,
+                CASE WHEN i % 3 = 0 THEN NULL ELSE i END AS nullable_col
+            FROM range(1000) t(i)
+        """
+        expected = conn.sql(sql).to_arrow_table()
+        actual = pa.table(conn.sql(sql))
+        assert actual.equals(expected)
+
+    def test_capsule_matches_to_arrow_table_nested_types(self):
+        """Fast path handles nested types (struct, list, map)."""
+        conn = duckdb.connect()
+        sql = """
+            SELECT
+                {'x': i, 'y': i::VARCHAR} AS struct_col,
+                [i, i + 1, i + 2] AS list_col,
+                MAP {i::VARCHAR: i * 10} AS map_col
+            FROM range(100) t(i)
+        """
+        expected = conn.sql(sql).to_arrow_table()
+        actual = pa.table(conn.sql(sql))
+        assert actual.equals(expected)
+
+    def test_capsule_multi_batch(self):
+        """Data exceeding the 1M-row batch size produces multiple batches, all yielded correctly."""
+        conn = duckdb.connect()
+        sql = "SELECT i, i::DOUBLE AS d FROM range(1500000) t(i)"
+        expected = conn.sql(sql).to_arrow_table()
+        actual = pa.table(conn.sql(sql))
+        assert actual.num_rows == 1500000
+        assert actual.equals(expected)
+
+    def test_capsule_empty_result(self):
+        """An empty result set produces a valid empty table with the correct schema."""
+        conn = duckdb.connect()
+        sql = "SELECT i AS a, i::VARCHAR AS b FROM range(10) t(i) WHERE i < 0"
+        expected = conn.sql(sql).to_arrow_table()
+        actual = pa.table(conn.sql(sql))
+        assert actual.num_rows == 0
+        assert actual.schema.equals(expected.schema)
+
+    def test_capsule_slow_path_after_execute(self):
+        """A pre-executed relation takes the slow path (MaterializedQueryResult) and still works."""
+        conn = duckdb.connect()
+        sql = "SELECT i, i::DOUBLE AS d FROM range(500) t(i)"
+        expected = conn.sql(sql).to_arrow_table()
+
+        rel = conn.sql(sql)
+        rel.execute()  # forces MaterializedCollector, not PhysicalArrowCollector
+        actual = pa.table(rel)
+        assert actual.equals(expected)
+
+
 @pytest.mark.skipif(
     not polars_supports_capsule(), reason="Polars version does not support the Arrow PyCapsule interface"
 )
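
Usage note (not part of the patch): any PyCapsule-aware consumer picks up the fast path
automatically, because pyarrow negotiates through the relation's __arrow_c_stream__ method,
exactly as the tests above do with pa.table(). A minimal sketch, assuming only that pyarrow
is installed; the query itself is illustrative:

    import duckdb
    import pyarrow as pa

    conn = duckdb.connect()
    rel = conn.sql("SELECT i, i::VARCHAR AS s FROM range(5) t(i)")

    # pa.table() consumes rel via __arrow_c_stream__; with this patch the query
    # executes with PhysicalArrowCollector, and the pre-built Arrow arrays are
    # handed over without a second materialization.
    tbl = pa.table(rel)
    assert tbl.num_rows == 5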