Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented PR change proposal #16

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,20 @@ def _send_retrieve_remote(
return _call_get(client, command_descriptor)


def _call_get(client, command_descriptor):
def _call_get(client: fl.FlightClient, command_descriptor: fl.FlightDescriptor):
flight = client.get_flight_info(command_descriptor)
ticket = flight.endpoints[0].ticket
reader = client.do_get(ticket)
return reader.read_all()


def _call_put(api, api_parameters, client, entity_df, table):
def _call_put(
api: str,
api_parameters: Dict[str, Any],
client: fl.FlightClient,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
):
# Generate unique command identifier
command_id = str(uuid.uuid4())
command = {
Expand All @@ -364,7 +370,7 @@ def _call_put(api, api_parameters, client, entity_df, table):


def _put_parameters(
command_descriptor,
command_descriptor: fl.FlightDescriptor,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
Expand Down
43 changes: 26 additions & 17 deletions sdk/python/feast/offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,46 @@ def __init__(self, store: FeatureStore, location: str, **kwargs):
self.offline_store = get_offline_store_from_config(store.config.offline_store)

@classmethod
def descriptor_to_key(self, descriptor):
def descriptor_to_key(self, descriptor: fl.FlightDescriptor):
return (
descriptor.descriptor_type.value,
descriptor.command,
tuple(descriptor.path or tuple()),
)

def _make_flight_info(self, key, descriptor, params):
def _make_flight_info(self, key: Any, descriptor: fl.FlightDescriptor):
endpoints = [fl.FlightEndpoint(repr(key), [self._location])]
# TODO calculate actual schema from the given features
schema = pa.schema([])

return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)

def get_flight_info(self, context, descriptor):
def get_flight_info(
self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor
):
key = OfflineServer.descriptor_to_key(descriptor)
if key in self.flights:
params = self.flights[key]
return self._make_flight_info(key, descriptor, params)
return self._make_flight_info(key, descriptor)
raise KeyError("Flight not found.")

def list_flights(self, context, criteria):
def list_flights(self, context: fl.ServerCallContext, criteria: bytes):
for key, table in self.flights.items():
if key[1] is not None:
descriptor = fl.FlightDescriptor.for_command(key[1])
else:
descriptor = fl.FlightDescriptor.for_path(*key[2])

yield self._make_flight_info(key, descriptor, table)
yield self._make_flight_info(key, descriptor)

# Expects to receive request parameters and stores them in the flights dictionary
# Indexed by the unique command
def do_put(self, context, descriptor, reader, writer):
def do_put(
self,
context: fl.ServerCallContext,
descriptor: fl.FlightDescriptor,
reader: fl.MetadataRecordBatchReader,
writer: fl.FlightMetadataWriter,
):
key = OfflineServer.descriptor_to_key(descriptor)
command = json.loads(key[1])
if "api" in command:
Expand All @@ -71,7 +78,7 @@ def do_put(self, context, descriptor, reader, writer):
else:
logger.warning(f"No 'api' field in command: {command}")

def _call_api(self, command, key):
def _call_api(self, command: dict, key: str):
remove_data = False
try:
api = command["api"]
Expand Down Expand Up @@ -145,7 +152,7 @@ def list_feature_views_by_name(

# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
# and returns the stream of data
def do_get(self, context, ticket):
def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
key = ast.literal_eval(ticket.ticket.decode())
if key not in self.flights:
logger.error(f"Unknown key {key}")
Expand Down Expand Up @@ -173,7 +180,7 @@ def do_get(self, context, ticket):
del self.flights[key]
return fl.RecordBatchStream(table)

def offline_write_batch(self, command, key):
def offline_write_batch(self, command: dict, key: str):
feature_view_names = command["feature_view_names"]
assert (
len(feature_view_names) == 1
Expand All @@ -193,12 +200,14 @@ def offline_write_batch(self, command, key):
self.store.config, feature_views[0], table, command["progress"]
)

def write_logged_features(self, command, key):
def write_logged_features(self, command: dict, key: str):
table = self.flights[key]
feature_service = self.store.get_feature_service(
command["feature_service_name"]
)

assert feature_service.logging_config is not None

self.offline_store.write_logged_features(
config=self.store.config,
data=table,
Expand All @@ -209,7 +218,7 @@ def write_logged_features(self, command, key):
registry=self.store.registry,
)

def pull_all_from_table_or_query(self, command):
def pull_all_from_table_or_query(self, command: dict):
return self.offline_store.pull_all_from_table_or_query(
self.store.config,
self.store.get_data_source(command["data_source_name"]),
Expand All @@ -220,7 +229,7 @@ def pull_all_from_table_or_query(self, command):
utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
)

def pull_latest_from_table_or_query(self, command):
def pull_latest_from_table_or_query(self, command: dict):
return self.offline_store.pull_latest_from_table_or_query(
self.store.config,
self.store.get_data_source(command["data_source_name"]),
Expand Down Expand Up @@ -249,7 +258,7 @@ def list_actions(self, context):
),
]

def get_historical_features(self, command, key):
def get_historical_features(self, command: dict, key: str):
# Extract parameters from the internal flights dictionary
entity_df_value = self.flights[key]
entity_df = pa.Table.to_pandas(entity_df_value)
Expand All @@ -274,7 +283,7 @@ def get_historical_features(self, command, key):
)
return retJob

def persist(self, retrieve_func, command, key):
def persist(self, retrieve_func: str, command: dict, key: str):
try:
if retrieve_func == OfflineServer.get_historical_features.__name__:
ret_job = self.get_historical_features(command, key)
Expand All @@ -295,7 +304,7 @@ def persist(self, retrieve_func, command, key):
traceback.print_exc()
raise e

def do_action(self, context, action):
def do_action(self, context: fl.ServerCallContext, action: fl.Action):
pass

def do_drop_dataset(self, dataset):
Expand Down
Loading