diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 76a3980e26a..811023a2b08 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -37,7 +37,7 @@ Project as ProjectContract, SemverString, ) -from dbt.contracts.project import PackageConfig, ProjectPackageMetadata +from dbt.contracts.project import PackageConfig, ModuleConfig, ProjectPackageMetadata from dbt.dataclass_schema import ValidationError from .renderer import DbtProjectYamlRenderer, PackageRenderer from .selectors import ( @@ -75,6 +75,14 @@ {error} """ +MALFORMED_MODULE_ERROR = """\ +The modules.yml file in this project is malformed. Please double check +the contents of this file and fix any errors before retrying. + +Validator Error: +{error} +""" + MISSING_DBT_PROJECT_ERROR = """\ No dbt_project.yml found at expected path {path} Verify that each entry within packages.yml (and their transitive dependencies) contains a file named dbt_project.yml @@ -103,6 +111,16 @@ def package_data_from_root(project_root): return packages_dict +def module_data_from_root(project_root): + module_filepath = resolve_path_from_base("modules.yml", project_root) + + if path_exists(module_filepath): + modules_dict = _load_yaml(module_filepath) + else: + modules_dict = None + return modules_dict + + def package_config_from_data(packages_data: Dict[str, Any]): if not packages_data: packages_data = {"packages": []} @@ -115,6 +133,18 @@ def package_config_from_data(packages_data: Dict[str, Any]): return packages +def module_config_from_data(modules_data: Dict[str, Any]): + if not modules_data: + modules_data = {"modules": []} + + try: + ModuleConfig.validate(modules_data) + modules = ModuleConfig.from_dict(modules_data) + except ValidationError as e: + raise DbtProjectError(MALFORMED_MODULE_ERROR.format(error=str(e.message))) from e + return modules + + def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]: """Parse multiple versions as read from disk. The versions value may be any one of: @@ -239,6 +269,7 @@ def _get_required_version( class RenderComponents: project_dict: Dict[str, Any] = field(metadata=dict(description="The project dictionary")) packages_dict: Dict[str, Any] = field(metadata=dict(description="The packages dictionary")) + modules_dict: Dict[str, Any] = field(metadata=dict(description="The modules dictionary")) selectors_dict: Dict[str, Any] = field(metadata=dict(description="The selectors dictionary")) @@ -273,11 +304,13 @@ def get_rendered( rendered_project = renderer.render_project(self.project_dict, self.project_root) rendered_packages = renderer.render_packages(self.packages_dict) + rendered_modules = renderer.render_modules(self.modules_dict) rendered_selectors = renderer.render_selectors(self.selectors_dict) return RenderComponents( project_dict=rendered_project, packages_dict=rendered_packages, + modules_dict=rendered_modules, selectors_dict=rendered_selectors, ) @@ -324,6 +357,7 @@ def create_project(self, rendered: RenderComponents) -> "Project": unrendered = RenderComponents( project_dict=self.project_dict, packages_dict=self.packages_dict, + modules_dict=self.modules_dict, selectors_dict=self.selectors_dict, ) dbt_version = _get_required_version( @@ -425,6 +459,7 @@ def create_project(self, rendered: RenderComponents) -> "Project": query_comment = _query_comment_from_cfg(cfg.query_comment) packages = package_config_from_data(rendered.packages_dict) + modules = module_config_from_data(rendered.modules_dict) selectors = selector_config_from_data(rendered.selectors_dict) manifest_selectors: Dict[str, Any] = {} if rendered.selectors_dict and rendered.selectors_dict["selectors"]: @@ -459,6 +494,7 @@ def create_project(self, rendered: RenderComponents) -> "Project": snapshots=snapshots, dbt_version=dbt_version, packages=packages, + modules=modules, manifest_selectors=manifest_selectors, selectors=selectors, query_comment=query_comment, @@ -481,6 +517,7 @@ def from_dicts( project_root: str, project_dict: Dict[str, Any], packages_dict: Dict[str, Any], + modules_dict: Dict[str, Any], selectors_dict: Dict[str, Any], *, verify_version: bool = False, @@ -495,6 +532,7 @@ def from_dicts( project_root=project_root, project_dict=project_dict, packages_dict=packages_dict, + modules_dict=modules_dict, selectors_dict=selectors_dict, verify_version=verify_version, ) @@ -506,12 +544,14 @@ def from_project_root( project_root = os.path.normpath(project_root) project_dict = load_raw_project(project_root) packages_dict = package_data_from_root(project_root) + modules_dict = module_data_from_root(project_root) selectors_dict = selector_data_from_root(project_root) return cls.from_dicts( project_root=project_root, project_dict=project_dict, selectors_dict=selectors_dict, packages_dict=packages_dict, + modules_dict=modules_dict, verify_version=verify_version, ) @@ -566,6 +606,7 @@ class Project: vars: VarProvider dbt_version: List[VersionSpecifier] packages: Dict[str, Any] + modules: Dict[str, Any] manifest_selectors: Dict[str, Any] selectors: SelectorConfig query_comment: QueryComment diff --git a/core/dbt/config/renderer.py b/core/dbt/config/renderer.py index 69361da18b7..c132a19307b 100644 --- a/core/dbt/config/renderer.py +++ b/core/dbt/config/renderer.py @@ -136,6 +136,10 @@ def render_packages(self, packages: Dict[str, Any]): package_renderer = self.get_package_renderer() return package_renderer.render_data(packages) + def render_modules(self, modules: Dict[str, Any]): + """Render the given modules dict""" + return self.render_data(modules) + def render_selectors(self, selectors: Dict[str, Any]): return self.render_data(selectors) diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index d58a9009922..1477786780e 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -159,6 +159,7 @@ def from_parts( snapshots=project.snapshots, dbt_version=project.dbt_version, packages=project.packages, + modules=project.modules, manifest_selectors=project.manifest_selectors, selectors=project.selectors, query_comment=project.query_comment, diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index 1ac9fc239f0..d8f2b79bb90 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -1,5 +1,7 @@ +import importlib import json import os +import sys from typing import Any, Dict, NoReturn, Optional, Mapping, Iterable, Set, List from dbt.flags import get_flags @@ -8,7 +10,9 @@ from dbt import utils from dbt.clients.jinja import get_rendered from dbt.clients.yaml_helper import yaml, safe_load, SafeLoader, Loader, Dumper # noqa: F401 +# from dbt.config.runtime import RuntimeConfig from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER +from dbt.contracts.project import Module from dbt.contracts.graph.nodes import Resource from dbt.exceptions import ( SecretEnvVarLocationError, @@ -75,7 +79,7 @@ def get_itertools_module_context() -> Dict[str, Any]: return {name: getattr(itertools, name) for name in context_exports} -def get_context_modules() -> Dict[str, Dict[str, Any]]: +def get_default_context_modules() -> Dict[str, Dict[str, Any]]: return { "pytz": get_pytz_module_context(), "datetime": get_datetime_module_context(), @@ -84,6 +88,30 @@ def get_context_modules() -> Dict[str, Dict[str, Any]]: } +def get_module_context(module: Module) -> Dict[str, Any]: + if module.location is not None: + sys.path.append(module.location) + + py_module = importlib.import_module(module.package) + + return {name: getattr(py_module, name) for name in module.exports} + + +def get_context_modules(modules: List[Module]) -> Dict[str, Dict[str, Any]]: + default_modules = get_default_context_modules() + custom_modules = { + module.package: get_module_context(module=module) + for module in modules + } + + # Overwrite default modules with custom modules if there are any + # conflicts, with the defaults kept for backwards compatibility + return { + **default_modules, + **custom_modules, + } + + class ContextMember: def __init__(self, value, name=None): self.name = name @@ -619,7 +647,10 @@ def modules(self) -> Dict[str, Any]: {% set dt_local = modules.pytz.timezone('US/Eastern').localize(dt) %} {{ dt_local }} """ # noqa - return get_context_modules() + if type(self).__name__ not in ("SecretContext", "TargetContext"): + return get_context_modules(modules=self.config.modules.modules) + + return get_default_context_modules() @contextproperty def flags(self) -> Any: diff --git a/core/dbt/contracts/project.py b/core/dbt/contracts/project.py index 581932e5888..0c0e276b59b 100644 --- a/core/dbt/contracts/project.py +++ b/core/dbt/contracts/project.py @@ -113,6 +113,22 @@ def validate(cls, data): super().validate(data) +@dataclass +class Module(Replaceable, HyphenatedDbtClassMixin): + package: str + exports: List[str] # __all__ is not allowed since not all modules implement this + location: Optional[str] = None # For local modules only (since we need to add them to the path) + + +@dataclass +class ModuleConfig(dbtClassMixin, Replaceable): + modules: List[Module] + + @classmethod + def validate(cls, *args, **kwargs): + pass + + @dataclass class ProjectPackageMetadata: name: str