From de5b0eb8e4922f16b7a8f36ed6373490f8b2da8d Mon Sep 17 00:00:00 2001 From: Theodor Mihalache <84387487+tmihalac@users.noreply.github.com> Date: Thu, 20 Jun 2024 10:27:09 -0400 Subject: [PATCH] refactor: Add parameters validation to OfflineServer (#4289) Add parameters validation to OfflineServer Signed-off-by: Theodor Mihalache --- sdk/python/feast/offline_server.py | 96 +++++++++++++++++++++++++++--- 1 file changed, 88 insertions(+), 8 deletions(-) diff --git a/sdk/python/feast/offline_server.py b/sdk/python/feast/offline_server.py index 718da1b109..be92620d68 100644 --- a/sdk/python/feast/offline_server.py +++ b/sdk/python/feast/offline_server.py @@ -74,14 +74,15 @@ def do_put( logger.debug(f"do_put: command is{command}, data is {data}") self.flights[key] = data - self._call_api(command, key) + self._call_api(command["api"], command, key) else: logger.warning(f"No 'api' field in command: {command}") - def _call_api(self, command: dict, key: str): + def _call_api(self, api: str, command: dict, key: str): + assert api is not None, "api can not be empty" + remove_data = False try: - api = command["api"] if api == OfflineServer.offline_write_batch.__name__: self.offline_write_batch(command, key) remove_data = True @@ -89,7 +90,7 @@ def _call_api(self, command: dict, key: str): self.write_logged_features(command, key) remove_data = True elif api == OfflineServer.persist.__name__: - self.persist(command["retrieve_func"], command, key) + self.persist(command, key) remove_data = True except Exception as e: remove_data = True @@ -150,6 +151,9 @@ def list_feature_views_by_name( for index, fv_name in enumerate(feature_view_names) ] + def _validate_do_get_parameters(self, command: dict): + assert "api" in command, "api parameter is mandatory" + # Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance # and returns the stream of data def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket): @@ -159,6 +163,9 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket): return None command = json.loads(key[1]) + + self._validate_do_get_parameters(command) + api = command["api"] logger.debug(f"get command is {command}") logger.debug(f"requested api is {api}") @@ -180,13 +187,26 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket): del self.flights[key] return fl.RecordBatchStream(table) - def offline_write_batch(self, command: dict, key: str): + def _validate_offline_write_batch_parameters(self, command: dict): + assert ( + "feature_view_names" in command + ), "feature_view_names is a mandatory parameter" + assert "name_aliases" in command, "name_aliases is a mandatory parameter" + feature_view_names = command["feature_view_names"] assert ( len(feature_view_names) == 1 ), "feature_view_names list should only have one item" + name_aliases = command["name_aliases"] assert len(name_aliases) == 1, "name_aliases list should only have one item" + + def offline_write_batch(self, command: dict, key: str): + self._validate_offline_write_batch_parameters(command) + + feature_view_names = command["feature_view_names"] + name_aliases = command["name_aliases"] + project = self.store.config.project feature_views = self.list_feature_views_by_name( feature_view_names=feature_view_names, @@ -194,19 +214,25 @@ def offline_write_batch(self, command: dict, key: str): project=project, ) - assert len(feature_views) == 1 + assert len(feature_views) == 1, "incorrect feature view" table = self.flights[key] self.offline_store.offline_write_batch( self.store.config, feature_views[0], table, command["progress"] ) + def _validate_write_logged_features_parameters(self, command: dict): + assert "feature_service_name" in command + def write_logged_features(self, command: dict, key: str): + self._validate_write_logged_features_parameters(command) table = self.flights[key] feature_service = self.store.get_feature_service( command["feature_service_name"] ) - assert feature_service.logging_config is not None + assert ( + feature_service.logging_config is not None + ), "feature service must have logging_config set" self.offline_store.write_logged_features( config=self.store.config, @@ -218,7 +244,23 @@ def write_logged_features(self, command: dict, key: str): registry=self.store.registry, ) + def _validate_pull_all_from_table_or_query_parameters(self, command: dict): + assert ( + "data_source_name" in command + ), "data_source_name is a mandatory parameter" + assert ( + "join_key_columns" in command + ), "join_key_columns is a mandatory parameter" + assert ( + "feature_name_columns" in command + ), "feature_name_columns is a mandatory parameter" + assert "timestamp_field" in command, "timestamp_field is a mandatory parameter" + assert "start_date" in command, "start_date is a mandatory parameter" + assert "end_date" in command, "end_date is a mandatory parameter" + def pull_all_from_table_or_query(self, command: dict): + self._validate_pull_all_from_table_or_query_parameters(command) + return self.offline_store.pull_all_from_table_or_query( self.store.config, self.store.get_data_source(command["data_source_name"]), @@ -229,7 +271,23 @@ def pull_all_from_table_or_query(self, command: dict): utils.make_tzaware(datetime.fromisoformat(command["end_date"])), ) + def _validate_pull_latest_from_table_or_query_parameters(self, command: dict): + assert ( + "data_source_name" in command + ), "data_source_name is a mandatory parameter" + assert ( + "join_key_columns" in command + ), "join_key_columns is a mandatory parameter" + assert ( + "feature_name_columns" in command + ), "feature_name_columns is a mandatory parameter" + assert "timestamp_field" in command, "timestamp_field is a mandatory parameter" + assert "start_date" in command, "start_date is a mandatory parameter" + assert "end_date" in command, "end_date is a mandatory parameter" + def pull_latest_from_table_or_query(self, command: dict): + self._validate_pull_latest_from_table_or_query_parameters(command) + return self.offline_store.pull_latest_from_table_or_query( self.store.config, self.store.get_data_source(command["data_source_name"]), @@ -258,20 +316,33 @@ def list_actions(self, context): ), ] + def _validate_get_historical_features_parameters(self, command: dict, key: str): + assert key in self.flights, f"missing key={key}" + assert "feature_view_names" in command, "feature_view_names is mandatory" + assert "name_aliases" in command, "name_aliases is mandatory" + assert "feature_refs" in command, "feature_refs is mandatory" + assert "project" in command, "project is mandatory" + assert "full_feature_names" in command, "full_feature_names is mandatory" + def get_historical_features(self, command: dict, key: str): + self._validate_get_historical_features_parameters(command, key) + # Extract parameters from the internal flights dictionary entity_df_value = self.flights[key] entity_df = pa.Table.to_pandas(entity_df_value) + feature_view_names = command["feature_view_names"] name_aliases = command["name_aliases"] feature_refs = command["feature_refs"] project = command["project"] full_feature_names = command["full_feature_names"] + feature_views = self.list_feature_views_by_name( feature_view_names=feature_view_names, name_aliases=name_aliases, project=project, ) + retJob = self.offline_store.get_historical_features( config=self.store.config, feature_views=feature_views, @@ -281,10 +352,19 @@ def get_historical_features(self, command: dict, key: str): project=project, full_feature_names=full_feature_names, ) + return retJob - def persist(self, retrieve_func: str, command: dict, key: str): + def _validate_persist_parameters(self, command: dict): + assert "retrieve_func" in command, "retrieve_func is mandatory" + assert "data_source_name" in command, "data_source_name is mandatory" + assert "allow_overwrite" in command, "allow_overwrite is mandatory" + + def persist(self, command: dict, key: str): + self._validate_persist_parameters(command) + try: + retrieve_func = command["retrieve_func"] if retrieve_func == OfflineServer.get_historical_features.__name__: ret_job = self.get_historical_features(command, key) elif (