Feast SDK integration for historical feature retrieval using Spark (#1054)

* Feast SDK integration for historical feature retrieval using Spark

Signed-off-by: Khor Shu Heng <[email protected]>

* Downgrade pyspark dependencies, add more tests

Signed-off-by: Khor Shu Heng <[email protected]>

* Don't expose the historical feature output config directly to the user

Signed-off-by: Khor Shu Heng <[email protected]>

Co-authored-by: Khor Shu Heng <[email protected]>
khorshuheng and khorshuheng authored Oct 15, 2020
1 parent 5e9a717 commit e8b24bb
Showing 17 changed files with 1,305 additions and 591 deletions.
122 changes: 119 additions & 3 deletions sdk/python/feast/client.py
@@ -14,6 +14,8 @@
import logging
import multiprocessing
import shutil
import uuid
from itertools import groupby
from typing import Any, Dict, List, Optional, Union

import grpc
@@ -30,6 +32,8 @@
CONFIG_SERVING_ENABLE_SSL_KEY,
CONFIG_SERVING_SERVER_SSL_CERT_KEY,
CONFIG_SERVING_URL_KEY,
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT,
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION,
FEAST_DEFAULT_OPTIONS,
)
from feast.core.CoreService_pb2 import (
@@ -70,6 +74,11 @@
_write_partitioned_table_from_source,
)
from feast.online_response import OnlineResponse, _infer_online_entity_rows
from feast.pyspark.abc import RetrievalJob
from feast.pyspark.launcher import (
start_historical_feature_retrieval_job,
start_historical_feature_retrieval_spark_session,
)
from feast.serving.ServingService_pb2 import (
GetFeastServingInfoRequest,
GetOnlineFeaturesRequestV2,
@@ -723,7 +732,6 @@ def get_online_features(
) -> OnlineResponse:
"""
Retrieves the latest online feature data from Feast Serving.
Args:
feature_refs: List of feature references that will be returned for each entity.
Each feature reference should have the following format:
@@ -733,12 +741,10 @@ def get_online_features(
entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.
project: Optionally specify the project override. If specified, uses the given project for retrieval.
Overrides the projects specified in Feature References if they are also specified.
Returns:
GetOnlineFeaturesResponse containing the feature data in records.
Each EntityRow provided will yield one record, which contains
data fields with data value and field status metadata (if included).
Examples:
>>> from feast import Client
>>>
@@ -767,3 +773,113 @@ def get_online_features(

response = OnlineResponse(response)
return response

def get_historical_features(
self,
feature_refs: List[str],
entity_source: Union[FileSource, BigQuerySource],
project: str = None,
) -> RetrievalJob:
"""
Launch a historical feature retrieval job.
Args:
feature_refs: List of feature references that will be returned for each entity.
Each feature reference should have the following format:
"feature_table:feature" where "feature_table" & "feature" refer to
the feature table and feature names respectively.
entity_source (Union[FileSource, BigQuerySource]): Source for the entity rows.
The user needs to make sure that the source is accessible from the Spark cluster
that will be used for the retrieval job.
project: Specifies the project that contains the feature tables
which the requested features belong to.
Returns:
Returns a retrieval job object that can be used to monitor retrieval
progress asynchronously and to materialize the results.
Examples:
>>> from feast import Client
>>> from datetime import datetime
>>> feast_client = Client(core_url="localhost:6565")
>>> feature_refs = ["bookings:bookings_7d", "bookings:booking_14d"]
>>> entity_source = FileSource("event_timestamp", "parquet", "gs://some-bucket/customer")
>>> feature_retrieval_job = feast_client.get_historical_features(
>>> feature_refs, entity_source, project="my_project")
>>> output_file_uri = feature_retrieval_job.get_output_file_uri()
"gs://some-bucket/output/
"""
feature_tables = self._get_feature_tables_from_feature_refs(
feature_refs, project
)
output_location = self._config.get(
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION
)
output_format = self._config.get(CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT)
job_id = f"historical-feature-{str(uuid.uuid4())}"

return start_historical_feature_retrieval_job(
self, entity_source, feature_tables, output_format, output_location, job_id
)

def get_historical_features_df(
self,
feature_refs: List[str],
entity_source: Union[FileSource, BigQuerySource],
project: str = None,
):
"""
Launch a historical feature retrieval job.
Args:
feature_refs: List of feature references that will be returned for each entity.
Each feature reference should have the following format:
"feature_table:feature" where "feature_table" & "feature" refer to
the feature table and feature names respectively.
entity_source (Union[FileSource, BigQuerySource]): Source for the entity rows.
The user needs to make sure that the source is accessible from the Spark cluster
that will be used for the retrieval job.
project: Specifies the project that contains the feature tables
which the requested features belong to.
Returns:
Returns the historical feature retrieval result in the form of a Spark dataframe.
Examples:
>>> from feast import Client
>>> from datetime import datetime
>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.getOrCreate()
>>> feast_client = Client(core_url="localhost:6565")
>>> feature_refs = ["bookings:bookings_7d", "bookings:booking_14d"]
>>> entity_source = FileSource("event_timestamp", "parquet", "gs://some-bucket/customer")
>>> df = feast_client.get_historical_features_df(
>>> feature_refs, entity_source, project="my_project")
"""
feature_tables = self._get_feature_tables_from_feature_refs(
feature_refs, project
)
return start_historical_feature_retrieval_spark_session(
self, entity_source, feature_tables
)

def _get_feature_tables_from_feature_refs(
self, feature_refs: List[str], project: Optional[str]
):
feature_refs_grouped_by_table = [
(feature_table_name, list(grouped_feature_refs))
for feature_table_name, grouped_feature_refs in groupby(
feature_refs, lambda x: x.split(":")[0]
)
]

feature_tables = []
for feature_table_name, grouped_feature_refs in feature_refs_grouped_by_table:
feature_table = self.get_feature_table(feature_table_name, project)
feature_names = [f.split(":")[-1] for f in grouped_feature_refs]
feature_table.features = [
f for f in feature_table.features if f.name in feature_names
]
feature_tables.append(feature_table)
return feature_tables
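Taken together, the two new client methods can be exercised roughly as in the sketch below. This is a minimal illustration based only on the docstrings above: the bucket path, feature references, project name, and timeout are placeholders; the FileSource import and constructor arguments are assumed to match the docstring examples; and a Spark launcher is assumed to already be configured via the new options in constants.py (shown after this file).

from feast import Client, FileSource  # FileSource import path assumed to match the SDK layout
from pyspark.sql import SparkSession

client = Client(core_url="localhost:6565")
feature_refs = ["bookings:bookings_7d", "bookings:booking_14d"]
entity_source = FileSource("event_timestamp", "parquet", "gs://some-bucket/customer")

# Asynchronous path: submit the Spark job, block for the output uri, then load the files yourself.
job = client.get_historical_features(feature_refs, entity_source, project="my_project")
output_uri = job.get_output_file_uri(timeout_sec=600)  # raises SparkJobFailure on failure or timeout

# Synchronous path: run the retrieval in a local Spark session and work with the dataframe directly.
spark = SparkSession.builder.getOrCreate()
df = client.get_historical_features_df(feature_refs, entity_source, project="my_project")
df.show()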
14 changes: 14 additions & 0 deletions sdk/python/feast/constants.py
@@ -64,6 +64,20 @@ class AuthProvider(Enum):
CONFIG_TIMEOUT_KEY = "timeout"
CONFIG_MAX_WAIT_INTERVAL_KEY = "max_wait_interval"

# Spark Job Config
CONFIG_SPARK_LAUNCHER = "spark_launcher" # standalone, dataproc, emr

CONFIG_SPARK_STANDALONE_MASTER = "spark_standalone_master"

CONFIG_SPARK_DATAPROC_CLUSTER_NAME = "dataproc_cluster_name"
CONFIG_SPARK_DATAPROC_PROJECT = "dataproc_project"
CONFIG_SPARK_DATAPROC_REGION = "dataproc_region"
CONFIG_SPARK_DATAPROC_STAGING_LOCATION = "dataproc_staging_location"

CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT = "historical_feature_output_format"
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION = "historical_feature_output_location"


# Configuration option default values
FEAST_DEFAULT_OPTIONS = {
# Default Feast project to use
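For reference, the new Spark options might be supplied to a client roughly as below. This is a sketch rather than documented usage: it assumes the Client forwards keyword arguments into its configuration (as the existing core_url option suggests), and the master URL and output location are placeholders.

from feast import Client

client = Client(
    core_url="localhost:6565",
    spark_launcher="standalone",                        # or "dataproc"
    spark_standalone_master="local[*]",                 # assumed to apply only to the standalone launcher
    historical_feature_output_format="parquet",
    historical_feature_output_location="gs://some-bucket/retrieval_output",
)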
193 changes: 193 additions & 0 deletions sdk/python/feast/pyspark/abc.py
@@ -0,0 +1,193 @@
import abc
from typing import Dict, List


class SparkJobFailure(Exception):
"""
Job submission failed, encountered error during execution, or timeout
"""

pass


class SparkJob(abc.ABC):
"""
Base class for all spark jobs
"""

@abc.abstractmethod
def get_id(self) -> str:
"""
Getter for the job id. The job id must be unique for each spark job submission.
Returns:
str: Job id.
"""
raise NotImplementedError


class RetrievalJob(SparkJob):
"""
Container for the historical feature retrieval job result
"""

@abc.abstractmethod
def get_output_file_uri(self, timeout_sec=None):
"""
Get the uri to the result file. This method will block until the
job succeeds, fails, or exceeds the timeout.
Args:
timeout_sec (int):
Maximum number of seconds to wait until the job is done. If "timeout_sec"
is exceeded or if the job fails, an exception will be raised.
Raises:
SparkJobFailure:
The spark job submission failed, encountered error during execution,
or timeout.
Returns:
str: file uri to the result file.
"""
raise NotImplementedError


class IngestionJob(SparkJob):
pass


class JobLauncher(abc.ABC):
"""
Submits spark jobs to a spark cluster. Currently supports only historical feature retrieval jobs.
"""

@abc.abstractmethod
def historical_feature_retrieval(
self,
pyspark_script: str,
entity_source_conf: Dict,
feature_tables_sources_conf: List[Dict],
feature_tables_conf: List[Dict],
destination_conf: Dict,
job_id: str,
**kwargs,
) -> RetrievalJob:
"""
Submits a historical feature retrieval job to a Spark cluster.
Args:
pyspark_script (str): Local file path to the pyspark script for historical feature
retrieval.
entity_source_conf (Dict): Entity data source configuration.
feature_tables_sources_conf (List[Dict]): List of feature tables data sources configurations.
feature_tables_conf (List[Dict]): List of feature table specification.
The order of the feature tables must correspond to that of feature_tables_sources_conf.
destination_conf (Dict): Retrieval job output destination.
job_id (str): A job id that is unique for each job submission.
Raises:
SparkJobFailure: The spark job submission failed, encountered error
during execution, or timeout.
Examples:
>>> # Entity source from file
>>> entity_source_conf = {
"file": {
"format": "parquet",
"path": "gs://some-gcs-bucket/customer",
"event_timestamp_column": "event_timestamp",
"options": {
"mergeSchema": "true"
}, # Optional. Options to be passed to Spark while reading the dataframe from source.
"field_mapping": {
"id": "customer_id"
} # Optional. Map the columns, where the key is the original column name and the value is the new column name.
}
}
>>> # Entity source from BigQuery
>>> entity_source_conf = {
"bq": {
"project": "gcp_project_id",
"dataset": "bq_dataset",
"table": "customer",
"event_timestamp_column": "event_timestamp",
}
}
>>> feature_table_sources_conf = [
{
"bq": {
"project": "gcp_project_id",
"dataset": "bq_dataset",
"table": "customer_transactions",
"event_timestamp_column": "event_timestamp",
"created_timestamp_column": "created_timestamp" # This field is mandatory for feature tables.
}
},
{
"file": {
"format": "parquet",
"path": "gs://some-gcs-bucket/customer_profile",
"event_timestamp_column": "event_timestamp",
"created_timestamp_column": "created_timestamp",
"options": {
"mergeSchema": "true"
}
}
},
]
>>> feature_tables_conf = [
{
"name": "customer_transactions",
"entities": [
{
"name": "customer
"type": "int32"
}
],
"features": [
{
"name": "total_transactions"
"type": "double"
},
{
"name": "total_discounts"
"type": "double"
}
],
"max_age": 86400 # In seconds.
},
{
"name": "customer_profile",
"entities": [
{
"name": "customer
"type": "int32"
}
],
"features": [
{
"name": "is_vip"
"type": "bool"
}
],
}
]
>>> destination_conf = {
"format": "parquet",
"path": "gs://some-gcs-bucket/retrieval_output"
}
Returns:
str: file uri to the result file.
"""
raise NotImplementedError
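To make the RetrievalJob contract above concrete, a trivial, purely hypothetical implementation could look like the sketch below. Real jobs returned by the concrete launchers (e.g. standalone or Dataproc, selected through the spark_launcher option) would poll the cluster and raise SparkJobFailure on failure or timeout.

class PrecomputedRetrievalJob(RetrievalJob):
    """Illustrative only: a retrieval job whose output location is already known."""

    def __init__(self, job_id: str, output_file_uri: str):
        self._job_id = job_id
        self._output_file_uri = output_file_uri

    def get_id(self) -> str:
        return self._job_id

    def get_output_file_uri(self, timeout_sec=None):
        # A real implementation would block on the Spark job here and raise
        # SparkJobFailure if it fails or does not finish within timeout_sec.
        return self._output_file_uri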
