Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix adapter reset race condition in lib.py #5921

Merged
merged 3 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changes/unreleased/Fixes-20220923-174504.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Fixes
body: Fix race condition when invoking dbt via lib.py concurrently
time: 2022-09-23T17:45:04.405026-04:00
custom:
Author: drewbanin
Issue: "5919"
PR: "5921"
61 changes: 34 additions & 27 deletions core/dbt/lib.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
# TODO: this file is one big TODO
import os
from dbt.exceptions import RuntimeException
from dbt import flags
from collections import namedtuple
from dataclasses import dataclass

RuntimeArgs = namedtuple("RuntimeArgs", "project_dir profiles_dir single_threaded profile target")

@dataclass
class RuntimeArgs:
project_dir: str
profiles_dir: str
single_threaded: bool
profile: str
target: str


def get_dbt_config(project_dir, args=None, single_threaded=False):
Expand All @@ -17,27 +23,30 @@ def get_dbt_config(project_dir, args=None, single_threaded=False):
else:
profiles_dir = flags.DEFAULT_PROFILES_DIR

profile = args.profile if hasattr(args, "profile") else None
target = args.target if hasattr(args, "target") else None

# Construct a phony config
config = RuntimeConfig.from_args(
RuntimeArgs(project_dir, profiles_dir, single_threaded, profile, target)
runtime_args = RuntimeArgs(
project_dir=project_dir,
profiles_dir=profiles_dir,
single_threaded=single_threaded,
profile=getattr(args, "profile", None),
target=getattr(args, "target", None),
)
# Clear previously registered adapters--
# this fixes cacheing behavior on the dbt-server

# Construct a RuntimeConfig from phony args
config = RuntimeConfig.from_args(runtime_args)

# Set global flags from arguments
flags.set_from_args(args, config)
dbt.adapters.factory.reset_adapters()
# Load the relevant adapter

# This is idempotent, so we can call it repeatedly
dbt.adapters.factory.register_adapter(config)
# Set invocation id

# Make sure we have a valid invocation_id
dbt.events.functions.set_invocation_id()

return config


def get_task_by_type(type):
# TODO: we need to tell dbt-server what tasks are available
from dbt.task.run import RunTask
from dbt.task.list import ListTask
from dbt.task.seed import SeedTask
Expand Down Expand Up @@ -70,16 +79,13 @@ def create_task(type, args, manifest, config):
def no_op(*args, **kwargs):
pass

# TODO: yuck, let's rethink tasks a little
task = task(args, config)

# Wow! We can monkeypatch taskCls.load_manifest to return _our_ manifest
task.load_manifest = no_op
task.manifest = manifest
return task


def _get_operation_node(manifest, project_path, sql):
def _get_operation_node(manifest, project_path, sql, node_name):
from dbt.parser.manifest import process_node
from dbt.parser.sql import SqlBlockParser
import dbt.adapters.factory
Expand All @@ -92,26 +98,28 @@ def _get_operation_node(manifest, project_path, sql):
)

adapter = dbt.adapters.factory.get_adapter(config)
# TODO : This needs a real name?
sql_node = block_parser.parse_remote(sql, "name")
sql_node = block_parser.parse_remote(sql, node_name)
process_node(config, manifest, sql_node)
return config, sql_node, adapter


def compile_sql(manifest, project_path, sql):
def compile_sql(manifest, project_path, sql, node_name="query"):
from dbt.task.sql import SqlCompileRunner

config, node, adapter = _get_operation_node(manifest, project_path, sql)
config, node, adapter = _get_operation_node(manifest, project_path, sql, node_name)

runner = SqlCompileRunner(config, adapter, node, 1, 1)

return runner.safe_run(manifest)


def execute_sql(manifest, project_path, sql):
def execute_sql(manifest, project_path, sql, node_name="query"):
from dbt.task.sql import SqlExecuteRunner

config, node, adapter = _get_operation_node(manifest, project_path, sql)
config, node, adapter = _get_operation_node(manifest, project_path, sql, node_name)

runner = SqlExecuteRunner(config, adapter, node, 1, 1)
# TODO: use same interface for runner

return runner.safe_run(manifest)


Expand All @@ -128,5 +136,4 @@ def deserialize_manifest(manifest_msgpack):


def serialize_manifest(manifest):
# TODO: what should this take as an arg?
return manifest.to_msgpack()