diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi
index cbc1a5da5..7f9df6028 100644
--- a/python/deltalake/_internal.pyi
+++ b/python/deltalake/_internal.pyi
@@ -265,7 +265,7 @@ class RawDeltaTable:
         commit_properties: CommitProperties | None,
         post_commithook_properties: PostCommitHookProperties | None,
     ) -> None: ...
-    def __datafusion_table_provider__(self) -> Any: ...
+    def __datafusion_table_provider__(self, ffi_codec: Any) -> Any: ...
     def write(
         self,
         data: RecordBatchReader,
diff --git a/python/deltalake/table.py b/python/deltalake/table.py
index 98a4fbee9..8b1a32df8 100644
--- a/python/deltalake/table.py
+++ b/python/deltalake/table.py
@@ -1166,7 +1166,7 @@ def create_write_transaction(
             post_commithook_properties=post_commithook_properties,
         )
 
-    def __datafusion_table_provider__(self) -> Any:
+    def __datafusion_table_provider__(self, ffi_codec: Any) -> Any:
         """Return the DataFusion table provider PyCapsule interface.
 
         To support DataFusion features such as push down filtering, this function will return a PyCapsule
@@ -1201,7 +1201,7 @@ def __datafusion_table_provider__(self) -> Any:
         +----+----+----+
         ```
         """
-        return self._table.__datafusion_table_provider__()
+        return self._table.__datafusion_table_provider__(ffi_codec)
 
 
 class TableMerger:
diff --git a/python/src/datafusion.rs b/python/src/datafusion.rs
index 0ad0274d4..a886534f0 100644
--- a/python/src/datafusion.rs
+++ b/python/src/datafusion.rs
@@ -11,6 +11,7 @@ use datafusion::physical_plan::limit::GlobalLimitExec;
 use datafusion::physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::{ExecutionPlan, Statistics};
+use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
 use deltalake::datafusion::catalog::{Session, TableProvider};
 use deltalake::datafusion::common::{Column, DFSchema, Result as DataFusionResult};
 use deltalake::datafusion::datasource::TableType;
@@ -19,6 +20,10 @@ use deltalake::datafusion::prelude::Expr;
 use deltalake::delta_datafusion::DeltaScanNext;
 use deltalake::{datafusion, DeltaResult, DeltaTableError};
 use parking_lot::RwLock;
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods};
+use pyo3::types::PyCapsule;
+use pyo3::{Bound, PyAny, PyResult};
 use tokio::runtime::Handle;
 
 #[derive(Debug)]
@@ -174,6 +179,62 @@ impl TableProvider for TokioDeltaScan {
     }
 }
 
+/// Extract an `FFI_LogicalExtensionCodec` from a Python object.
+///
+/// `obj` must expose a `__datafusion_logical_extension_codec__` attribute
+/// that, called without arguments, returns a PyCapsule named
+/// `datafusion_logical_extension_codec`.
+///
+/// # Errors
+///
+/// Returns a `PyValueError` if the attribute is missing; failures from the
+/// attribute lookup, the call, the downcast and `validate_pycapsule` are
+/// propagated unchanged.
+pub(crate) fn ffi_logical_codec_from_pycapsule(
+    obj: &Bound<PyAny>,
+) -> PyResult<FFI_LogicalExtensionCodec> {
+    let attr_name = "__datafusion_logical_extension_codec__";
+
+    if !obj.hasattr(attr_name)? {
+        return Err(PyValueError::new_err(
+            "Expected PyCapsule object for FFI_LogicalExtensionCodec, but attribute does not exist",
+        ));
+    }
+
+    let capsule = obj.getattr(attr_name)?.call0()?;
+    let capsule = capsule.downcast::<PyCapsule>()?;
+    validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;
+
+    // SAFETY: the capsule name checked above is the only evidence that the
+    // payload is an `FFI_LogicalExtensionCodec`; the producer must honor it.
+    let codec = unsafe { capsule.reference::<FFI_LogicalExtensionCodec>() };
+
+    Ok(codec.clone())
+}
+
+/// Check that `capsule` carries exactly the expected `name`.
+///
+/// # Errors
+///
+/// Returns a `PyValueError` if the capsule has no name or a different one;
+/// errors from reading or decoding the name are propagated unchanged.
+pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
+    let Some(capsule_name) = capsule.name()? else {
+        return Err(PyValueError::new_err(format!(
+            "Expected {name} PyCapsule to have name set."
+        )));
+    };
+
+    let capsule_name = capsule_name.to_str()?;
+    if capsule_name != name {
+        return Err(PyValueError::new_err(format!(
+            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
+        )));
+    }
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 17fa9ff5e..5587d4055 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -22,7 +22,6 @@ use deltalake::datafusion::execution::TaskContextProvider;
 #[repr(C)]
 struct FFITableProviderCapsuleData {
     provider: FFI_TableProvider,
-    _ctx: Arc<SessionContext>,
 }
 use delta_kernel::expressions::Scalar;
 use delta_kernel::schema::{MetadataValue, StructField};
@@ -78,7 +77,7 @@ use uuid::Uuid;
 
 use writer::maybe_lazy_cast_reader;
 
-use crate::datafusion::TokioDeltaScan;
+use crate::datafusion::{ffi_logical_codec_from_pycapsule, TokioDeltaScan};
 use crate::error::{to_rt_err, DeltaError, DeltaProtocolError, PythonError};
 use crate::features::TableFeatures;
 use crate::filesystem::FsConfig;
@@ -1847,6 +1846,7 @@ impl RawDeltaTable {
     fn __datafusion_table_provider__<'py>(
         &self,
         py: Python<'py>,
+        ffi_codec: &Bound<'py, PyAny>,
     ) -> PyResult<Bound<'py, PyCapsule>> {
         let handle = rt().handle().clone();
         let name = CString::new("datafusion_table_provider").unwrap();
@@ -1862,15 +1862,10 @@ impl RawDeltaTable {
         let tokio_scan =
             Arc::new(TokioDeltaScan::new(scan, handle.clone())) as Arc<dyn TableProvider>;
 
-        let ctx = Arc::new(SessionContext::new());
-        let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>;
-        let ffi_task_ctx = FFI_TaskContextProvider::from(&task_ctx_provider);
+        let codec = ffi_logical_codec_from_pycapsule(ffi_codec)?;
         let provider =
-            FFI_TableProvider::new(tokio_scan, false, Some(handle.clone()), ffi_task_ctx, None);
+            FFI_TableProvider::new_with_ffi_codec(tokio_scan, false, Some(handle.clone()), codec);
 
-        let capsule_data = FFITableProviderCapsuleData {
-            provider,
-            _ctx: ctx,
-        };
+        let capsule_data = FFITableProviderCapsuleData { provider };
         PyCapsule::new(py, capsule_data, Some(name))
     }