Skip to content

Commit

Permalink
expose table name in proto extension codec (#11139)
Browse files Browse the repository at this point in the history
  • Loading branch information
leoyvens authored Jun 28, 2024
1 parent 838e0f7 commit 3bd7200
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
12 changes: 11 additions & 1 deletion datafusion/proto/src/logical_plan/file_formats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use datafusion::{
},
prelude::SessionContext,
};
use datafusion_common::not_impl_err;
use datafusion_common::{not_impl_err, TableReference};

use super::LogicalExtensionCodec;

Expand Down Expand Up @@ -53,6 +53,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
fn try_decode_table_provider(
&self,
_buf: &[u8],
_table_ref: &TableReference,
_schema: arrow::datatypes::SchemaRef,
_ctx: &datafusion::prelude::SessionContext,
) -> datafusion_common::Result<
Expand All @@ -63,6 +64,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec {

fn try_encode_table_provider(
&self,
_table_ref: &TableReference,
_node: std::sync::Arc<dyn datafusion::datasource::TableProvider>,
_buf: &mut Vec<u8>,
) -> datafusion_common::Result<()> {
Expand Down Expand Up @@ -127,6 +129,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec {
fn try_decode_table_provider(
&self,
_buf: &[u8],
_table_ref: &TableReference,
_schema: arrow::datatypes::SchemaRef,
_ctx: &datafusion::prelude::SessionContext,
) -> datafusion_common::Result<
Expand All @@ -137,6 +140,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec {

fn try_encode_table_provider(
&self,
_table_ref: &TableReference,
_node: std::sync::Arc<dyn datafusion::datasource::TableProvider>,
_buf: &mut Vec<u8>,
) -> datafusion_common::Result<()> {
Expand Down Expand Up @@ -201,6 +205,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec {
fn try_decode_table_provider(
&self,
_buf: &[u8],
_table_ref: &TableReference,
_schema: arrow::datatypes::SchemaRef,
_ctx: &datafusion::prelude::SessionContext,
) -> datafusion_common::Result<
Expand All @@ -211,6 +216,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec {

fn try_encode_table_provider(
&self,
_table_ref: &TableReference,
_node: std::sync::Arc<dyn datafusion::datasource::TableProvider>,
_buf: &mut Vec<u8>,
) -> datafusion_common::Result<()> {
Expand Down Expand Up @@ -275,6 +281,7 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec {
fn try_decode_table_provider(
&self,
_buf: &[u8],
_table_ref: &TableReference,
_schema: arrow::datatypes::SchemaRef,
_ctx: &datafusion::prelude::SessionContext,
) -> datafusion_common::Result<
Expand All @@ -285,6 +292,7 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec {

fn try_encode_table_provider(
&self,
_table_ref: &TableReference,
_node: std::sync::Arc<dyn datafusion::datasource::TableProvider>,
_buf: &mut Vec<u8>,
) -> datafusion_common::Result<()> {
Expand Down Expand Up @@ -349,6 +357,7 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec {
fn try_decode_table_provider(
&self,
_buf: &[u8],
_table_ref: &TableReference,
_schema: arrow::datatypes::SchemaRef,
_cts: &datafusion::prelude::SessionContext,
) -> datafusion_common::Result<
Expand All @@ -359,6 +368,7 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec {

fn try_encode_table_provider(
&self,
_table_ref: &TableReference,
_node: std::sync::Arc<dyn datafusion::datasource::TableProvider>,
_buf: &mut Vec<u8>,
) -> datafusion_common::Result<()> {
Expand Down
14 changes: 10 additions & 4 deletions datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,14 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync {
/// Decodes a custom [`TableProvider`] from the serialized bytes in `buf`.
///
/// `table_ref` is the reference (name) of the table being decoded, so that
/// implementations can take the table's identity into account. When a
/// custom scan node is deserialized, the caller derives it from the scan's
/// `table_name` field and passes it here together with the scan's
/// `custom_table_data` bytes.
///
/// `schema` and `ctx` are supplied by the caller alongside the bytes;
/// presumably `schema` is the schema the decoded provider should expose —
/// TODO confirm against the call site.
///
/// # Errors
///
/// Returns an error if a provider cannot be produced from `buf`.
fn try_decode_table_provider(
&self,
buf: &[u8],
table_ref: &TableReference,
schema: SchemaRef,
ctx: &SessionContext,
) -> Result<Arc<dyn TableProvider>>;

/// Encodes the custom table provider `node` by writing its serialized form
/// into `buf`.
///
/// `table_ref` is the reference (name) of the table that `node` is scanned
/// as, so that implementations can include or inspect the table's identity
/// when serializing. When a custom table scan is serialized, the caller
/// passes a fresh, empty `Vec<u8>` as `buf`.
///
/// # Errors
///
/// Returns an error if `node` cannot be serialized; the caller wraps that
/// error with the context "Error serializing custom table".
fn try_encode_table_provider(
&self,
table_ref: &TableReference,
node: Arc<dyn TableProvider>,
buf: &mut Vec<u8>,
) -> Result<()>;
Expand Down Expand Up @@ -164,6 +166,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec {
fn try_decode_table_provider(
&self,
_buf: &[u8],
_table_ref: &TableReference,
_schema: SchemaRef,
_ctx: &SessionContext,
) -> Result<Arc<dyn TableProvider>> {
Expand All @@ -172,6 +175,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec {

fn try_encode_table_provider(
&self,
_table_ref: &TableReference,
_node: Arc<dyn TableProvider>,
_buf: &mut Vec<u8>,
) -> Result<()> {
Expand Down Expand Up @@ -445,15 +449,17 @@ impl AsLogicalPlan for LogicalPlanNode {
.iter()
.map(|expr| from_proto::parse_expr(expr, ctx, extension_codec))
.collect::<Result<Vec<_>, _>>()?;

let table_name =
from_table_reference(scan.table_name.as_ref(), "CustomScan")?;

let provider = extension_codec.try_decode_table_provider(
&scan.custom_table_data,
&table_name,
schema,
ctx,
)?;

let table_name =
from_table_reference(scan.table_name.as_ref(), "CustomScan")?;

LogicalPlanBuilder::scan_with_filters(
table_name,
provider_as_source(provider),
Expand Down Expand Up @@ -1048,7 +1054,7 @@ impl AsLogicalPlan for LogicalPlanNode {
} else {
let mut bytes = vec![];
extension_codec
.try_encode_table_provider(provider, &mut bytes)
.try_encode_table_provider(table_name, provider, &mut bytes)
.map_err(|e| context!("Error serializing custom table", e))?;
let scan = CustomScan(CustomTableScanNode {
table_name: Some(table_name.clone().into()),
Expand Down
13 changes: 12 additions & 1 deletion datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use datafusion_common::config::TableOptions;
use datafusion_common::scalar::ScalarStructBuilder;
use datafusion_common::{
internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef,
DataFusionError, Result, ScalarValue,
DataFusionError, Result, ScalarValue, TableReference,
};
use datafusion_expr::dml::CopyTo;
use datafusion_expr::expr::{
Expand Down Expand Up @@ -134,6 +134,9 @@ pub struct TestTableProto {
/// URL of the table root
#[prost(string, tag = "1")]
pub url: String,
/// Qualified table name
#[prost(string, tag = "2")]
pub table_name: String,
}

#[derive(Debug)]
Expand All @@ -156,12 +159,14 @@ impl LogicalExtensionCodec for TestTableProviderCodec {
fn try_decode_table_provider(
&self,
buf: &[u8],
table_ref: &TableReference,
schema: SchemaRef,
_ctx: &SessionContext,
) -> Result<Arc<dyn TableProvider>> {
let msg = TestTableProto::decode(buf).map_err(|_| {
DataFusionError::Internal("Error decoding test table".to_string())
})?;
assert_eq!(msg.table_name, table_ref.to_string());
let provider = TestTableProvider {
url: msg.url,
schema,
Expand All @@ -171,6 +176,7 @@ impl LogicalExtensionCodec for TestTableProviderCodec {

fn try_encode_table_provider(
&self,
table_ref: &TableReference,
node: Arc<dyn TableProvider>,
buf: &mut Vec<u8>,
) -> Result<()> {
Expand All @@ -181,6 +187,7 @@ impl LogicalExtensionCodec for TestTableProviderCodec {
.expect("Can't encode non-test tables");
let msg = TestTableProto {
url: table.url.clone(),
table_name: table_ref.to_string(),
};
msg.encode(buf).map_err(|_| {
DataFusionError::Internal("Error encoding test table".to_string())
Expand Down Expand Up @@ -866,6 +873,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec {
fn try_decode_table_provider(
&self,
_buf: &[u8],
_table_ref: &TableReference,
_schema: SchemaRef,
_ctx: &SessionContext,
) -> Result<Arc<dyn TableProvider>> {
Expand All @@ -874,6 +882,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec {

fn try_encode_table_provider(
&self,
_table_ref: &TableReference,
_node: Arc<dyn TableProvider>,
_buf: &mut Vec<u8>,
) -> Result<()> {
Expand Down Expand Up @@ -943,6 +952,7 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec {
fn try_decode_table_provider(
&self,
_buf: &[u8],
_table_ref: &TableReference,
_schema: SchemaRef,
_ctx: &SessionContext,
) -> Result<Arc<dyn TableProvider>> {
Expand All @@ -951,6 +961,7 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec {

fn try_encode_table_provider(
&self,
_table_ref: &TableReference,
_node: Arc<dyn TableProvider>,
_buf: &mut Vec<u8>,
) -> Result<()> {
Expand Down

0 comments on commit 3bd7200

Please sign in to comment.