diff --git a/Cargo.lock b/Cargo.lock index 2331e62ab..986e71c81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5630,6 +5630,7 @@ name = "sedona" version = "0.4.0" dependencies = [ "arrow-array", + "arrow-buffer", "arrow-schema", "async-trait", "aws-config", @@ -5639,6 +5640,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-ffi", + "datafusion-optimizer", "dirs", "futures", "geo-traits", @@ -5659,6 +5661,7 @@ dependencies = [ "sedona-pointcloud", "sedona-proj", "sedona-query-planner", + "sedona-raster", "sedona-raster-functions", "sedona-raster-gdal", "sedona-s2geography", @@ -6094,10 +6097,12 @@ dependencies = [ "arrow-buffer", "arrow-ipc", "arrow-schema", + "async-trait", "datafusion-common", "sedona-common", "sedona-schema", "sedona-testing", + "tokio", ] [[package]] @@ -6107,6 +6112,7 @@ dependencies = [ "arrow-array", "arrow-buffer", "arrow-schema", + "async-trait", "criterion", "datafusion-common", "datafusion-expr", @@ -6121,6 +6127,7 @@ dependencies = [ "sedona-testing", "sedona-tg", "serde_json", + "tokio", "wkb", ] @@ -6131,6 +6138,7 @@ dependencies = [ "arrow-array", "arrow-buffer", "arrow-schema", + "async-trait", "criterion", "datafusion-common", "datafusion-expr", @@ -6151,7 +6159,9 @@ name = "sedona-raster-zarr" version = "0.4.0" dependencies = [ "arrow-array", + "arrow-buffer", "arrow-schema", + "async-trait", "datafusion-common", "futures", "log", diff --git a/c/sedona-extension/src/scalar_kernel.rs b/c/sedona-extension/src/scalar_kernel.rs index 66f0f787c..78b5070b6 100644 --- a/c/sedona-extension/src/scalar_kernel.rs +++ b/c/sedona-extension/src/scalar_kernel.rs @@ -364,8 +364,8 @@ impl ExportedScalarKernel { /// when passed across a boundary. 
pub fn with_function_name(self, function_name: impl AsRef) -> Self { Self { - inner: self.inner, function_name: Some(CString::from_str(function_name.as_ref()).unwrap()), + ..self } } diff --git a/rust/sedona-expr/src/scalar_udf.rs b/rust/sedona-expr/src/scalar_udf.rs index de0a87f97..815f9d37f 100644 --- a/rust/sedona-expr/src/scalar_udf.rs +++ b/rust/sedona-expr/src/scalar_udf.rs @@ -14,7 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -use std::{any::Any, fmt::Debug, sync::Arc}; +use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc}; use arrow_schema::{DataType, FieldRef}; use datafusion_common::config::ConfigOptions; @@ -71,6 +71,15 @@ pub struct SedonaScalarUDF { signature: Signature, kernels: Vec, aliases: Vec, + /// Class-level, string-keyed metadata describing this UDF to the + /// planner. Flags are set via [`SedonaScalarUDF::with_metadata`] and + /// read back via [`SedonaScalarUDF::metadata`]; the planner keys off + /// well-known entries (e.g. the raster `"needs_pixels"` flag, whose + /// key is owned by `sedona-raster-functions`) without this crate + /// knowing their meaning. The map shape leaves room for further + /// planner-visible flags — and a future cross-cdylib FFI carrying + /// them — without a new field per flag. + metadata: HashMap, } impl PartialEq for SedonaScalarUDF { @@ -191,17 +200,32 @@ impl SedonaScalarUDF { signature, kernels, aliases: vec![], + metadata: HashMap::new(), } } /// Add aliases to an existing SedonaScalarUDF pub fn with_aliases(self, aliases: Vec) -> SedonaScalarUDF { - Self { - name: self.name, - signature: self.signature, - kernels: self.kernels, - aliases, - } + Self { aliases, ..self } + } + + /// Set a class-level metadata entry on this UDF, returning the + /// modified UDF. Metadata is planner-visible (e.g. 
the + /// `RS_EnsureLoaded` optimizer rule reads the raster `"needs_pixels"` + /// flag) and crosses the `sedona-extension` FFI boundary so + /// plugin-defined UDFs can declare it too. + pub fn with_metadata( + mut self, + key: impl Into, + value: impl Into, + ) -> SedonaScalarUDF { + self.metadata.insert(key.into(), value.into()); + self + } + + /// Class-level metadata map describing this UDF to the planner. + pub fn metadata(&self) -> &HashMap { + &self.metadata } /// Create a SedonaScalarUDF from a single kernel @@ -334,6 +358,30 @@ mod tests { use super::*; + #[test] + fn metadata_defaults_empty_and_set_via_builder() { + let udf = SedonaScalarUDF::new("u", vec![], Volatility::Immutable); + assert!(udf.metadata().get("a_flag").is_none()); + + let annotated = udf.with_metadata("a_flag", "true"); + assert_eq!( + annotated.metadata().get("a_flag").map(String::as_str), + Some("true") + ); + } + + #[test] + fn metadata_survives_with_aliases() { + let udf = SedonaScalarUDF::new("u", vec![], Volatility::Immutable) + .with_metadata("a_flag", "true") + .with_aliases(vec!["u_alias".to_string()]); + assert_eq!( + udf.metadata().get("a_flag").map(String::as_str), + Some("true") + ); + assert_eq!(udf.aliases(), &["u_alias".to_string()]); + } + #[test] fn udf_empty() -> Result<()> { // UDF with no implementations diff --git a/rust/sedona-query-planner/src/ensure_loaded.rs b/rust/sedona-query-planner/src/ensure_loaded.rs new file mode 100644 index 000000000..12901d490 --- /dev/null +++ b/rust/sedona-query-planner/src/ensure_loaded.rs @@ -0,0 +1,483 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logical optimizer rule that wraps raster arguments of `needs_pixels` +//! UDFs with `RS_EnsureLoaded`, so OutDb byte materialisation happens +//! explicitly in the logical plan instead of as a hidden side effect +//! inside the kernel. +//! +//! After this rule, calls like `RS_Value(raster, x, y)` (where +//! `RS_Value` is annotated with the `needs_pixels` metadata flag) become +//! `RS_Value(RS_EnsureLoaded(raster), x, y)`. DataFusion's +//! `CommonSubexprEliminate` pass deduplicates identical +//! `RS_EnsureLoaded(col)` calls across multiple `needs_pixels` UDFs +//! sharing the same raster column — provided `RS_EnsureLoaded`'s +//! signature is `Volatility::Stable` (not `Volatile`). +//! +//! This is a logical optimizer rule (not an analyzer rule) so it can +//! look `RS_EnsureLoaded` up from the [`FunctionRegistry`] rather than +//! capturing an `Arc` at construction time. Because optimizer rules run +//! to a fixpoint, the rewrite is idempotent: an argument already wrapped +//! in `RS_EnsureLoaded` is left alone (see [`is_ensure_loaded_call`]). 
+ +use std::sync::Arc; + +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr_schema::ExprSchemable; +use datafusion_expr::{Expr, LogicalPlan, ScalarUDF}; +use datafusion_optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule}; +use sedona_common::sedona_internal_err; +use sedona_expr::scalar_udf::SedonaScalarUDF; +use sedona_schema::datatypes::SedonaType; + +/// `SedonaScalarUDF` metadata key marking a UDF whose kernels read raster +/// pixel bytes. Duplicated from `sedona_raster_functions` (the owner), +/// which this crate can't depend on — keep the literal in sync with +/// `sedona_raster_functions::rs_ensure_loaded::NEEDS_PIXELS_METADATA_KEY`. +const NEEDS_PIXELS_METADATA_KEY: &str = "needs_pixels"; + +/// Logical optimizer rule wrapping raster arguments of `needs_bytes` +/// UDFs with `RS_EnsureLoaded`. Stateless — the `RS_EnsureLoaded` UDF +/// is resolved from the session's [`FunctionRegistry`] at rewrite time. +#[derive(Default, Debug)] +pub struct EnsureLoadedOptimizerRule; + +impl OptimizerRule for EnsureLoadedOptimizerRule { + fn name(&self) -> &str { + "sedona.ensure_loaded" + } + + fn apply_order(&self) -> Option { + // Bottom-up so a nested `RS_X(RS_Y(rast))` is rewritten + // inside-out: the inner call's raster arg is wrapped first, then + // the outer call sees the (now-wrapped, still raster-typed) arg + // and the idempotency guard keeps it from double-wrapping. + Some(ApplyOrder::BottomUp) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // Resolve RS_EnsureLoaded from the registry. A context that never + // registered it (no raster support) has nothing to rewrite. 
+ let Some(registry) = config.function_registry() else { + return Ok(Transformed::no(plan)); + }; + let Ok(ensure_loaded_udf) = registry.udf("rs_ensureloaded") else { + return Ok(Transformed::no(plan)); + }; + + // Type-check argument expressions against the merged schema of the + // node's INPUTS, not the node's own (output) schema. For a + // Projection the output schema holds the projected results + // (`rs_value(rast, …)`), not the input `rast` column the argument + // references, so `plan.schema()` would fail to recognise the raster + // arg and silently skip wrapping. Single-input nodes (Projection, + // Filter, …) use their one input; a Join's `filter` references + // left ⋈ right, so the merged schema resolves either side. Leaf + // nodes carry no wrappable expressions. + let inputs = plan.inputs(); + if inputs.is_empty() { + return Ok(Transformed::no(plan)); + } + let Some(schema) = merged_input_schema(&inputs) else { + // Schemas couldn't be merged (e.g. ambiguous duplicate + // qualifiers in a self-join). Skip this node rather than + // failing the query — a missed wrap surfaces later as a clear + // "raster bytes not loaded" error, not a wrong result. + return Ok(Transformed::no(plan)); + }; + drop(inputs); + + plan.map_expressions(|e| { + e.transform_up(|expr| rewrite_expr_node(expr, &schema, &ensure_loaded_udf)) + }) + } +} + +/// Merge the schemas of all inputs into one. Returns `None` if the merge +/// fails (DataFusion's [`DFSchema::join`] errors on ambiguous duplicate +/// qualified fields). +fn merged_input_schema(inputs: &[&LogicalPlan]) -> Option> { + let mut merged = inputs[0].schema().as_ref().clone(); + for input in &inputs[1..] { + merged = merged.join(input.schema()).ok()?; + } + Some(Arc::new(merged)) +} + +/// Single-step rewrite: if `expr` is a `needs_bytes` UDF call, wrap each +/// raster-typed arg with `RS_EnsureLoaded`. 
Two guards keep it correct: +/// it never wraps `RS_EnsureLoaded` itself (recursion), and it never +/// re-wraps an arg already wrapped in `RS_EnsureLoaded` (idempotency, +/// required because optimizer rules run to a fixpoint). +fn rewrite_expr_node( + expr: Expr, + schema: &Arc, + ensure_loaded_udf: &Arc, +) -> Result> { + let Expr::ScalarFunction(ref func_call) = expr else { + return Ok(Transformed::no(expr)); + }; + + // Recursion guard. + if func_call.func.name() == "rs_ensureloaded" { + return Ok(Transformed::no(expr)); + } + + // Only annotated SedonaScalarUDFs participate. DataFusion built-ins + // and unannotated UDFs pass through unchanged. + let needs_bytes = func_call + .func + .inner() + .as_any() + .downcast_ref::() + .map(|u| { + u.metadata() + .get(NEEDS_PIXELS_METADATA_KEY) + .map(String::as_str) + == Some("true") + }) + .unwrap_or(false); + if !needs_bytes { + return Ok(Transformed::no(expr)); + } + + // Structurally impossible: we matched `expr` as `Expr::ScalarFunction` + // a few lines up. Surface it as an internal error rather than a panic + // so a future refactor that breaks the invariant fails the query + // cleanly instead of crashing a worker. + let Expr::ScalarFunction(ScalarFunction { func, args }) = expr else { + return sedona_internal_err!( + "rewrite_expr_node: expected ScalarFunction after match, got a different Expr variant" + ); + }; + let mut changed = false; + let new_args: Vec = args + .into_iter() + .map(|arg| { + // Idempotency guard: a fixpoint re-run sees the wrapped arg + // (still raster-typed after RS_EnsureLoaded's identity + // `return_field`); don't wrap it again. 
+ if is_ensure_loaded_call(&arg) { + return arg; + } + if expr_is_raster(&arg, schema) { + changed = true; + Expr::ScalarFunction(ScalarFunction { + func: Arc::clone(ensure_loaded_udf), + args: vec![arg], + }) + } else { + arg + } + }) + .collect(); + + let rewritten = Expr::ScalarFunction(ScalarFunction { + func, + args: new_args, + }); + if changed { + Ok(Transformed::yes(rewritten)) + } else { + Ok(Transformed::no(rewritten)) + } +} + +/// True if `expr` is a call to `RS_EnsureLoaded`. +fn is_ensure_loaded_call(expr: &Expr) -> bool { + matches!(expr, Expr::ScalarFunction(sf) if sf.func.name() == "rs_ensureloaded") +} + +/// True if `expr` evaluates to a `SedonaType::Raster` under the given +/// schema. Uses `to_field` (not `get_type`) so the Field's extension +/// metadata is available — `SedonaType::Raster` is identified by an +/// `"sedona.raster"` extension type, not by raw `DataType::Struct`. +fn expr_is_raster(expr: &Expr, schema: &Arc) -> bool { + let Ok((_, field)) = expr.to_field(schema.as_ref()) else { + return false; + }; + matches!( + SedonaType::from_storage_field(&field), + Ok(SedonaType::Raster) + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::tree_node::TreeNodeRecursion; + use datafusion_expr::{col, ScalarUDF, Volatility}; + use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarUDF, SimpleSedonaScalarKernel}; + use sedona_schema::matchers::ArgMatcher; + + /// A stand-in `rs_ensureloaded` UDF. The rule keys off the name and + /// the `needs_bytes` marker, never the real async impl (which lives + /// in the `sedona` crate and can't be referenced here), so a plain + /// SedonaScalarUDF carrying the canonical name is sufficient. 
+ fn fake_ensure_loaded_udf() -> Arc { + let kernel: ScalarKernelRef = SimpleSedonaScalarKernel::new_ref( + ArgMatcher::new(vec![ArgMatcher::is_raster()], SedonaType::Raster), + Arc::new(|_, _| unreachable!("stub kernel; rewrite never invokes it")), + ); + let udf = SedonaScalarUDF::new("rs_ensureloaded", vec![kernel], Volatility::Immutable); + Arc::new(ScalarUDF::new_from_impl(udf)) + } + + /// A `needs_bytes` UDF accepting a raster, returning Int32. + fn needs_bytes_udf(name: &str) -> Arc { + let kernel: ScalarKernelRef = SimpleSedonaScalarKernel::new_ref( + ArgMatcher::new( + vec![ArgMatcher::is_raster()], + SedonaType::Arrow(DataType::Int32), + ), + Arc::new(|_, _| unreachable!("stub kernel; not invoked at plan time")), + ); + let udf = SedonaScalarUDF::new(name, vec![kernel], Volatility::Immutable) + .with_metadata(NEEDS_PIXELS_METADATA_KEY, "true"); + Arc::new(ScalarUDF::new_from_impl(udf)) + } + + /// Same shape but without the `needs_bytes` annotation. + fn metadata_only_udf(name: &str) -> Arc { + let kernel: ScalarKernelRef = SimpleSedonaScalarKernel::new_ref( + ArgMatcher::new( + vec![ArgMatcher::is_raster()], + SedonaType::Arrow(DataType::Int32), + ), + Arc::new(|_, _| unreachable!("stub kernel; not invoked at plan time")), + ); + let udf = SedonaScalarUDF::new(name, vec![kernel], Volatility::Immutable); + Arc::new(ScalarUDF::new_from_impl(udf)) + } + + fn raster_schema_named(name: &str) -> Arc { + let field = SedonaType::Raster.to_storage_field(name, true).unwrap(); + let arrow_schema = Arc::new(Schema::new(vec![field])); + Arc::new(DFSchema::try_from(arrow_schema.as_ref().clone()).unwrap()) + } + + fn int_schema(name: &str) -> Arc { + let field = Field::new(name, DataType::Int64, true); + let arrow_schema = Arc::new(Schema::new(vec![field])); + Arc::new(DFSchema::try_from(arrow_schema.as_ref().clone()).unwrap()) + } + + fn count_ensure_loaded(expr: &Expr) -> usize { + let mut n = 0; + expr.apply(|e| { + if let Expr::ScalarFunction(sf) = e { + if 
sf.func.name() == "rs_ensureloaded" { + n += 1; + } + } + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + n + } + + fn rewrite(expr: Expr, schema: &Arc, udf: &Arc) -> Expr { + rewrite_expr_node(expr, schema, udf).unwrap().data + } + + #[test] + fn wraps_raster_arg_of_needs_bytes_udf() { + let schema = raster_schema_named("rast"); + let udf = fake_ensure_loaded_udf(); + let call = Expr::ScalarFunction(ScalarFunction { + func: needs_bytes_udf("rs_mock"), + args: vec![col("rast")], + }); + let out = rewrite(call, &schema, &udf); + let Expr::ScalarFunction(ScalarFunction { args, .. }) = &out else { + panic!("expected ScalarFunction, got {out:?}"); + }; + assert!( + is_ensure_loaded_call(&args[0]), + "raster arg should be wrapped" + ); + } + + #[test] + fn leaves_non_raster_args_alone() { + let schema = int_schema("n"); + let udf = fake_ensure_loaded_udf(); + let call = Expr::ScalarFunction(ScalarFunction { + func: needs_bytes_udf("rs_mock"), + args: vec![col("n")], + }); + let out = rewrite(call, &schema, &udf); + assert_eq!(count_ensure_loaded(&out), 0); + } + + #[test] + fn leaves_metadata_only_udfs_alone() { + let schema = raster_schema_named("rast"); + let udf = fake_ensure_loaded_udf(); + let call = Expr::ScalarFunction(ScalarFunction { + func: metadata_only_udf("rs_meta"), + args: vec![col("rast")], + }); + let out = rewrite(call, &schema, &udf); + assert_eq!(count_ensure_loaded(&out), 0); + } + + #[test] + fn recursion_guard_does_not_wrap_rs_ensure_loaded_itself() { + let schema = raster_schema_named("rast"); + let udf = fake_ensure_loaded_udf(); + let call = Expr::ScalarFunction(ScalarFunction { + func: Arc::clone(&udf), + args: vec![col("rast")], + }); + let out = rewrite(call, &schema, &udf); + // Still exactly one — its raster arg is not itself wrapped. 
+ assert_eq!(count_ensure_loaded(&out), 1); + } + + #[test] + fn idempotency_guard_does_not_rewrap_already_wrapped_arg() { + // Models the fixpoint re-run: the input already has + // rs_mock(rs_ensureloaded(rast)). A second pass must NOT produce + // rs_mock(rs_ensureloaded(rs_ensureloaded(rast))). + let schema = raster_schema_named("rast"); + let udf = fake_ensure_loaded_udf(); + let already_wrapped = Expr::ScalarFunction(ScalarFunction { + func: Arc::clone(&udf), + args: vec![col("rast")], + }); + let call = Expr::ScalarFunction(ScalarFunction { + func: needs_bytes_udf("rs_mock"), + args: vec![already_wrapped], + }); + let out = rewrite(call, &schema, &udf); + assert_eq!( + count_ensure_loaded(&out), + 1, + "already-wrapped arg must not be wrapped again: {out:?}" + ); + } + + #[test] + fn registers_immediately_before_cse() { + use crate::optimizer::register_ensure_loaded_optimizer; + use datafusion::execution::session_state::SessionStateBuilder; + + let builder = SessionStateBuilder::new().with_default_features(); + let mut builder = register_ensure_loaded_optimizer(builder).unwrap(); + + let rules = &builder.optimizer().as_ref().unwrap().rules; + let ours = rules + .iter() + .position(|r| r.name() == "sedona.ensure_loaded") + .expect("rule registered"); + let cse = rules + .iter() + .position(|r| r.name() == "common_sub_expression_eliminate") + .expect("CSE present in default optimizer"); + assert_eq!( + ours + 1, + cse, + "ensure_loaded must sit immediately before CSE so wraps dedupe in the same pass" + ); + } + + #[test] + fn merged_schema_resolves_raster_across_a_join() { + // Two single-raster inputs (left `a`, right `b`); the merged + // schema must see both so a join filter referencing either side's + // raster resolves and gets wrapped. 
+ let left = LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation { + produce_one_row: false, + schema: raster_schema_named("a"), + }); + let right = LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation { + produce_one_row: false, + schema: raster_schema_named("b"), + }); + let inputs = [&left, &right]; + let merged = merged_input_schema(&inputs).expect("schemas merge"); + + let udf = fake_ensure_loaded_udf(); + // rs_mock(b) — the right side's raster, only resolvable via the + // merged schema. + let call = Expr::ScalarFunction(ScalarFunction { + func: needs_bytes_udf("rs_mock"), + args: vec![col("b")], + }); + let out = rewrite(call, &merged, &udf); + assert_eq!( + count_ensure_loaded(&out), + 1, + "raster arg from the right join input should be wrapped: {out:?}" + ); + } + + #[test] + fn rule_wraps_raster_arg_through_a_projection() { + // Drives the real `OptimizerRule::rewrite()` (not the + // `rewrite_expr_node` helper) on a Projection — `SELECT rs_mock(rast)`. + // The projection's OUTPUT schema holds the result column, not the + // input `rast`, so the rule must type-check against the INPUT schema + // to recognise and wrap the raster arg. A regression guard against + // switching to `plan.schema()`, which would silently skip wrapping + // here (the common single-projection case). + use datafusion::execution::session_state::SessionStateBuilder; + use datafusion_expr::registry::FunctionRegistry; + use datafusion_expr::{EmptyRelation, LogicalPlanBuilder}; + + // SessionState doubles as the OptimizerConfig and carries the + // function registry the rule resolves `rs_ensureloaded` from. 
+ let mut state = SessionStateBuilder::new().with_default_features().build(); + state.register_udf(fake_ensure_loaded_udf()).unwrap(); + + let scan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: raster_schema_named("rast"), + }); + let proj = Expr::ScalarFunction(ScalarFunction { + func: needs_bytes_udf("rs_mock"), + args: vec![col("rast")], + }); + let plan = LogicalPlanBuilder::from(scan) + .project(vec![proj]) + .unwrap() + .build() + .unwrap(); + + let out = EnsureLoadedOptimizerRule.rewrite(plan, &state).unwrap(); + + let wrapped: usize = out.data.expressions().iter().map(count_ensure_loaded).sum(); + assert_eq!( + wrapped, 1, + "projection's raster arg should be wrapped via the input schema: {:?}", + out.data + ); + } +} diff --git a/rust/sedona-query-planner/src/lib.rs b/rust/sedona-query-planner/src/lib.rs index 4bec6d2b7..85b5190c0 100644 --- a/rust/sedona-query-planner/src/lib.rs +++ b/rust/sedona-query-planner/src/lib.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub mod ensure_loaded; mod logical_plan_node; pub mod optimizer; pub mod probe_shuffle_exec; diff --git a/rust/sedona-query-planner/src/optimizer.rs b/rust/sedona-query-planner/src/optimizer.rs index 9d30a07f1..c40bf5b96 100644 --- a/rust/sedona-query-planner/src/optimizer.rs +++ b/rust/sedona-query-planner/src/optimizer.rs @@ -16,6 +16,7 @@ // under the License. use std::sync::Arc; +use crate::ensure_loaded::EnsureLoadedOptimizerRule; use crate::logical_plan_node::SpatialJoinPlanNode; use crate::spatial_expr_utils::{ collect_spatial_predicate_names, find_knn_query_side, KNNJoinQuerySide, @@ -94,6 +95,33 @@ pub fn register_spatial_join_logical_optimizer( Ok(session_state_builder) } +/// Register the `RS_EnsureLoaded`-wrapping logical optimizer rule. 
+/// +/// Inserts [`EnsureLoadedOptimizerRule`] immediately before DataFusion's +/// `common_sub_expression_eliminate` so that, in the same optimizer +/// pass, CSE can dedupe the `RS_EnsureLoaded(col)` wraps this rule +/// injects across multiple `needs_bytes` UDFs sharing a raster column. +/// Falls back to appending if CSE isn't present. +pub fn register_ensure_loaded_optimizer( + mut session_state_builder: SessionStateBuilder, +) -> Result { + let optimizer = session_state_builder + .optimizer() + .get_or_insert_with(Optimizer::new); + + let rule = Arc::new(EnsureLoadedOptimizerRule); + match optimizer + .rules + .iter() + .position(|r| r.name() == "common_sub_expression_eliminate") + { + Some(cse_pos) => optimizer.rules.insert(cse_pos, rule), + None => optimizer.rules.push(rule), + } + + Ok(session_state_builder) +} + /// Early optimizer rule that converts KNN joins to `SpatialJoinPlanNode` extension nodes /// *before* DataFusion's `PushDownFilter` runs. /// diff --git a/rust/sedona-raster-functions/Cargo.toml b/rust/sedona-raster-functions/Cargo.toml index 36ab816fb..7d65c7a03 100644 --- a/rust/sedona-raster-functions/Cargo.toml +++ b/rust/sedona-raster-functions/Cargo.toml @@ -34,6 +34,7 @@ result_large_err = "allow" arrow-schema = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +async-trait = { workspace = true } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } sedona-common = { workspace = true } @@ -52,6 +53,7 @@ geo-traits = { workspace = true } sedona-testing = { workspace = true, features = ["criterion"] } sedona-proj = { workspace = true, features = ["proj-sys"] } rstest = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } [[bench]] harness = false diff --git a/rust/sedona-raster-functions/src/lib.rs b/rust/sedona-raster-functions/src/lib.rs index dbc3749ed..1d9c10fd9 100644 --- a/rust/sedona-raster-functions/src/lib.rs +++ 
b/rust/sedona-raster-functions/src/lib.rs @@ -22,6 +22,7 @@ pub mod register; pub mod rs_band_accessors; pub mod rs_bandpath; pub mod rs_convexhull; +pub mod rs_ensure_loaded; pub mod rs_envelope; pub mod rs_example; pub mod rs_georeference; diff --git a/rust/sedona-raster-functions/src/rs_ensure_loaded.rs b/rust/sedona-raster-functions/src/rs_ensure_loaded.rs new file mode 100644 index 000000000..3fb786f18 --- /dev/null +++ b/rust/sedona-raster-functions/src/rs_ensure_loaded.rs @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `RS_EnsureLoaded(raster) -> raster` — async UDF that materialises +//! the pixel bytes of any OutDb bands in the input raster column. +//! +//! Walks every input row, identifies bands whose `data` column is empty +//! (the schema-OutDb discriminator), groups them by `outdb_format`, +//! dispatches each via the [`RasterLoaderRegistry`] held on `SedonaContext`, +//! and assembles an output `RecordBatch` of the same row count whose +//! `data` columns are populated with the loaded bytes. InDb bands pass +//! through unchanged. Other band/raster metadata is preserved verbatim. 
+ +use std::any::Any; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, RwLock}; + +use arrow_array::{Array, ArrayRef, StructArray}; +use arrow_buffer::Buffer; +use arrow_schema::{DataType, FieldRef}; +use async_trait::async_trait; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{plan_err, Result}; +use datafusion_expr::async_udf::AsyncScalarUDFImpl; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use sedona_common::{sedona_internal_datafusion_err, sedona_internal_err}; +use sedona_raster::array::RasterStructArray; +use sedona_raster::builder::RasterBuilder; +use sedona_raster::raster_loader::{ + AsyncByteLoader, RasterLoadRequest, RasterLoaderConfig, RasterLoaderRegistry, +}; +use sedona_raster::traits::RasterRef; + +/// `SedonaScalarUDF` metadata key marking a UDF whose kernels read raster +/// pixel bytes. A raster function sets it (value `"true"`) via +/// `with_metadata`; the `RS_EnsureLoaded` optimizer rule keys off it to +/// decide whether to wrap raster arguments with byte materialisation. +/// +/// This crate owns the key. The optimizer rule lives in +/// `sedona-query-planner`, which can't depend on this crate, so it carries +/// a duplicate of the same string literal — keep the two in sync. +pub const NEEDS_PIXELS_METADATA_KEY: &str = "needs_pixels"; + +/// Async UDF that resolves OutDb bands by dispatching through the +/// [`RasterLoaderRegistry`] stashed in `ConfigOptions` as a +/// [`RasterLoaderConfig`] extension. The UDF instance itself is +/// session-agnostic — it pulls the registry handle out of +/// `args.config_options.extensions.get::()` at +/// dispatch time. This matches DataFusion's +/// `AsyncScalarUDFImpl::invoke_async_with_args` surface (only +/// `Arc` is reachable from the async fn) and mirrors how +/// `CrsProvider` flows through the session's options. 
+#[derive(Debug)] +pub struct RsEnsureLoaded { + signature: Signature, +} + +impl Default for RsEnsureLoaded { + fn default() -> Self { + Self::new() + } +} + +impl RsEnsureLoaded { + pub fn new() -> Self { + Self { + // `any(1, ...)` accepts whatever single-arg type the caller + // passes; we validate "argument is a Raster Struct" in + // `return_field_from_args` and at runtime. Using `Signature::any` (vs. + // `Signature::user_defined`) sidesteps DataFusion's + // `coerce_types` call path, which `AsyncScalarUDF` doesn't + // delegate to the inner impl. + // + // `Stable` (not `Volatile`) so DataFusion's CSE pass can + // deduplicate identical RS_EnsureLoaded(col) calls injected + // by the optimizer rule. Semantic: within a single query the + // byte materialisation is deterministic for fixed inputs; + // across queries the underlying storage may change, so the + // result isn't `Immutable`. + signature: Signature::any(1, Volatility::Stable), + } + } +} + +/// Pull the shared registry handle out of a `ConfigOptions`. Returns a +/// helpful error if the [`RasterLoaderConfig`] extension isn't installed +/// — that only happens if a caller bypasses `SedonaContext::new` to +/// build their own session, in which case naming the extension is the +/// right diagnostic. +fn registry_handle_from_config( + config: &ConfigOptions, +) -> Result>> { + config + .extensions + .get::() + .map(|cfg| cfg.registry.handle()) + .ok_or_else(|| { + sedona_internal_datafusion_err!( + "RasterLoaderConfig is not registered in this session's ConfigOptions; \ + RS_EnsureLoaded cannot dispatch without it. Use SedonaContext::new() \ + or insert the extension manually." 
+ ) + }) +} + +fn lookup_loader( + registry: &Arc>, + format: &str, +) -> Result> { + let guard = registry.read().map_err(|e| { + sedona_internal_datafusion_err!("raster loader registry lock poisoned: {e}") + })?; + if let Some(loader) = guard.get(format) { + return Ok(loader); + } + // Build a diagnostic that lists registered formats so users know + // which loaders are registered. + let registered: Vec = guard.formats().map(String::from).collect(); + let registered_msg = if registered.is_empty() { + "no raster loaders are registered in this session".to_string() + } else { + format!("registered formats: {}", registered.join(", ")) + }; + plan_err!("no raster loader registered for format '{format}' — {registered_msg}") +} + +// One RsEnsureLoaded per session by construction — equality and hash +// are by identity (i.e. by name). DataFusion needs these to deduplicate +// `ScalarUDF` instances in the function registry; the struct holds no +// per-session state of its own. +impl PartialEq for RsEnsureLoaded { + fn eq(&self, _other: &Self) -> bool { + true + } +} +impl Eq for RsEnsureLoaded {} +impl Hash for RsEnsureLoaded { + fn hash(&self, state: &mut H) { + "rs_ensureloaded".hash(state); + } +} + +impl ScalarUDFImpl for RsEnsureLoaded { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "rs_ensureloaded" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // Never called in practice — `return_field_from_args` below is the + // authoritative output-type source and carries the raster + // extension metadata that a bare `DataType` would drop. Provided + // only to satisfy the trait. 
+ sedona_internal_err!( + "RS_EnsureLoaded::return_type should not be called; return_field_from_args is authoritative" + ) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // Identity on schema: the output raster has the same fields as the + // input — only the `data` column's bytes change. Return the input + // field verbatim so its `"sedona.raster"` extension metadata + // survives; building a fresh `Field` from the bare `DataType` + // (as the default `return_type`-based path does) would strip the + // extension and downstream code would stop recognising the column + // as a Raster. + if args.arg_fields.len() != 1 { + return plan_err!( + "RS_EnsureLoaded expects exactly one argument, got {}", + args.arg_fields.len() + ); + } + let field = &args.arg_fields[0]; + if !matches!(field.data_type(), DataType::Struct(_)) { + return plan_err!( + "RS_EnsureLoaded expects a Raster (Struct) argument, got {}", + field.data_type() + ); + } + Ok(Arc::clone(field)) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + // DataFusion routes async UDFs through `invoke_async_with_args` + // on the AsyncFuncExec node; this sync entry should never be + // called for an `AsyncScalarUDF`-wrapped impl. + sedona_internal_err!( + "RS_EnsureLoaded is async; AsyncFuncExec should have dispatched to invoke_async_with_args" + ) + } +} + +#[async_trait] +impl AsyncScalarUDFImpl for RsEnsureLoaded { + /// Materialising OutDb bytes is per-row I/O, so favour larger input + /// batches over DataFusion's default to amortise loader dispatch and + /// keep the async pipeline fed. 
+ fn ideal_batch_size(&self) -> Option { + Some(1024) + } + + async fn invoke_async_with_args(&self, args: ScalarFunctionArgs) -> Result { + let input_array = match args.args.into_iter().next() { + Some(ColumnarValue::Array(arr)) => arr, + Some(ColumnarValue::Scalar(_)) => { + return sedona_internal_err!( + "RS_EnsureLoaded does not support scalar inputs; pass a column reference" + ) + } + None => return sedona_internal_err!("RS_EnsureLoaded received zero arguments"), + }; + + let registry = registry_handle_from_config(&args.config_options)?; + let output = ensure_loaded(&input_array, |format| lookup_loader(®istry, format)).await?; + + Ok(ColumnarValue::Array(output)) + } +} + +/// Sequentially resolve OutDb bands in `input` and return a new raster +/// StructArray with `data` populated. +/// +/// Sequential rather than `buffer_unordered` for the first cut: holding +/// borrows from the input across the `loader.load(...).await` point is +/// tricky enough with one outstanding future that we'd rather extract +/// owned metadata, dispatch, and move on. Parallel fan-out is a follow-up +/// optimisation that doesn't change the trait surface or the registry +/// contract. 
+async fn ensure_loaded(input_array: &ArrayRef, mut lookup: F) -> Result +where + F: FnMut(&str) -> Result>, +{ + let input_struct = input_array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: expected StructArray input, got {:?}", + input_array.data_type() + ) + })?; + + let rasters = RasterStructArray::new(input_struct); + let mut builder = RasterBuilder::new(rasters.len()); + + for raster_idx in 0..rasters.len() { + if rasters.is_null(raster_idx) { + builder.append_null().map_err(|e| { + sedona_internal_datafusion_err!("RS_EnsureLoaded: append_null failed: {e}") + })?; + continue; + } + + let raster = rasters.get(raster_idx).map_err(|e| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: bad input raster row {raster_idx}: {e}" + ) + })?; + + // Owned per-row metadata so the borrows don't span the per-band + // `await` points further down. + let transform: [f64; 6] = raster.transform().try_into().map_err(|_| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: raster row {raster_idx} transform is not 6 elements" + ) + })?; + let spatial_dims_owned: Vec = raster + .spatial_dims() + .iter() + .map(|s| s.to_string()) + .collect(); + let spatial_dims: Vec<&str> = spatial_dims_owned.iter().map(String::as_str).collect(); + let spatial_shape: Vec = raster.spatial_shape().to_vec(); + let crs: Option = raster.crs().map(|s| s.to_string()); + + builder + .start_raster_nd(&transform, &spatial_dims, &spatial_shape, crs.as_deref()) + .map_err(|e| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: start_raster_nd failed at row {raster_idx}: {e}" + ) + })?; + + let num_bands = raster.num_bands(); + for band_idx in 0..num_bands { + // Extract everything we need from the band as owned data + // before any `await`, so the future is straightforwardly Send. 
+ let band_name = raster.band_name(band_idx).map(|s| s.to_string()); + let ( + dim_names_owned, + source_shape, + data_type, + nodata, + outdb_uri, + outdb_format, + indb_bytes, + ) = { + let band = raster.band(band_idx).map_err(|e| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: bad input band ({raster_idx},{band_idx}): {e}" + ) + })?; + let dim_names_owned: Vec = + band.dim_names().iter().map(|s| s.to_string()).collect(); + let source_shape: Vec = band.raw_source_shape().to_vec(); + let data_type = band.data_type(); + let nodata: Option> = band.nodata().map(|b| b.to_vec()); + let outdb_uri: Option = band.outdb_uri().map(|s| s.to_string()); + let outdb_format: Option = band.outdb_format().map(|s| s.to_string()); + // For InDb bands, copy bytes into an owned Buffer. + // `Buffer::from_vec` is zero-copy ownership transfer of + // the Vec; the per-row clone of `band.data()` itself is + // the one InDb copy we accept for sequential simplicity. + // + // This passthrough copy (and the equivalent re-copy of + // freshly-loaded OutDb bytes through the builder) goes + // away once `RasterBuilder` gains a zero-copy band-data + // path that references an existing `arrow_buffer::Buffer` + // as a `BinaryViewArray` row. Tracked in + // https://github.com/apache/sedona-db/issues/894. + let indb_bytes: Option = if band.is_indb() { + Some(Buffer::from_vec(band.data().to_vec())) + } else { + None + }; + ( + dim_names_owned, + source_shape, + data_type, + nodata, + outdb_uri, + outdb_format, + indb_bytes, + ) + }; + + let dim_names: Vec<&str> = dim_names_owned.iter().map(String::as_str).collect(); + builder + .start_band_nd( + band_name.as_deref(), + &dim_names, + &source_shape, + data_type, + nodata.as_deref(), + outdb_uri.as_deref(), + outdb_format.as_deref(), + ) + .map_err(|e| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: start_band_nd failed at ({raster_idx},{band_idx}): {e}" + ) + })?; + + // Resolve the bytes: InDb passes through; OutDb dispatches. 
+ let resolved: Buffer = if let Some(buf) = indb_bytes { + buf + } else { + let format = outdb_format.as_deref().ok_or_else(|| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: OutDb band ({raster_idx},{band_idx}) has empty data \ + but no outdb_format set" + ) + })?; + let uri = outdb_uri.as_deref().ok_or_else(|| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: OutDb band ({raster_idx},{band_idx}) has empty data \ + but no outdb_uri set" + ) + })?; + let loader = lookup(format)?; + let req = RasterLoadRequest { + uri, + dim_names: &dim_names, + source_shape: &source_shape, + data_type, + }; + loader.load(&req).await.map_err(|e| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: loader for format '{format}' failed on \ + band ({raster_idx},{band_idx}): {e}" + ) + })? + }; + + // Validate the resolved length so an under-sized loader output + // surfaces here, not as garbage bytes downstream. + let expected_bytes = source_shape + .iter() + .try_fold(1u64, |acc, &d| acc.checked_mul(d)) + .and_then(|elems| elems.checked_mul(data_type.byte_size() as u64)) + .ok_or_else(|| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: band ({raster_idx},{band_idx}) byte count overflows u64" + ) + })?; + let got = resolved.len(); + if got as u64 != expected_bytes { + return sedona_internal_err!( + "RS_EnsureLoaded: band ({raster_idx},{band_idx}) expected {expected_bytes} \ + bytes but loader returned {got}" + ); + } + + // Follow up: ensure loaded bytes are not copied + // https://github.com/apache/sedona-db/issues/894. 
+ builder.band_data_writer().append_value(resolved.as_slice()); + builder.finish_band().map_err(|e| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: finish_band failed at ({raster_idx},{band_idx}): {e}" + ) + })?; + } + + builder.finish_raster().map_err(|e| { + sedona_internal_datafusion_err!( + "RS_EnsureLoaded: finish_raster failed at row {raster_idx}: {e}" + ) + })?; + } + + let output_struct = builder.finish().map_err(|e| { + sedona_internal_datafusion_err!("RS_EnsureLoaded: builder.finish failed: {e}") + })?; + Ok(Arc::new(output_struct) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + use arrow_array::Array; + use sedona_raster::array::RasterStructArray; + use sedona_raster::builder::RasterBuilder; + use sedona_raster::raster_loader::RasterLoaderRegistry; + use sedona_raster::traits::RasterRef; + use sedona_schema::raster::BandDataType; + + /// Records load requests and returns a deterministic byte pattern. + #[derive(Debug, Default)] + struct RecordingLoader { + seen: Mutex, BandDataType)>>, + } + + #[async_trait] + impl AsyncByteLoader for RecordingLoader { + async fn load( + &self, + req: &RasterLoadRequest<'_>, + ) -> Result { + self.seen.lock().unwrap().push(( + req.uri.to_string(), + req.source_shape.to_vec(), + req.data_type, + )); + let elements: u64 = req.source_shape.iter().copied().product(); + let len = elements as usize * req.data_type.byte_size(); + // Fill with a recognisable pattern: byte i = (i % 251) as u8. + let bytes: Vec = (0..len).map(|i| (i % 251) as u8).collect(); + Ok(Buffer::from_vec(bytes)) + } + } + + /// Build a 1-row raster with one OutDb band ready for the loader to + /// materialise. 
+ fn build_outdb_input(uri: &str, format: &str, source_shape: &[u64]) -> StructArray { + let mut b = RasterBuilder::new(1); + b.start_raster_nd( + &[0.0, 1.0, 0.0, 0.0, 0.0, -1.0], + &["y", "x"], + &source_shape.iter().map(|&v| v as i64).collect::>(), + None, + ) + .unwrap(); + b.start_band_nd( + Some("band0"), + &["y", "x"], + source_shape, + BandDataType::UInt8, + None, + Some(uri), + Some(format), + ) + .unwrap(); + // OutDb bands write empty data. + b.band_data_writer().append_value([0u8; 0]); + b.finish_band().unwrap(); + b.finish_raster().unwrap(); + b.finish().unwrap() + } + + /// Build a 1-row raster with one InDb band — bytes are inline, + /// `outdb_uri`/`outdb_format` are null. + fn build_indb_input(source_shape: &[u64], data: &[u8]) -> StructArray { + let mut b = RasterBuilder::new(1); + b.start_raster_nd( + &[0.0, 1.0, 0.0, 0.0, 0.0, -1.0], + &["y", "x"], + &source_shape.iter().map(|&v| v as i64).collect::>(), + None, + ) + .unwrap(); + b.start_band_nd( + Some("band0"), + &["y", "x"], + source_shape, + BandDataType::UInt8, + None, + None, + None, + ) + .unwrap(); + b.band_data_writer().append_value(data); + b.finish_band().unwrap(); + b.finish_raster().unwrap(); + b.finish().unwrap() + } + + fn registry_with( + format: &str, + loader: Arc, + ) -> Arc> { + let mut reg = RasterLoaderRegistry::new(); + reg.register(format, loader); + Arc::new(RwLock::new(reg)) + } + + /// Regression guard: `RS_EnsureLoaded`'s declared output field must + /// keep the `"sedona.raster"` extension metadata. If it ever reverts + /// to a bare-`DataType` return path the output column stops being + /// recognised as a Raster, and the analyzer rule (which wraps raster + /// args of needs_bytes UDFs) would both fail to detect already-wrapped + /// args and break downstream raster kernels reading the result. 
+ #[test] + fn return_field_preserves_raster_extension() { + use datafusion_expr::ReturnFieldArgs; + use sedona_schema::datatypes::SedonaType; + + let raster_field = SedonaType::Raster.to_storage_field("rast", true).unwrap(); + let arg_fields = [Arc::new(raster_field)]; + let args = ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &[None], + }; + + let out = RsEnsureLoaded::new().return_field_from_args(args).unwrap(); + + // The output must round-trip back to SedonaType::Raster — proving + // the extension type survived, not just the raw Struct DataType. + assert!( + matches!(SedonaType::from_storage_field(&out), Ok(SedonaType::Raster)), + "output field lost its raster extension: {out:?}" + ); + } + + #[test] + fn return_field_rejects_non_raster_arg() { + use arrow_schema::{DataType, Field}; + use datafusion_expr::ReturnFieldArgs; + + let arg_fields = [Arc::new(Field::new("n", DataType::Int32, true))]; + let args = ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &[None], + }; + let err = RsEnsureLoaded::new() + .return_field_from_args(args) + .unwrap_err() + .to_string(); + assert!(err.contains("Raster"), "{err}"); + } + + #[tokio::test] + async fn ensure_loaded_populates_outdb_band_data() { + let input_struct = build_outdb_input("file:///tmp/foo.tif", "mock", &[2, 3]); + let input: ArrayRef = Arc::new(input_struct); + + let loader: Arc = Arc::new(RecordingLoader::default()); + let loader_dyn: Arc = loader.clone(); + let reg = registry_with("mock", loader_dyn); + + let out = ensure_loaded(&input, |fmt| { + reg.read() + .unwrap() + .get(fmt) + .ok_or_else(|| datafusion_common::DataFusionError::Plan(format!("no '{fmt}'"))) + }) + .await + .unwrap(); + + let out_struct = out.as_any().downcast_ref::().unwrap(); + let out_rasters = RasterStructArray::new(out_struct); + assert_eq!(out_rasters.len(), 1); + let r = out_rasters.get(0).unwrap(); + let band = r.band(0).unwrap(); + // Loader filled 6 bytes (2 × 3 × UInt8) with the (i % 251) 
pattern. + assert_eq!(band.data(), &[0, 1, 2, 3, 4, 5]); + // outdb_uri / outdb_format are preserved as provenance. + assert_eq!(band.outdb_uri(), Some("file:///tmp/foo.tif")); + assert_eq!(band.outdb_format(), Some("mock")); + + // Loader saw one request. + let seen = loader.seen.lock().unwrap(); + assert_eq!(seen.len(), 1); + assert_eq!(seen[0].0, "file:///tmp/foo.tif"); + assert_eq!(seen[0].1, vec![2, 3]); + assert_eq!(seen[0].2, BandDataType::UInt8); + } + + #[tokio::test] + async fn ensure_loaded_passes_through_indb_bands_without_calling_loader() { + let pixels: Vec = (10..16).collect(); // 6 bytes + let input_struct = build_indb_input(&[2, 3], &pixels); + let input: ArrayRef = Arc::new(input_struct); + + let loader: Arc = Arc::new(RecordingLoader::default()); + let loader_dyn: Arc = loader.clone(); + let reg = registry_with("mock", loader_dyn); + + let out = ensure_loaded(&input, |fmt| { + reg.read() + .unwrap() + .get(fmt) + .ok_or_else(|| datafusion_common::DataFusionError::Plan(format!("no '{fmt}'"))) + }) + .await + .unwrap(); + + let out_struct = out.as_any().downcast_ref::().unwrap(); + let out_rasters = RasterStructArray::new(out_struct); + let r = out_rasters.get(0).unwrap(); + let band = r.band(0).unwrap(); + assert_eq!(band.data(), &pixels); + + // Loader was never called. 
+ assert!(loader.seen.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn ensure_loaded_errors_when_format_not_registered() { + let input_struct = build_outdb_input("s3://bucket/foo.zarr", "zarr", &[2, 3]); + let input: ArrayRef = Arc::new(input_struct); + + let reg: Arc> = + Arc::new(RwLock::new(RasterLoaderRegistry::new())); + + let err = ensure_loaded(&input, |fmt| { + reg.read().unwrap().get(fmt).ok_or_else(|| { + datafusion_common::DataFusionError::Plan(format!( + "no raster loader registered for format '{fmt}'" + )) + }) + }) + .await + .unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("zarr"), + "expected error to mention missing format 'zarr', got: {msg}" + ); + } + + #[tokio::test] + async fn ensure_loaded_errors_on_undersized_loader_output() { + let input_struct = build_outdb_input("file:///tmp/foo.tif", "mock", &[2, 3]); + let input: ArrayRef = Arc::new(input_struct); + + #[derive(Debug, Default)] + struct ShortLoader; + + #[async_trait] + impl AsyncByteLoader for ShortLoader { + async fn load( + &self, + _req: &RasterLoadRequest<'_>, + ) -> Result { + // Return one too few bytes (5 instead of 6). + Ok(Buffer::from_vec(vec![0u8; 5])) + } + } + + let loader_dyn: Arc = Arc::new(ShortLoader); + let reg = registry_with("mock", loader_dyn); + + let err = ensure_loaded(&input, |fmt| { + reg.read() + .unwrap() + .get(fmt) + .ok_or_else(|| datafusion_common::DataFusionError::Plan(format!("no '{fmt}'"))) + }) + .await + .unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("expected") && msg.contains("loader returned"), + "expected diagnostic about expected vs actual loader bytes, got: {msg}" + ); + } + + #[tokio::test] + async fn ensure_loaded_preserves_null_raster_rows() { + // Build a 2-row input: one OutDb band, one null raster row. 
+ let mut b = RasterBuilder::new(2); + b.start_raster_nd(&[0.0, 1.0, 0.0, 0.0, 0.0, -1.0], &["y", "x"], &[2, 3], None) + .unwrap(); + b.start_band_nd( + Some("band0"), + &["y", "x"], + &[2, 3], + BandDataType::UInt8, + None, + Some("file:///tmp/foo.tif"), + Some("mock"), + ) + .unwrap(); + b.band_data_writer().append_value([0u8; 0]); + b.finish_band().unwrap(); + b.finish_raster().unwrap(); + b.append_null().unwrap(); + let input_struct = b.finish().unwrap(); + let input: ArrayRef = Arc::new(input_struct); + + let loader_dyn: Arc = Arc::new(RecordingLoader::default()); + let reg = registry_with("mock", loader_dyn); + + let out = ensure_loaded(&input, |fmt| { + reg.read() + .unwrap() + .get(fmt) + .ok_or_else(|| datafusion_common::DataFusionError::Plan(format!("no '{fmt}'"))) + }) + .await + .unwrap(); + + assert_eq!(out.len(), 2); + assert!(!out.is_null(0)); + assert!(out.is_null(1)); + } +} diff --git a/rust/sedona-raster-gdal/Cargo.toml b/rust/sedona-raster-gdal/Cargo.toml index d0fe9b9f7..e263b0ee3 100644 --- a/rust/sedona-raster-gdal/Cargo.toml +++ b/rust/sedona-raster-gdal/Cargo.toml @@ -34,6 +34,7 @@ result_large_err = "allow" arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-schema = { workspace = true } +async-trait = { workspace = true } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } lru = { workspace = true } @@ -43,13 +44,14 @@ sedona-functions = { workspace = true } sedona-gdal = { workspace = true } sedona-raster = { workspace = true } sedona-schema = { workspace = true } +tokio = { workspace = true } [dev-dependencies] criterion = { workspace = true } sedona-gdal = { workspace = true, features = ["gdal-sys"] } sedona-testing = { workspace = true } tempfile = { workspace = true } -tokio = { workspace = true, features = ["rt-multi-thread"] } +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } [[bench]] harness = false diff --git 
a/rust/sedona-raster-gdal/src/gdal_dataset_provider.rs b/rust/sedona-raster-gdal/src/gdal_dataset_provider.rs index f8f878677..f635a5abb 100644 --- a/rust/sedona-raster-gdal/src/gdal_dataset_provider.rs +++ b/rust/sedona-raster-gdal/src/gdal_dataset_provider.rs @@ -188,7 +188,7 @@ impl GDALDatasetCache { }) } - fn get_or_create_outdb_source( + pub(crate) fn get_or_create_outdb_source( &self, gdal: &Gdal, path: &str, diff --git a/rust/sedona-raster-gdal/src/lib.rs b/rust/sedona-raster-gdal/src/lib.rs index 360320b56..faf006dc3 100644 --- a/rust/sedona-raster-gdal/src/lib.rs +++ b/rust/sedona-raster-gdal/src/lib.rs @@ -32,16 +32,16 @@ mod gdal_common; #[allow(dead_code)] mod gdal_dataset_provider; +mod raster_loader; mod rs_frompath; -mod utils; - -#[cfg(test)] mod source_uri; +mod utils; // Re-export main dataset conversion functions pub use gdal_common::{ band_data_type_to_gdal, bytes_to_f64, gdal_to_band_data_type, gdal_type_byte_size, nodata_bytes_to_f64, nodata_f64_to_bytes, }; +pub use raster_loader::{GdalLoader, GDAL_FORMAT}; pub use rs_frompath::rs_frompath_udf; pub use utils::{append_as_indb_raster, append_as_outdb_raster, dataset_to_indb_raster}; diff --git a/rust/sedona-raster-gdal/src/raster_loader.rs b/rust/sedona-raster-gdal/src/raster_loader.rs new file mode 100644 index 000000000..71127df04 --- /dev/null +++ b/rust/sedona-raster-gdal/src/raster_loader.rs @@ -0,0 +1,548 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! GDAL backend implementing [`sedona_raster::raster_loader::AsyncByteLoader`]. +//! +//! Reads OutDb raster bands identified by a `#band=N` URI fragment via +//! GDAL's blocking API. The blocking work runs inside +//! `tokio::task::spawn_blocking` so the caller's async runtime is not +//! stalled. Dataset opens are cached per-thread via the existing +//! `GDALDatasetCache` thread-local, so repeated queries against the +//! same file pay one open per worker thread. +//! +//! ## Cancellation +//! +//! Reads run as a loop of block-height-aligned strips +//! (`band.block_size().1` rows per iteration), with a cooperative +//! cancellation check between strips. When the outer async future is +//! dropped (e.g. a query is cancelled), a [`CancelOnDrop`] guard flips +//! a shared [`AtomicBool`]; the next iteration of the loop observes +//! the flag and returns a cancellation error rather than running to +//! completion. +//! +//! Cancellation granularity is the source's natural block height: +//! +//! * Strip GeoTIFF: typically 1–64 rows per check (fine-grained). +//! * Tile GeoTIFF (COG): typically 256 rows per check (fine-grained). +//! * PNG/JPEG and similar whole-image-block formats: the first read +//! forces full decompression in one call; subsequent in-call rows +//! hit GDAL's block cache. Effectively whole-image cancellation +//! granularity for these formats — the byte-cap below is the +//! primary safety net for them. +//! +//! ## Byte cap +//! +//! Requests are pre-validated against [`MAX_OUTDB_LOAD_BYTES`] (4 GiB) +//! 
before the blocking task is spawned. This catches runaway requests +//! (corrupt metadata, accidentally-huge bands) at the boundary so they +//! can't tie up a blocking-pool thread. +//! +//! Registered against the per-session +//! [`RasterLoaderRegistry`](sedona_raster::raster_loader::RasterLoaderRegistry) +//! under the format key `"gdal"`. The `sedona` crate constructs a +//! [`GdalLoader`] from `SedonaContext::new_from_context` and registers +//! it during session bootstrap. + +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use arrow_buffer::Buffer; +use arrow_schema::ArrowError; +use async_trait::async_trait; +use datafusion_common::{DataFusionError, Result as DFResult}; +use sedona_gdal::raster::rasterband::RasterBand; +use sedona_raster::raster_loader::{AsyncByteLoader, RasterLoadRequest}; + +use crate::gdal_common::{convert_gdal_err, gdal_to_band_data_type, with_gdal}; +use crate::gdal_dataset_provider::thread_local_cache; +use crate::source_uri::parse_outdb_source; + +/// Format key the loader is registered under. Keep in sync with +/// `SedonaContext::new_from_context` and any band-builder code emitting +/// `outdb_format` values. +pub const GDAL_FORMAT: &str = "gdal"; + +/// Maximum bytes a single OutDb load request will produce. +/// +/// Requests with `Π source_shape × data_type.byte_size()` greater than +/// this value are rejected before spawning the blocking read task. +/// 4 GiB is intentionally conservative: typical satellite imagery bands +/// (Landsat, Sentinel-2, MODIS) are under 1 GiB; anything larger usually +/// indicates corrupt metadata or an accidentally-huge band claim. If we +/// ever want a tunable here, [`SedonaOptions`] is the natural home for +/// the override. +pub const MAX_OUTDB_LOAD_BYTES: u64 = 4 * 1024 * 1024 * 1024; + +/// GDAL-backed `AsyncByteLoader`. 
+/// +/// Stateless: the per-thread dataset cache lives in a thread-local owned +/// by `sedona-raster-gdal::gdal_dataset_provider`, so constructing a +/// `GdalLoader` is free and instances are interchangeable. +#[derive(Debug, Default, Clone, Copy)] +pub struct GdalLoader; + +impl GdalLoader { + pub fn new() -> Self { + Self + } +} + +/// Drop guard that flips an `AtomicBool` when the outer async future +/// is dropped. Paired with a `spawn_blocking` task that polls the same +/// flag between unit-of-work iterations: dropping the outer future +/// signals the blocking task to exit at the next checkpoint. +struct CancelOnDrop(Arc); + +impl Drop for CancelOnDrop { + fn drop(&mut self) { + self.0.store(true, Ordering::Release); + } +} + +#[async_trait] +impl AsyncByteLoader for GdalLoader { + async fn load(&self, req: &RasterLoadRequest<'_>) -> Result { + // Validate request shape synchronously, before spawning a blocking + // task — these are programming errors, no point queueing them + // onto a worker. 
+ if req.source_shape.len() != 2 { + return Err(ArrowError::NotYetImplemented(format!( + "GDAL raster loader only supports 2-D bands; got source_shape with {} dims", + req.source_shape.len() + ))); + } + if req.dim_names != ["y", "x"] { + return Err(ArrowError::InvalidArgumentError(format!( + "GDAL raster loader requires dim_names=[\"y\", \"x\"]; got {:?}", + req.dim_names + ))); + } + let height = usize::try_from(req.source_shape[0]).map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "GDAL OutDb source_shape[0]={} exceeds usize::MAX", + req.source_shape[0] + )) + })?; + let width = usize::try_from(req.source_shape[1]).map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "GDAL OutDb source_shape[1]={} exceeds usize::MAX", + req.source_shape[1] + )) + })?; + let byte_size = req.data_type.byte_size(); + + // Byte-cap validation: compute Π source_shape × byte_size in u64 + // with checked arithmetic so a hostile request can't wrap to a + // small accept-value. Reject before allocating. + let expected_bytes_u64 = (req.source_shape[0]) + .checked_mul(req.source_shape[1]) + .and_then(|elems| elems.checked_mul(byte_size as u64)) + .ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "GDAL OutDb request byte count overflows u64 for source_shape {:?} × byte_size {}", + req.source_shape, byte_size + )) + })?; + if expected_bytes_u64 > MAX_OUTDB_LOAD_BYTES { + return Err(ArrowError::InvalidArgumentError(format!( + "GDAL OutDb request exceeds MAX_OUTDB_LOAD_BYTES ({} > {}); \ + increase the cap or split the band into smaller reads", + expected_bytes_u64, MAX_OUTDB_LOAD_BYTES + ))); + } + let expected_bytes = expected_bytes_u64 as usize; + + // Take owned copies for the spawn_blocking closure (the closure + // must be `'static`). + let uri = req.uri.to_string(); + let expected_dtype = req.data_type; + + // Cancellation plumbing: the guard lives in this async fn's + // frame. 
On normal completion `_guard` drops after the await + // returns, flipping the flag on an already-finished blocking + // task (no-op). On cancellation (outer future dropped + // mid-await), the guard drops first and flips the flag; the + // blocking task observes it at the next strip boundary and + // returns a cancellation error. + let cancel: Arc = Arc::new(AtomicBool::new(false)); + let _guard = CancelOnDrop(Arc::clone(&cancel)); + + let buffer = tokio::task::spawn_blocking({ + let cancel = Arc::clone(&cancel); + move || -> Result { + with_gdal(|gdal| { + // `#band=N` fragment, with N defaulting to 1 if absent. + let (path, band_num) = parse_outdb_source(&uri)?; + let cache = thread_local_cache()?; + let dataset = cache.get_or_create_outdb_source(gdal, &path, None)?; + let band = dataset + .rasterband(band_num as usize) + .map_err(convert_gdal_err)?; + + // Verify the file's pixel type matches the band metadata's + // claim BEFORE reading. The bytes-out path doesn't convert; + // a mismatch would produce a 2x-or-N/2 byte count and the + // size check in `RS_EnsureLoaded` would mis-blame the + // loader for size rather than naming the dtype mismatch. + // Catch it cleanly here. + let file_dtype = gdal_to_band_data_type(band.band_type())?; + if file_dtype != expected_dtype { + return sedona_common::sedona_internal_err!( + "GDAL OutDb band metadata claims {:?} but file {} band {} is {:?}", + expected_dtype, + uri, + band_num, + file_dtype + ); + } + + // Pre-allocate the output buffer once; each strip read + // writes into a contiguous slice. 
+ let mut output = vec![0u8; expected_bytes]; + read_band_blockwise(&band, &mut output, width, height, byte_size, &cancel)?; + Ok(Buffer::from_vec(output)) + }) + .map_err(|e| ArrowError::ExternalError(Box::new(e))) + } + }) + .await + .map_err(|e| { + ArrowError::ExternalError(Box::new(sedona_common::sedona_internal_datafusion_err!( + "GDAL raster loader task panicked or was cancelled: {e}" + ))) + })??; + + Ok(buffer) + } +} + +/// Read a band's full extent into `output` in row-major order, looping +/// over block-height-aligned horizontal strips. +/// +/// The cancellation flag is checked between strips. Each iteration +/// reads at most `block_h` rows via [`RasterBand::read_into_bytes`] +/// directly into the appropriate slice of `output`. For strip-layout +/// files, each iteration covers exactly one strip; for tile-layout +/// files, each iteration covers one row of tiles. GDAL's internal +/// block cache amortises decompression cost so the per-iteration +/// overhead is small. +/// +/// `output` must have length `width * height * byte_size`; assumed by +/// the caller. +fn read_band_blockwise( + band: &RasterBand<'_>, + output: &mut [u8], + width: usize, + height: usize, + byte_size: usize, + cancel: &AtomicBool, +) -> DFResult<()> { + let row_bytes = width.saturating_mul(byte_size); + // `block_size().1` is the band's natural strip / tile height. Edge + // bands sometimes report `0` for degenerate inputs; clamp to >=1 + // so the loop always makes progress. 
+ let (_block_w, block_h) = band.block_size(); + let block_h = block_h.max(1); + + let mut y_start: usize = 0; + while y_start < height { + if cancel.load(Ordering::Acquire) { + return Err(cancelled_err(y_start, height)); + } + let chunk_h = (height - y_start).min(block_h); + let byte_off = y_start.saturating_mul(row_bytes); + let byte_end = byte_off.saturating_add(chunk_h.saturating_mul(row_bytes)); + // Sanity: should always hold given the caller's pre-allocated + // output slice; defensive in case of arithmetic surprises. + if byte_end > output.len() { + return sedona_common::sedona_internal_err!( + "GDAL OutDb read range [{}..{}) exceeds output buffer length {}", + byte_off, + byte_end, + output.len() + ); + } + band.read_into_bytes( + (0, y_start as isize), + (width, chunk_h), + (width, chunk_h), + &mut output[byte_off..byte_end], + None, + ) + .map_err(convert_gdal_err)?; + y_start += chunk_h; + } + Ok(()) +} + +fn cancelled_err(y_start: usize, height: usize) -> DataFusionError { + sedona_common::sedona_internal_datafusion_err!( + "GDAL OutDb load cancelled at row {y_start} of {height}" + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::gdal_common::with_gdal; + use sedona_gdal::raster::types::Buffer as GdalBuffer; + use sedona_schema::raster::BandDataType; + use tempfile::TempDir; + + /// Write a 2-row × 3-col UInt8 GeoTIFF and return its path. Pixels + /// `0..6` in row-major C-order. 
+ fn write_uint8_geotiff(dir: &TempDir, name: &str) -> String { + let path = dir.path().join(name); + let path_str = path.to_string_lossy().to_string(); + with_gdal(|gdal| { + let driver = gdal.get_driver_by_name("GTiff").unwrap(); + let dataset = driver + .create_with_band_type::(&path_str, 3, 2, 1) + .unwrap(); + dataset + .set_geo_transform(&[0.0, 1.0, 0.0, 2.0, 0.0, -1.0]) + .unwrap(); + let band = dataset.rasterband(1).unwrap(); + let mut buffer = GdalBuffer::new((3, 2), (0..6u8).collect::>()); + band.write((0, 0), (3, 2), &mut buffer).unwrap(); + Ok(()) + }) + .unwrap(); + path_str + } + + /// Write a `width × height` UInt8 GeoTIFF where pixel `(x, y)` = `(y * width + x) as u8`. + /// Used to verify block-aligned strip reads assemble identical bytes + /// to a single bulk read. + fn write_pattern_geotiff(dir: &TempDir, name: &str, width: usize, height: usize) -> String { + let path = dir.path().join(name); + let path_str = path.to_string_lossy().to_string(); + with_gdal(|gdal| { + let driver = gdal.get_driver_by_name("GTiff").unwrap(); + let dataset = driver + .create_with_band_type::(&path_str, width, height, 1) + .unwrap(); + dataset + .set_geo_transform(&[0.0, 1.0, 0.0, height as f64, 0.0, -1.0]) + .unwrap(); + let band = dataset.rasterband(1).unwrap(); + let pixels: Vec = (0..width * height).map(|i| (i % 251) as u8).collect(); + let mut buffer = GdalBuffer::new((width, height), pixels); + band.write((0, 0), (width, height), &mut buffer).unwrap(); + Ok(()) + }) + .unwrap(); + path_str + } + + #[tokio::test] + async fn gdal_loader_reads_2d_uint8_geotiff() { + let tmp = TempDir::new().unwrap(); + let path = write_uint8_geotiff(&tmp, "fixture.tif"); + let uri = format!("{path}#band=1"); + + let loader = GdalLoader::new(); + let req = RasterLoadRequest { + uri: &uri, + dim_names: &["y", "x"], + source_shape: &[2, 3], + data_type: BandDataType::UInt8, + }; + + let buf = loader.load(&req).await.unwrap(); + assert_eq!(buf.len(), 6); + assert_eq!(buf.as_slice(), 
&[0u8, 1, 2, 3, 4, 5]); + } + + #[tokio::test] + async fn gdal_loader_defaults_to_band_1_when_fragment_missing() { + let tmp = TempDir::new().unwrap(); + let path = write_uint8_geotiff(&tmp, "no_fragment.tif"); + let uri = path; + + let loader = GdalLoader::new(); + let req = RasterLoadRequest { + uri: &uri, + dim_names: &["y", "x"], + source_shape: &[2, 3], + data_type: BandDataType::UInt8, + }; + let buf = loader.load(&req).await.unwrap(); + assert_eq!(buf.len(), 6); + } + + #[tokio::test] + async fn gdal_loader_rejects_non_2d_source_shape() { + let loader = GdalLoader::new(); + let req = RasterLoadRequest { + uri: "ignored", + dim_names: &["t", "y", "x"], + source_shape: &[2, 3, 4], + data_type: BandDataType::UInt8, + }; + let err = loader.load(&req).await.unwrap_err(); + assert!( + err.to_string().contains("2-D"), + "expected 2-D rejection diagnostic, got: {err}" + ); + } + + #[tokio::test] + async fn gdal_loader_rejects_non_yx_dim_names() { + let loader = GdalLoader::new(); + let req = RasterLoadRequest { + uri: "ignored", + dim_names: &["x", "y"], // transposed + source_shape: &[2, 3], + data_type: BandDataType::UInt8, + }; + let err = loader.load(&req).await.unwrap_err(); + assert!( + err.to_string().contains("dim_names"), + "expected dim_names rejection diagnostic, got: {err}" + ); + } + + #[tokio::test] + async fn gdal_loader_errors_when_dtype_disagrees_with_file() { + let tmp = TempDir::new().unwrap(); + let path = write_uint8_geotiff(&tmp, "dtype_mismatch.tif"); + let uri = format!("{path}#band=1"); + + let loader = GdalLoader::new(); + let req = RasterLoadRequest { + uri: &uri, + dim_names: &["y", "x"], + source_shape: &[2, 3], + // File is UInt8 but we claim Int16 — should fail with a + // clear dtype-mismatch message, not garbled bytes. 
+ data_type: BandDataType::Int16, + }; + let err = loader.load(&req).await.unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("metadata claims") && (msg.contains("UInt8") || msg.contains("Int16")), + "expected dtype-mismatch diagnostic, got: {msg}" + ); + } + + #[tokio::test] + async fn gdal_loader_errors_on_missing_file() { + let loader = GdalLoader::new(); + let req = RasterLoadRequest { + uri: "/nonexistent/path/to/file.tif#band=1", + dim_names: &["y", "x"], + source_shape: &[2, 3], + data_type: BandDataType::UInt8, + }; + let err = loader.load(&req).await.unwrap_err(); + // GDAL's "no such file" error message wraps through our convert. + assert!(err.to_string().to_lowercase().contains("nonexistent")); + } + + #[tokio::test] + async fn gdal_loader_errors_on_band_index_out_of_range() { + let tmp = TempDir::new().unwrap(); + let path = write_uint8_geotiff(&tmp, "oob_band.tif"); + // File has 1 band; ask for band 5. + let uri = format!("{path}#band=5"); + + let loader = GdalLoader::new(); + let req = RasterLoadRequest { + uri: &uri, + dim_names: &["y", "x"], + source_shape: &[2, 3], + data_type: BandDataType::UInt8, + }; + let err = loader.load(&req).await.unwrap_err(); + let msg = err.to_string(); + // GDAL surfaces this as a band-index error; just verify the + // dispatch went through and the error was propagated, not the + // exact GDAL phrasing. + assert!( + !msg.contains("dim_names") && !msg.contains("2-D"), + "expected a GDAL-layer error, not request-validation; got: {msg}" + ); + } + + #[tokio::test] + async fn gdal_loader_rejects_request_over_byte_cap() { + let loader = GdalLoader::new(); + // 2^32 elements × 4 bytes = 16 GiB, well over the 4 GiB cap. + // (Source shape values are u64, so this fits the request + // struct; only the cap should reject it.)
+ let req = RasterLoadRequest { + uri: "ignored", + dim_names: &["y", "x"], + source_shape: &[1 << 16, 1 << 16], + data_type: BandDataType::Float32, + }; + let err = loader.load(&req).await.unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("MAX_OUTDB_LOAD_BYTES"), + "expected byte-cap diagnostic, got: {msg}" + ); + } + + #[tokio::test] + async fn gdal_loader_multi_strip_read_matches_single_call() { + // 64-row image: at default GTiff strip layout, this produces + // multiple strips, exercising the block-iter loop. + let tmp = TempDir::new().unwrap(); + let path = write_pattern_geotiff(&tmp, "multistrip.tif", 16, 64); + let uri = format!("{path}#band=1"); + + let loader = GdalLoader::new(); + let req = RasterLoadRequest { + uri: &uri, + dim_names: &["y", "x"], + source_shape: &[64, 16], + data_type: BandDataType::UInt8, + }; + let buf = loader.load(&req).await.unwrap(); + let expected: Vec = (0..16 * 64).map(|i| (i % 251) as u8).collect(); + assert_eq!(buf.as_slice(), expected.as_slice()); + } + + /// Pre-arm the cancellation flag, then drive `read_band_blockwise` + /// directly against a real band. The loop should bail before + /// reading anything. + #[test] + fn read_band_blockwise_honours_pre_cancelled_flag() { + let tmp = TempDir::new().unwrap(); + let path = write_pattern_geotiff(&tmp, "cancel.tif", 16, 64); + with_gdal(|gdal| { + let cache = thread_local_cache()?; + let dataset = cache.get_or_create_outdb_source(gdal, &path, None)?; + let band = dataset.rasterband(1).map_err(convert_gdal_err)?; + let cancel = AtomicBool::new(true); + let mut out = vec![0u8; 16 * 64]; + let err = read_band_blockwise(&band, &mut out, 16, 64, 1, &cancel) + .expect_err("pre-armed cancel flag should short-circuit the loop"); + let msg = err.to_string(); + assert!( + msg.contains("cancelled"), + "expected a cancellation diagnostic, got: {msg}" + ); + // Output buffer was never written into. 
+ assert!(out.iter().all(|&b| b == 0)); + Ok(()) + }) + .unwrap(); + } +} diff --git a/rust/sedona-raster-zarr/Cargo.toml b/rust/sedona-raster-zarr/Cargo.toml index 93e946812..9023f833d 100644 --- a/rust/sedona-raster-zarr/Cargo.toml +++ b/rust/sedona-raster-zarr/Cargo.toml @@ -32,7 +32,9 @@ result_large_err = "allow" [dependencies] arrow-array = { workspace = true } +arrow-buffer = { workspace = true } arrow-schema = { workspace = true } +async-trait = { workspace = true } datafusion-common = { workspace = true } futures = { workspace = true } log = { workspace = true } @@ -41,6 +43,7 @@ sedona-raster = { workspace = true } sedona-schema = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +tokio = { workspace = true } zarrs = { workspace = true, features = ["filesystem", "gzip", "zstd", "blosc", "crc32c", "sharding", "transpose"] } zarrs_filesystem = { workspace = true } diff --git a/rust/sedona-raster-zarr/src/lib.rs b/rust/sedona-raster-zarr/src/lib.rs index 1d35d29af..d6e653503 100644 --- a/rust/sedona-raster-zarr/src/lib.rs +++ b/rust/sedona-raster-zarr/src/lib.rs @@ -30,6 +30,8 @@ mod dtype; mod geozarr; mod loader; +mod raster_loader; mod source_uri; pub use loader::ZarrChunkReader; +pub use raster_loader::{ZarrLoader, ZARR_FORMAT}; diff --git a/rust/sedona-raster-zarr/src/loader.rs b/rust/sedona-raster-zarr/src/loader.rs index 631565df2..7a13b4c25 100644 --- a/rust/sedona-raster-zarr/src/loader.rs +++ b/rust/sedona-raster-zarr/src/loader.rs @@ -32,9 +32,7 @@ use sedona_common::sedona_internal_datafusion_err; use sedona_raster::builder::RasterBuilder; use sedona_schema::datatypes::SedonaType; use sedona_schema::raster::BandDataType; -use zarrs::array::Array; -#[cfg(test)] -use zarrs::array::ArrayBytes; +use zarrs::array::{Array, ArrayBytes}; use zarrs::group::Group; use zarrs_filesystem::FilesystemStore; @@ -52,7 +50,7 @@ use crate::source_uri::{build_chunk_anchor, group_uri_to_filesystem_path}; /// formatting. 
/// /// Rows always emit OutDb-style: `data` is empty, `outdb_uri` carries -/// a chunk anchor that the async OutDb resolver (registered separately) +/// a chunk anchor that the async raster byte loader (registered separately) /// resolves to bytes on demand. pub struct ZarrChunkReader { schema: SchemaRef, @@ -171,7 +169,7 @@ impl ZarrChunkReader { // Every band gets its chunk-anchor URI populated as // provenance metadata. `data.is_empty()` is the InDb/OutDb // discriminator; this reader always emits empty `data` and - // defers pixel-byte resolution to the OutDb resolver. + // defers pixel-byte resolution to the raster byte loader. let anchor = build_chunk_anchor(&self.group_uri, &info.path, &self.chunk_indices); builder.start_band_nd( Some(info.path.as_str()), @@ -653,14 +651,9 @@ fn advance_chunk_indices(chunk_indices: &mut [u64], chunk_grid_shape: &[u64]) -> /// types — those don't have a `BandDataType` counterpart anyway, so the /// dtype check in `collect_array_infos` rejects them upstream. /// -/// This is the only pixel-byte read primitive in the crate. The loader -/// itself never calls it today — it always emits OutDb anchors — but -/// the async `RS_EnsureLoaded` resolver (follow-up PR) will. Lives -/// behind `#[cfg(test)]` until the resolver lands; the unit test below -/// exercises it so the implementation doesn't bit-rot in the -/// meantime. -#[cfg(test)] -fn retrieve_chunk_bytes( +/// The only pixel-byte read primitive in the crate. Consumed by the +/// async raster byte loader (`raster_loader::ZarrLoader`). +pub(crate) fn retrieve_chunk_bytes( array: &Array, chunk_indices: &[u64], ) -> Result, ArrowError> { diff --git a/rust/sedona-raster-zarr/src/raster_loader.rs b/rust/sedona-raster-zarr/src/raster_loader.rs new file mode 100644 index 000000000..537e9cec9 --- /dev/null +++ b/rust/sedona-raster-zarr/src/raster_loader.rs @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Zarr backend implementing [`sedona_raster::raster_loader::AsyncByteLoader`]. +//! +//! Resolves a band's OutDb URI back into a Zarr chunk read: the URI is +//! a chunk anchor of the form +//! `#array=&chunk=,,...` (see +//! [`crate::source_uri::build_chunk_anchor`]). The loader parses the +//! anchor, opens the Zarr store and array, and retrieves the named +//! chunk's bytes via `zarrs` — all wrapped in +//! `tokio::task::spawn_blocking` so the caller's async runtime is not +//! stalled by Zarr's blocking decoder. +//! +//! Registered against the per-session +//! [`RasterLoaderRegistry`](sedona_raster::raster_loader::RasterLoaderRegistry) +//! under the format key [`ZARR_FORMAT`]. As an out-of-tree plugin, +//! `sedona-raster-zarr` does not depend on `sedona` — callers wire the +//! registration themselves from their `SedonaContext` setup: +//! +//! ```ignore +//! ctx.register_raster_loader( +//! sedona_raster_zarr::ZARR_FORMAT, +//! std::sync::Arc::new(sedona_raster_zarr::ZarrLoader::new()), +//! ); +//! 
``` + +use std::sync::Arc; + +use arrow_buffer::Buffer; +use arrow_schema::ArrowError; +use async_trait::async_trait; +use sedona_common::sedona_internal_datafusion_err; +use sedona_raster::raster_loader::{AsyncByteLoader, RasterLoadRequest}; +use zarrs::group::Group; +use zarrs_filesystem::FilesystemStore; + +use crate::dtype::zarr_to_band_data_type; +use crate::loader::retrieve_chunk_bytes; +use crate::source_uri::{group_uri_to_filesystem_path, parse_chunk_anchor}; + +/// Format key the loader registers under. Keep in sync with +/// `outdb_format` values emitted by the Zarr reader's band builder +/// (see `crate::loader`). +pub const ZARR_FORMAT: &str = "zarr"; + +/// Async raster byte loader for Zarr-backed bands. +/// +/// Stateless: dataset opens use a fresh `FilesystemStore` per call. +/// Caching the open store per `(store_uri, array_path)` is a follow-up +/// optimisation that doesn't change the trait surface. +#[derive(Debug, Default, Clone, Copy)] +pub struct ZarrLoader; + +impl ZarrLoader { + pub fn new() -> Self { + Self + } +} + +#[async_trait] +impl AsyncByteLoader for ZarrLoader { + async fn load(&self, req: &RasterLoadRequest<'_>) -> Result { + // Take owned copies for the spawn_blocking closure. + let uri = req.uri.to_string(); + let expected_dtype = req.data_type; + + let buffer = tokio::task::spawn_blocking(move || -> Result { + let anchor = parse_chunk_anchor(&uri)?; + let fs_path = group_uri_to_filesystem_path(&anchor.store_uri)?; + let store = FilesystemStore::new(&fs_path).map_err(|e| { + ArrowError::ExternalError(Box::new(sedona_internal_datafusion_err!( + "failed to open Zarr store at {}: {e}", + fs_path.display() + ))) + })?; + let storage: Arc = Arc::new(store); + + // The group itself isn't strictly needed to open an array, + // but opening it first gives a clear diagnostic when the + // store has no readable group root.
+ let group = Group::open(storage.clone(), "/").map_err(|e| { + ArrowError::ExternalError(Box::new(sedona_internal_datafusion_err!( + "failed to open Zarr group at {}: {e}", + fs_path.display() + ))) + })?; + let _ = group; // group handle dropped; array open uses storage directly. + + let array_path = if anchor.array_path.starts_with('/') { + anchor.array_path.clone() + } else { + format!("/{}", anchor.array_path) + }; + let array = zarrs::array::Array::open(storage, &array_path).map_err(|e| { + ArrowError::ExternalError(Box::new(sedona_internal_datafusion_err!( + "failed to open Zarr array {}: {e}", + array_path + ))) + })?; + + // Verify the Zarr array's dtype matches the band metadata's + // claim before reading. Mismatches catch silent byte-count + // surprises here rather than letting RS_EnsureLoaded's + // expected-byte-count check mis-blame the loader for size. + let file_dtype = zarr_to_band_data_type(array.data_type())?; + if file_dtype != expected_dtype { + return Err(ArrowError::ExternalError(Box::new( + sedona_internal_datafusion_err!( + "Zarr OutDb band metadata claims {:?} but array {} is {:?}", + expected_dtype, + array_path, + file_dtype + ), + ))); + } + + let bytes = retrieve_chunk_bytes(&array, &anchor.chunk_indices)?; + Ok(Buffer::from_vec(bytes)) + }) + .await + .map_err(|e| { + ArrowError::ExternalError(Box::new(sedona_internal_datafusion_err!( + "Zarr raster loader task panicked or was cancelled: {e}" + ))) + })??; + + Ok(buffer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use sedona_schema::raster::BandDataType; + use tempfile::TempDir; + use zarrs::array::ArrayBuilder; + use zarrs::array::{data_type as zarr_dtype, FillValue}; + use zarrs::group::GroupBuilder; + + use crate::source_uri::build_chunk_anchor; + + /// Build a Zarr group at `/store.zarr` containing one array + /// `temperature` of UInt8 with shape [2, 3] and chunk shape [2, 3] + /// (one chunk). Returns the store URI and array path. 
+ fn build_uint8_zarr(dir: &TempDir) -> (String, &'static str, Vec) { + let store_path = dir.path().join("store.zarr"); + let store = Arc::new(FilesystemStore::new(&store_path).unwrap()); + + // Root group metadata — Zarr v3 stores need this for + // `Group::open(store, "/")` to succeed. + GroupBuilder::new() + .build(store.clone(), "/") + .unwrap() + .store_metadata() + .unwrap(); + + let array = ArrayBuilder::new( + vec![2, 3], + vec![2, 3], + zarr_dtype::uint8(), + FillValue::from(0u8), + ) + .build(store.clone(), "/temperature") + .unwrap(); + array.store_metadata().unwrap(); + + let pixels: Vec = vec![10, 11, 12, 13, 14, 15]; + array.store_chunk(&[0, 0], pixels.clone()).unwrap(); + + let store_uri = format!("file://{}", store_path.display()); + (store_uri, "temperature", pixels) + } + + #[tokio::test] + async fn zarr_loader_reads_uint8_chunk() { + let tmp = TempDir::new().unwrap(); + let (store_uri, array_path, expected_pixels) = build_uint8_zarr(&tmp); + let uri = build_chunk_anchor(&store_uri, array_path, &[0, 0]); + + let loader = ZarrLoader::new(); + let req = RasterLoadRequest { + uri: &uri, + dim_names: &["y", "x"], + source_shape: &[2, 3], + data_type: BandDataType::UInt8, + }; + let buf = loader.load(&req).await.unwrap(); + assert_eq!(buf.as_slice(), expected_pixels.as_slice()); + } + + #[tokio::test] + async fn zarr_loader_errors_when_dtype_disagrees_with_array() { + let tmp = TempDir::new().unwrap(); + let (store_uri, array_path, _) = build_uint8_zarr(&tmp); + let uri = build_chunk_anchor(&store_uri, array_path, &[0, 0]); + + let loader = ZarrLoader::new(); + let req = RasterLoadRequest { + uri: &uri, + dim_names: &["y", "x"], + source_shape: &[2, 3], + // Array is UInt8 but the band claims Int16. 
+ data_type: BandDataType::Int16, + }; + let err = loader.load(&req).await.unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("metadata claims") && (msg.contains("UInt8") || msg.contains("Int16")), + "expected dtype-mismatch diagnostic, got: {msg}" + ); + } + + #[tokio::test] + async fn zarr_loader_errors_on_malformed_chunk_anchor_uri() { + let loader = ZarrLoader::new(); + let req = RasterLoadRequest { + uri: "file:///tmp/foo.zarr", // missing fragment + dim_names: &["y", "x"], + source_shape: &[2, 3], + data_type: BandDataType::UInt8, + }; + let err = loader.load(&req).await.unwrap_err(); + assert!( + err.to_string().contains("missing"), + "expected missing-fragment diagnostic, got: {err}" + ); + } + + #[tokio::test] + async fn zarr_loader_errors_on_missing_array_path() { + let tmp = TempDir::new().unwrap(); + let (store_uri, _, _) = build_uint8_zarr(&tmp); + // Anchor a chunk against a non-existent array. + let uri = build_chunk_anchor(&store_uri, "nonexistent", &[0, 0]); + + let loader = ZarrLoader::new(); + let req = RasterLoadRequest { + uri: &uri, + dim_names: &["y", "x"], + source_shape: &[2, 3], + data_type: BandDataType::UInt8, + }; + let err = loader.load(&req).await.unwrap_err(); + assert!( + err.to_string().contains("nonexistent") + || err.to_string().to_lowercase().contains("array"), + "expected diagnostic to name the missing array path, got: {err}" + ); + } + + #[tokio::test] + async fn zarr_loader_errors_on_cloud_scheme_until_supported() { + let loader = ZarrLoader::new(); + let uri = build_chunk_anchor("s3://bucket/foo.zarr", "temperature", &[0, 0]); + let req = RasterLoadRequest { + uri: &uri, + dim_names: &["y", "x"], + source_shape: &[2, 3], + data_type: BandDataType::UInt8, + }; + let err = loader.load(&req).await.unwrap_err(); + assert!( + err.to_string().contains("cloud") || err.to_string().contains("s3://"), + "expected cloud-scheme rejection, got: {err}" + ); + } +} diff --git a/rust/sedona-raster-zarr/src/source_uri.rs 
b/rust/sedona-raster-zarr/src/source_uri.rs index 6bd984854..e60c18a71 100644 --- a/rust/sedona-raster-zarr/src/source_uri.rs +++ b/rust/sedona-raster-zarr/src/source_uri.rs @@ -40,12 +40,8 @@ use arrow_schema::ArrowError; -/// Parts of a chunk-anchor URI. -/// -/// `#[cfg(test)]`: no production consumer yet. The async byte -/// resolver (separate follow-up) will parse `outdb_uri` values back -/// into this struct. -#[cfg(test)] +/// Parts of a chunk-anchor URI. The async raster byte loader parses +/// `outdb_uri` values back into this struct. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ChunkAnchor { /// Original store URI for the *group* (e.g. `file:///tmp/foo.zarr`, @@ -80,10 +76,6 @@ pub fn build_chunk_anchor(store_uri: &str, array_path: &str, chunk_indices: &[u6 /// /// Strict: rejects URIs that don't carry both `array=` and `chunk=` /// fragment parameters with valid values. -/// -/// `#[cfg(test)]`: pairs with [`ChunkAnchor`]; resurrected when the -/// async byte resolver lands. 
-#[cfg(test)] pub fn parse_chunk_anchor(uri: &str) -> Result { let (store_uri, fragment) = uri.split_once('#').ok_or_else(|| { ArrowError::InvalidArgumentError(format!( diff --git a/rust/sedona-raster/Cargo.toml b/rust/sedona-raster/Cargo.toml index 37e7ecfbf..f6b692429 100644 --- a/rust/sedona-raster/Cargo.toml +++ b/rust/sedona-raster/Cargo.toml @@ -34,6 +34,7 @@ result_large_err = "allow" arrow-schema = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +async-trait = { workspace = true } datafusion-common = { workspace = true } sedona-common = { workspace = true } sedona-schema = { workspace = true } @@ -42,3 +43,4 @@ sedona-schema = { workspace = true } sedona-testing = { workspace = true } approx = { workspace = true } arrow-ipc = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt"] } diff --git a/rust/sedona-raster/src/array.rs b/rust/sedona-raster/src/array.rs index bc7c2317b..57d1dbf16 100644 --- a/rust/sedona-raster/src/array.rs +++ b/rust/sedona-raster/src/array.rs @@ -247,6 +247,12 @@ impl<'a> RasterRef for RasterRefImpl<'a> { // path, which is not yet implemented. Surface it loudly here rather // than silently rejecting the band, so callers see the standardised // SedonaDB-internal-error framing. + // + // This rejection is also the guardrail keeping `RS_EnsureLoaded` + // correct: it drops `view()` on rebuild, so it would corrupt a + // viewed band. When this comes off (view composition), the loader + // request/response must round-trip the view — tracked in + // . 
if !self.band_view_list.is_null(band_row) { return Err(ArrowError::ExternalError(Box::new( sedona_common::sedona_internal_datafusion_err!( diff --git a/rust/sedona-raster/src/lib.rs b/rust/sedona-raster/src/lib.rs index 77db0c0dd..d5398804d 100644 --- a/rust/sedona-raster/src/lib.rs +++ b/rust/sedona-raster/src/lib.rs @@ -19,4 +19,5 @@ pub mod affine_transformation; pub mod array; pub mod builder; pub mod display; +pub mod raster_loader; pub mod traits; diff --git a/rust/sedona-raster/src/raster_loader.rs b/rust/sedona-raster/src/raster_loader.rs new file mode 100644 index 000000000..3242c2bb9 --- /dev/null +++ b/rust/sedona-raster/src/raster_loader.rs @@ -0,0 +1,411 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Async byte-loading for schema-OutDb raster bands. +//! +//! `sedona-raster` deliberately knows nothing about GDAL, Zarr, or any +//! other backend. Backends implement [`AsyncByteLoader`] and register +//! themselves with a format key against an [`RasterLoaderRegistry`]. The +//! `RS_EnsureLoaded` UDF in the `sedona` crate consumes the registry to +//! materialise OutDb bands at query time; band accessors +//! (`BandRef::nd_buffer()` / `contiguous_data()`) do **not** invoke the +//! 
loader transparently — they return whatever is in the `data` column +//! verbatim, surfacing a clear error when the column is empty. + +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +use arrow_buffer::Buffer; +use arrow_schema::ArrowError; +use datafusion_common::config::{ + ConfigEntry, ConfigExtension, ConfigField, ExtensionOptions, Visit, +}; +use datafusion_common::{config_err, Result as DFResult}; +use sedona_schema::raster::BandDataType; + +/// Everything a backend needs to materialise a single OutDb band's bytes. +/// +/// Constructed by `RS_EnsureLoaded` once per row from the band's schema +/// metadata. The lifetime is the request's, not the loader's — borrowed +/// fields point into the input `RecordBatch` and stay valid for the +/// duration of the [`AsyncByteLoader::load`] future. +/// +/// This request carries no band *view*: it always asks for the full +/// `source_shape`, and `load` returns a bare `Buffer`. Once non-identity +/// views become constructible, the request must carry the desired view +/// and the response must report the realized `(buffer, source_shape, +/// view)` so a loader can range-read a sub-window and the caller can tell +/// what it got. Tracked in +/// . +#[derive(Debug, Clone, Copy)] +pub struct RasterLoadRequest<'a> { + /// Anchor URI from the band's `outdb_uri` column. Bare paths and + /// scheme'd URIs both allowed; backend is responsible for parsing. + pub uri: &'a str, + /// Per-axis names parallel to `source_shape`. Backends use this to + /// map their native axis order onto the band's. + pub dim_names: &'a [&'a str], + /// Raw source shape in `dim_names` order. The loader returns a + /// `Buffer` whose length equals `Π source_shape × data_type.byte_size()` + /// bytes, encoding pixels in C-order over `dim_names`. + pub source_shape: &'a [u64], + /// Pixel type the band claims. The loader returns bytes encoding this + /// type and errors if the source disagrees (e.g. file's dtype differs). 
+ pub data_type: BandDataType, +} + +/// Backend trait. Implementers live in format-specific crates +/// (`sedona-raster-gdal`, `sedona-raster-zarr`, …) and are registered +/// against an [`RasterLoaderRegistry`] under a format key matching the +/// band's `outdb_format` column. +/// +/// Synchronous backends (e.g. GDAL) wrap their I/O in +/// `tokio::task::spawn_blocking` inside the impl — the trait itself stays +/// async-only so the dispatcher (`RS_EnsureLoaded`) can `buffer_unordered` +/// over many in-flight loads. The result type is +/// [`arrow_buffer::Buffer`] (not `Vec`) so backends that already +/// produce reference-counted bytes (e.g. `object_store` returning +/// `bytes::Bytes`) hand them off zero-copy, and so the dispatcher can +/// build the output `BinaryViewArray` directly from collected Buffers +/// without an extra copy through a `BinaryViewBuilder` block buffer. +#[async_trait::async_trait] +pub trait AsyncByteLoader: Send + Sync + std::fmt::Debug { + /// Fetch the band's bytes. The returned `Buffer` must contain exactly + /// `Π source_shape × data_type.byte_size()` bytes in C-order over + /// `dim_names`. Errors propagate to the caller of `RS_EnsureLoaded`. + async fn load(&self, req: &RasterLoadRequest<'_>) -> Result; +} + +/// Process-side registry mapping `outdb_format` keys to loader instances. +/// +/// One registry instance per `SedonaContext`. The owning context wraps it +/// in `Arc>` so extension crates (`sedona-raster-zarr`, future COG / +/// Icechunk / …) can register their loaders post-construction via a +/// public `SedonaContext::register_raster_loader` API. +#[derive(Debug, Default)] +pub struct RasterLoaderRegistry { + loaders: HashMap>, +} + +impl RasterLoaderRegistry { + /// Construct an empty registry. Compiled-in backends (`sedona-raster-gdal` + /// under the `gdal` feature) register themselves from `SedonaContext::new`; + /// extension backends register via `SedonaContext::register_raster_loader`. 
+ pub fn new() -> Self { + Self::default() + } + + /// Register a loader under a format key. Later registrations for the + /// same key overwrite — registries are mutable for the lifetime of + /// the session and there's no value in locking down after first + /// registration (a process with runtime-registered backends may + /// legitimately swap implementations during setup). + pub fn register(&mut self, format: impl Into, loader: Arc) { + self.loaders.insert(format.into(), loader); + } + + /// Look up a loader by format key. Returns `None` for keys with no + /// registered backend. `RS_EnsureLoaded` surfaces the `None` case as a + /// query-time error that names the missing format and points users at + /// the install/register step. + pub fn get(&self, format: &str) -> Option> { + self.loaders.get(format).cloned() + } + + /// Iterate registered format keys. Useful for diagnostics ("no loader + /// for 'zarr'; registered formats are: gdal"). + pub fn formats(&self) -> impl Iterator { + self.loaders.keys().map(String::as_str) + } + + /// True if any loader is registered. + pub fn is_empty(&self) -> bool { + self.loaders.is_empty() + } +} + +/// `ConfigField`-shaped wrapper around the shared registry handle. +/// +/// Mirrors `sedona_common::option::CrsProviderOption` so the registry +/// can live inside a `ConfigOptions` extension and stay reachable from +/// `AsyncScalarUDFImpl::invoke_async_with_args` (which only receives +/// `Arc`). The inner `Arc>` is cloned +/// between the `SedonaContext` (mutable register API) and the config +/// extension (read at UDF dispatch time); both observe the same +/// underlying lock. +#[derive(Debug, Clone)] +pub struct RasterLoaderRegistryOption(Arc>); + +impl RasterLoaderRegistryOption { + /// Wrap an existing shared registry handle. + pub fn new(inner: Arc>) -> Self { + Self(inner) + } + + /// Clone the inner Arc for callers that need their own owning handle + /// (e.g. 
`SedonaContext::register_raster_loader` needs to write + /// through the same lock that the config extension exposes for + /// reads). + pub fn handle(&self) -> Arc> { + Arc::clone(&self.0) + } +} + +impl Default for RasterLoaderRegistryOption { + fn default() -> Self { + Self(Arc::new(RwLock::new(RasterLoaderRegistry::new()))) + } +} + +impl PartialEq for RasterLoaderRegistryOption { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl ConfigField for RasterLoaderRegistryOption { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + let snapshot = match self.0.read() { + Ok(g) => g.formats().map(String::from).collect::>(), + Err(p) => p + .into_inner() + .formats() + .map(String::from) + .collect::>(), + }; + v.some( + key, + format!("RasterLoaderRegistry {{ formats: {snapshot:?} }}"), + description, + ); + } + + fn set(&mut self, key: &str, _value: &str) -> DFResult<()> { + config_err!("Can't set {key} from SQL") + } +} + +/// `ConfigExtension` that stashes the per-session +/// [`RasterLoaderRegistry`] inside a `ConfigOptions`. Registered into +/// the session's `ConfigOptions` at `SedonaContext::new_from_context` +/// time; consumed by the `RS_EnsureLoaded` async UDF at dispatch time +/// via `args.config_options.extensions.get::()`. +/// +/// The PREFIX namespace is `sedona.raster_loader` — kept separate from +/// `sedona`'s main `SedonaOptions` extension because this lives in +/// `sedona-raster` (which is downstream of `sedona-common` in the +/// dependency graph) and adding a field to `SedonaOptions` would +/// require an undesirable circular dep. +#[derive(Debug, Default, Clone, PartialEq)] +pub struct RasterLoaderConfig { + pub registry: RasterLoaderRegistryOption, +} + +impl RasterLoaderConfig { + /// Build a config extension that closes over an existing shared + /// registry handle.
Use this rather than `default()` when wiring + /// from `SedonaContext::new_from_context` so the context's mutable + /// `register_raster_loader` API writes to the same `RwLock` the + /// config extension exposes for reads. + pub fn from_handle(registry: Arc>) -> Self { + Self { + registry: RasterLoaderRegistryOption::new(registry), + } + } +} + +impl ConfigExtension for RasterLoaderConfig { + const PREFIX: &'static str = "sedona.raster_loader"; +} + +impl ConfigField for RasterLoaderConfig { + fn visit(&self, v: &mut V, key_prefix: &str, _description: &'static str) { + let key = if key_prefix.is_empty() { + "registry".to_string() + } else { + format!("{key_prefix}.registry") + }; + self.registry + .visit(v, &key, "Registered raster byte loaders"); + } + + fn set(&mut self, key: &str, _value: &str) -> DFResult<()> { + config_err!("Can't set {key} from SQL") + } +} + +impl ExtensionOptions for RasterLoaderConfig { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } + + fn cloned(&self) -> Box { + Box::new(self.clone()) + } + + fn set(&mut self, key: &str, value: &str) -> DFResult<()> { + ::set(self, key, value) + } + + fn entries(&self) -> Vec { + struct EntryCollector(Vec); + impl Visit for EntryCollector { + fn some( + &mut self, + key: &str, + value: V, + description: &'static str, + ) { + self.0.push(ConfigEntry { + key: key.to_string(), + value: Some(value.to_string()), + description, + }); + } + fn none(&mut self, key: &str, description: &'static str) { + self.0.push(ConfigEntry { + key: key.to_string(), + value: None, + description, + }); + } + } + let mut collector = EntryCollector(vec![]); + self.visit(&mut collector, Self::PREFIX, ""); + collector.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + /// Minimal in-test loader: records the request and returns a buffer + /// of `Π source_shape × byte_size` zeros. 
+ #[derive(Debug, Default)] + struct MockLoader { + seen: Mutex)>>, + } + + #[async_trait::async_trait] + impl AsyncByteLoader for MockLoader { + async fn load(&self, req: &RasterLoadRequest<'_>) -> Result { + self.seen + .lock() + .unwrap() + .push((req.uri.to_string(), req.source_shape.to_vec())); + let elements: u64 = req.source_shape.iter().copied().product(); + let len = elements as usize * req.data_type.byte_size(); + Ok(Buffer::from_vec(vec![0u8; len])) + } + } + + #[test] + fn registry_starts_empty_and_reports_no_formats() { + let r = RasterLoaderRegistry::new(); + assert!(r.is_empty()); + assert!(r.get("gdal").is_none()); + assert_eq!(r.formats().count(), 0); + } + + #[test] + fn registry_get_returns_registered_loader() { + let mut r = RasterLoaderRegistry::new(); + r.register("mock", Arc::new(MockLoader::default())); + assert!(!r.is_empty()); + assert!(r.get("mock").is_some()); + assert!(r.get("gdal").is_none()); + } + + #[test] + fn registry_register_overwrites_existing_key() { + let mut r = RasterLoaderRegistry::new(); + let first = Arc::new(MockLoader::default()); + let second = Arc::new(MockLoader::default()); + r.register("mock", first.clone()); + r.register("mock", second.clone()); + // Two distinct loaders pushed under the same key; the second wins. 
+ let resolved = r.get("mock").unwrap(); + assert!(Arc::ptr_eq( + &(resolved as Arc), + &(second as Arc) + )); + } + + #[test] + fn registry_formats_lists_registered_keys() { + let mut r = RasterLoaderRegistry::new(); + r.register("gdal", Arc::new(MockLoader::default())); + r.register("zarr", Arc::new(MockLoader::default())); + let mut formats: Vec<&str> = r.formats().collect(); + formats.sort(); + assert_eq!(formats, vec!["gdal", "zarr"]); + } + + #[tokio::test] + async fn loader_load_returns_buffer_of_expected_size() { + let loader = MockLoader::default(); + let req = RasterLoadRequest { + uri: "file:///tmp/foo.tif", + dim_names: &["y", "x"], + source_shape: &[3, 4], + data_type: BandDataType::UInt8, + }; + let buf = loader.load(&req).await.unwrap(); + assert_eq!(buf.len(), 12); // 3 × 4 × 1 byte + let seen = loader.seen.lock().unwrap(); + assert_eq!(seen.len(), 1); + assert_eq!(seen[0].0, "file:///tmp/foo.tif"); + assert_eq!(seen[0].1, vec![3, 4]); + } + + #[tokio::test] + async fn loader_load_through_registry_dispatches_to_correct_backend() { + let mut r = RasterLoaderRegistry::new(); + let gdal = Arc::new(MockLoader::default()); + let zarr = Arc::new(MockLoader::default()); + r.register("gdal", gdal.clone()); + r.register("zarr", zarr.clone()); + + let req = RasterLoadRequest { + uri: "s3://bucket/cube.zarr", + dim_names: &["t", "y", "x"], + source_shape: &[2, 3, 4], + data_type: BandDataType::Float32, + }; + let loader = r.get("zarr").unwrap(); + let buf = loader.load(&req).await.unwrap(); + assert_eq!(buf.len(), 2 * 3 * 4 * 4); // Float32 = 4 bytes + + // Dispatched to zarr, not gdal. + assert_eq!(zarr.seen.lock().unwrap().len(), 1); + assert_eq!(gdal.seen.lock().unwrap().len(), 0); + } + + #[test] + fn registry_get_missing_format_returns_none_for_diagnostic_message() { + let r = RasterLoaderRegistry::new(); + // Caller (RS_EnsureLoaded) sees None and can build a diagnostic + // listing the registered formats. 
+ assert!(r.get("nonexistent").is_none()); + } +} diff --git a/rust/sedona/Cargo.toml b/rust/sedona/Cargo.toml index ef5fb9e85..42b1a654d 100644 --- a/rust/sedona/Cargo.toml +++ b/rust/sedona/Cargo.toml @@ -54,6 +54,7 @@ rstest = { workspace = true } [dependencies] arrow-schema = { workspace = true } arrow-array = { workspace = true } +arrow-buffer = { workspace = true } async-trait = { workspace = true } aws-config = { version = "1.5.17", optional = true } aws-credential-types = { version = "1.2.0", optional = true } @@ -62,6 +63,7 @@ datafusion = { workspace = true, default_features = false, features = ["sql", "p datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-ffi = { workspace = true } +datafusion-optimizer = { workspace = true } dirs = { workspace = true } futures = { workspace = true } geo-traits = { workspace = true } @@ -80,6 +82,7 @@ sedona-geos = { workspace = true, optional = true } sedona-pointcloud = { workspace = true, optional = true } sedona-proj = { workspace = true } sedona-gdal = { workspace = true } +sedona-raster = { workspace = true } sedona-raster-functions = { workspace = true } sedona-raster-gdal = { workspace = true } sedona-schema = { workspace = true } diff --git a/rust/sedona/src/context.rs b/rust/sedona/src/context.rs index 5cf4d653a..aa8d73a18 100644 --- a/rust/sedona/src/context.rs +++ b/rust/sedona/src/context.rs @@ -16,7 +16,7 @@ // under the License. 
use std::{ collections::{HashMap, VecDeque}, - sync::Arc, + sync::{Arc, RwLock}, }; use crate::exec::create_plan_from_sql; @@ -64,12 +64,15 @@ use sedona_pointcloud::las::{ format::{Extension, LasFormatFactory}, options::{GeometryEncoding, LasExtraBytes, LasOptions}, }; +use sedona_raster_functions::rs_ensure_loaded::RsEnsureLoaded; #[cfg(feature = "gpu")] use sedona_spatial_join_gpu::options::GpuOptions; use sedona_query_planner::{ - optimizer::register_spatial_join_logical_optimizer, query_planner::SedonaQueryPlanner, + optimizer::{register_ensure_loaded_optimizer, register_spatial_join_logical_optimizer}, + query_planner::SedonaQueryPlanner, }; +use sedona_raster::raster_loader::{AsyncByteLoader, RasterLoaderConfig, RasterLoaderRegistry}; /// Sedona SessionContext wrapper /// @@ -80,6 +83,12 @@ use sedona_query_planner::{ pub struct SedonaContext { pub ctx: SessionContext, pub functions: FunctionSet, + /// Per-session registry of async raster byte loaders, keyed by + /// `outdb_format`. Held behind an `Arc>` so the registered + /// `RS_EnsureLoaded` UDF instance and any extension crates' `register(&ctx)` + /// entry points observe the same map. See + /// [`SedonaContext::register_raster_loader`]. + raster_loader_registry: Arc>, } impl SedonaContext { @@ -186,6 +195,7 @@ impl SedonaContext { } state_builder = register_spatial_join_logical_optimizer(state_builder)?; + state_builder = register_ensure_loaded_optimizer(state_builder)?; state_builder = state_builder.with_query_planner(Arc::new(planner)); let mut state = state_builder.build(); @@ -225,8 +235,56 @@ impl SedonaContext { let mut out = Self { ctx, functions: FunctionSet::new(), + raster_loader_registry: Arc::new(RwLock::new(RasterLoaderRegistry::new())), }; + // Stash a clone of the shared registry handle inside + // `ConfigOptions` via the `RasterLoaderConfig` extension. 
The + // RS_EnsureLoaded async UDF reads from there at dispatch time — + // `AsyncScalarUDFImpl::invoke_async_with_args` only receives + // `Arc<ConfigOptions>`, so this is the path that keeps the + // registry reachable at the UDF's invocation site. Mirrors how + // `CrsProviderOption` works inside `SedonaOptions`. + // + // Writes through `SedonaContext::register_raster_loader` (which + // mutates the registry behind the `Arc` held in `out.raster_loader_registry`) are immediately + // visible to UDF reads through this config extension because + // both handles share the same `RwLock`. + out.ctx + .state_ref() + .write() + .config_mut() + .options_mut() + .extensions + .insert(RasterLoaderConfig::from_handle(Arc::clone( + &out.raster_loader_registry, + ))); + + // Register the RS_EnsureLoaded async UDF. It pulls the registry + // out of `args.config_options` at dispatch time, so it doesn't + // need to close over the Arc itself — the UDF instance is + // session-agnostic. The logical optimizer rule registered above + // (`register_ensure_loaded_optimizer`) resolves this UDF from the + // function registry at rewrite time, so it must be registered + // before any query is planned. + { + use datafusion_expr::async_udf::AsyncScalarUDF; + let udf = AsyncScalarUDF::new(Arc::new(RsEnsureLoaded::new())); + out.ctx.register_udf(udf.into_scalar_udf()); + }; + + // Register the GDAL raster byte loader. `sedona-raster-gdal` is a + // mandatory dep of `sedona`, but libgdal itself is dlopen'd + // lazily by `sedona-gdal` (workspace-default-features = false), + // so this registration is safe on systems without libgdal — + // the loader's `load()` call will surface a clean "libgdal not + // found" error when first invoked, but registration and import + // succeed regardless. 
+ out.register_raster_loader( + sedona_raster_gdal::GDAL_FORMAT, + Arc::new(sedona_raster_gdal::GdalLoader::new()), + ); + // Register table functions out.ctx.register_udtf( "sd_random_geometry", @@ -289,6 +347,41 @@ impl SedonaContext { Ok(()) } + /// Register an async raster byte loader under a `format` key. + /// + /// `format` matches the band-level `outdb_format` column value + /// (`"gdal"`, `"zarr"`, …). At query time the `RS_EnsureLoaded` UDF + /// looks up the loader for each OutDb band's format and dispatches + /// the byte fetch through it. + /// + /// Used by both compiled-in backends (the GDAL loader, which + /// `new_from_context` registers unconditionally through this method) + /// and out-of-tree extensions (`sedona-raster-zarr::register(&ctx)` + /// from user code after construction). Later registrations under the + /// same `format` key overwrite earlier ones. + pub fn register_raster_loader( + &self, + format: impl Into, + loader: Arc, + ) { + // Lock poisoning here would mean a previous registrant panicked + // mid-write — recover-by-ignoring (on a poisoned lock the + // registration is silently skipped and the caller gets no error) + // matches how DataFusion handles session-state writes elsewhere. + if let Ok(mut guard) = self.raster_loader_registry.write() { + guard.register(format, loader); + } + } + + /// Returns a snapshot list of currently-registered OutDb format keys. + /// Useful for diagnostics (e.g., listing registered backends after + /// session setup). 
+ pub fn registered_raster_formats(&self) -> Vec { + self.raster_loader_registry + .read() + .map(|g| g.formats().map(String::from).collect()) + .unwrap_or_default() + } + /// Register all functions in a [FunctionSet] with this context pub fn register_function_set(&mut self, function_set: FunctionSet) { for udf in function_set.scalar_udfs() { @@ -651,6 +744,52 @@ mod tests { use super::*; + #[tokio::test] + async fn outdb_registry_has_gdal_at_bootstrap_and_accepts_runtime_registration() { + use arrow_buffer::Buffer; + use sedona_raster::raster_loader::{AsyncByteLoader, RasterLoadRequest}; + + let ctx = SedonaContext::new(); + // GDAL is always registered at bootstrap (compiled-in backend). + // Extension backends like Zarr add themselves later via + // `register_raster_loader`. + let initial_formats = ctx.registered_raster_formats(); + assert!( + initial_formats.contains(&"gdal".to_string()), + "fresh SedonaContext should register the GDAL backend at bootstrap; got {initial_formats:?}" + ); + + // Mock runtime registration on top of the bootstrap state. + #[derive(Debug)] + struct MockLoader; + #[async_trait] + impl AsyncByteLoader for MockLoader { + async fn load( + &self, + _req: &RasterLoadRequest<'_>, + ) -> std::result::Result { + Ok(Buffer::from_vec(Vec::::new())) + } + } + ctx.register_raster_loader("mock", Arc::new(MockLoader)); + let mut after = ctx.registered_raster_formats(); + after.sort(); + assert!(after.contains(&"gdal".to_string())); + assert!(after.contains(&"mock".to_string())); + + // RS_EnsureLoaded is registered as a UDF at session bootstrap. + let udf = ctx + .ctx + .state() + .scalar_functions() + .get("rs_ensureloaded") + .cloned(); + assert!( + udf.is_some(), + "RS_EnsureLoaded should be registered by SedonaContext::new()" + ); + } + #[tokio::test] async fn basic_sql() -> Result<()> { let ctx = SedonaContext::new();