Skip to content

Commit

Permalink
fix: Fix broken proto conversion methods for data sources (#2603)
Browse files Browse the repository at this point in the history
* Fix Snowflake proto conversion and add test

Signed-off-by: Felix Wang <[email protected]>

* Add proto conversion test for FileSource

Signed-off-by: Felix Wang <[email protected]>

* Fix Redshift proto conversion and add test

Signed-off-by: Felix Wang <[email protected]>

* Add proto conversion test for BigQuerySource

Signed-off-by: Felix Wang <[email protected]>

* Fix tests to use DataSource.from_proto

Signed-off-by: Felix Wang <[email protected]>

* Add proto conversion test for KafkaSource

Signed-off-by: Felix Wang <[email protected]>

* Add proto conversion test for KinesisSource

Signed-off-by: Felix Wang <[email protected]>

* Add proto conversion test for PushSource

Signed-off-by: Felix Wang <[email protected]>

* Add proto conversion test for PushSource

Signed-off-by: Felix Wang <[email protected]>

* Add name and other fixes

Signed-off-by: Felix Wang <[email protected]>

* Fix proto conversion tests

Signed-off-by: Felix Wang <[email protected]>

* Add tags to test

Signed-off-by: Felix Wang <[email protected]>

* Fix BigQuerySource bug

Signed-off-by: Felix Wang <[email protected]>

* Fix bug in RedshiftSource and TrinoSource

Signed-off-by: Felix Wang <[email protected]>

* Remove references to event_timestamp_column

Signed-off-by: Felix Wang <[email protected]>
  • Loading branch information
felixwang9817 authored Apr 24, 2022
1 parent c94a69c commit 00ed65a
Show file tree
Hide file tree
Showing 28 changed files with 313 additions and 412 deletions.
2 changes: 1 addition & 1 deletion go/cmd/server/logging/feature_repo/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# for more info.
driver_hourly_stats = FileSource(
path="driver_stats.parquet",
event_timestamp_column="event_timestamp",
timestamp_field="event_timestamp",
created_timestamp_column="created",
)

Expand Down
23 changes: 15 additions & 8 deletions sdk/python/feast/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def __init__(

if _message_format is None:
raise ValueError("Message format must be specified for Kafka source")
print("Asdfasdf")

super().__init__(
event_timestamp_column=_event_timestamp_column,
created_timestamp_column=created_timestamp_column,
Expand Down Expand Up @@ -467,7 +467,9 @@ def from_proto(data_source: DataSourceProto):
description=data_source.description,
tags=dict(data_source.tags),
owner=data_source.owner,
batch_source=DataSource.from_proto(data_source.batch_source),
batch_source=DataSource.from_proto(data_source.batch_source)
if data_source.batch_source
else None,
)

def to_proto(self) -> DataSourceProto:
Expand Down Expand Up @@ -500,17 +502,20 @@ class RequestSource(DataSource):
"""
RequestSource that can be used to provide input features for on demand transforms
Args:
Attributes:
name: Name of the request data source
schema Union[Dict[str, ValueType], List[Field]]: Schema mapping from the input feature name to a ValueType
description (optional): A human-readable description.
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
owner (optional): The owner of the request data source, typically the email of the primary
schema: Schema mapping from the input feature name to a ValueType
description: A human-readable description.
tags: A dictionary of key-value pairs to store arbitrary metadata.
owner: The owner of the request data source, typically the email of the primary
maintainer.
"""

name: str
schema: List[Field]
description: str
tags: Dict[str, str]
owner: str

def __init__(
self,
Expand Down Expand Up @@ -697,7 +702,9 @@ def from_proto(data_source: DataSourceProto):
description=data_source.description,
tags=dict(data_source.tags),
owner=data_source.owner,
batch_source=DataSource.from_proto(data_source.batch_source),
batch_source=DataSource.from_proto(data_source.batch_source)
if data_source.batch_source
else None,
)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def update_entities_with_inferred_types_from_feature_views(
def update_data_sources_with_inferred_event_timestamp_col(
data_sources: List[DataSource], config: RepoConfig
) -> None:
ERROR_MSG_PREFIX = "Unable to infer DataSource event_timestamp_column"
ERROR_MSG_PREFIX = "Unable to infer DataSource timestamp_field"

for data_source in data_sources:
if isinstance(data_source, RequestSource):
Expand Down
22 changes: 11 additions & 11 deletions sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def pull_latest_from_table_or_query(
data_source: DataSource,
join_key_columns: List[str],
feature_name_columns: List[str],
event_timestamp_column: str,
timestamp_field: str,
created_timestamp_column: Optional[str],
start_date: datetime,
end_date: datetime,
Expand All @@ -96,7 +96,7 @@ def pull_latest_from_table_or_query(
partition_by_join_key_string = (
"PARTITION BY " + partition_by_join_key_string
)
timestamps = [event_timestamp_column]
timestamps = [timestamp_field]
if created_timestamp_column:
timestamps.append(created_timestamp_column)
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
Expand All @@ -114,7 +114,7 @@ def pull_latest_from_table_or_query(
SELECT {field_string},
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row
FROM {from_expression}
WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
)
WHERE _feast_row = 1
"""
Expand All @@ -131,7 +131,7 @@ def pull_all_from_table_or_query(
data_source: DataSource,
join_key_columns: List[str],
feature_name_columns: List[str],
event_timestamp_column: str,
timestamp_field: str,
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
Expand All @@ -143,12 +143,12 @@ def pull_all_from_table_or_query(
location=config.offline_store.location,
)
field_string = ", ".join(
join_key_columns + feature_name_columns + [event_timestamp_column]
join_key_columns + feature_name_columns + [timestamp_field]
)
query = f"""
SELECT {field_string}
FROM {from_expression}
WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
"""
return BigQueryRetrievalJob(
query=query, client=client, config=config, full_feature_names=False,
Expand Down Expand Up @@ -583,9 +583,9 @@ def _get_bigquery_client(project: Optional[str] = None, location: Optional[str]
1. We first join the current feature_view to the entity dataframe that has been passed.
This JOIN has the following logic:
- For each row of the entity dataframe, only keep the rows where the `event_timestamp_column`
- For each row of the entity dataframe, only keep the rows where the `timestamp_field`
is less than the one provided in the entity dataframe
- If there a TTL for the current feature_view, also keep the rows where the `event_timestamp_column`
- If there a TTL for the current feature_view, also keep the rows where the `timestamp_field`
is higher the the one provided minus the TTL
- For each row, Join on the entity key and retrieve the `entity_row_unique_id` that has been
computed previously
Expand All @@ -596,16 +596,16 @@ def _get_bigquery_client(project: Optional[str] = None, location: Optional[str]
{{ featureview.name }}__subquery AS (
SELECT
{{ featureview.event_timestamp_column }} as event_timestamp,
{{ featureview.timestamp_field }} as event_timestamp,
{{ featureview.created_timestamp_column ~ ' as created_timestamp,' if featureview.created_timestamp_column else '' }}
{{ featureview.entity_selections | join(', ')}}{% if featureview.entity_selections %},{% else %}{% endif %}
{% for feature in featureview.features %}
{{ feature }} as {% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %}{% if loop.last %}{% else %}, {% endif %}
{% endfor %}
FROM {{ featureview.table_subquery }}
WHERE {{ featureview.event_timestamp_column }} <= '{{ featureview.max_event_timestamp }}'
WHERE {{ featureview.timestamp_field }} <= '{{ featureview.max_event_timestamp }}'
{% if featureview.ttl == 0 %}{% else %}
AND {{ featureview.event_timestamp_column }} >= '{{ featureview.min_event_timestamp }}'
AND {{ featureview.timestamp_field }} >= '{{ featureview.min_event_timestamp }}'
{% endif %}
),
Expand Down
56 changes: 9 additions & 47 deletions sdk/python/feast/infra/offline_stores/bigquery_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,9 @@ def __eq__(self, other):
)

return (
self.name == other.name
and self.bigquery_options.table == other.bigquery_options.table
and self.bigquery_options.query == other.bigquery_options.query
and self.timestamp_field == other.timestamp_field
and self.created_timestamp_column == other.created_timestamp_column
and self.field_mapping == other.field_mapping
and self.description == other.description
and self.tags == other.tags
and self.owner == other.owner
super().__eq__(other)
and self.table == other.table
and self.query == other.query
)

@property
Expand All @@ -120,7 +114,6 @@ def query(self):

@staticmethod
def from_proto(data_source: DataSourceProto):

assert data_source.HasField("bigquery_options")

return BigQuerySource(
Expand All @@ -144,11 +137,10 @@ def to_proto(self) -> DataSourceProto:
description=self.description,
tags=self.tags,
owner=self.owner,
timestamp_field=self.timestamp_field,
created_timestamp_column=self.created_timestamp_column,
)

data_source_proto.timestamp_field = self.timestamp_field
data_source_proto.created_timestamp_column = self.created_timestamp_column

return data_source_proto

def validate(self, config: RepoConfig):
Expand Down Expand Up @@ -179,7 +171,7 @@ def get_table_column_names_and_types(
from google.cloud import bigquery

client = bigquery.Client()
if self.table is not None:
if self.table:
schema = client.get_table(self.table).schema
if not isinstance(schema[0], bigquery.schema.SchemaField):
raise TypeError("Could not parse BigQuery table schema.")
Expand All @@ -200,42 +192,14 @@ def get_table_column_names_and_types(

class BigQueryOptions:
"""
DataSource BigQuery options used to source features from BigQuery query
Configuration options for a BigQuery data source.
"""

def __init__(
self, table: Optional[str], query: Optional[str],
):
self._table = table
self._query = query

@property
def query(self):
"""
Returns the BigQuery SQL query referenced by this source
"""
return self._query

@query.setter
def query(self, query):
"""
Sets the BigQuery SQL query referenced by this source
"""
self._query = query

@property
def table(self):
"""
Returns the table ref of this BQ table
"""
return self._table

@table.setter
def table(self, table):
"""
Sets the table ref of this BQ table
"""
self._table = table
self.table = table or ""
self.query = query or ""

@classmethod
def from_proto(cls, bigquery_options_proto: DataSourceProto.BigQueryOptions):
Expand All @@ -248,7 +212,6 @@ def from_proto(cls, bigquery_options_proto: DataSourceProto.BigQueryOptions):
Returns:
Returns a BigQueryOptions object based on the bigquery_options protobuf
"""

bigquery_options = cls(
table=bigquery_options_proto.table, query=bigquery_options_proto.query,
)
Expand All @@ -262,7 +225,6 @@ def to_proto(self) -> DataSourceProto.BigQueryOptions:
Returns:
BigQueryOptionsProto protobuf
"""

bigquery_options_proto = DataSourceProto.BigQueryOptions(
table=self.table, query=self.query,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def pull_latest_from_table_or_query(
data_source: DataSource,
join_key_columns: List[str],
feature_name_columns: List[str],
event_timestamp_column: str,
timestamp_field: str,
created_timestamp_column: Optional[str],
start_date: datetime,
end_date: datetime,
Expand All @@ -68,7 +68,7 @@ def pull_latest_from_table_or_query(
partition_by_join_key_string = (
"PARTITION BY " + partition_by_join_key_string
)
timestamps = [event_timestamp_column]
timestamps = [timestamp_field]
if created_timestamp_column:
timestamps.append(created_timestamp_column)
timestamp_desc_string = " DESC, ".join(_append_alias(timestamps, "a")) + " DESC"
Expand All @@ -87,7 +87,7 @@ def pull_latest_from_table_or_query(
SELECT {a_field_string},
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row
FROM ({from_expression}) a
WHERE a."{event_timestamp_column}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz
WHERE a."{timestamp_field}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz
) b
WHERE _feast_row = 1
"""
Expand Down Expand Up @@ -191,15 +191,15 @@ def pull_all_from_table_or_query(
data_source: DataSource,
join_key_columns: List[str],
feature_name_columns: List[str],
event_timestamp_column: str,
timestamp_field: str,
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(data_source, PostgreSQLSource)
from_expression = data_source.get_table_query_string()

field_string = ", ".join(
join_key_columns + feature_name_columns + [event_timestamp_column]
join_key_columns + feature_name_columns + [timestamp_field]
)

start_date = start_date.astimezone(tz=utc)
Expand All @@ -208,7 +208,7 @@ def pull_all_from_table_or_query(
query = f"""
SELECT {field_string}
FROM {from_expression}
WHERE "{event_timestamp_column}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz
WHERE "{timestamp_field}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz
"""

return PostgreSQLRetrievalJob(
Expand Down Expand Up @@ -415,9 +415,9 @@ def build_point_in_time_query(
1. We first join the current feature_view to the entity dataframe that has been passed.
This JOIN has the following logic:
- For each row of the entity dataframe, only keep the rows where the `event_timestamp_column`
- For each row of the entity dataframe, only keep the rows where the `timestamp_field`
is less than the one provided in the entity dataframe
- If there a TTL for the current feature_view, also keep the rows where the `event_timestamp_column`
- If there a TTL for the current feature_view, also keep the rows where the `timestamp_field`
is higher the the one provided minus the TTL
- For each row, Join on the entity key and retrieve the `entity_row_unique_id` that has been
computed previously
Expand All @@ -428,16 +428,16 @@ def build_point_in_time_query(
"{{ featureview.name }}__subquery" AS (
SELECT
"{{ featureview.event_timestamp_column }}" as event_timestamp,
"{{ featureview.timestamp_field }}" as event_timestamp,
{{ '"' ~ featureview.created_timestamp_column ~ '" as created_timestamp,' if featureview.created_timestamp_column else '' }}
{{ featureview.entity_selections | join(', ')}}{% if featureview.entity_selections %},{% else %}{% endif %}
{% for feature in featureview.features %}
"{{ feature }}" as {% if full_feature_names %}"{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}"{% else %}"{{ featureview.field_mapping.get(feature, feature) }}"{% endif %}{% if loop.last %}{% else %}, {% endif %}
{% endfor %}
FROM {{ featureview.table_subquery }} AS sub
WHERE "{{ featureview.event_timestamp_column }}" <= (SELECT MAX(entity_timestamp) FROM entity_dataframe)
WHERE "{{ featureview.timestamp_field }}" <= (SELECT MAX(entity_timestamp) FROM entity_dataframe)
{% if featureview.ttl == 0 %}{% else %}
AND "{{ featureview.event_timestamp_column }}" >= (SELECT MIN(entity_timestamp) FROM entity_dataframe) - {{ featureview.ttl }} * interval '1' second
AND "{{ featureview.timestamp_field }}" >= (SELECT MIN(entity_timestamp) FROM entity_dataframe) - {{ featureview.ttl }} * interval '1' second
{% endif %}
),
Expand Down
Loading

0 comments on commit 00ed65a

Please sign in to comment.