Resolve MyPy errors in Cosmos pre-commit #377

Merged · 5 commits · Jul 24, 2023
Changes from all commits
13 changes: 7 additions & 6 deletions .pre-commit-config.yaml
@@ -58,12 +58,13 @@ repos:
    - id: blacken-docs
      alias: black
      additional_dependencies: [black>=22.10.0]
-  #- repo: https://github.com/pre-commit/mirrors-mypy
-  #  rev: 'v1.3.0'
-  #  hooks:
-  #  - id: mypy
-  #    name: mypy-python-sdk
-  #    additional_dependencies: [types-PyYAML, types-attrs, attrs, types-requests, types-python-dateutil]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: 'v1.3.0'
+    hooks:
+    - id: mypy
+      name: mypy-python-sdk
+      additional_dependencies: [types-PyYAML, types-attrs, attrs, types-requests, types-python-dateutil]
+      files: ^cosmos

ci:
  autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
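Note: with the mirrors-mypy hook re-enabled rather than commented out, MyPy now runs on every commit, and the added files: ^cosmos line scopes the check to the cosmos package. Assuming a standard pre-commit setup, the CI check can be reproduced locally with pre-commit run mypy --all-files.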
9 changes: 3 additions & 6 deletions cosmos/__init__.py
@@ -1,16 +1,17 @@
+# type: ignore # ignores "Cannot assign to a type" MyPy error

"""
Astronomer Cosmos is a library for rendering dbt workflows in Airflow.

Contains dags, task groups, and operators.
"""

__version__ = "0.7.5"

from cosmos.airflow.dag import DbtDag
from cosmos.airflow.task_group import DbtTaskGroup
from cosmos.constants import LoadMode, TestBehavior, ExecutionMode
from cosmos.dataset import get_dbt_dataset

+from cosmos.operators.lazy_load import MissingPackage

from cosmos.operators.local import (
    DbtDepsLocalOperator,
@@ -32,8 +33,6 @@
        DbtTestDockerOperator,
    )
except ImportError:
-    from cosmos.operators.lazy_load import MissingPackage
-
    DbtLSDockerOperator = MissingPackage("cosmos.operators.docker.DbtLSDockerOperator", "docker")
    DbtRunDockerOperator = MissingPackage("cosmos.operators.docker.DbtRunDockerOperator", "docker")
    DbtRunOperationDockerOperator = MissingPackage(
@@ -54,8 +53,6 @@
        DbtTestKubernetesOperator,
    )
except ImportError:
-    from cosmos.operators.lazy_load import MissingPackage
-
    DbtLSKubernetesOperator = MissingPackage(
        "cosmos.operators.kubernetes.DbtLSKubernetesOperator",
        "kubernetes",
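Hoisting the MissingPackage import to the top of the module removes the duplicated imports inside the two except blocks and gives MyPy a single unconditional definition to reason about. For readers unfamiliar with the pattern, here is a minimal sketch of such a placeholder; this is an assumed simplification, not the actual cosmos.operators.lazy_load source:

class MissingPackage:
    # Stand-in for an operator whose optional dependency is not installed.
    def __init__(self, module_name: str, optional_dependency_name: str) -> None:
        self.module_name = module_name
        self.optional_dependency_name = optional_dependency_name

    def __call__(self, *args: object, **kwargs: object) -> None:
        # Fail only when the DAG author actually tries to instantiate the operator.
        raise RuntimeError(
            f"Could not import {self.module_name}. Install the "
            f"'{self.optional_dependency_name}' extra to use this operator."
        )

This keeps `from cosmos import DbtLSDockerOperator` importable without the docker extra installed, deferring the failure to the moment the operator is used.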
5 changes: 3 additions & 2 deletions cosmos/airflow/dag.py
@@ -10,7 +10,7 @@
from cosmos.converter import airflow_kwargs, specific_kwargs, DbtToAirflowConverter


-class DbtDag(DAG, DbtToAirflowConverter):
+class DbtDag(DAG, DbtToAirflowConverter):  # type: ignore[misc] # ignores subclass MyPy error
    """
    Render a dbt project as an Airflow DAG.
    """
@@ -21,4 +21,5 @@ def __init__(
        **kwargs: Any,
    ) -> None:
        DAG.__init__(self, *args, **airflow_kwargs(**kwargs))
-        DbtToAirflowConverter.__init__(self, *args, dag=self, **specific_kwargs(**kwargs))
+        kwargs["dag"] = self
+        DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs))
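The constructor now writes the DAG into kwargs before splitting them, instead of passing dag=self alongside **specific_kwargs(**kwargs). A plausible motivation (inferred, not stated in the PR): if the incoming kwargs ever already contain a dag entry, passing dag=self explicitly as well would raise a duplicate-keyword TypeError, whereas overwriting the key routes everything through one path. A toy demonstration with hypothetical names:

def init(*args: object, **kwargs: object) -> dict:
    return dict(kwargs)

opts = {"dag": "user-supplied", "select": ["model_a"]}
# init("x", dag="self", **opts)  # TypeError: got multiple values for keyword argument 'dag'
opts["dag"] = "self"  # overwrite the key instead
print(init("x", **opts))  # {'dag': 'self', 'select': ['model_a']}

DbtTaskGroup below applies the same change with task_group in place of dag.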
18 changes: 10 additions & 8 deletions cosmos/airflow/graph.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
-from typing import Callable
+from typing import Any, Callable

from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup
@@ -42,15 +42,16 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[str]:
    parents = []
    leaves = []
    materialized_nodes = [node for node in nodes.values() if node.unique_id in tasks_ids]
-    [parents.extend(node.depends_on) for node in materialized_nodes]
+    for node in materialized_nodes:
+        parents.extend(node.depends_on)
    parents_ids = set(parents)
    for node in materialized_nodes:
        if node.unique_id not in parents_ids:
            leaves.append(node.unique_id)
    return leaves


-def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict) -> TaskMetadata:
+def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any]) -> TaskMetadata | None:
    """
    Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
@@ -80,13 +81,14 @@ def create_task_metadata
        return task_metadata
    else:
        logger.error(f"Unsupported resource type {node.resource_type} (node {node.unique_id}).")
+        return None


def create_test_task_metadata(
    test_task_name: str,
    execution_mode: ExecutionMode,
-    task_args: dict,
-    on_warning_callback: callable,
+    task_args: dict[str, Any],
+    on_warning_callback: Callable[..., Any] | None = None,
    model_name: str | None = None,
) -> TaskMetadata:
    """
@@ -118,12 +120,12 @@ def build_airflow_graph(
    nodes: dict[str, DbtNode],
    dag: DAG,  # Airflow-specific - parent DAG where to associate tasks and (optional) task groups
    execution_mode: ExecutionMode,  # Cosmos-specific - decide what which class to use
-    task_args: dict[str, str],  # Cosmos/DBT - used to instantiate tasks
+    task_args: dict[str, Any],  # Cosmos/DBT - used to instantiate tasks
    test_behavior: TestBehavior,  # Cosmos-specific: how to inject tests to Airflow DAG
    dbt_project_name: str,  # DBT / Cosmos - used to name test task if mode is after_all,
    conn_id: str,  # Cosmos, dataset URI
    task_group: TaskGroup | None = None,
-    on_warning_callback: Callable | None = None,  # argument specific to the DBT test command
+    on_warning_callback: Callable[..., Any] | None = None,  # argument specific to the DBT test command
    emit_datasets: bool = True,  # Cosmos
) -> None:
    """
@@ -191,7 +193,7 @@ def build_airflow_graph(
f"{dbt_project_name}_test", execution_mode, task_args=task_args, on_warning_callback=on_warning_callback
)
test_task = create_airflow_task(test_meta, dag, task_group=task_group)
leaves_ids = calculate_leaves(tasks_ids=tasks_map.keys(), nodes=nodes)
leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes)
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task

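Besides replacing the side-effect list comprehension with a plain loop, the annotations are now stricter: create_task_metadata is declared to return TaskMetadata | None (the error branch previously fell through and returned None implicitly), and the call site wraps tasks_map.keys() in list() because dict_keys is not a list[str]. A small worked example of the leaf calculation, using a minimal stand-in for DbtNode with only the fields the function reads (all nodes assumed materialized, a simplification):

from dataclasses import dataclass, field

@dataclass
class Node:
    unique_id: str
    depends_on: list = field(default_factory=list)

nodes = {
    "model.a": Node("model.a"),
    "model.b": Node("model.b", depends_on=["model.a"]),
    "model.c": Node("model.c", depends_on=["model.b"]),
}

parents = set()
for node in nodes.values():
    parents.update(node.depends_on)

leaves = [uid for uid in nodes if uid not in parents]
print(leaves)  # ['model.c'], the node the after_all test task is wired behind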
5 changes: 3 additions & 2 deletions cosmos/airflow/task_group.py
@@ -9,7 +9,7 @@
from cosmos.converter import airflow_kwargs, specific_kwargs, DbtToAirflowConverter


-class DbtTaskGroup(TaskGroup, DbtToAirflowConverter):
+class DbtTaskGroup(TaskGroup, DbtToAirflowConverter):  # type: ignore[misc] # ignores subclass MyPy error
    """
    Render a dbt project as an Airflow Task Group.
    """
@@ -21,4 +21,5 @@ def __init__(
    ) -> None:
        group_id = kwargs.get("group_id", kwargs.get("dbt_project_name", "dbt_task_group"))
        TaskGroup.__init__(self, group_id, *args, **airflow_kwargs(**kwargs))
-        DbtToAirflowConverter.__init__(self, *args, task_group=self, **specific_kwargs(**kwargs))
+        kwargs["task_group"] = self
+        DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs))
28 changes: 19 additions & 9 deletions cosmos/converter.py
@@ -1,14 +1,17 @@
+# mypy: ignore-errors
+# ignoring enum Mypy errors
+
from __future__ import annotations
+from enum import Enum

import inspect
import logging
-import pathlib
-from enum import Enum
-from typing import Any, Callable, Optional
+from typing import Any, Callable

from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup
+from pathlib import Path

from cosmos.airflow.graph import build_airflow_graph
from cosmos.constants import ExecutionMode, LoadMode, TestBehavior
@@ -142,8 +145,8 @@ def __init__(
        exclude: list[str] | None = None,
        execution_mode: str | ExecutionMode = ExecutionMode.LOCAL,
        load_mode: str | LoadMode = LoadMode.AUTOMATIC,
-        manifest_path: str | pathlib.Path | None = None,
-        on_warning_callback: Optional[Callable] = None,
+        manifest_path: str | Path | None = None,
+        on_warning_callback: Callable[..., Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
@@ -154,12 +157,19 @@
        execution_mode = convert_value_to_enum(execution_mode, ExecutionMode)
        load_mode = convert_value_to_enum(load_mode, LoadMode)

+        test_behavior = convert_value_to_enum(test_behavior, TestBehavior)
+        execution_mode = convert_value_to_enum(execution_mode, ExecutionMode)
+        load_mode = convert_value_to_enum(load_mode, LoadMode)
+
+        if type(manifest_path) == str:
+            manifest_path = Path(manifest_path)
+
        dbt_project = DbtProject(
            name=dbt_project_name,
-            root_dir=dbt_root_path,
-            models_dir=dbt_models_dir,
-            seeds_dir=dbt_seeds_dir,
-            snapshots_dir=dbt_snapshots_dir,
+            root_dir=Path(dbt_root_path),
+            models_dir=Path(dbt_models_dir) if dbt_models_dir else None,
+            seeds_dir=Path(dbt_seeds_dir) if dbt_seeds_dir else None,
+            snapshots_dir=Path(dbt_snapshots_dir) if dbt_snapshots_dir else None,
            manifest_path=manifest_path,
        )
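The converter now normalizes its inputs up front: string enum values are coerced to their Enum members, a plain-string manifest_path becomes a Path, and the project directories are wrapped in Path (or left as None) before DbtProject is constructed. An isinstance(manifest_path, str) check would be the more idiomatic spelling, though the type(...) == str form behaves the same here. A sketch of the coercion contract follows; convert_value_to_enum itself is not shown in the diff, so this is an assumed minimal equivalent, with illustrative member values:

from enum import Enum

class ExecutionMode(Enum):
    LOCAL = "local"
    DOCKER = "docker"
    KUBERNETES = "kubernetes"

def convert_value_to_enum(value, enum_class):
    # Pass enum members through untouched; map raw strings by value.
    return value if isinstance(value, enum_class) else enum_class(value)

assert convert_value_to_enum("local", ExecutionMode) is ExecutionMode.LOCAL
assert convert_value_to_enum(ExecutionMode.DOCKER, ExecutionMode) is ExecutionMode.DOCKER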
3 changes: 1 addition & 2 deletions cosmos/core/airflow.py
@@ -1,6 +1,5 @@
import importlib
import logging
-from typing import Optional

from airflow.models import BaseOperator
from airflow.models.dag import DAG
@@ -11,7 +10,7 @@
logger = logging.getLogger(__name__)


-def get_airflow_task(task: Task, dag: DAG, task_group: Optional[TaskGroup] = None) -> BaseOperator:
+def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator:
    """
    Get the Airflow Operator class for a Task.

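Replacing Optional[TaskGroup] with the quoted "TaskGroup | None" drops the typing.Optional import while keeping the module importable everywhere: a string annotation is never evaluated at runtime, so the PEP 604 union syntax is safe even on interpreters where TaskGroup | None would raise, yet MyPy still parses and checks it. A minimal illustration:

from typing import Optional

def f(x: "int | None" = None) -> "int | None":
    # Quoted PEP 604 union: never evaluated at runtime, still checked by MyPy.
    return x

def g(x: Optional[int] = None) -> Optional[int]:
    # The older spelling MyPy treats as equivalent.
    return x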
15 changes: 9 additions & 6 deletions cosmos/dataset.py
@@ -1,22 +1,25 @@
+from typing import Any, Tuple
+
+
try:
    from airflow.datasets import Dataset
-except ImportError:
+except (ImportError, ModuleNotFoundError):
    from logging import getLogger

    logger = getLogger(__name__)

-    class Dataset:
+    class Dataset:  # type: ignore[no-redef]
        cosmos_override = True

-        def __init__(self, id: str, *args, **kwargs):
+        def __init__(self, id: str, *args: Tuple[Any], **kwargs: str):
            self.id = id
            logger.warning("Datasets are not supported in Airflow < 2.5.0")

-        def __eq__(self, other) -> bool:
-            return self.id == other.id
+        def __eq__(self, other: "Dataset") -> bool:
+            return bool(self.id == other.id)


-def get_dbt_dataset(connection_id: str, project_name: str, model_name: str):
+def get_dbt_dataset(connection_id: str, project_name: str, model_name: str) -> Dataset:
    return Dataset(f"DBT://{connection_id.upper()}/{project_name.upper()}/{model_name.upper()}")
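The Dataset shim is the classic compatibility pattern: try the real import, and define a stand-in only when it fails. MyPy sees two definitions of the same name, hence the targeted # type: ignore[no-redef]. Strictly, ModuleNotFoundError is a subclass of ImportError, so broadening the except clause is belt-and-braces rather than a behavior change. The pattern in miniature, simplified from the diff above:

try:
    from airflow.datasets import Dataset  # available on newer Airflow versions
except (ImportError, ModuleNotFoundError):
    class Dataset:  # type: ignore[no-redef]
        def __init__(self, id: str) -> None:
            self.id = id

# Either branch supports the URI-style construction used by get_dbt_dataset:
dataset = Dataset("DBT://MY_CONN/JAFFLE_SHOP/ORDERS")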
32 changes: 19 additions & 13 deletions cosmos/dbt/graph.py
@@ -4,6 +4,7 @@
import logging
import os
from dataclasses import dataclass, field
+from pathlib import Path
from subprocess import Popen, PIPE
from typing import Any

@@ -34,9 +35,9 @@ class DbtNode:

    name: str
    unique_id: str
-    resource_type: str
+    resource_type: DbtResourceType
    depends_on: list[str]
-    file_path: str
+    file_path: Path
    tags: list[str] = field(default_factory=lambda: [])
    config: dict[str, Any] = field(default_factory=lambda: {})

@@ -66,7 +67,7 @@ def __init__(
        self,
        project: DbtProject,
        exclude: list[str] | None = None,
-        select: list[str] = None,
+        select: list[str] | None = None,
        dbt_cmd: str = get_system_dbt(),
    ):
        self.project = project
@@ -112,7 +113,7 @@ def load

        load_method[method]()

-    def load_via_dbt_ls(self):
+    def load_via_dbt_ls(self) -> None:
        """
        This is the most accurate way of loading `dbt` projects and filtering them out, since it uses the `dbt` command
        line for both parsing and filtering the nodes.
@@ -130,7 +131,12 @@ def load_via_dbt_ls(self):
logger.info(f"Running command: {command}")
try:
process = Popen(
command, stdout=PIPE, stderr=PIPE, cwd=self.project.dir, universal_newlines=True, env=os.environ
command, # type: ignore[arg-type]
stdout=PIPE,
stderr=PIPE,
cwd=self.project.dir,
universal_newlines=True,
env=os.environ,
)
except FileNotFoundError as exception:
raise CosmosLoadDbtException(f"Unable to run the command due to the error:\n{exception}")
@@ -164,7 +170,7 @@ def load_via_dbt_ls(self):
        self.nodes = nodes
        self.filtered_nodes = nodes

-    def load_via_custom_parser(self):
+    def load_via_custom_parser(self) -> None:
        """
        This is the least accurate way of loading `dbt` projects and filtering them out, since it uses custom Cosmos
        logic, which is usually a subset of what is available in `dbt`.
@@ -177,11 +183,11 @@ def load_via_custom_parser(self):
        * self.filtered_nodes
        """
        logger.info("Trying to parse the dbt project using a custom Cosmos method...")

        project = LegacyDbtProject(
-            dbt_root_path=self.project.root_dir,
-            dbt_models_dir=self.project.models_dir.stem,
-            dbt_snapshots_dir=self.project.snapshots_dir.stem,
-            dbt_seeds_dir=self.project.seeds_dir.stem,
+            dbt_root_path=str(self.project.root_dir),
+            dbt_models_dir=self.project.models_dir.stem if self.project.models_dir else None,
+            dbt_snapshots_dir=self.project.snapshots_dir.stem if self.project.snapshots_dir else None,
+            dbt_seeds_dir=self.project.seeds_dir.stem if self.project.seeds_dir else None,
            project_name=self.project.name,
        )
        nodes = {}
@@ -192,7 +198,7 @@
                name=model_name,
                unique_id=model_name,
                resource_type=DbtResourceType(model.type.value),
-                depends_on=model.config.upstream_models,
+                depends_on=list(model.config.upstream_models),
                file_path=model.path,
                tags=[],
                config=config,
@@ -204,7 +210,7 @@
            project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude
        )

-    def load_from_dbt_manifest(self):
+    def load_from_dbt_manifest(self) -> None:
        """
        This approach accurately loads `dbt` projects using the `manifest.yml` file.

@@ -217,7 +223,7 @@
"""
logger.info("Trying to parse the dbt project using a dbt manifest...")
nodes = {}
with open(self.project.manifest_path) as fp:
with open(self.project.manifest_path) as fp: # type: ignore[arg-type]
manifest = json.load(fp)

for unique_id, node_dict in manifest.get("nodes", {}).items():
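load_from_dbt_manifest builds its DbtNode entries straight from dbt's compiled manifest, which likely also explains the # type: ignore[arg-type] on open(): manifest_path is typed str | Path | None, and None is not an acceptable argument. A sketch of the extraction the loader performs, following dbt's documented manifest schema (the target/manifest.json location is an assumption; dbt writes the file under the project's target directory):

import json
from pathlib import Path

manifest = json.loads(Path("target/manifest.json").read_text())
for unique_id, node_dict in manifest.get("nodes", {}).items():
    print(
        unique_id,  # e.g. "model.jaffle_shop.orders"
        node_dict["resource_type"],  # "model", "seed", "snapshot", "test", ...
        node_dict["depends_on"]["nodes"],  # upstream unique_ids
    )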