diff --git a/dagster_polars/io_managers/base.py b/dagster_polars/io_managers/base.py index 36b7f9b..18bd401 100644 --- a/dagster_polars/io_managers/base.py +++ b/dagster_polars/io_managers/base.py @@ -269,32 +269,32 @@ def dump_to_path( else: assert obj is not None, "output should not be None if it's type is not Optional" if not annotation_for_storage_metadata(typing_type): - if typing_type == pl.DataFrame: + if typing_type in POLARS_EAGER_FRAME_ANNOTATIONS: obj = cast(pl.DataFrame, obj) df = obj self.write_df_to_path(context=context, df=df, path=path) - elif typing_type == pl.LazyFrame: + elif typing_type in POLARS_LAZY_FRAME_ANNOTATIONS: obj = cast(pl.LazyFrame, obj) df = obj self.sink_df_to_path(context=context, df=df, path=path) else: - raise NotImplementedError + raise NotImplementedError(f"dump_df_to_path for {typing_type} is not implemented") else: if not annotation_is_typing_optional(typing_type): frame_type = get_args(typing_type)[0] else: frame_type = get_args(get_args(typing_type)[0])[0] - if frame_type == pl.DataFrame: + if frame_type in POLARS_EAGER_FRAME_ANNOTATIONS: obj = cast(Tuple[pl.DataFrame, Dict[str, Any]], obj) df, metadata = obj self.write_df_to_path(context=context, df=df, path=path, metadata=metadata) - elif frame_type == pl.LazyFrame: + elif frame_type in POLARS_LAZY_FRAME_ANNOTATIONS: obj = cast(Tuple[pl.LazyFrame, Dict[str, Any]], obj) df, metadata = obj self.sink_df_to_path(context=context, df=df, path=path, metadata=metadata) else: - raise NotImplementedError + raise NotImplementedError(f"dump_df_to_path for {typing_type} is not implemented") def load_from_path( self, context: InputContext, path: "UPath", partition_key: Optional[str] = None