Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/duckdb_py/pyrelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,19 @@ py::object DuckDBPyRelation::ToArrowCapsule(const py::object &requested_schema)
if (!rel) {
return py::none();
}
// The PyCapsule protocol doesn't allow custom parameters, so we use the same
// default batch size as fetch_arrow_table / fetch_record_batch.
idx_t batch_size = 1000000;
auto &config = ClientConfig::GetConfig(*rel->context->GetContext());
ScopedConfigSetting scoped_setting(
config,
[&batch_size](ClientConfig &config) {
config.get_result_collector = [&batch_size](ClientContext &context,
PreparedStatementData &data) -> PhysicalOperator & {
return PhysicalArrowCollector::Create(context, data, batch_size);
};
},
[](ClientConfig &config) { config.get_result_collector = nullptr; });
ExecuteOrThrow();
}
AssertResultOpen();
Expand All @@ -1003,7 +1016,8 @@ py::object DuckDBPyRelation::ToArrowCapsule(const py::object &requested_schema)
PolarsDataFrame DuckDBPyRelation::ToPolars(idx_t batch_size, bool lazy) {
if (!lazy) {
auto arrow = ToArrowTableInternal(batch_size, true);
return py::cast<PolarsDataFrame>(pybind11::module_::import("polars").attr("DataFrame")(arrow));
return py::cast<PolarsDataFrame>(
pybind11::module_::import("polars").attr("from_arrow")(arrow, py::arg("rechunk") = false));
}
auto &import_cache = *DuckDBPyConnection::ImportCache();
auto lazy_frame_produce = import_cache.duckdb.polars_io.duckdb_source();
Expand Down
96 changes: 96 additions & 0 deletions src/duckdb_py/pyresult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,81 @@ duckdb::pyarrow::RecordBatchReader DuckDBPyResult::FetchRecordBatchReader(idx_t
return py::cast<duckdb::pyarrow::RecordBatchReader>(record_batch_reader);
}

// Wraps pre-built Arrow arrays from an ArrowQueryResult into an ArrowArrayStream.
// This avoids the double-materialization that happens when using ResultArrowArrayStreamWrapper
// with an ArrowQueryResult (which throws NotImplementedException from FetchInternal).
//
// Ownership model: the wrapper is heap-allocated by the caller and owns the query
// result, the consumed arrays, and the embedded C stream. Release() (reached via
// stream.release) deletes the wrapper and therefore everything it owns.
struct ArrowQueryResultStreamWrapper {
	explicit ArrowQueryResultStreamWrapper(unique_ptr<QueryResult> result_p) : result(std::move(result_p)), index(0) {
		auto &arrow_result = result->Cast<ArrowQueryResult>();
		arrays = arrow_result.ConsumeArrays();
		// Copy out the metadata needed to build the schema later; the consumed
		// arrays alone do not carry type/name information.
		types = result->types;
		names = result->names;
		client_properties = result->client_properties;

		// Wire up the Arrow C stream callbacks; private_data points back at us.
		stream.private_data = this;
		stream.get_schema = GetSchema;
		stream.get_next = GetNext;
		stream.release = Release;
		stream.get_last_error = GetLastError;
	}

	// C callback: writes the result schema into 'out'.
	// Returns 0 on success, non-zero on error (Arrow C stream contract).
	static int GetSchema(ArrowArrayStream *stream, ArrowSchema *out) {
		if (!stream->release) {
			// Stream was already released.
			return -1;
		}
		auto self = reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
		out->release = nullptr;
		try {
			ArrowConverter::ToArrowSchema(out, self->types, self->names, self->client_properties);
		} catch (std::exception &e) {
			// Catch all standard exceptions, not just std::runtime_error: letting any
			// exception escape through this C callback would be undefined behavior.
			self->last_error = e.what();
			return -1;
		} catch (...) {
			self->last_error = "unknown error while converting to an Arrow schema";
			return -1;
		}
		return 0;
	}

	// C callback: yields the next pre-built array, or an end-of-stream marker
	// (out->release == nullptr) once all arrays have been consumed.
	static int GetNext(ArrowArrayStream *stream, ArrowArray *out) {
		if (!stream->release) {
			return -1;
		}
		auto self = reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
		if (self->index >= self->arrays.size()) {
			// End of stream per the Arrow C spec: release set to nullptr, return 0.
			out->release = nullptr;
			return 0;
		}
		// Transfer ownership of the array to the consumer: copy the struct out,
		// then clear our release callback so we never free it a second time.
		*out = self->arrays[self->index]->arrow_array;
		self->arrays[self->index]->arrow_array.release = nullptr;
		self->index++;
		return 0;
	}

	// C callback: releases the stream AND the owning wrapper (see ownership model above).
	static void Release(ArrowArrayStream *stream) {
		if (!stream || !stream->release) {
			return;
		}
		// Mark released before deleting so re-entrant calls are no-ops.
		stream->release = nullptr;
		delete reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
	}

	// C callback: returns the message stored by the last failing callback.
	static const char *GetLastError(ArrowArrayStream *stream) {
		if (!stream->release) {
			return "stream was released";
		}
		auto self = reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
		return self->last_error.c_str();
	}

	//! The embedded C stream handed to consumers (points back here via private_data)
	ArrowArrayStream stream;
	//! Keeps the query result alive for the lifetime of the stream
	unique_ptr<QueryResult> result;
	//! Pre-built arrays consumed from the ArrowQueryResult, yielded one per GetNext call
	vector<unique_ptr<ArrowArrayWrapper>> arrays;
	//! Result column types (needed to build the Arrow schema)
	vector<LogicalType> types;
	//! Result column names (needed to build the Arrow schema)
	vector<string> names;
	//! Client settings (e.g. timezone) that influence schema conversion
	ClientProperties client_properties;
	//! Index of the next array to hand out
	idx_t index;
	//! Message reported through GetLastError after a failing callback
	string last_error;
};

// Destructor for capsules that own a heap-allocated ArrowArrayStream (slow path).
static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) {
auto data = PyCapsule_GetPointer(object, "arrow_array_stream");
if (!data) {
Expand All @@ -508,7 +583,28 @@ static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) {
delete stream;
}

// Capsule destructor for the fast path, where the capsule points at an
// ArrowArrayStream embedded inside an ArrowQueryResultStreamWrapper.
// Invoking release() tears down both the stream and its owning wrapper.
static void ArrowArrayStreamEmbeddedPyCapsuleDestructor(PyObject *object) {
	auto capsule_ptr = PyCapsule_GetPointer(object, "arrow_array_stream");
	if (capsule_ptr == nullptr) {
		return;
	}
	auto embedded_stream = reinterpret_cast<ArrowArrayStream *>(capsule_ptr);
	if (embedded_stream->release != nullptr) {
		embedded_stream->release(embedded_stream);
	}
}

py::object DuckDBPyResult::FetchArrowCapsule(idx_t rows_per_batch) {
if (result && result->type == QueryResultType::ARROW_RESULT) {
// Fast path: yield pre-built Arrow arrays directly.
// The wrapper is heap-allocated; Release() deletes it via private_data.
// The capsule points at the embedded stream field — no separate heap allocation needed.
auto wrapper = new ArrowQueryResultStreamWrapper(std::move(result));
return py::capsule(&wrapper->stream, "arrow_array_stream", ArrowArrayStreamEmbeddedPyCapsuleDestructor);
}
// Existing slow path for MaterializedQueryResult / StreamQueryResult
auto stream_p = FetchArrowArrayStream(rows_per_batch);
auto stream = new ArrowArrayStream();
*stream = stream_p;
Expand Down
68 changes: 68 additions & 0 deletions tests/fast/arrow/test_arrow_pycapsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import duckdb

pa = pytest.importorskip("pyarrow")
pl = pytest.importorskip("polars")


Expand All @@ -11,6 +12,73 @@ def polars_supports_capsule():
return Version(pl.__version__) >= Version("1.4.1")


class TestArrowPyCapsuleExport:
    """Tests for the PyCapsule export path (rel.__arrow_c_stream__).

    Each test runs the same query through two routes: to_arrow_table() as the
    reference, and pa.table(relation), which consumes the relation via the Arrow
    PyCapsule interface (the PhysicalArrowCollector + ArrowQueryResultStreamWrapper
    fast path), then asserts both routes agree.
    """

    def test_capsule_matches_to_arrow_table(self):
        """Capsule export matches to_arrow_table across scalar types and NULLs."""
        con = duckdb.connect()
        query = """
            SELECT
                i AS int_col,
                i::DOUBLE AS double_col,
                'row_' || i::VARCHAR AS str_col,
                i % 2 = 0 AS bool_col,
                CASE WHEN i % 3 = 0 THEN NULL ELSE i END AS nullable_col
            FROM range(1000) t(i)
        """
        via_capsule = pa.table(con.sql(query))
        reference = con.sql(query).to_arrow_table()
        assert via_capsule.equals(reference)

    def test_capsule_matches_to_arrow_table_nested_types(self):
        """Capsule export handles nested types (struct, list, map)."""
        con = duckdb.connect()
        query = """
            SELECT
                {'x': i, 'y': i::VARCHAR} AS struct_col,
                [i, i+1, i+2] AS list_col,
                MAP {i::VARCHAR: i*10} AS map_col,
            FROM range(100) t(i)
        """
        via_capsule = pa.table(con.sql(query))
        reference = con.sql(query).to_arrow_table()
        assert via_capsule.equals(reference)

    def test_capsule_multi_batch(self):
        """Data beyond the 1M default batch size spans multiple batches, all yielded."""
        con = duckdb.connect()
        query = "SELECT i, i::DOUBLE AS d FROM range(1500000) t(i)"
        via_capsule = pa.table(con.sql(query))
        assert via_capsule.num_rows == 1500000
        reference = con.sql(query).to_arrow_table()
        assert via_capsule.equals(reference)

    def test_capsule_empty_result(self):
        """An empty result yields a valid zero-row table with the right schema."""
        con = duckdb.connect()
        query = "SELECT i AS a, i::VARCHAR AS b FROM range(10) t(i) WHERE i < 0"
        via_capsule = pa.table(con.sql(query))
        assert via_capsule.num_rows == 0
        reference = con.sql(query).to_arrow_table()
        assert via_capsule.schema.equals(reference.schema)

    def test_capsule_slow_path_after_execute(self):
        """A pre-executed relation (MaterializedQueryResult) uses the slow path and still works."""
        con = duckdb.connect()
        query = "SELECT i, i::DOUBLE AS d FROM range(500) t(i)"
        reference = con.sql(query).to_arrow_table()

        relation = con.sql(query)
        relation.execute()  # forces MaterializedCollector, not PhysicalArrowCollector
        via_capsule = pa.table(relation)
        assert via_capsule.equals(reference)


@pytest.mark.skipif(
not polars_supports_capsule(), reason="Polars version does not support the Arrow PyCapsule interface"
)
Expand Down