diff --git a/sdk/python/feast/infra/utils/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake_utils.py index 3513daa878..a467a9de42 100644 --- a/sdk/python/feast/infra/utils/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake_utils.py @@ -4,9 +4,11 @@ import string from logging import getLogger from tempfile import TemporaryDirectory -from typing import Dict, Iterator, List, Optional, Tuple, cast +from typing import Any, Dict, Iterator, List, Optional, Tuple, cast import pandas as pd +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization from tenacity import ( retry, retry_if_exception_type, @@ -40,18 +42,17 @@ def execute_snowflake_statement(conn: SnowflakeConnection, query) -> SnowflakeCu def get_snowflake_conn(config, autocommit=True) -> SnowflakeConnection: - if config.type == "snowflake.offline": - config_header = "connections.feast_offline_store" + assert config.type == "snowflake.offline" + config_header = "connections.feast_offline_store" config_dict = dict(config) # read config file config_reader = configparser.ConfigParser() config_reader.read([config_dict["config_path"]]) + kwargs: Dict[str, Any] = {} if config_reader.has_section(config_header): kwargs = dict(config_reader[config_header]) - else: - kwargs = {} if "schema" in kwargs: kwargs["schema_"] = kwargs.pop("schema") @@ -67,6 +68,13 @@ def get_snowflake_conn(config, autocommit=True) -> SnowflakeConnection: else: kwargs["schema"] = '"PUBLIC"' + # https://docs.snowflake.com/en/user-guide/python-connector-example.html#using-key-pair-authentication-key-pair-rotation + # https://docs.snowflake.com/en/user-guide/key-pair-auth.html#configuring-key-pair-authentication + if "private_key" in kwargs: + kwargs["private_key"] = parse_private_key_path( + kwargs["private_key"], kwargs["private_key_passphrase"] + ) + try: conn = snowflake.connector.connect( application="feast", autocommit=autocommit, **kwargs @@ -288,3 +296,21 @@ def chunk_helper(lst: pd.DataFrame, n: int) -> Iterator[Tuple[int, pd.DataFrame] """Helper generator to chunk a sequence efficiently with current index like if enumerate was called on sequence.""" for i in range(0, len(lst), n): yield int(i / n), lst[i : i + n] + + +def parse_private_key_path(key_path: str, private_key_passphrase: str) -> bytes: + + with open(key_path, "rb") as key: + p_key = serialization.load_pem_private_key( + key.read(), + password=private_key_passphrase.encode(), + backend=default_backend(), + ) + + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + return pkb