diff --git a/sdk/python/feast/infra/offline_stores/remote.py b/sdk/python/feast/infra/offline_stores/remote.py index 20e5a4bdf9..dc657017d9 100644 --- a/sdk/python/feast/infra/offline_stores/remote.py +++ b/sdk/python/feast/infra/offline_stores/remote.py @@ -335,14 +335,20 @@ def _send_retrieve_remote( return _call_get(client, command_descriptor) -def _call_get(client, command_descriptor): +def _call_get(client: fl.FlightClient, command_descriptor: fl.FlightDescriptor): flight = client.get_flight_info(command_descriptor) ticket = flight.endpoints[0].ticket reader = client.do_get(ticket) return reader.read_all() -def _call_put(api, api_parameters, client, entity_df, table): +def _call_put( + api: str, + api_parameters: Dict[str, Any], + client: fl.FlightClient, + entity_df: Union[pd.DataFrame, str], + table: pa.Table, +): # Generate unique command identifier command_id = str(uuid.uuid4()) command = { @@ -364,7 +370,7 @@ def _call_put(api, api_parameters, client, entity_df, table): def _put_parameters( - command_descriptor, + command_descriptor: fl.FlightDescriptor, entity_df: Union[pd.DataFrame, str], table: pa.Table, client: fl.FlightClient, diff --git a/sdk/python/feast/offline_server.py b/sdk/python/feast/offline_server.py index ff392cc44f..718da1b109 100644 --- a/sdk/python/feast/offline_server.py +++ b/sdk/python/feast/offline_server.py @@ -27,39 +27,46 @@ def __init__(self, store: FeatureStore, location: str, **kwargs): self.offline_store = get_offline_store_from_config(store.config.offline_store) @classmethod - def descriptor_to_key(self, descriptor): + def descriptor_to_key(self, descriptor: fl.FlightDescriptor): return ( descriptor.descriptor_type.value, descriptor.command, tuple(descriptor.path or tuple()), ) - def _make_flight_info(self, key, descriptor, params): + def _make_flight_info(self, key: Any, descriptor: fl.FlightDescriptor): endpoints = [fl.FlightEndpoint(repr(key), [self._location])] # TODO calculate actual schema from the given features schema = pa.schema([]) return fl.FlightInfo(schema, descriptor, endpoints, -1, -1) - def get_flight_info(self, context, descriptor): + def get_flight_info( + self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor + ): key = OfflineServer.descriptor_to_key(descriptor) if key in self.flights: - params = self.flights[key] - return self._make_flight_info(key, descriptor, params) + return self._make_flight_info(key, descriptor) raise KeyError("Flight not found.") - def list_flights(self, context, criteria): + def list_flights(self, context: fl.ServerCallContext, criteria: bytes): for key, table in self.flights.items(): if key[1] is not None: descriptor = fl.FlightDescriptor.for_command(key[1]) else: descriptor = fl.FlightDescriptor.for_path(*key[2]) - yield self._make_flight_info(key, descriptor, table) + yield self._make_flight_info(key, descriptor) # Expects to receive request parameters and stores them in the flights dictionary # Indexed by the unique command - def do_put(self, context, descriptor, reader, writer): + def do_put( + self, + context: fl.ServerCallContext, + descriptor: fl.FlightDescriptor, + reader: fl.MetadataRecordBatchReader, + writer: fl.FlightMetadataWriter, + ): key = OfflineServer.descriptor_to_key(descriptor) command = json.loads(key[1]) if "api" in command: @@ -71,7 +78,7 @@ def do_put(self, context, descriptor, reader, writer): else: logger.warning(f"No 'api' field in command: {command}") - def _call_api(self, command, key): + def _call_api(self, command: dict, key: str): remove_data = False try: api = command["api"] @@ -145,7 +152,7 @@ def list_feature_views_by_name( # Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance # and returns the stream of data - def do_get(self, context, ticket): + def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket): key = ast.literal_eval(ticket.ticket.decode()) if key not in self.flights: logger.error(f"Unknown key {key}") @@ -173,7 +180,7 @@ def do_get(self, context, ticket): del self.flights[key] return fl.RecordBatchStream(table) - def offline_write_batch(self, command, key): + def offline_write_batch(self, command: dict, key: str): feature_view_names = command["feature_view_names"] assert ( len(feature_view_names) == 1 @@ -193,12 +200,14 @@ def offline_write_batch(self, command, key): self.store.config, feature_views[0], table, command["progress"] ) - def write_logged_features(self, command, key): + def write_logged_features(self, command: dict, key: str): table = self.flights[key] feature_service = self.store.get_feature_service( command["feature_service_name"] ) + assert feature_service.logging_config is not None + self.offline_store.write_logged_features( config=self.store.config, data=table, @@ -209,7 +218,7 @@ def write_logged_features(self, command, key): registry=self.store.registry, ) - def pull_all_from_table_or_query(self, command): + def pull_all_from_table_or_query(self, command: dict): return self.offline_store.pull_all_from_table_or_query( self.store.config, self.store.get_data_source(command["data_source_name"]), @@ -220,7 +229,7 @@ def pull_all_from_table_or_query(self, command): utils.make_tzaware(datetime.fromisoformat(command["end_date"])), ) - def pull_latest_from_table_or_query(self, command): + def pull_latest_from_table_or_query(self, command: dict): return self.offline_store.pull_latest_from_table_or_query( self.store.config, self.store.get_data_source(command["data_source_name"]), @@ -249,7 +258,7 @@ def list_actions(self, context): ), ] - def get_historical_features(self, command, key): + def get_historical_features(self, command: dict, key: str): # Extract parameters from the internal flights dictionary entity_df_value = self.flights[key] entity_df = pa.Table.to_pandas(entity_df_value) @@ -274,7 +283,7 @@ def get_historical_features(self, command, key): ) return retJob - def persist(self, retrieve_func, command, key): + def persist(self, retrieve_func: str, command: dict, key: str): try: if retrieve_func == OfflineServer.get_historical_features.__name__: ret_job = self.get_historical_features(command, key) @@ -295,7 +304,7 @@ def persist(self, retrieve_func, command, key): traceback.print_exc() raise e - def do_action(self, context, action): + def do_action(self, context: fl.ServerCallContext, action: fl.Action): pass def do_drop_dataset(self, dataset):