Skip to content

Commit

Permalink
Support user-supplied profiles (#390)
Browse files Browse the repository at this point in the history
## Description

<!-- Add a brief but complete description of the change. -->

This PR extends the `ProfileConfig` interface to support two modes:
1. the profile mapping, which takes an Airflow connection and translates
it to a dbt profile
2. a user-supplied profile file

When using a profile mapping, you can do one of the following:

1. Use the AutomaticProfileMapping class to auto-select a profile
mapping

```python
ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_mapping=get_automatic_profile_mapping(
        conn_id="airflow_db",
        args={"schema": "public"},
    ),
)
```

2. Specify a specific profile mapping

```python
ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_=PostgresUserPasswordProfileMapping(
        conn_id="airflow_db",
        profile_args={"schema": "public"},
    ),
)
```

When using your own profile, you can supply the `path_to_profiles_yml`:

```python
ProfileConfig(
    profile_name="default",
    target_name="dev",
    path_to_profiles_yml=DBT_ROOT_PATH / "jaffle_shop" / "profiles.yml",
)
```

## Related Issue(s)

<!-- If this PR closes an issue, you can use a keyword to auto-close.
-->
<!-- i.e. "closes #0000" -->

closes #336
closes #269

## Breaking Change?

<!-- If this introduces a breaking change, specify that here. -->

This changes the `ProfileConfig` class, which isn't yet user-facing.

## Checklist

- [ ] I have made corresponding changes to the documentation (if
required)
- [x] I have added tests that prove my fix is effective or that my
feature works

---------

Co-authored-by: Harel Shein <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Harel Shein <[email protected]>
  • Loading branch information
4 people authored Jul 26, 2023
1 parent e1ff962 commit 3185e22
Show file tree
Hide file tree
Showing 36 changed files with 808 additions and 522 deletions.
72 changes: 65 additions & 7 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@
from __future__ import annotations

import shutil
import contextlib
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from logging import getLogger
from typing import Iterator

from cosmos.constants import TestBehavior, ExecutionMode, LoadMode
from cosmos.exceptions import CosmosValueError
from cosmos.profiles import BaseProfileMapping

logger = getLogger(__name__)

DEFAULT_PROFILES_FILE_NAME = "profiles.yml"


@dataclass
class RenderConfig:
Expand Down Expand Up @@ -66,13 +72,13 @@ def __post_init__(self) -> None:

def validate_project(self) -> None:
"Validates that the project, models, and seeds directories exist."
project_yml_path = self.dbt_project_path / "dbt_project.yml"
project_yml_path = Path(self.dbt_project_path) / "dbt_project.yml"
mandatory_paths = {
"dbt_project.yml": project_yml_path,
"models directory ": self.models_relative_path,
}
for name, path in mandatory_paths.items():
if path is None or not path.exists():
if path is None or not Path(path).exists():
raise CosmosValueError(f"Could not find {name} at {project_yml_path}")

def is_manifest_available(self) -> bool:
Expand All @@ -87,24 +93,76 @@ def is_manifest_available(self) -> bool:
@property
def project_name(self) -> str:
"The name of the dbt project."
return self.dbt_project_path.stem
return Path(self.dbt_project_path).stem


@dataclass
class ProfileConfig:
"""
Class for setting profile config.
Class for setting profile config. Supports two modes of operation:
1. Using a user-supplied profiles.yml file. If using this mode, set profiles_yml_filepath to the
path to the file.
2. Using cosmos to map Airflow connections to dbt profiles. If using this mode, set
profile_mapping to a subclass of BaseProfileMapping.
:param profile_name: The name of the dbt profile to use.
:param target_name: The name of the dbt target to use.
:param conn_id: The Airflow connection ID to use.
:param profiles_yml_filepath: The path to a profiles.yml file to use.
:param profile_mapping: A mapping of Airflow connections to dbt profiles.
"""

# should always be set to be explicit
profile_name: str
target_name: str
conn_id: str
profile_args: dict[str, str] = field(default_factory=dict)

# should be set if using a user-supplied profiles.yml
profiles_yml_filepath: Path | None = None

# should be set if using cosmos to map Airflow connections to dbt profiles
profile_mapping: BaseProfileMapping | None = None

def __post_init__(self) -> None:
"Validates that we have enough information to render a profile."
# if using a user-supplied profiles.yml, validate that it exists
if self.profiles_yml_filepath and not self.profiles_yml_filepath.exists():
raise CosmosValueError(f"The file {self.profiles_yml_filepath} does not exist.")

def validate_profile(self) -> None:
"Validates that we have enough information to render a profile."
if not self.profiles_yml_filepath and not self.profile_mapping:
raise CosmosValueError("Either profiles_yml_filepath or profile_mapping must be set to render a profile")

@contextlib.contextmanager
def ensure_profile(self, desired_profile_path: Path | None = None) -> Iterator[tuple[Path, dict[str, str]]]:
"Context manager to ensure that there is a profile. If not, create one."
if self.profiles_yml_filepath:
logger.info("Using user-supplied profiles.yml at %s", self.profiles_yml_filepath)
yield Path(self.profiles_yml_filepath), {}

elif self.profile_mapping:
profile_contents = self.profile_mapping.get_profile_file_contents(
profile_name=self.profile_name, target_name=self.target_name
)

if desired_profile_path:
logger.info(
"Writing profile to %s with the following contents:\n%s",
desired_profile_path,
profile_contents,
)
# write profile_contents to desired_profile_path using yaml library
desired_profile_path.write_text(profile_contents)
yield desired_profile_path, self.profile_mapping.env_vars
else:
with tempfile.TemporaryDirectory() as temp_dir:
temp_file = Path(temp_dir) / DEFAULT_PROFILES_FILE_NAME
logger.info(
"Creating temporary profiles.yml at %s with the following contents:\n%s",
temp_file,
profile_contents,
)
temp_file.write_text(profile_contents)
yield temp_file, self.profile_mapping.env_vars


@dataclass
Expand Down
18 changes: 10 additions & 8 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,6 @@ def __init__(
) -> None:
project_config.validate_project()

conn_id = profile_config.conn_id
profile_args = profile_config.profile_args
profile_name_override = profile_config.profile_name
target_name_override = profile_config.target_name
emit_datasets = render_config.emit_datasets
dbt_root_path = project_config.dbt_project_path.parent
dbt_project_name = project_config.dbt_project_path.name
Expand All @@ -123,6 +119,14 @@ def __init__(
manifest_path = project_config.parsed_manifest_path
dbt_executable_path = execution_config.dbt_executable_path

conn_id = "unknown"
if profile_config and profile_config.profile_mapping:
conn_id = profile_config.profile_mapping.conn_id

profile_args = {}
if profile_config.profile_mapping:
profile_args = profile_config.profile_mapping.profile_args

if not operator_args:
operator_args = {}

Expand All @@ -140,17 +144,15 @@ def __init__(
exclude=exclude,
select=select,
dbt_cmd=dbt_executable_path,
profile_config=profile_config,
)
dbt_graph.load(method=load_mode, execution_mode=execution_mode)

task_args = {
**operator_args,
"profile_args": profile_args,
"profile_name": profile_name_override,
"target_name": target_name_override,
# the following args may be only needed for local / venv:
"project_dir": dbt_project.dir,
"conn_id": conn_id,
"profile_config": profile_config,
}

validate_arguments(select, exclude, profile_args, task_args)
Expand Down
50 changes: 40 additions & 10 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import json
import logging
import os
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from subprocess import Popen, PIPE
from typing import Any

from cosmos.config import ProfileConfig
from cosmos.constants import DbtResourceType, ExecutionMode, LoadMode
from cosmos.dbt.executable import get_system_dbt
from cosmos.dbt.parser.project import DbtProject as LegacyDbtProject
Expand Down Expand Up @@ -69,15 +71,21 @@ def __init__(
exclude: list[str] | None = None,
select: list[str] | None = None,
dbt_cmd: str = get_system_dbt(),
profile_config: ProfileConfig | None = None,
):
self.project = project
self.exclude = exclude or []
self.select = select or []
self.profile_config = profile_config

# specific to loading using ls
self.dbt_cmd = dbt_cmd

def load(self, method: LoadMode = LoadMode.AUTOMATIC, execution_mode: ExecutionMode = ExecutionMode.LOCAL) -> None:
def load(
self,
method: LoadMode = LoadMode.AUTOMATIC,
execution_mode: ExecutionMode = ExecutionMode.LOCAL,
) -> None:
"""
Load a `dbt` project into a `DbtGraph`, setting `nodes` and `filtered_nodes` accordingly.
Expand Down Expand Up @@ -123,27 +131,49 @@ def load_via_dbt_ls(self) -> None:
* self.filtered_nodes
"""
logger.info("Trying to parse the dbt project using dbt ls...")
command = [self.dbt_cmd, "ls", "--output", "json", "--profiles-dir", self.project.dir]

if not self.profile_config:
raise CosmosLoadDbtException("Unable to load dbt project without a profile config")

if not shutil.which(self.dbt_cmd):
raise CosmosLoadDbtException(f"Unable to find the dbt executable: {self.dbt_cmd}")

command = [self.dbt_cmd, "ls", "--output", "json"]

if self.exclude:
command.extend(["--exclude", *self.exclude])

if self.select:
command.extend(["--select", *self.select])
logger.info(f"Running command: {command}")
try:

with self.profile_config.ensure_profile() as (profile_path, env_vars):
command.extend(
[
"--profiles-dir",
str(profile_path.parent),
"--profile",
self.profile_config.profile_name,
"--target",
self.profile_config.target_name,
]
)

logger.info("Running command: `%s`", " ".join(command))
process = Popen(
command, # type: ignore[arg-type]
command,
stdout=PIPE,
stderr=PIPE,
cwd=self.project.dir,
universal_newlines=True,
env=os.environ,
env={
**os.environ,
**env_vars,
},
)
except FileNotFoundError as exception:
raise CosmosLoadDbtException(f"Unable to run the command due to the error:\n{exception}")

stdout, stderr = process.communicate()
stdout, stderr = process.communicate()

logger.debug(f"Output: {stdout}")
logger.debug("Output: %s", stdout)

if stderr or "Runtime Error" in stdout:
details = stderr or stdout
Expand Down
2 changes: 1 addition & 1 deletion cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class DbtBaseOperator(BaseOperator): # type: ignore[misc] # ignores subclass My
def __init__(
self,
project_dir: str,
conn_id: str,
conn_id: str | None = None,
base_cmd: list[str] | None = None,
select: str | None = None,
exclude: str | None = None,
Expand Down
Loading

0 comments on commit 3185e22

Please sign in to comment.