diff --git a/cosmos/profiles/base.py b/cosmos/profiles/base.py index 7ef17e57e..2db741132 100644 --- a/cosmos/profiles/base.py +++ b/cosmos/profiles/base.py @@ -58,21 +58,23 @@ def can_claim_connection(self) -> bool: if self.conn.conn_type != self.airflow_connection_type: return False - logger.info(dir(self.conn)) - logger.info(self.conn.__dict__) + generated_profile = self.profile for field in self.required_fields: - try: - if not getattr(self, field): + # if it's a secret field, check if we can get it + if field in self.secret_fields: + if not self.get_dbt_value(field): logger.info( - "1 Not using mapping %s because %s is not set", + "Not using mapping %s because %s is not set", self.__class__.__name__, field, ) return False - except AttributeError: + + # otherwise, check if it's in the generated profile + if not generated_profile.get(field): logger.info( - "2 Not using mapping %s because %s is not set", + "Not using mapping %s because %s is not set", self.__class__.__name__, field, ) @@ -159,23 +161,15 @@ def get_dbt_value(self, name: str) -> Any: # otherwise, we don't have it - return None return None - def __getattr__(self, name: str) -> Any: - "If the attribute doesn't exist, try to get it from profile_args or the Airflow connection." - logger.info( - "Couldn't find attribute %s on %s. Trying to get it from `profile_args` or the Airflow connection.", - name, - self.__class__.__name__, - ) - - # if it doesn't exist, try to get it from profile_args or the Airflow connection - attempted_value = self.get_dbt_value(name) - if attempted_value is not None: - return attempted_value - - raise AttributeError( - f"{self.__class__.__name__} has no attribute {name}. If this is a dbt profile field, " - f"ensure it's set either in the profile_args or the Airflow connection." - ) + @property + def mapped_params(self) -> dict[str, Any]: + "Turns the self.airflow_param_mapping into a dictionary of dbt fields and their values." + mapped_params = {} + + for dbt_field in self.airflow_param_mapping: + mapped_params[dbt_field] = self.get_dbt_value(dbt_field) + + return mapped_params @classmethod def filter_null(cls, args: dict[str, Any]) -> dict[str, Any]: diff --git a/cosmos/profiles/bigquery/service_account_file.py b/cosmos/profiles/bigquery/service_account_file.py index 980f07328..699ee0e11 100644 --- a/cosmos/profiles/bigquery/service_account_file.py +++ b/cosmos/profiles/bigquery/service_account_file.py @@ -32,11 +32,9 @@ class GoogleCloudServiceAccountFileProfileMapping(BaseProfileMapping): def profile(self) -> dict[str, Any | None]: "Generates profile. Defaults `threads` to 1." return { + **self.mapped_params, "type": "bigquery", "method": "service-account", - "project": self.project, - "dataset": self.dataset, - "threads": self.profile_args.get("threads") or 1, - "keyfile": self.keyfile, + "threads": 1, **self.profile_args, } diff --git a/cosmos/profiles/bigquery/service_account_keyfile_dict.py b/cosmos/profiles/bigquery/service_account_keyfile_dict.py index f5bd87e20..acf3f6c3f 100644 --- a/cosmos/profiles/bigquery/service_account_keyfile_dict.py +++ b/cosmos/profiles/bigquery/service_account_keyfile_dict.py @@ -19,15 +19,15 @@ class GoogleCloudServiceAccountDictProfileMapping(BaseProfileMapping): required_fields = [ "project", "dataset", - "keyfile_dict", + "keyfile_json", ] airflow_param_mapping = { "project": "extra.project", # multiple options for dataset because of older Airflow versions - "dataset": ["extra.dataset", "dataset"], + "dataset": "extra.dataset", # multiple options for keyfile_dict param name because of older Airflow versions - "keyfile_dict": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"], + "keyfile_json": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"], } @property @@ -38,11 +38,9 @@ def profile(self) -> dict[str, Any | None]: we generate a temporary file and the DBT profile uses it. """ return { + **self.mapped_params, "type": "bigquery", "method": "service-account-json", - "project": self.project, - "dataset": self.dataset, - "threads": self.profile_args.get("threads") or 1, - "keyfile_json": self.keyfile_dict, + "threads": 1, **self.profile_args, } diff --git a/cosmos/profiles/databricks/token.py b/cosmos/profiles/databricks/token.py index 75ed80e5c..ed057de52 100644 --- a/cosmos/profiles/databricks/token.py +++ b/cosmos/profiles/databricks/token.py @@ -38,10 +38,8 @@ class DatabricksTokenProfileMapping(BaseProfileMapping): def profile(self) -> dict[str, Any | None]: "Generates profile. The token is stored in an environment variable." return { + **self.mapped_params, "type": "databricks", - "schema": self.schema, - "host": self.host, - "http_path": self.http_path, **self.profile_args, # token should always get set as env var "token": self.get_env_var_format("token"), diff --git a/cosmos/profiles/exasol/user_pass.py b/cosmos/profiles/exasol/user_pass.py index eb89d2a6d..b951137c6 100644 --- a/cosmos/profiles/exasol/user_pass.py +++ b/cosmos/profiles/exasol/user_pass.py @@ -46,17 +46,8 @@ class ExasolUserPasswordProfileMapping(BaseProfileMapping): def profile(self) -> dict[str, Any | None]: "Gets profile. The password is stored in an environment variable." profile_vars = { + **self.mapped_params, "type": "exasol", - "threads": self.threads, - "dsn": self.dsn, - "user": self.user, - "dbname": self.dbname, - "schema": self.schema, - "encryption": self.conn.extra_dejson.get("encryption"), - "compression": self.conn.extra_dejson.get("compression"), - "connect_timeout": self.conn.extra_dejson.get("connection_timeout"), - "socket_timeout": self.conn.extra_dejson.get("socket_timeout"), - "protocol_version": self.conn.extra_dejson.get("protocol_version"), **self.profile_args, # password should always get set as env var "password": self.get_env_var_format("password"), diff --git a/cosmos/profiles/postgres/user_pass.py b/cosmos/profiles/postgres/user_pass.py index 983999900..28b2052c6 100644 --- a/cosmos/profiles/postgres/user_pass.py +++ b/cosmos/profiles/postgres/user_pass.py @@ -14,13 +14,11 @@ class PostgresUserPasswordProfileMapping(BaseProfileMapping): """ airflow_connection_type: str = "postgres" - default_port = 5432 required_fields = [ "host", "user", "password", - "port", "dbname", "schema", ] @@ -41,14 +39,9 @@ class PostgresUserPasswordProfileMapping(BaseProfileMapping): def profile(self) -> dict[str, Any | None]: "Gets profile. The password is stored in an environment variable." profile = { + **self.mapped_params, "type": "postgres", - "host": self.conn.host, - "user": self.conn.login, - "port": self.conn.port or self.default_port, - "dbname": self.dbname, - "schema": self.schema, - "keepalives_idle": self.conn.extra_dejson.get("keepalives_idle"), - "sslmode": self.conn.extra_dejson.get("sslmode"), + "port": 5432, **self.profile_args, # password should always get set as env var "password": self.get_env_var_format("password"), diff --git a/cosmos/profiles/redshift/user_pass.py b/cosmos/profiles/redshift/user_pass.py index 6547e8351..565896055 100644 --- a/cosmos/profiles/redshift/user_pass.py +++ b/cosmos/profiles/redshift/user_pass.py @@ -14,7 +14,6 @@ class RedshiftUserPasswordProfileMapping(BaseProfileMapping): """ airflow_connection_type: str = "redshift" - default_port = 5432 required_fields = [ "host", @@ -41,17 +40,12 @@ class RedshiftUserPasswordProfileMapping(BaseProfileMapping): def profile(self) -> dict[str, Any | None]: "Gets profile." profile = { + **self.mapped_params, "type": "redshift", - "host": self.host, - "user": self.user, - "password": self.get_env_var_format("password"), - "port": self.port or self.default_port, - "dbname": self.dbname, - "schema": self.schema, - "connection_timeout": self.conn.extra_dejson.get("timeout"), - "sslmode": self.conn.extra_dejson.get("sslmode"), - "region": self.conn.extra_dejson.get("region"), + "port": 5439, **self.profile_args, + # password should always get set as env var + "password": self.get_env_var_format("password"), } return self.filter_null(profile) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 306fa2547..ba2f0ea14 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -63,13 +63,8 @@ def conn(self) -> Connection: def profile(self) -> dict[str, Any | None]: "Gets profile." profile_vars = { + **self.mapped_params, "type": "snowflake", - "account": self.account, - "user": self.user, - "schema": self.schema, - "database": self.database, - "role": self.conn.extra_dejson.get("role"), - "warehouse": self.conn.extra_dejson.get("warehouse"), **self.profile_args, # password should always get set as env var "password": self.get_env_var_format("password"), diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index c41a3e63e..d5b7c4b7b 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -64,13 +64,8 @@ def conn(self) -> Connection: def profile(self) -> dict[str, Any | None]: "Gets profile." profile_vars = { + **self.mapped_params, "type": "snowflake", - "account": self.account, - "user": self.user, - "schema": self.schema, - "database": self.database, - "role": self.conn.extra_dejson.get("role"), - "warehouse": self.conn.extra_dejson.get("warehouse"), **self.profile_args, # private_key should always get set as env var "private_key_content": self.get_env_var_format("private_key_content"), diff --git a/cosmos/profiles/spark/thrift.py b/cosmos/profiles/spark/thrift.py index 942188247..2bf26f23d 100644 --- a/cosmos/profiles/spark/thrift.py +++ b/cosmos/profiles/spark/thrift.py @@ -23,6 +23,7 @@ class SparkThriftProfileMapping(BaseProfileMapping): airflow_param_mapping = { "host": "host", + "port": "port", } @property @@ -31,11 +32,9 @@ def profile(self) -> dict[str, Any | None]: Return a dbt Spark profile based on the Airflow Spark connection. """ profile_vars = { + **self.mapped_params, "type": "spark", "method": "thrift", - "schema": self.schema, - "host": self.host, - "port": self.conn.port, **self.profile_args, } diff --git a/cosmos/profiles/trino/base.py b/cosmos/profiles/trino/base.py index 8d4bf5572..929cf1788 100644 --- a/cosmos/profiles/trino/base.py +++ b/cosmos/profiles/trino/base.py @@ -26,15 +26,11 @@ class TrinoBaseProfileMapping(BaseProfileMapping): } @property - def profile(self) -> dict[str, Any | None]: + def profile(self) -> dict[str, Any]: "Gets profile." profile_vars = { + **self.mapped_params, "type": "trino", - "host": self.host, - "port": self.port, - "database": self.database, - "schema": self.schema, - "session_properties": self.conn.extra_dejson.get("session_properties"), **self.profile_args, } diff --git a/cosmos/profiles/trino/certificate.py b/cosmos/profiles/trino/certificate.py index eec934827..38ec7e45c 100644 --- a/cosmos/profiles/trino/certificate.py +++ b/cosmos/profiles/trino/certificate.py @@ -29,10 +29,9 @@ def profile(self) -> dict[str, Any | None]: "Gets profile." common_profile_vars = super().profile profile_vars = { + **self.mapped_params, **common_profile_vars, "method": "certificate", - "client_certificate": self.client_certificate, - "client_private_key": self.client_private_key, **self.profile_args, } diff --git a/cosmos/profiles/trino/ldap.py b/cosmos/profiles/trino/ldap.py index 3b6d4ff83..0c01908a0 100644 --- a/cosmos/profiles/trino/ldap.py +++ b/cosmos/profiles/trino/ldap.py @@ -34,11 +34,12 @@ def profile(self) -> dict[str, Any | None]: """ common_profile_vars = super().profile profile_vars = { + **self.mapped_params, **common_profile_vars, "method": "ldap", - "user": self.user, - "password": self.get_env_var_format("password"), **self.profile_args, + # password should always get set as env var + "password": self.get_env_var_format("password"), } # remove any null values diff --git a/tests/profiles/exasol/test_exasol_user_pass.py b/tests/profiles/exasol/test_exasol_user_pass.py index b229f9836..b2880c222 100644 --- a/tests/profiles/exasol/test_exasol_user_pass.py +++ b/tests/profiles/exasol/test_exasol_user_pass.py @@ -193,7 +193,7 @@ def test_dsn_formatting() -> None: with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): profile_mapping = ExasolUserPasswordProfileMapping(conn, {"schema": "my_schema", "threads": 1}) - assert profile_mapping.dsn == "my_host:1000" + assert profile_mapping.get_dbt_value("dsn") == "my_host:1000" # next, test with a host that doesn't include a port conn = Connection( @@ -207,7 +207,7 @@ def test_dsn_formatting() -> None: with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): profile_mapping = ExasolUserPasswordProfileMapping(conn, {"schema": "my_schema", "threads": 1}) - assert profile_mapping.dsn == "my_host:8563" # should default to 8563 + assert profile_mapping.get_dbt_value("dsn") == "my_host:8563" # should default to 8563 # lastly, test with a port override conn = Connection( @@ -222,4 +222,4 @@ def test_dsn_formatting() -> None: with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): profile_mapping = ExasolUserPasswordProfileMapping(conn, {"schema": "my_schema", "threads": 1}) - assert profile_mapping.dsn == "my_host:1000" + assert profile_mapping.get_dbt_value("dsn") == "my_host:1000" diff --git a/tests/profiles/postgres/test_pg_user_pass.py b/tests/profiles/postgres/test_pg_user_pass.py index 67cd541f7..db1a45701 100644 --- a/tests/profiles/postgres/test_pg_user_pass.py +++ b/tests/profiles/postgres/test_pg_user_pass.py @@ -48,7 +48,6 @@ def test_connection_claiming() -> None: "host": "my_host", "login": "my_user", "password": "my_password", - "port": 5432, "schema": "my_database", }