Skip to content

Commit

Permalink
Organise code better (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
ismailsimsek authored Jul 9, 2024
1 parent 732693a commit f1762f2
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 130 deletions.
1 change: 0 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
# commented out to save from github limits
python-version: [ "3.8", "3.9", "3.10", "3.11" ]
dbt-version: [ "1.7", "1.8" ]
steps:
Expand Down
138 changes: 12 additions & 126 deletions opendbt/client.py
Original file line number Diff line number Diff line change
@@ -1,144 +1,30 @@
import importlib
from multiprocessing.context import SpawnContext
from typing import Optional

import dbt
from dbt.adapters.base.plugin import AdapterPlugin
from dbt.adapters.factory import FACTORY, Adapter
from dbt.cli.main import dbtRunner as DbtCliRunner
from dbt.cli.main import dbtRunnerResult
from dbt.contracts.results import RunResult
from dbt.exceptions import DbtRuntimeError
from dbt.version import get_installed_version as get_dbt_version
from packaging.version import Version

DBT_CUSTOM_ADAPTER_VAR = 'dbt_custom_adapter'
DBT_VERISON = get_dbt_version()

if Version(DBT_VERISON.to_version_string(skip_matcher=True)) > Version("1.8.0"):
try:
from dbt.adapters.contracts.connection import AdapterRequiredConfig
from dbt.adapters.events.types import (
AdapterRegistered,
)
from dbt_common.events.base_types import EventLevel
from dbt_common.events.functions import fire_event
except ImportError:
pass
else:
try:
from dbt.events.base_types import EventLevel
from dbt.events.functions import fire_event
from dbt.events.types import AdapterRegistered
from importlib import import_module
from dbt.events.functions import fire_event
from dbt.events.types import AdapterRegistered
from dbt.semver import VersionSpecifier
except ImportError:
pass


def get_custom_adapter_config_value(self, config: 'AdapterRequiredConfig') -> str:
# FIRST: it's set as cli value: dbt run --vars {'dbt_custom_adapter': 'custom_adapters.DuckDBAdapterV1Custom'}
if hasattr(config, 'cli_vars') and DBT_CUSTOM_ADAPTER_VAR in config.cli_vars:
custom_adapter_class_name: str = config.cli_vars[DBT_CUSTOM_ADAPTER_VAR]
if custom_adapter_class_name and custom_adapter_class_name.strip():
return custom_adapter_class_name
# SECOND: it's set inside dbt_project.yml
if hasattr(config, 'vars') and DBT_CUSTOM_ADAPTER_VAR in config.vars.to_dict():
custom_adapter_class_name: str = config.vars.to_dict()[DBT_CUSTOM_ADAPTER_VAR]
if custom_adapter_class_name and custom_adapter_class_name.strip():
return custom_adapter_class_name

return None


def get_custom_adapter_class_by_name(self, custom_adapter_class_name: str):
if "." not in custom_adapter_class_name:
raise ValueError(f"Unexpected adapter class name: `{custom_adapter_class_name}` ,"
f"Expecting something like:`my.sample.library.MyAdapterClass`")

__module, __class = custom_adapter_class_name.rsplit('.', 1)
try:
user_adapter_module = importlib.import_module(__module)
user_adapter_class = getattr(user_adapter_module, __class)
return user_adapter_class
except ModuleNotFoundError as mnfe:
raise Exception(f"Module of provided adapter not found, provided: {custom_adapter_class_name}") from mnfe


DBT_VERSION = get_dbt_version()
# ================================================================================================================
# Add further extension below, extend dbt using Monkey Patching!
# Monkey Patching! Override dbt lib AdapterContainer.register_adapter method with new one above
# ================================================================================================================
# dbt < 1.8
def register_adapter_v1(self, config: 'AdapterRequiredConfig') -> None:
# ==== CUSTOM CODE ====
# ==== END CUSTOM CODE ====
adapter_name = config.credentials.type
adapter_type = self.get_adapter_class_by_name(adapter_name)
adapter_version = import_module(f".{adapter_name}.__version__", "dbt.adapters").version
# ==== CUSTOM CODE ====
custom_adapter_class_name: str = self.get_custom_adapter_config_value(config)
if custom_adapter_class_name and custom_adapter_class_name.strip():
# OVERRIDE DEFAULT ADAPTER BY USER GIVEN ADAPTER CLASS
adapter_type = self.get_custom_adapter_class_by_name(custom_adapter_class_name)
# ==== END CUSTOM CODE ====
adapter_version_specifier = VersionSpecifier.from_version_string(
adapter_version
).to_version_string()
fire_event(
AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version_specifier)
)
with self.lock:
if adapter_name in self.adapters:
# this shouldn't really happen...
return

adapter: Adapter = adapter_type(config) # type: ignore
self.adapters[adapter_name] = adapter


# dbt >=1.8
def register_adapter_v2(
self,
config: 'AdapterRequiredConfig',
mp_context: SpawnContext,
adapter_registered_log_level: Optional[EventLevel] = EventLevel.INFO,
) -> None:
adapter_name = config.credentials.type
adapter_type = self.get_adapter_class_by_name(adapter_name)
adapter_version = self._adapter_version(adapter_name)
# ==== CUSTOM CODE ====
custom_adapter_class_name: str = self.get_custom_adapter_config_value(config)
if custom_adapter_class_name and custom_adapter_class_name.strip():
# OVERRIDE DEFAULT ADAPTER BY USER GIVEN ADAPTER CLASS
adapter_type = self.get_custom_adapter_class_by_name(custom_adapter_class_name)
# ==== END CUSTOM CODE ====
fire_event(
AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version),
level=adapter_registered_log_level,
)
with self.lock:
if adapter_name in self.adapters:
# this shouldn't really happen...
return
from opendbt import dbtcommon

adapter: Adapter = adapter_type(config, mp_context) # type: ignore
self.adapters[adapter_name] = adapter
# STEP-1 add new methods
dbt.adapters.factory.AdapterContainer.get_custom_adapter_config_value = dbtcommon.get_custom_adapter_config_value
dbt.adapters.factory.AdapterContainer.get_custom_adapter_class_by_name = dbtcommon.get_custom_adapter_class_by_name
# # STEP-2 override existing method
if Version(DBT_VERSION.to_version_string(skip_matcher=True)) > Version("1.8.0"):
from opendbt import dbt18


# ================================================================================================================
# Monkey Patching! Override dbt lib AdapterContainer.register_adapter method with new one above
# ================================================================================================================
# add new methods
dbt.adapters.factory.AdapterContainer.get_custom_adapter_config_value = get_custom_adapter_config_value
dbt.adapters.factory.AdapterContainer.get_custom_adapter_class_by_name = get_custom_adapter_class_by_name
# override existing method
if Version(DBT_VERISON.to_version_string(skip_matcher=True)) > Version("1.8.0"):
dbt.adapters.factory.AdapterContainer.register_adapter = register_adapter_v2
dbt.adapters.factory.AdapterContainer.register_adapter = dbt18.register_adapter
else:
dbt.adapters.factory.AdapterContainer.register_adapter = register_adapter_v1
from opendbt import dbt17

dbt.adapters.factory.AdapterContainer.register_adapter = dbt17.register_adapter

class OpenDbtCli:

Expand Down
33 changes: 33 additions & 0 deletions opendbt/dbt17.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from importlib import import_module

from dbt.adapters.factory import Adapter
from dbt.events.functions import fire_event
from dbt.events.types import AdapterRegistered
from dbt.semver import VersionSpecifier


def register_adapter(self, config: 'AdapterRequiredConfig') -> None:
# ==== CUSTOM CODE ====
# ==== END CUSTOM CODE ====
adapter_name = config.credentials.type
adapter_type = self.get_adapter_class_by_name(adapter_name)
adapter_version = import_module(f".{adapter_name}.__version__", "dbt.adapters").version
# ==== CUSTOM CODE ====
custom_adapter_class_name: str = self.get_custom_adapter_config_value(config)
if custom_adapter_class_name and custom_adapter_class_name.strip():
# OVERRIDE DEFAULT ADAPTER BY USER GIVEN ADAPTER CLASS
adapter_type = self.get_custom_adapter_class_by_name(custom_adapter_class_name)
# ==== END CUSTOM CODE ====
adapter_version_specifier = VersionSpecifier.from_version_string(
adapter_version
).to_version_string()
fire_event(
AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version_specifier)
)
with self.lock:
if adapter_name in self.adapters:
# this shouldn't really happen...
return

adapter: Adapter = adapter_type(config) # type: ignore
self.adapters[adapter_name] = adapter
38 changes: 38 additions & 0 deletions opendbt/dbt18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from multiprocessing.context import SpawnContext
from typing import Optional

from dbt.adapters.contracts.connection import AdapterRequiredConfig
from dbt.adapters.events.types import (
AdapterRegistered,
)
from dbt.adapters.factory import Adapter
from dbt_common.events.base_types import EventLevel
from dbt_common.events.functions import fire_event


def register_adapter(
self,
config: 'AdapterRequiredConfig',
mp_context: SpawnContext,
adapter_registered_log_level: Optional[EventLevel] = EventLevel.INFO,
) -> None:
adapter_name = config.credentials.type
adapter_type = self.get_adapter_class_by_name(adapter_name)
adapter_version = self._adapter_version(adapter_name)
# ==== CUSTOM CODE ====
custom_adapter_class_name: str = self.get_custom_adapter_config_value(config)
if custom_adapter_class_name and custom_adapter_class_name.strip():
# OVERRIDE DEFAULT ADAPTER BY USER GIVEN ADAPTER CLASS
adapter_type = self.get_custom_adapter_class_by_name(custom_adapter_class_name)
# ==== END CUSTOM CODE ====
fire_event(
AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version),
level=adapter_registered_log_level,
)
with self.lock:
if adapter_name in self.adapters:
# this shouldn't really happen...
return

adapter: Adapter = adapter_type(config, mp_context) # type: ignore
self.adapters[adapter_name] = adapter
32 changes: 32 additions & 0 deletions opendbt/dbtcommon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import importlib

DBT_CUSTOM_ADAPTER_VAR = 'dbt_custom_adapter'


def get_custom_adapter_config_value(self, config: 'AdapterRequiredConfig') -> str:
# FIRST: it's set as cli value: dbt run --vars {'dbt_custom_adapter': 'custom_adapters.DuckDBAdapterV1Custom'}
if hasattr(config, 'cli_vars') and DBT_CUSTOM_ADAPTER_VAR in config.cli_vars:
custom_adapter_class_name: str = config.cli_vars[DBT_CUSTOM_ADAPTER_VAR]
if custom_adapter_class_name and custom_adapter_class_name.strip():
return custom_adapter_class_name
# SECOND: it's set inside dbt_project.yml
if hasattr(config, 'vars') and DBT_CUSTOM_ADAPTER_VAR in config.vars.to_dict():
custom_adapter_class_name: str = config.vars.to_dict()[DBT_CUSTOM_ADAPTER_VAR]
if custom_adapter_class_name and custom_adapter_class_name.strip():
return custom_adapter_class_name

return None


def get_custom_adapter_class_by_name(self, custom_adapter_class_name: str):
if "." not in custom_adapter_class_name:
raise ValueError(f"Unexpected adapter class name: `{custom_adapter_class_name}` ,"
f"Expecting something like:`my.sample.library.MyAdapterClass`")

__module, __class = custom_adapter_class_name.rsplit('.', 1)
try:
user_adapter_module = importlib.import_module(__module)
user_adapter_class = getattr(user_adapter_module, __class)
return user_adapter_class
except ModuleNotFoundError as mnfe:
raise Exception(f"Module of provided adapter not found, provided: {custom_adapter_class_name}") from mnfe
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
'opendbt = opendbt:main',
],
},
version='0.2.0',
version='0.3.0',
packages=find_packages(),
author="Memiiso Organization",
description='Python opendbt',
Expand Down
4 changes: 2 additions & 2 deletions tests/test_custom_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from packaging.version import Version

from opendbt import OpenDbtProject
from opendbt.client import DBT_VERISON
from opendbt.client import DBT_VERSION


class TestOpenDbtProject(TestCase):
RESOURCES_DIR = Path(__file__).parent.joinpath("resources")
DBTTEST_DIR = RESOURCES_DIR.joinpath("dbttest")

def test_run_with_custom_adapter(self):
if Version(DBT_VERISON.to_version_string(skip_matcher=True)) > Version("1.8.0"):
if Version(DBT_VERSION.to_version_string(skip_matcher=True)) > Version("1.8.0"):
dbt_custom_adapter = 'opendbt.examples.DuckDBAdapterV1Custom_afer_dbt18'
else:
dbt_custom_adapter = 'opendbt.examples.DuckDBAdapterV1Custom_before_dbt18'
Expand Down

0 comments on commit f1762f2

Please sign in to comment.