Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ class RawDeltaTable:
commit_properties: CommitProperties | None,
post_commithook_properties: PostCommitHookProperties | None,
) -> None: ...
def __datafusion_table_provider__(self) -> Any: ...
def __datafusion_table_provider__(self, ffi_codec: Any) -> Any: ...
def write(
self,
data: RecordBatchReader,
Expand Down
4 changes: 2 additions & 2 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,7 @@ def create_write_transaction(
post_commithook_properties=post_commithook_properties,
)

def __datafusion_table_provider__(self) -> Any:
def __datafusion_table_provider__(self, ffi_codec: Any) -> Any:
"""Return the DataFusion table provider PyCapsule interface.

To support DataFusion features such as push down filtering, this function will return a PyCapsule
Expand Down Expand Up @@ -1201,7 +1201,7 @@ def __datafusion_table_provider__(self) -> Any:
+----+----+----+
```
"""
return self._table.__datafusion_table_provider__()
return self._table.__datafusion_table_provider__(ffi_codec)


class TableMerger:
Expand Down
44 changes: 44 additions & 0 deletions python/src/datafusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use datafusion::physical_plan::limit::GlobalLimitExec;
use datafusion::physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::{ExecutionPlan, Statistics};
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use deltalake::datafusion::catalog::{Session, TableProvider};
use deltalake::datafusion::common::{Column, DFSchema, Result as DataFusionResult};
use deltalake::datafusion::datasource::TableType;
Expand All @@ -19,6 +20,10 @@ use deltalake::datafusion::prelude::Expr;
use deltalake::delta_datafusion::DeltaScanNext;
use deltalake::{datafusion, DeltaResult, DeltaTableError};
use parking_lot::RwLock;
use pyo3::{Bound, PyAny, PyResult};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods};
use pyo3::types::PyCapsule;
use tokio::runtime::Handle;

#[derive(Debug)]
Expand Down Expand Up @@ -174,6 +179,45 @@ impl TableProvider for TokioDeltaScan {
}
}

/// Obtains an [`FFI_LogicalExtensionCodec`] from a Python object.
///
/// The object must expose a `__datafusion_logical_extension_codec__` method.
/// That method is called with no arguments and its return value must be a
/// `PyCapsule` named `datafusion_logical_extension_codec`. The codec stored in
/// the capsule is cloned and returned.
///
/// # Errors
///
/// Returns a `PyValueError` when the attribute is missing. Errors raised while
/// looking up or calling the attribute, downcasting the result to a
/// `PyCapsule`, or validating the capsule name are propagated unchanged.
pub(crate) fn ffi_logical_codec_from_pycapsule(
    obj: &Bound<PyAny>,
) -> PyResult<FFI_LogicalExtensionCodec> {
    let attr_name = "__datafusion_logical_extension_codec__";

    // Guard clause: without the exporter method there is nothing to unwrap.
    if !obj.hasattr(attr_name)? {
        return Err(PyValueError::new_err(
            "Expected PyCapsule object for FFI_LogicalExtensionCodec, but attribute does not exist",
        ));
    }

    let exported = obj.getattr(attr_name)?.call0()?;
    let capsule = exported.downcast::<PyCapsule>()?;
    validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;

    // SAFETY: the capsule name was checked just above. This relies on the
    // convention (not verifiable from here) that a capsule carrying that name
    // wraps an `FFI_LogicalExtensionCodec`.
    let codec = unsafe { capsule.reference::<FFI_LogicalExtensionCodec>() };

    Ok(codec.clone())
}

/// Checks that `capsule` carries exactly the expected `name`.
///
/// # Errors
///
/// Returns a `PyValueError` when the capsule has no name set, or when its name
/// differs from `name`. Errors from reading the capsule name or from decoding
/// it as UTF-8 are propagated unchanged.
pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
    // `let … else` binds the name and handles the unnamed case in one step,
    // so no `unwrap()` (and no panic path) is needed.
    let Some(raw_name) = capsule.name()? else {
        return Err(PyValueError::new_err(format!(
            "Expected {name} PyCapsule to have name set."
        )));
    };

    let capsule_name = raw_name.to_str()?;
    if capsule_name != name {
        return Err(PyValueError::new_err(format!(
            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
        )));
    }

    Ok(())
}


#[cfg(test)]
mod tests {
use super::*;
Expand Down
11 changes: 4 additions & 7 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use deltalake::datafusion::execution::TaskContextProvider;
#[repr(C)]
struct FFITableProviderCapsuleData {
provider: FFI_TableProvider,
_ctx: Arc<SessionContext>,
}
use delta_kernel::expressions::Scalar;
use delta_kernel::schema::{MetadataValue, StructField};
Expand Down Expand Up @@ -78,7 +77,7 @@ use uuid::Uuid;

use writer::maybe_lazy_cast_reader;

use crate::datafusion::TokioDeltaScan;
use crate::datafusion::{ffi_logical_codec_from_pycapsule, TokioDeltaScan};
use crate::error::{to_rt_err, DeltaError, DeltaProtocolError, PythonError};
use crate::features::TableFeatures;
use crate::filesystem::FsConfig;
Expand Down Expand Up @@ -1847,6 +1846,7 @@ impl RawDeltaTable {
fn __datafusion_table_provider__<'py>(
&self,
py: Python<'py>,
ffi_codec: &Bound<'py, PyAny>
) -> PyResult<Bound<'py, PyCapsule>> {
let handle = rt().handle().clone();
let name = CString::new("datafusion_table_provider").unwrap();
Expand All @@ -1862,15 +1862,12 @@ impl RawDeltaTable {
let tokio_scan =
Arc::new(TokioDeltaScan::new(scan, handle.clone())) as Arc<dyn TableProvider>;

let ctx = Arc::new(SessionContext::new());
let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>;
let ffi_task_ctx = FFI_TaskContextProvider::from(&task_ctx_provider);
let codec = ffi_logical_codec_from_pycapsule(ffi_codec)?;
let provider =
FFI_TableProvider::new(tokio_scan, false, Some(handle.clone()), ffi_task_ctx, None);
FFI_TableProvider::new_with_ffi_codec(tokio_scan, false, Some(handle.clone()), codec);

let capsule_data = FFITableProviderCapsuleData {
provider,
_ctx: ctx,
};
PyCapsule::new(py, capsule_data, Some(name))
}
Expand Down
Loading