Skip to content

Commit

Permalink
Introduce support to override the profile name (#314)
Browse files Browse the repository at this point in the history
Allow a user to specify a `profile_name_override`. When specified, the profile Cosmos generates will be called the value passed.
This argument is available at the `DAG`, `TaskGroup` and `Operator` levels.
Closes: #266
  • Loading branch information
jlaneve authored Jun 9, 2023
1 parent 4e38f7b commit e976a4c
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 15 deletions.
1 change: 1 addition & 0 deletions cosmos/providers/dbt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from pathlib import Path

DBT_PROFILE_PATH = Path(os.path.expanduser("~")).joinpath(".dbt/profiles.yml")
DEFAULT_DBT_PROFILE_NAME = "cosmos_profile"
54 changes: 39 additions & 15 deletions cosmos/providers/dbt/core/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from airflow.utils.session import NEW_SESSION, provide_session
from sqlalchemy.orm import Session

from cosmos.providers.dbt.constants import DEFAULT_DBT_PROFILE_NAME
from cosmos.providers.dbt.core.operators.base import DbtBaseOperator
from cosmos.providers.dbt.core.profiles import get_profile_mapping
from cosmos.providers.dbt.core.utils.adapted_subprocesshook import (
Expand All @@ -35,6 +36,8 @@ class DbtLocalBaseOperator(DbtBaseOperator):
:param profile_args: Arguments to pass to the profile. See
:py:class:`cosmos.providers.dbt.core.profiles.BaseProfileMapping`.
:param profile_name: A name to use for the dbt profile. If not provided, and no profile target is found
in your project's dbt_project.yml, "cosmos_profile" is used.
:param install_deps: If true, install dependencies before running the command
:param callback: A callback function called on after a dbt run with a path to the dbt project directory.
"""
Expand All @@ -49,11 +52,13 @@ def __init__(
install_deps: bool = False,
callback: Optional[Callable[[str], None]] = None,
profile_args: dict[str, str] = {},
profile_name: str | None = None,
**kwargs,
) -> None:
self.install_deps = install_deps
self.profile_args = profile_args
self.callback = callback
self.profile_name = profile_name
self.compiled_sql = ""
super().__init__(**kwargs)

Expand Down Expand Up @@ -112,6 +117,38 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se
def run_subprocess(self, *args, **kwargs):
return self.subprocess_hook.run_command(*args, **kwargs)

def get_profile_name(self, project_dir: str) -> str:
"""
Returns the profile name to use. Precedence is:
1. The profile name passed in to the operator
2. The profile name in the dbt_project.yml file
3. "cosmos_profile"
"""
if self.profile_name:
return self.profile_name

# get the profile name from the dbt_project.yml file
dbt_project_path = os.path.join(project_dir, "dbt_project.yml")

# if there's no dbt_project.yml file, we're not in a dbt project
# and need to raise an error
if not os.path.exists(dbt_project_path):
raise AirflowException(f"dbt project directory {self.project_dir} does not contain a dbt_project.yml file.")

# get file contents using path
dbt_project = yaml.safe_load(Path(dbt_project_path).read_text(encoding="utf-8")) or {}

profile_name = dbt_project.get("profile", DEFAULT_DBT_PROFILE_NAME)

if not isinstance(profile_name, str):
raise AirflowException(
f"dbt project directory {self.project_dir} contains a dbt_project.yml file, but the profile "
f"specified in the file is not a string. Please specify a string profile name, or pass a profile "
f"name to the operator."
)

return profile_name

def run_command(
self,
cmd: list[str],
Expand All @@ -135,20 +172,7 @@ def run_command(
tmp_project_dir,
)

# get the profile name from the dbt_project.yml file
dbt_project_path = os.path.join(tmp_project_dir, "dbt_project.yml")

# if there's no dbt_project.yml file, we're not in a dbt project
# and need to raise an error
if not os.path.exists(dbt_project_path):
raise AirflowException(
f"dbt project directory {self.project_dir} does not contain a dbt_project.yml file."
)

with open(dbt_project_path, encoding="utf-8") as f:
dbt_project = yaml.safe_load(f)

profile_name = dbt_project.get("profile")
profile_name = self.get_profile_name(tmp_project_dir)

# need to write the profile to a file because dbt requires a profile file
# and doesn't accept a profile as a string
Expand Down Expand Up @@ -178,7 +202,7 @@ def run_command(
logger.info("Trying to run the command:\n %s\nFrom %s", cmd, tmp_project_dir)

result = self.run_subprocess(
command=cmd,
command=cmd + ["--profile", profile_name],
env=env,
output_encoding=self.output_encoding,
cwd=tmp_project_dir,
Expand Down
6 changes: 6 additions & 0 deletions cosmos/providers/dbt/dag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
This module contains a function to render a dbt project as an Airflow DAG.
"""
from __future__ import annotations

try:
from typing import Literal
except ImportError:
Expand All @@ -24,6 +26,8 @@ class DbtDag(CosmosDag):
:param dbt_seeds_dir: The path to the dbt seeds directory within the project
:param conn_id: The Airflow connection ID to use for the dbt profile
:param profile_args: Arguments to pass to the dbt profile
:param profile_name_override: A name to use for the dbt profile. If not provided, and no profile target is found
in your project's dbt_project.yml, "cosmos_profile" is used.
:param dbt_args: Parameters to pass to the underlying dbt operators, can include dbt_executable_path to utilize venv
:param operator_args: Parameters to pass to the underlying operators, can include KubernetesPodOperator
or DockerOperator parameters
Expand All @@ -45,6 +49,7 @@ def __init__(
conn_id: str,
profile_args: Dict[str, str] = {},
dbt_args: Dict[str, Any] = {},
profile_name_override: str | None = None,
operator_args: Dict[str, Any] = {},
emit_datasets: bool = True,
dbt_root_path: str = "/usr/local/airflow/dags/dbt",
Expand Down Expand Up @@ -76,6 +81,7 @@ def __init__(
emit_datasets=emit_datasets,
conn_id=conn_id,
profile_args=profile_args,
profile_name=profile_name_override,
select=select,
exclude=exclude,
execution_mode=execution_mode,
Expand Down
7 changes: 7 additions & 0 deletions cosmos/providers/dbt/render.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
This module contains a function to render a dbt project into Cosmos entities.
"""
from __future__ import annotations

import itertools
import logging

Expand Down Expand Up @@ -40,6 +42,7 @@ def render_project(
emit_datasets: bool = True,
conn_id: str = "default_conn_id",
profile_args: Dict[str, str] = {},
profile_name: str | None = None,
select: Dict[str, List[str]] = {},
exclude: Dict[str, List[str]] = {},
execution_mode: Literal["local", "docker", "kubernetes"] = "local",
Expand All @@ -58,6 +61,8 @@ def render_project(
:param emit_datasets: If enabled test nodes emit Airflow Datasets for downstream cross-DAG dependencies
:param conn_id: The Airflow connection ID to use
:param profile_args: Arguments to pass to the dbt profile
:param profile_name: A name to use for the dbt profile. If not provided, and no profile target is found
in your project's dbt_project.yml, "cosmos_profile" is used.
:param select: A dict of dbt selector arguments (i.e., {"tags": ["tag_1", "tag_2"]})
:param exclude: A dict of dbt exclude arguments (i.e., {"tags": ["tag_1", "tag_2]}})
:param execution_mode: The execution mode in which the dbt project should be run.
Expand Down Expand Up @@ -132,12 +137,14 @@ def render_project(
**operator_args,
"models": model_name,
"profile_args": profile_args,
"profile_name": profile_name,
}
test_args: Dict[str, Any] = {
**task_args,
**operator_args,
"models": model_name,
"profile_args": profile_args,
"profile_name": profile_name,
}
# DbtTestOperator specific arg
test_args["on_warning_callback"] = on_warning_callback
Expand Down
6 changes: 6 additions & 0 deletions cosmos/providers/dbt/task_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
This module contains a function to render a dbt project as an Airflow Task Group.
"""
from __future__ import annotations

try:
from typing import Literal
except ImportError:
Expand All @@ -25,6 +27,8 @@ class DbtTaskGroup(CosmosTaskGroup):
:param dbt_seeds_dir: The path to the dbt seeds directory within the project
:param conn_id: The Airflow connection ID to use for the dbt profile
:param profile_args: Arguments to pass to the dbt profile
:param profile_name_override: A name to use for the dbt profile. If not provided, and no profile target is found
in your project's dbt_project.yml, "cosmos_profile" is used.
:param dbt_args: Parameters to pass to the underlying dbt operators, can include dbt_executable_path to utilize venv
:param operator_args: Parameters to pass to the underlying operators, can include KubernetesPodOperator
or DockerOperator parameters
Expand All @@ -45,6 +49,7 @@ def __init__(
dbt_project_name: str,
conn_id: str,
profile_args: Dict[str, str] = {},
profile_name_override: Optional[str] = None,
dbt_args: Dict[str, Any] = {},
operator_args: Dict[str, Any] = {},
emit_datasets: bool = True,
Expand Down Expand Up @@ -79,6 +84,7 @@ def __init__(
emit_datasets=emit_datasets,
conn_id=conn_id,
profile_args=profile_args,
profile_name=profile_name_override,
select=select,
exclude=exclude,
execution_mode=execution_mode,
Expand Down
1 change: 1 addition & 0 deletions dev/dags/basic_cosmos_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
profile_args={
"schema": "public",
},
profile_name_override="airflow",
# normal dag parameters
schedule_interval="@daily",
start_date=datetime(2023, 1, 1),
Expand Down
15 changes: 15 additions & 0 deletions docs/dbt/connections-profiles.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ secret values are passed to dbt as environment variables with the following nami
For example, a Snowflake password field would be passed to dbt as an environment variable with the name
``COSMOS_CONN_SNOWFLAKE_PASSWORD``.

Profile Name
------------

By default, Cosmos will use the dbt profile name specified in your project's dbt_project.yml file. However, you can
override this by passing in a ``profile_name_override`` parameter to either ``DbtDag`` or ``DbtTaskGroup``. This is useful
if you have macros or other code that depends on the profile name. For example, to ensure we always use the profile name
``my_profile_name`` in the following example, we can pass in a ``profile_name_override`` parameter to ``DbtDag``:

.. code-block:: python
dag = DbtDag(profile_name_override="my_profile_name", ...)
If no profile name is specified, and there's no profile target in the dbt_project.yml file, Cosmos will use the
default profile name ``cosmos_profile``.


Available Profile Mappings
==========================
Expand Down
44 changes: 44 additions & 0 deletions tests/providers/dbt/core/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,47 @@ def test_dbt_base_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None
"START_DATE": "2023-02-15 12:30:00",
}
assert env == expected_env


@patch("os.path.exists")
def test_get_profile_name(mock_os_path_exists) -> None:
mock_os_path_exists.return_value = True

# check that a user-specified profile name is returned when specified
dbt_base_operator = DbtLocalBaseOperator(
conn_id="my_airflow_connection",
task_id="my-task",
project_dir="my/dir",
profile_name="default",
)
assert dbt_base_operator.get_profile_name("path/to/dir") == "default"

# check that the dbt_project profile name is returned when no user-specified profile name is specified
dbt_base_operator = DbtLocalBaseOperator(
conn_id="my_airflow_connection",
task_id="my-task",
project_dir="my/dir",
)
with patch("pathlib.Path.read_text") as mock_read_text:
mock_read_text.return_value = "profile: default"
assert dbt_base_operator.get_profile_name("path/to/dir") == "default"

# check that the default profile name is returned when no user-specified profile name is specified and no
# dbt_project profile name is specified
with patch("pathlib.Path.read_text") as mock_read_text:
mock_read_text.return_value = ""
assert dbt_base_operator.get_profile_name("path/to/dir") == "cosmos_profile"

mock_read_text.return_value = "other_config: other_value"
assert dbt_base_operator.get_profile_name("path/to/dir") == "cosmos_profile"

# test that we raise an AirflowException if the profile argument is not a string
with patch("pathlib.Path.read_text") as mock_read_text:
mock_read_text.return_value = "profile:\n my_key: my_value"
with pytest.raises(AirflowException):
dbt_base_operator.get_profile_name("path/to/dir")

# test that we raise an AirflowException if there's no dbt_project.yml file
mock_os_path_exists.return_value = False
with pytest.raises(AirflowException):
dbt_base_operator.get_profile_name("path/to/dir")

0 comments on commit e976a4c

Please sign in to comment.