From decb72b0f0ad28ab56da7bcde4cd701e8e0944fc Mon Sep 17 00:00:00 2001 From: Hal Ali Date: Thu, 19 Sep 2024 15:35:39 -0400 Subject: [PATCH] feat: Add new config parameter `private_key` (#260) # Description Add a config parameter `private_key`. This allows a user of target-snowflake to specify the `private_key` directly. This is useful in cases where it is difficult to save the private_key to a file but easy to pass in an environment variable (e.g. a container) --- README.md | 3 +- target_snowflake/connector.py | 80 +++++++++++++++++++++++++---------- target_snowflake/target.py | 15 ++++++- 3 files changed, 73 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 36e22a6..028c6fc 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,8 @@ Built with the [Meltano Singer SDK](https://sdk.meltano.com). |:---------------------------|:---------|:------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | user | True | None | The login name for your Snowflake user. | | password | False | None | The password for your Snowflake user. | -| private_key_path | False | None | Path to file containing private key. | +| private_key | False | None | The private key contents. For KeyPair authentication either private_key or private_key_path must be provided. | +| private_key_path | False | None | Path to file containing private key. For KeyPair authentication either private_key or private_key_path must be provided. | | private_key_passphrase | False | None | Passphrase to decrypt private key if encrypted. | | account | True | None | Your account identifier. See [Account Identifiers](https://docs.snowflake.com/en/user-guide/admin-account-identifier.html). | | database | True | None | The initial database for the Snowflake session. | diff --git a/target_snowflake/connector.py b/target_snowflake/connector.py index e2c4858..fcaadd5 100644 --- a/target_snowflake/connector.py +++ b/target_snowflake/connector.py @@ -1,6 +1,9 @@ from __future__ import annotations +from enum import Enum +from functools import cached_property from operator import contains, eq +from pathlib import Path from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast import snowflake.sqlalchemy.custom_types as sct @@ -10,6 +13,7 @@ from singer_sdk import typing as th from singer_sdk.connectors import SQLConnector from singer_sdk.connectors.sql import FullyQualifiedName +from singer_sdk.exceptions import ConfigValidationError from snowflake.sqlalchemy import URL from snowflake.sqlalchemy.base import SnowflakeIdentifierPreparer from snowflake.sqlalchemy.snowdialect import SnowflakeDialect @@ -62,6 +66,14 @@ def prepare_part(self, part: str) -> str: return self.dialect.identifier_preparer.quote(part) +class SnowflakeAuthMethod(Enum): + """Supported methods to authenticate to snowflake""" + + BROWSER = 1 + PASSWORD = 2 + KEY_PAIR = 3 + + class SnowflakeConnector(SQLConnector): """Snowflake Target Connector. @@ -124,6 +136,47 @@ def _convert_type(sql_type): # noqa: ANN205, ANN001 return sql_type + def get_private_key(self): + """Get private key from the right location.""" + phrase = self.config.get("private_key_passphrase") + encoded_passphrase = phrase.encode() if phrase else None + if "private_key_path" in self.config: + with Path.open(self.config["private_key_path"], "rb") as key: + key_content = key.read() + else: + key_content = self.config["private_key"].encode() + + p_key = serialization.load_pem_private_key( + key_content, + password=encoded_passphrase, + backend=default_backend(), + ) + + return p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + @cached_property + def auth_method(self) -> SnowflakeAuthMethod: + """Validate & return the authentication method based on config.""" + if self.config.get("use_browser_authentication"): + return SnowflakeAuthMethod.BROWSER + + valid_auth_methods = {"private_key", "private_key_path", "password"} + config_auth_methods = [x for x in self.config if x in valid_auth_methods] + if len(config_auth_methods) != 1: + msg = ( + "Neither password nor private key was provided for " + "authentication. For password-less browser authentication via SSO, " + "set use_browser_authentication config option to True." + ) + raise ConfigValidationError(msg) + if config_auth_methods[0] in ["private_key", "private_key_path"]: + return SnowflakeAuthMethod.KEY_PAIR + return SnowflakeAuthMethod.PASSWORD + def get_sqlalchemy_url(self, config: dict) -> str: """Generates a SQLAlchemy URL for Snowflake. @@ -136,17 +189,10 @@ def get_sqlalchemy_url(self, config: dict) -> str: "database": config["database"], } - if config.get("use_browser_authentication"): + if self.auth_method == SnowflakeAuthMethod.BROWSER: params["authenticator"] = "externalbrowser" - elif "password" in config: + elif self.auth_method == SnowflakeAuthMethod.PASSWORD: params["password"] = config["password"] - elif "private_key_path" not in config: - msg = ( - "Neither password nor private_key_path was provided for " - "authentication. For password-less browser authentication via SSO, " - "set use_browser_authentication config option to True." - ) - raise Exception(msg) # noqa: TRY002 for option in ["warehouse", "role"]: if config.get(option): @@ -173,20 +219,8 @@ def create_engine(self) -> Engine: "QUOTED_IDENTIFIERS_IGNORE_CASE": "TRUE", }, } - if "private_key_path" in self.config: - with open(self.config["private_key_path"], "rb") as private_key_file: # noqa: PTH123 - private_key = serialization.load_pem_private_key( - private_key_file.read(), - password=self.config["private_key_passphrase"].encode() - if "private_key_passphrase" in self.config - else None, - backend=default_backend(), - ) - connect_args["private_key"] = private_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) + if self.auth_method == SnowflakeAuthMethod.KEY_PAIR: + connect_args["private_key"] = self.get_private_key() engine = sqlalchemy.create_engine( self.sqlalchemy_url, connect_args=connect_args, diff --git a/target_snowflake/target.py b/target_snowflake/target.py index 79716a6..eca5d9c 100644 --- a/target_snowflake/target.py +++ b/target_snowflake/target.py @@ -30,11 +30,24 @@ class TargetSnowflake(SQLTarget): required=False, description="The password for your Snowflake user.", ), + th.Property( + "private_key", + th.StringType, + required=False, + secret=True, + description=( + "The private key contents. For KeyPair authentication either " + "private_key or private_key_path must be provided." + ), + ), th.Property( "private_key_path", th.StringType, required=False, - description="Path to file containing private key.", + description=( + "Path to file containing private key. For KeyPair authentication either " + "private_key or private_key_path must be provided." + ), ), th.Property( "private_key_passphrase",