Skip to content

Commit

Permalink
[FEAT] sql image_decode (#2757)
Browse files Browse the repository at this point in the history
- adds sql `image_decode` function. 

- moves `image_decode` out of `daft-dsl` and into `daft-functions`. 

- some small changes to `SQLFunction` so that it can work with
`ScalarUDF`
  • Loading branch information
universalmind303 authored Sep 4, 2024
1 parent 60ebf82 commit 734c13f
Show file tree
Hide file tree
Showing 15 changed files with 364 additions and 227 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,6 @@ class PyExpr:
def utf8_to_date(self, format: str) -> PyExpr: ...
def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PyExpr: ...
def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PyExpr: ...
def image_decode(self, raise_error_on_failure: bool, mode: ImageMode | None = None) -> PyExpr: ...
def image_encode(self, image_format: ImageFormat) -> PyExpr: ...
def image_resize(self, w: int, h: int) -> PyExpr: ...
def image_crop(self, bbox: PyExpr) -> PyExpr: ...
Expand Down Expand Up @@ -1246,6 +1245,7 @@ def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_s
def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ...
def cbrt(expr: PyExpr) -> PyExpr: ...
def to_struct(inputs: list[PyExpr]) -> PyExpr: ...
def image_decode(expr: PyExpr, raise_on_error: bool, mode: ImageMode | None = None) -> PyExpr: ...

class PyCatalog:
@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3141,7 +3141,7 @@ def decode(
mode = ImageMode.from_mode_string(mode.upper())
if not isinstance(mode, ImageMode):
raise ValueError(f"mode must be a string or ImageMode variant, but got: {mode}")
return Expression._from_pyexpr(self._expr.image_decode(raise_error_on_failure=raise_on_error, mode=mode))
return Expression._from_pyexpr(native.image_decode(self._expr, raise_on_error=raise_on_error, mode=mode))

def encode(self, image_format: str | ImageFormat) -> Expression:
"""
Expand Down
64 changes: 0 additions & 64 deletions src/daft-dsl/src/functions/image/decode.rs

This file was deleted.

31 changes: 3 additions & 28 deletions src/daft-dsl/src/functions/image/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
mod crop;
mod decode;
mod encode;
mod resize;
mod to_mode;

use crop::CropEvaluator;
use decode::DecodeEvaluator;
use encode::EncodeEvaluator;
use resize::ResizeEvaluator;
use serde::{Deserialize, Serialize};
Expand All @@ -19,21 +17,10 @@ use super::FunctionEvaluator;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ImageExpr {
Decode {
raise_error_on_failure: bool,
mode: Option<ImageMode>,
},
Encode {
image_format: ImageFormat,
},
Resize {
w: u32,
h: u32,
},
Encode { image_format: ImageFormat },
Resize { w: u32, h: u32 },
Crop(),
ToMode {
mode: ImageMode,
},
ToMode { mode: ImageMode },
}

impl ImageExpr {
Expand All @@ -42,7 +29,6 @@ impl ImageExpr {
use ImageExpr::*;

match self {
Decode { .. } => &DecodeEvaluator {},
Encode { .. } => &EncodeEvaluator {},
Resize { .. } => &ResizeEvaluator {},
Crop { .. } => &CropEvaluator {},
Expand All @@ -51,17 +37,6 @@ impl ImageExpr {
}
}

pub fn decode(input: ExprRef, raise_error_on_failure: bool, mode: Option<ImageMode>) -> ExprRef {
Expr::Function {
func: super::FunctionExpr::Image(ImageExpr::Decode {
raise_error_on_failure,
mode,
}),
inputs: vec![input],
}
.into()
}

pub fn encode(input: ExprRef, image_format: ImageFormat) -> ExprRef {
Expr::Function {
func: super::FunctionExpr::Image(ImageExpr::Encode { image_format }),
Expand Down
9 changes: 0 additions & 9 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -797,15 +797,6 @@ impl PyExpr {
Ok(normalize(self.into(), opts).into())
}

pub fn image_decode(
&self,
raise_error_on_failure: bool,
mode: Option<ImageMode>,
) -> PyResult<Self> {
use crate::functions::image::decode;
Ok(decode(self.into(), raise_error_on_failure, mode).into())
}

pub fn image_encode(&self, image_format: ImageFormat) -> PyResult<Self> {
use crate::functions::image::encode;
Ok(encode(self.into(), image_format).into())
Expand Down
104 changes: 104 additions & 0 deletions src/daft-functions/src/image/decode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use daft_core::{
datatypes::{DataType, Field, ImageMode},
schema::Schema,
series::Series,
};

use common_error::{DaftError, DaftResult};
use daft_dsl::{
functions::{ScalarFunction, ScalarUDF},
ExprRef,
};
use serde::{Deserialize, Serialize};

/// Container for the keyword arguments for `image_decode`
/// ex:
/// ```text
/// image_decode(input)
/// image_decode(input, mode='RGB')
/// image_decode(input, mode='RGB', on_error='raise')
/// image_decode(input, on_error='null')
/// image_decode(input, on_error='null', mode='RGB')
/// image_decode(input, mode='RGB', on_error='null')
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct ImageDecode {
pub mode: Option<ImageMode>,
pub raise_on_error: bool,
}

impl Default for ImageDecode {
fn default() -> Self {
Self {
mode: None,
raise_on_error: true,
}
}
}

#[typetag::serde]
impl ScalarUDF for ImageDecode {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &'static str {
"image_decode"
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
match inputs {
[input] => {
let field = input.to_field(schema)?;
if !matches!(field.dtype, DataType::Binary) {
return Err(DaftError::TypeError(format!(
"ImageDecode can only decode BinaryArrays, got {}",
field
)));
}
Ok(Field::new(field.name, DataType::Image(self.mode)))
}
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input arg, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
let raise_error_on_failure = self.raise_on_error;
match inputs {
[input] => input.image_decode(raise_error_on_failure, self.mode),
_ => Err(DaftError::ValueError(format!(
"Expected 1 input arg, got {}",
inputs.len()
))),
}
}
}

pub fn decode(input: ExprRef, args: Option<ImageDecode>) -> ExprRef {
ScalarFunction::new(args.unwrap_or_default(), vec![input]).into()
}

#[cfg(feature = "python")]
use {
daft_dsl::python::PyExpr,
pyo3::{pyfunction, PyResult},
};

#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "image_decode")]
pub fn py_decode(
expr: PyExpr,
raise_on_error: Option<bool>,
mode: Option<ImageMode>,
) -> PyResult<PyExpr> {
let image_decode = ImageDecode {
mode,
raise_on_error: raise_on_error.unwrap_or(true),
};

Ok(decode(expr.into(), Some(image_decode)).into())
}
11 changes: 11 additions & 0 deletions src/daft-functions/src/image/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
pub mod decode;

#[cfg(feature = "python")]
use pyo3::prelude::*;

#[cfg(feature = "python")]
pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_wrapped(wrap_pyfunction!(decode::py_decode))?;

Ok(())
}
3 changes: 3 additions & 0 deletions src/daft-functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pub mod count_matches;
pub mod distance;
pub mod hash;
pub mod image;
pub mod list_sort;
pub mod minhash;
pub mod numeric;
Expand All @@ -28,6 +29,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_wrapped(wrap_pyfunction!(tokenize::python::tokenize_encode))?;
parent.add_wrapped(wrap_pyfunction!(uri::python::url_download))?;
parent.add_wrapped(wrap_pyfunction!(uri::python::url_upload))?;
image::register_modules(_py, parent)?;
Ok(())
}

Expand All @@ -42,6 +44,7 @@ impl From<Error> for std::io::Error {
std::io::Error::new(std::io::ErrorKind::Other, err)
}
}

impl From<Error> for DaftError {
fn from(err: Error) -> DaftError {
DaftError::External(err.into())
Expand Down
3 changes: 2 additions & 1 deletion src/daft-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ common-daft-config = {path = "../common/daft-config"}
common-error = {path = "../common/error"}
daft-core = {path = "../daft-core"}
daft-dsl = {path = "../daft-dsl"}
daft-functions = {path = "../daft-functions"}
daft-plan = {path = "../daft-plan"}
once_cell = {workspace = true}
pyo3 = {workspace = true, optional = true}
Expand All @@ -13,7 +14,7 @@ snafu.workspace = true
rstest = {workspace = true}

[features]
python = ["dep:pyo3", "common-error/python"]
python = ["dep:pyo3", "common-error/python", "daft-functions/python"]

[package]
name = "daft-sql"
Expand Down
Loading

0 comments on commit 734c13f

Please sign in to comment.