Skip to content

Commit

Permalink
fix: Update redshift api (#2479)
Browse files Browse the repository at this point in the history
* Fix

Signed-off-by: Kevin Zhang <[email protected]>

* Fix

Signed-off-by: Kevin Zhang <[email protected]>

* Remove warning

Signed-off-by: Kevin Zhang <[email protected]>
  • Loading branch information
kevjumba authored Apr 4, 2022
1 parent 4864252 commit 4fa73a9
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 9 deletions.
3 changes: 3 additions & 0 deletions protos/feast/core/DataSource.proto
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ message DataSource {

// Redshift schema name
string schema = 3;

// Redshift database name
string database = 4;
}

// Defines options for DataSource that sources features from a Snowflake Query
Expand Down
47 changes: 39 additions & 8 deletions sdk/python/feast/infra/offline_stores/redshift_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
description: Optional[str] = "",
tags: Optional[Dict[str, str]] = None,
owner: Optional[str] = "",
database: Optional[str] = "",
):
"""
Creates a RedshiftSource object.
Expand All @@ -47,11 +48,12 @@ def __init__(
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
owner (optional): The owner of the redshift source, typically the email of the primary
maintainer.
database (optional): The Redshift database name.
"""
# The default Redshift schema is named "public".
_schema = "public" if table and not schema else schema
self.redshift_options = RedshiftOptions(
table=table, schema=_schema, query=query
table=table, schema=_schema, query=query, database=database
)

if table is None and query is None:
Expand Down Expand Up @@ -102,6 +104,7 @@ def from_proto(data_source: DataSourceProto):
description=data_source.description,
tags=dict(data_source.tags),
owner=data_source.owner,
database=data_source.redshift_options.database,
)

# Note: Python requires redefining hash in child classes that override __eq__
Expand All @@ -119,6 +122,7 @@ def __eq__(self, other):
and self.redshift_options.table == other.redshift_options.table
and self.redshift_options.schema == other.redshift_options.schema
and self.redshift_options.query == other.redshift_options.query
and self.redshift_options.database == other.redshift_options.database
and self.event_timestamp_column == other.event_timestamp_column
and self.created_timestamp_column == other.created_timestamp_column
and self.field_mapping == other.field_mapping
Expand All @@ -139,9 +143,14 @@ def schema(self):

@property
def query(self):
"""Returns the Redshift options of this Redshift source."""
"""Returns the Redshift query of this Redshift source."""
return self.redshift_options.query

@property
def database(self):
"""Returns the Redshift database name of this Redshift source.

When this value is empty or None, get_table_column_names_and_types falls
back to the database configured on the offline store
(config.offline_store.database).
"""
return self.redshift_options.database

def to_proto(self) -> DataSourceProto:
"""
Converts a RedshiftSource object to its protobuf representation.
Expand Down Expand Up @@ -197,12 +206,15 @@ def get_table_column_names_and_types(
assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

client = aws_utils.get_redshift_data_client(config.offline_store.region)

if self.table is not None:
try:
table = client.describe_table(
ClusterIdentifier=config.offline_store.cluster_id,
Database=config.offline_store.database,
Database=(
self.database
if self.database
else config.offline_store.database
),
DbUser=config.offline_store.user,
Table=self.table,
Schema=self.schema,
Expand All @@ -221,7 +233,7 @@ def get_table_column_names_and_types(
statement_id = aws_utils.execute_redshift_statement(
client,
config.offline_store.cluster_id,
config.offline_store.database,
self.database if self.database else config.offline_store.database,
config.offline_store.user,
f"SELECT * FROM ({self.query}) LIMIT 1",
)
Expand All @@ -238,11 +250,16 @@ class RedshiftOptions:
"""

def __init__(
self, table: Optional[str], schema: Optional[str], query: Optional[str]
self,
table: Optional[str],
schema: Optional[str],
query: Optional[str],
database: Optional[str],
):
self._table = table
self._schema = schema
self._query = query
self._database = database

@property
def query(self):
Expand Down Expand Up @@ -274,6 +291,16 @@ def schema(self, schema):
"""Sets the schema of this Redshift table."""
self._schema = schema

@property
def database(self):
"""Returns the database name of this Redshift table."""
return self._database

@database.setter
def database(self, database):
"""Sets the database name of this Redshift table."""
self._database = database

@classmethod
def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
"""
Expand All @@ -289,6 +316,7 @@ def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
table=redshift_options_proto.table,
schema=redshift_options_proto.schema,
query=redshift_options_proto.query,
database=redshift_options_proto.database,
)

return redshift_options
Expand All @@ -301,7 +329,10 @@ def to_proto(self) -> DataSourceProto.RedshiftOptions:
A RedshiftOptionsProto protobuf.
"""
redshift_options_proto = DataSourceProto.RedshiftOptions(
table=self.table, schema=self.schema, query=self.query,
table=self.table,
schema=self.schema,
query=self.query,
database=self.database,
)

return redshift_options_proto
Expand All @@ -314,7 +345,7 @@ class SavedDatasetRedshiftStorage(SavedDatasetStorage):

def __init__(self, table_ref: str):
self.redshift_options = RedshiftOptions(
table=table_ref, schema=None, query=None
table=table_ref, schema=None, query=None, database=None
)

@staticmethod
Expand Down
5 changes: 4 additions & 1 deletion sdk/python/feast/templates/aws/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ def bootstrap():

repo_path = pathlib.Path(__file__).parent.absolute()
config_file = repo_path / "feature_store.yaml"
driver_file = repo_path / "driver_repo.py"

replace_str_in_file(config_file, "%AWS_REGION%", aws_region)
replace_str_in_file(config_file, "%REDSHIFT_CLUSTER_ID%", cluster_id)
replace_str_in_file(config_file, "%REDSHIFT_DATABASE%", database)
replace_str_in_file(driver_file, "%REDSHIFT_DATABASE%", database)
replace_str_in_file(config_file, "%REDSHIFT_USER%", user)
replace_str_in_file(
config_file, "%REDSHIFT_S3_STAGING_LOCATION%", s3_staging_location
driver_file, config_file, "%REDSHIFT_S3_STAGING_LOCATION%", s3_staging_location
)
replace_str_in_file(config_file,)
replace_str_in_file(config_file, "%REDSHIFT_IAM_ROLE%", iam_role)


Expand Down
2 changes: 2 additions & 0 deletions sdk/python/feast/templates/aws/driver_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
# The (optional) created timestamp is used to ensure there are no duplicate
# feature rows in the offline store or when building training datasets
created_timestamp_column="created",
# The Redshift database that contains the source table.
database="%REDSHIFT_DATABASE%",
)

# Feature views are a grouping based on how features are stored in either the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def create_data_source(
created_timestamp_column=created_timestamp_column,
date_partition_column="",
field_mapping=field_mapping or {"ts_1": "ts"},
database=self.offline_store_config.database,
)

def create_saved_dataset_destination(self) -> SavedDatasetRedshiftStorage:
Expand Down

0 comments on commit 4fa73a9

Please sign in to comment.