Skip to content

Commit

Permalink
fix: Prevent error when extractor can't be imported
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Muda <[email protected]>
  • Loading branch information
kacpermuda committed May 21, 2024
1 parent a81504e commit 4d6c095
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
21 changes: 11 additions & 10 deletions airflow/providers/openlineage/extractors/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@
# under the License.
from __future__ import annotations

from contextlib import suppress
from typing import TYPE_CHECKING, Iterator

from airflow.providers.openlineage import conf
from airflow.providers.openlineage.extractors import BaseExtractor, OperatorLineage
from airflow.providers.openlineage.extractors.base import DefaultExtractor
from airflow.providers.openlineage.extractors.bash import BashExtractor
from airflow.providers.openlineage.extractors.python import PythonExtractor
from airflow.providers.openlineage.utils.utils import get_unknown_source_attribute_run_facet
from airflow.providers.openlineage.utils.utils import (
get_unknown_source_attribute_run_facet,
try_import_from_string,
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
from openlineage.client.run import Dataset
Expand All @@ -35,11 +36,6 @@
from airflow.models import Operator


def try_import_from_string(string):
with suppress(ImportError):
return import_string(string)


def _iter_extractor_types() -> Iterator[type[BaseExtractor]]:
if PythonExtractor is not None:
yield PythonExtractor
Expand All @@ -61,10 +57,15 @@ def __init__(self):
self.extractors[operator_class] = extractor

for extractor_path in conf.custom_extractors():
extractor: type[BaseExtractor] = try_import_from_string(extractor_path)
extractor: type[BaseExtractor] | None = try_import_from_string(extractor_path)
if not extractor:
self.log.warning(
"OpenLineage is unable to import custom extractor `%s`; will ignore it.", extractor_path
)
continue
for operator_class in extractor.get_operator_classnames():
if operator_class in self.extractors:
self.log.debug(
self.log.warning(
"Duplicate OpenLineage custom extractor found for `%s`. "
"`%s` will be used instead of `%s`",
operator_class,
Expand Down
6 changes: 6 additions & 0 deletions airflow/providers/openlineage/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from airflow.utils.context import AirflowContextDeprecationWarning
from airflow.utils.log.secrets_masker import Redactable, Redacted, SecretsMasker, should_hide_value_for_key
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
from airflow.models import DagRun, TaskInstance
Expand All @@ -52,6 +53,11 @@
_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"


def try_import_from_string(string: str) -> Any:
with suppress(ImportError):
return import_string(string)


def get_operator_class(task: BaseOperator) -> type:
if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"):
return task.operator_class
Expand Down

0 comments on commit 4d6c095

Please sign in to comment.