fix: Rewrite Spark materialization engine to use mapInPandas #3936

Merged
2 commits merged on Feb 13, 2024
@@ -3,6 +3,7 @@
 from typing import Callable, List, Literal, Optional, Sequence, Union, cast

 import dill
 import pandas
+import pandas as pd
 import pyarrow
 from tqdm import tqdm
@@ -178,9 +179,9 @@ def _materialize_one(
                 self.repo_config.batch_engine.partitions
             )

-        spark_df.foreachPartition(
-            lambda x: _process_by_partition(x, spark_serialized_artifacts)
-        )
+        spark_df.mapInPandas(
+            lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int"
+        ).count()  # dummy action to force evaluation

         return SparkMaterializationJob(
             job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
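The change above replaces foreachPartition, which handed the callback an iterator of Spark Row objects, with mapInPandas, which hands it an iterator of pandas DataFrames and requires an output schema ("status int" here). Because mapInPandas is a lazy transformation, the trailing .count() is what actually triggers the per-partition writes. A minimal, self-contained sketch of the same pattern follows; the example DataFrame and the write_partition callback are illustrative only and not part of this PR:

from typing import Iterator

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Illustrative input; the real engine builds its DataFrame from the offline store.
df = spark.createDataFrame([(1, 10.0), (2, 20.0)], ["driver_id", "conv_rate"])


def write_partition(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for batch in batches:
        # Side effects (e.g. writing the rows to the online store) would go here.
        pass
    # mapInPandas must yield DataFrames matching the declared schema ("status int").
    yield pd.DataFrame({"status": [1]})


# mapInPandas is lazy; count() is the dummy action that forces the writes to run.
df.mapInPandas(write_partition, "status int").count()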
@@ -225,38 +226,40 @@ def unserialize(self):
         return feature_view, online_store, repo_config


-def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts):
-    """Load pandas df to online store"""
-
-    # convert to pyarrow table
-    dicts = []
-    for row in rows:
-        dicts.append(row.asDict())
-
-    df = pd.DataFrame.from_records(dicts)
-    if df.shape[0] == 0:
-        print("Skipping")
-        return
-
-    table = pyarrow.Table.from_pandas(df)
-
-    # unserialize artifacts
-    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()
-
-    if feature_view.batch_source.field_mapping is not None:
-        table = _run_pyarrow_field_mapping(
-            table, feature_view.batch_source.field_mapping
-        )
-
-    join_key_to_value_type = {
-        entity.name: entity.dtype.to_value_type()
-        for entity in feature_view.entity_columns
-    }
-
-    rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type)
-    online_store.online_write_batch(
-        repo_config,
-        feature_view,
-        rows_to_write,
-        lambda x: None,
-    )
+def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArtifacts):
+    for pdf in iterator:
+        if pdf.shape[0] == 0:
+            print("Skipping")
+            return
+
+        table = pyarrow.Table.from_pandas(pdf)
+
+        (
+            feature_view,
+            online_store,
+            repo_config,
+        ) = spark_serialized_artifacts.unserialize()
+
+        if feature_view.batch_source.field_mapping is not None:
+            table = _run_pyarrow_field_mapping(
+                table, feature_view.batch_source.field_mapping
+            )
+
+        join_key_to_value_type = {
+            entity.name: entity.dtype.to_value_type()
+            for entity in feature_view.entity_columns
+        }
+
+        rows_to_write = _convert_arrow_to_proto(
+            table, feature_view, join_key_to_value_type
+        )
+        online_store.online_write_batch(
+            repo_config,
+            feature_view,
+            rows_to_write,
+            lambda x: None,
+        )
+
+    yield pd.DataFrame(
+        [pd.Series(range(1, 2))]
+    )  # dummy result because mapInPandas needs to return something
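One behavioral detail not visible in the diff: with mapInPandas each Spark partition reaches _map_by_partition as a sequence of pandas DataFrames, one per Arrow record batch, so a large partition is processed in chunks rather than as a single DataFrame. If the chunk size matters for the online store writes, it can be tuned through Spark's standard Arrow batching option; the value below is purely illustrative:

# Standard Spark option (defaults to 10000 rows per batch).
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "20000")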
setup.py (2 changes: 1 addition & 1 deletion)
@@ -155,7 +155,7 @@
     "grpcio-testing>=1.56.2,<2",
     "minio==7.1.0",
     "mock==2.0.0",
-    "moto",
+    "moto<5",
     "mypy>=0.981,<0.990",
     "avro==1.10.0",
     "fsspec<2023.10.0",