-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(ingest/snowflake): Okta OAuth support; update docs #8157
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from enum import Enum | ||
from typing import List, Optional | ||
|
||
from pydantic import Field, SecretStr | ||
|
||
from datahub.configuration import ConfigModel | ||
|
||
|
||
class OAuthIdentityProvider(Enum): | ||
MICROSOFT = "microsoft" | ||
OKTA = "okta" | ||
|
||
|
||
class OAuthConfiguration(ConfigModel): | ||
provider: OAuthIdentityProvider = Field( | ||
description="Identity provider for oauth." | ||
"Supported providers are microsoft and okta." | ||
) | ||
authority_url: str = Field(description="Authority url of your identity provider") | ||
client_id: str = Field(description="client id of your registered application") | ||
scopes: List[str] = Field(description="scopes required to connect to snowflake") | ||
use_certificate: bool = Field( | ||
description="Do you want to use certificate and private key to authenticate using oauth", | ||
default=False, | ||
) | ||
client_secret: Optional[SecretStr] = Field( | ||
description="client secret of the application if use_certificate = false" | ||
) | ||
encoded_oauth_public_key: Optional[str] = Field( | ||
description="base64 encoded certificate content if use_certificate = true" | ||
) | ||
encoded_oauth_private_key: Optional[str] = Field( | ||
description="base64 encoded private key content if use_certificate = true" | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,14 +12,15 @@ | |
OAUTH_AUTHENTICATOR, | ||
) | ||
|
||
from datahub.configuration.common import AllowDenyPattern, OauthConfiguration | ||
from datahub.configuration.common import AllowDenyPattern | ||
from datahub.configuration.oauth import OAuthConfiguration, OAuthIdentityProvider | ||
from datahub.configuration.time_window_config import BaseTimeWindowConfig | ||
from datahub.configuration.validate_field_rename import pydantic_renamed_field | ||
from datahub.ingestion.source.snowflake.constants import ( | ||
CLIENT_PREFETCH_THREADS, | ||
CLIENT_SESSION_KEEP_ALIVE, | ||
) | ||
from datahub.ingestion.source.sql.oauth_generator import OauthTokenGenerator | ||
from datahub.ingestion.source.sql.oauth_generator import OAuthTokenGenerator | ||
from datahub.ingestion.source.sql.sql_config import ( | ||
SQLAlchemyConfig, | ||
make_sqlalchemy_uri, | ||
|
@@ -69,7 +70,7 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig): | |
description="Password for your private key. Required if using key pair authentication with encrypted private key.", | ||
) | ||
|
||
oauth_config: Optional[OauthConfiguration] = pydantic.Field( | ||
oauth_config: Optional[OAuthConfiguration] = pydantic.Field( | ||
default=None, | ||
description="oauth configuration - https://docs.snowflake.com/en/user-guide/python-connector-example.html#connecting-with-oauth", | ||
) | ||
|
@@ -137,48 +138,36 @@ def authenticator_type_is_valid(cls, v, values, field): | |
f"At least one should be set when using {v} authentication" | ||
) | ||
elif v == "OAUTH_AUTHENTICATOR": | ||
if values.get("oauth_config") is None: | ||
raise ValueError( | ||
f"'oauth_config' is none but should be set when using {v} authentication" | ||
) | ||
if values.get("oauth_config").provider is None: | ||
raise ValueError( | ||
f"'oauth_config.provider' is none " | ||
f"but should be set when using {v} authentication" | ||
) | ||
if values.get("oauth_config").client_id is None: | ||
raise ValueError( | ||
f"'oauth_config.client_id' is none " | ||
f"but should be set when using {v} authentication" | ||
) | ||
if values.get("oauth_config").scopes is None: | ||
cls._check_oauth_config(values.get("oauth_config")) | ||
logger.info(f"using authenticator type '{v}'") | ||
return v | ||
|
||
@staticmethod | ||
def _check_oauth_config(oauth_config: Optional[OAuthConfiguration]) -> None: | ||
if oauth_config is None: | ||
raise ValueError( | ||
"'oauth_config' is none but should be set when using OAUTH_AUTHENTICATOR authentication" | ||
) | ||
if oauth_config.use_certificate is True: | ||
if oauth_config.provider == OAuthIdentityProvider.OKTA.value: | ||
raise ValueError( | ||
f"'oauth_config.scopes' was none " | ||
f"but should be set when using {v} authentication" | ||
"Certificate authentication is not supported for Okta." | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If this is always true for Okta, it would help to move this to a validator in OAuthConfiguration. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Pretty sure you can set up Okta auth with public/private keys; I just couldn't get it to work easily with Snowflake. |
||
) | ||
if values.get("oauth_config").authority_url is None: | ||
if oauth_config.encoded_oauth_private_key is None: | ||
raise ValueError( | ||
f"'oauth_config.authority_url' was none " | ||
f"but should be set when using {v} authentication" | ||
"'base64_encoded_oauth_private_key' was none " | ||
"but should be set when using certificate for oauth_config" | ||
) | ||
if values.get("oauth_config").use_certificate is True: | ||
if values.get("oauth_config").encoded_oauth_private_key is None: | ||
raise ValueError( | ||
"'base64_encoded_oauth_private_key' was none " | ||
"but should be set when using certificate for oauth_config" | ||
) | ||
if values.get("oauth").encoded_oauth_public_key is None: | ||
raise ValueError( | ||
"'base64_encoded_oauth_public_key' was none" | ||
"but should be set when using use_certificate true for oauth_config" | ||
) | ||
elif values.get("oauth_config").client_secret is None: | ||
if oauth_config.encoded_oauth_public_key is None: | ||
raise ValueError( | ||
"'oauth_config.client_secret' was none " | ||
"but should be set when using use_certificate false for oauth_config" | ||
"'base64_encoded_oauth_public_key' was none" | ||
"but should be set when using use_certificate true for oauth_config" | ||
) | ||
logger.info(f"using authenticator type '{v}'") | ||
return v | ||
elif oauth_config.client_secret is None: | ||
raise ValueError( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can some of this validation logic live on the OAuthConfiguration object? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Probably, but I'm not sure what will be Snowflake-specific and what won't just yet. |
||
"'oauth_config.client_secret' was none " | ||
"but should be set when using use_certificate false for oauth_config" | ||
) | ||
|
||
@pydantic.validator("include_view_lineage") | ||
def validate_include_view_lineage(cls, v, values): | ||
|
@@ -297,14 +286,16 @@ def get_options(self) -> dict: | |
self.options["connect_args"] = options_connect_args | ||
return self.options | ||
|
||
def get_oauth_connection(self): | ||
def get_oauth_connection(self) -> snowflake.connector.SnowflakeConnection: | ||
assert ( | ||
self.oauth_config | ||
), "oauth_config should be provided if using oauth based authentication" | ||
generator = OauthTokenGenerator( | ||
self.oauth_config.client_id, | ||
self.oauth_config.authority_url, | ||
self.oauth_config.provider, | ||
generator = OAuthTokenGenerator( | ||
client_id=self.oauth_config.client_id, | ||
authority_url=self.oauth_config.authority_url, | ||
provider=self.oauth_config.provider, | ||
username=self.username, | ||
password=self.password, | ||
) | ||
if self.oauth_config.use_certificate: | ||
response = generator.get_token_with_certificate( | ||
|
@@ -313,11 +304,18 @@ def get_oauth_connection(self): | |
scopes=self.oauth_config.scopes, | ||
) | ||
else: | ||
assert self.oauth_config.client_secret | ||
response = generator.get_token_with_secret( | ||
secret=str(self.oauth_config.client_secret), | ||
secret=str(self.oauth_config.client_secret.get_secret_value()), | ||
scopes=self.oauth_config.scopes, | ||
) | ||
token = response["access_token"] | ||
try: | ||
token = response["access_token"] | ||
except KeyError: | ||
raise ValueError( | ||
f"access_token not found in response {response}. " | ||
"Please check your OAuth configuration." | ||
) | ||
connect_args = self.get_options()["connect_args"] | ||
return snowflake.connector.connect( | ||
user=self.username, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for updating these docs.