From 6c87bed66b63c4ffbac4dc44eda6276fdc5b2e6b Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 08:41:41 -0600 Subject: [PATCH 01/13] Add manifest.expect method --- core/dbt/contracts/graph/manifest.py | 9 +++++++++ core/dbt/linker.py | 19 +++++++++---------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index c9f880c407c..374ebb0882f 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -415,6 +415,7 @@ def patch_nodes(self, patches): 'not found or is disabled').format(patch.name) ) + # TODO: why is this here? def __getattr__(self, name): raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, name) @@ -483,6 +484,14 @@ def to_dict(self, omit_none=True, validate=False): def write(self, path): self.writable_manifest().write(path) + def expect(self, unique_id: str) -> CompileResultNode: + if unique_id not in self.nodes: + # something terrible has happened + raise dbt.exceptions.InternalException( + 'Expected node {} not found in manifest'.format(unique_id) + ) + return self.nodes[unique_id] + @dataclass class WritableManifest(JsonSchemaMixin, Writable): diff --git a/core/dbt/linker.py b/core/dbt/linker.py index 3fe610c2c20..4a67b639719 100644 --- a/core/dbt/linker.py +++ b/core/dbt/linker.py @@ -1,8 +1,10 @@ -import networkx as nx from queue import PriorityQueue +from typing import Iterable, Optional +import networkx as nx import threading +from dbt.contracts.graph.manifest import Manifest from dbt.node_types import NodeType @@ -42,11 +44,8 @@ def __init__(self, graph, manifest): # populate the initial queue self._find_new_additions() - def get_node(self, node_id): - return self.manifest.nodes[node_id] - def _include_in_cost(self, node_id): - node = self.get_node(node_id) + node = self.manifest.expect(node_id) if not is_blocking_dependency(node): return False if node.get_materialization() == 'ephemeral': @@ -95,7 +94,7 @@ def get(self, block=True, timeout=None): _, node_id = self.inner.get(block=block, timeout=timeout) with self.lock: self._mark_in_progress(node_id) - return self.get_node(node_id) + return self.manifest.expect(node_id) def __len__(self): """The length of the queue is the number of tasks left for the queue to @@ -167,7 +166,6 @@ def join(self): """ self.inner.join() - def _subset_graph(graph, include_nodes): """Create and return a new graph that is a shallow copy of graph but with only the nodes in include_nodes. Transitive edges across removed nodes are @@ -189,7 +187,6 @@ def _subset_graph(graph, include_nodes): ) return new_graph - class Linker: def __init__(self, data=None): if data is None: @@ -216,7 +213,9 @@ def find_cycles(self): return None - def as_graph_queue(self, manifest, limit_to=None): + def as_graph_queue( + self, manifest: Manifest, limit_to: Optional[Iterable[str]] = None + ) -> GraphQueue: """Returns a queue over nodes in the graph that tracks progress of dependecies. 
""" @@ -259,6 +258,6 @@ def read_graph(self, infile): def _updated_graph(graph, manifest): graph = graph.copy() for node_id in graph.nodes(): - data = manifest.nodes[node_id].to_dict() + data = manifest.expect(node_id).to_dict() graph.add_node(node_id, **data) return graph From d86092ae78e147e04fdc3807ee1de4ac9e00e2d1 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 08:53:29 -0600 Subject: [PATCH 02/13] loader.GraphLoader -> parser.manifest.ManifestLoader Create new helper function dbt.perf_utils.get_full_manifest Update task.runnable accordingly Update RPC server accordingly --- core/dbt/adapters/base/impl.py | 4 ++-- core/dbt/compilation.py | 1 - core/dbt/linker.py | 2 ++ core/dbt/{loader.py => parser/manifest.py} | 13 +++++++++++- core/dbt/perf_utils.py | 19 +++++++++++++++++ core/dbt/task/generate.py | 3 +-- core/dbt/task/rpc_server.py | 14 ++++++------- core/dbt/task/runnable.py | 21 ++++--------------- test/unit/test_graph.py | 8 +++---- ...{test_loader.py => test_parse_manifest.py} | 6 +++--- test/unit/test_postgres_adapter.py | 3 ++- test/unit/test_snowflake_adapter.py | 3 ++- tox.ini | 2 +- 13 files changed, 58 insertions(+), 41 deletions(-) rename core/dbt/{loader.py => parser/manifest.py} (97%) create mode 100644 core/dbt/perf_utils.py rename test/unit/{test_loader.py => test_parse_manifest.py} (96%) diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index bdac9ea40f8..0eea68de714 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -16,7 +16,7 @@ from dbt.config import RuntimeConfig from dbt.contracts.graph.manifest import Manifest from dbt.node_types import NodeType -from dbt.loader import GraphLoader +from dbt.parser.manifest import load_internal_manifest from dbt.logger import GLOBAL_LOGGER as logger from dbt.utils import filter_null_values @@ -280,7 +280,7 @@ def check_internal_manifest(self) -> Optional[Manifest]: def load_internal_manifest(self) -> Manifest: if self._internal_manifest_lazy is None: - manifest = GraphLoader.load_internal(self.config) + manifest = load_internal_manifest(self.config) self._internal_manifest_lazy = manifest return self._internal_manifest_lazy diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 1a582ac3f92..d603754c2db 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -13,7 +13,6 @@ import dbt.contracts.project import dbt.exceptions import dbt.flags -import dbt.loader import dbt.config from dbt.contracts.graph.compiled import InjectedCTE, COMPILED_TYPES from dbt.contracts.graph.parsed import ParsedNode diff --git a/core/dbt/linker.py b/core/dbt/linker.py index 4a67b639719..fca7bdd270e 100644 --- a/core/dbt/linker.py +++ b/core/dbt/linker.py @@ -166,6 +166,7 @@ def join(self): """ self.inner.join() + def _subset_graph(graph, include_nodes): """Create and return a new graph that is a shallow copy of graph but with only the nodes in include_nodes. 
Transitive edges across removed nodes are @@ -187,6 +188,7 @@ def _subset_graph(graph, include_nodes): ) return new_graph + class Linker: def __init__(self, data=None): if data is None: diff --git a/core/dbt/loader.py b/core/dbt/parser/manifest.py similarity index 97% rename from core/dbt/loader.py rename to core/dbt/parser/manifest.py index 112ce124212..4324a54626f 100644 --- a/core/dbt/loader.py +++ b/core/dbt/parser/manifest.py @@ -82,7 +82,7 @@ def make_parse_result( ) -class GraphLoader: +class ManifestLoader: def __init__( self, root_project: RuntimeConfig, all_projects: Mapping[str, Project] ) -> None: @@ -292,6 +292,7 @@ def load_all( loader.write_parse_results() manifest = loader.create_manifest() _check_manifest(manifest, root_config) + manifest.build_flat_graph() return manifest @classmethod @@ -388,3 +389,13 @@ def load_all_projects(config) -> Mapping[str, Project]: def load_internal_projects(config): return dict(_load_projects(config, internal_project_names())) + + +def load_internal_manifest(config: RuntimeConfig) -> Manifest: + return ManifestLoader.load_internal(config) + + +def load_manifest( + config: RuntimeConfig, internal_manifest: Optional[Manifest] +) -> Manifest: + return ManifestLoader.load_all(config, internal_manifest) diff --git a/core/dbt/perf_utils.py b/core/dbt/perf_utils.py new file mode 100644 index 00000000000..273c9a1a3dc --- /dev/null +++ b/core/dbt/perf_utils.py @@ -0,0 +1,19 @@ +"""A collection of performance-enhancing functions that have to know just a +little bit too much to go anywhere else. +""" +from dbt.adapters.factory import get_adapter +from dbt.parser.manifest import load_manifest +from dbt.contracts.graph.manifest import Manifest +from dbt.config import RuntimeConfig + + +def get_full_manifest(config: RuntimeConfig) -> Manifest: + """Load the full manifest, using the adapter's internal manifest if it + exists to skip parsing internal (dbt + plugins) macros a second time. + + Also, make sure that we force-laod the adapter's manifest, so it gets + attached to the adapter for any methods that need it. 
+ """ + adapter = get_adapter(config) # type: ignore + internal: Manifest = adapter.load_internal_manifest() + return load_manifest(config, internal) diff --git a/core/dbt/task/generate.py b/core/dbt/task/generate.py index 13a7a801dee..4d4ced1043c 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -179,7 +179,6 @@ def _coerce_decimal(value): class GenerateTask(CompileTask): def _get_manifest(self) -> Manifest: - # manifest = dbt.loader.GraphLoader.load_all(self.config) return self.manifest def run(self): @@ -215,7 +214,7 @@ def run(self): path = os.path.join(self.config.target_path, CATALOG_FILENAME) results.write(path) - write_manifest(self.manifest, self.config) + write_manifest(self.config, self.manifest) dbt.ui.printer.print_timestamped_line( 'Catalog written to {}'.format(os.path.abspath(path)) diff --git a/core/dbt/task/rpc_server.py b/core/dbt/task/rpc_server.py index e43eaf2f0cd..a7e4c4dea41 100644 --- a/core/dbt/task/rpc_server.py +++ b/core/dbt/task/rpc_server.py @@ -16,11 +16,12 @@ log_manager, ) from dbt.task.base import ConfiguredTask -from dbt.task.compile import CompileTask from dbt.task.remote import RPCTask from dbt.utils import ForgivingJSONEncoder, env_set_truthy -from dbt import rpc +from dbt.rpc.response_manager import ResponseManager +from dbt.rpc.task_manager import TaskManager from dbt.rpc.logger import ServerContext, HTTPRequest, RPCResponse +from dbt.perf_utils import get_full_manifest SINGLE_THREADED_WEBSERVER = env_set_truthy('DBT_SINGLE_THREADED_WEBSERVER') @@ -34,11 +35,8 @@ def reload_manager(task_manager, tasks): logs = [] try: - compile_task = CompileTask(task_manager.args, task_manager.config) with list_handler(logs): - compile_task.run() - manifest = compile_task.manifest - manifest.build_flat_graph() + manifest = get_full_manifest(task_manager.config) for cls in tasks: task_manager.add_task_handler(cls, manifest) @@ -99,7 +97,7 @@ def __init__(self, args, config, tasks=None): ) super().__init__(args, config) self._tasks = tasks or self._default_tasks() - self.task_manager = rpc.TaskManager(self.args, self.config) + self.task_manager = TaskManager(self.args, self.config) self._reloader = None self._reload_task_manager() signal.signal(signal.SIGHUP, self._sighup_handler) @@ -189,7 +187,7 @@ def run(self): @Request.application def handle_jsonrpc_request(self, request): with HTTPRequest(request): - jsonrpc_response = rpc.ResponseManager.handle( + jsonrpc_response = ResponseManager.handle( request, self.task_manager ) json_data = json.dumps( diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 50ad7881b1f..855e6ba9ed5 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -17,7 +17,7 @@ ) from dbt.compilation import compile_manifest from dbt.contracts.results import ExecutionResult -from dbt.loader import GraphLoader +from dbt.perf_utils import get_full_manifest import dbt.exceptions import dbt.flags @@ -31,25 +31,11 @@ RUNNING_STATE = DbtProcessState('running') -def write_manifest(manifest, config): +def write_manifest(config, manifest): if dbt.flags.WRITE_JSON: manifest.write(os.path.join(config.target_path, MANIFEST_FILE_NAME)) -def load_manifest(config): - # performance trick: if the adapter has a manifest loaded, use that to - # avoid parsing internal macros twice. Also, when loading the adapter's - # manifest, load the internal manifest to avoid running the graph laoder - # twice. 
- adapter = get_adapter(config) - - internal = adapter.load_internal_manifest() - manifest = GraphLoader.load_all(config, internal_manifest=internal) - - write_manifest(manifest, config) - return manifest - - class ManifestTask(ConfiguredTask): def __init__(self, args, config): super().__init__(args, config) @@ -57,7 +43,8 @@ def __init__(self, args, config): self.linker = None def load_manifest(self): - self.manifest = load_manifest(self.config) + self.manifest = get_full_manifest(self.config) + write_manifest(self.config, self.manifest) def compile_manifest(self): self.linker = compile_manifest(self.config, self.manifest) diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 519c150cd2c..157030d225f 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -10,7 +10,7 @@ import dbt.parser import dbt.config import dbt.utils -import dbt.loader +import dbt.parser.manifest from dbt.contracts.graph.manifest import FilePath, SourceFile, FileHash from dbt.parser.results import ParseResult from dbt.parser.base import BaseParser @@ -43,7 +43,7 @@ def setUp(self): self.graph_result = None self.write_gpickle_patcher = patch('networkx.write_gpickle') - self.load_projects_patcher = patch('dbt.loader._load_projects') + self.load_projects_patcher = patch('dbt.parser.manifest._load_projects') self.file_system_patcher = patch.object( dbt.parser.search.FilesystemSearcher, '__new__' ) @@ -81,7 +81,7 @@ def _load_projects(config, paths): self.mock_models = [] - self.load_patch = patch('dbt.loader.make_parse_result') + self.load_patch = patch('dbt.parser.manifest.make_parse_result') self.mock_parse_result = self.load_patch.start() self.mock_parse_result.return_value = ParseResult.rpc() @@ -140,7 +140,7 @@ def use_models(self, models): self.mock_models.append(source_file) def load_manifest(self, config): - loader = dbt.loader.GraphLoader(config, {config.project_name: config}) + loader = dbt.parser.manifest.ManifestLoader(config, {config.project_name: config}) loader.load() return loader.create_manifest() diff --git a/test/unit/test_loader.py b/test/unit/test_parse_manifest.py similarity index 96% rename from test/unit/test_loader.py rename to test/unit/test_parse_manifest.py index 079dde1c62f..0e5fb6f1b5f 100644 --- a/test/unit/test_loader.py +++ b/test/unit/test_parse_manifest.py @@ -3,10 +3,10 @@ from .utils import config_from_parts_or_dicts, normalize -from dbt import loader from dbt.contracts.graph.manifest import FileHash, FilePath, SourceFile from dbt.parser import ParseResult from dbt.parser.search import FileBlock +from dbt.parser import manifest class MatchingHash(FileHash): @@ -56,10 +56,10 @@ def setUp(self): cli_vars='{"test_schema_name": "foo"}' ) self.parser = mock.MagicMock() - self.patched_result_builder = mock.patch('dbt.loader.make_parse_result') + self.patched_result_builder = mock.patch('dbt.parser.manifest.make_parse_result') self.mock_result_builder = self.patched_result_builder.start() self.patched_result_builder.return_value = self._new_results() - self.loader = loader.GraphLoader( + self.loader = manifest.ManifestLoader( self.root_project_config, {'root': self.root_project_config} ) diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index 634f011b54f..e14ecd5b304 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -2,6 +2,7 @@ from unittest import mock import dbt.flags as flags +import dbt.parser.manifest from dbt.task.debug import DebugTask from dbt.adapters.postgres import 
PostgresAdapter @@ -235,7 +236,7 @@ def setUp(self): self.adapter.acquire_connection() inject_adapter(self.adapter) - self.load_patch = mock.patch('dbt.loader.make_parse_result') + self.load_patch = mock.patch('dbt.parser.manifest.make_parse_result') self.mock_parse_result = self.load_patch.start() self.mock_parse_result.return_value = ParseResult.rpc() diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index d4bfde13e5f..61ce6e1747e 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -4,6 +4,7 @@ import dbt.flags as flags +import dbt.parser.manifest from dbt.adapters.snowflake import SnowflakeAdapter from dbt.logger import GLOBAL_LOGGER as logger # noqa from dbt.parser.results import ParseResult @@ -51,7 +52,7 @@ def setUp(self): ) self.snowflake = self.patcher.start() - self.load_patch = mock.patch('dbt.loader.make_parse_result') + self.load_patch = mock.patch('dbt.parser.manifest.make_parse_result') self.mock_parse_result = self.load_patch.start() self.mock_parse_result.return_value = ParseResult.rpc() diff --git a/tox.ini b/tox.ini index 278f59265a8..94dbb13fa20 100644 --- a/tox.ini +++ b/tox.ini @@ -25,12 +25,12 @@ commands = /bin/bash -c '$(which mypy) \ core/dbt/hooks.py \ core/dbt/include \ core/dbt/links.py \ - core/dbt/loader.py \ core/dbt/logger.py \ core/dbt/main.py \ core/dbt/node_runners.py \ core/dbt/node_types.py \ core/dbt/parser \ + core/dbt/perf_utils.py \ core/dbt/profiler.py \ core/dbt/py.typed \ core/dbt/rpc \ From ef16a99f88910fe8bf07347eede9666ef28f0d6e Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 10:04:21 -0600 Subject: [PATCH 03/13] Refactors for mypy: initial refactoring of adapter factory stuff Move HasCredentials protocol into connection contract and use that in the base connection --- core/dbt/adapters/base/connections.py | 7 +- core/dbt/adapters/factory.py | 168 ++++++++++++++++---------- core/dbt/contracts/connection.py | 5 + core/dbt/task/base.py | 5 + 4 files changed, 115 insertions(+), 70 deletions(-) diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index 9f2cedf4616..dbe6873dde2 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -10,8 +10,9 @@ import dbt.exceptions import dbt.flags -from dbt.config import Profile -from dbt.contracts.connection import Connection, Identifier, ConnectionState +from dbt.contracts.connection import ( + Connection, Identifier, ConnectionState, HasCredentials +) from dbt.logger import GLOBAL_LOGGER as logger @@ -30,7 +31,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta): """ TYPE: str = NotImplemented - def __init__(self, profile: Profile): + def __init__(self, profile: HasCredentials): self.profile = profile self.thread_connections: Dict[Hashable, Connection] = {} self.lock = multiprocessing.RLock() diff --git a/core/dbt/adapters/factory.py b/core/dbt/adapters/factory.py index 9c7a78b2e6a..c0157b382fd 100644 --- a/core/dbt/adapters/factory.py +++ b/core/dbt/adapters/factory.py @@ -1,108 +1,142 @@ import threading from importlib import import_module -from typing import Type, Dict, TypeVar +from typing import Type, Dict, Any from dbt.exceptions import RuntimeException from dbt.include.global_project import PACKAGES from dbt.logger import GLOBAL_LOGGER as logger -from dbt.contracts.connection import Credentials +from dbt.contracts.connection import Credentials, HasCredentials +from dbt.adapters.base.impl import BaseAdapter +from 
dbt.adapters.base.plugin import AdapterPlugin # TODO: we can't import these because they cause an import cycle. -# currently RuntimeConfig needs to figure out default quoting for its adapter. -# We should push that elsewhere when we fixup project/profile stuff -# Instead here are some import loop avoiding-hacks right now. And Profile has -# to call into load_plugin to get credentials, so adapter/relation don't work -RuntimeConfig = TypeVar('RuntimeConfig') -BaseAdapter = TypeVar('BaseAdapter') -BaseRelation = TypeVar('BaseRelation') +# Profile has to call into load_plugin to get credentials, so adapter/relation +# don't work +BaseRelation = Any -ADAPTER_TYPES: Dict[str, Type[BaseAdapter]] = {} -_ADAPTERS: Dict[str, BaseAdapter] = {} -_ADAPTER_LOCK = threading.Lock() +Adapter = BaseAdapter -def get_adapter_class_by_name(adapter_name: str) -> Type[BaseAdapter]: - with _ADAPTER_LOCK: - if adapter_name in ADAPTER_TYPES: - return ADAPTER_TYPES[adapter_name] +class AdpaterContainer: + def __init__(self): + self.lock = threading.Lock() + self.adapters: Dict[str, Adapter] = {} + self.adapter_types: Dict[str, Type[Adapter]] = {} - adapter_names = ", ".join(ADAPTER_TYPES.keys()) + def get_adapter_class_by_name(self, name: str) -> Type[Adapter]: + with self.lock: + if name in self.adapter_types: + return self.adapter_types[name] - message = "Invalid adapter type {}! Must be one of {}" - formatted_message = message.format(adapter_name, adapter_names) - raise RuntimeException(formatted_message) + names = ", ".join(self.adapter_types.keys()) + message = f"Invalid adapter type {name}! Must be one of {names}" + raise RuntimeException(message) -def get_relation_class_by_name(adapter_name: str) -> Type[BaseRelation]: - adapter = get_adapter_class_by_name(adapter_name) - return adapter.Relation + def get_relation_class_by_name(self, name: str) -> Type[BaseRelation]: + adapter = self.get_adapter_class_by_name(name) + return adapter.Relation + def load_plugin(self, name: str) -> Type[Credentials]: + # this doesn't need a lock: in the worst case we'll overwrite PACKAGES + # and adapter_type entries with the same value, as they're all + # singletons + try: + mod = import_module('.' + name, 'dbt.adapters') + except ImportError as e: + logger.info("Error importing adapter: {}".format(e)) + raise RuntimeException( + "Could not find adapter type {}!".format(name) + ) + if not hasattr(mod, 'Plugin'): + raise RuntimeException( + f'Could not find plugin in {name} plugin module' + ) + plugin: AdapterPlugin = mod.Plugin # type: ignore + plugin_type = plugin.adapter.type() -def load_plugin(adapter_name: str) -> Credentials: - # this doesn't need a lock: in the worst case we'll overwrite PACKAGES and - # _ADAPTER_TYPE entries with the same value, as they're all singletons - try: - mod = import_module('.' 
+ adapter_name, 'dbt.adapters') - except ImportError as e: - logger.info("Error importing adapter: {}".format(e)) - raise RuntimeException( - "Could not find adapter type {}!".format(adapter_name) - ) - plugin = mod.Plugin + if plugin_type != name: + raise RuntimeException( + f'Expected to find adapter with type named {name}, got ' + f'adapter with type {plugin_type}' + ) - if plugin.adapter.type() != adapter_name: - raise RuntimeException( - 'Expected to find adapter with type named {}, got adapter with ' - 'type {}' - .format(adapter_name, plugin.adapter.type()) - ) + with self.lock: + # things do hold the lock to iterate over it so we need it to add + self.adapter_types[name] = plugin.adapter - with _ADAPTER_LOCK: - # things do hold the lock to iterate over it so we need ot to add stuff - ADAPTER_TYPES[adapter_name] = plugin.adapter + PACKAGES[plugin.project_name] = plugin.include_path - PACKAGES[plugin.project_name] = plugin.include_path + for dep in plugin.dependencies: + self.load_plugin(dep) - for dep in plugin.dependencies: - load_plugin(dep) + return plugin.credentials - return plugin.credentials + def register_adapter(self, config: HasCredentials) -> Adapter: + adapter_name = config.credentials.type + adapter_type = self.get_adapter_class_by_name(adapter_name) + with self.lock: + if adapter_name in self.adapters: + # this shouldn't really happen... + return -def get_adapter(config: RuntimeConfig) -> BaseAdapter: - adapter_name = config.credentials.type + adapter: Adapter = adapter_type(config) # type: ignore + self.adapters[adapter_name] = adapter - # Atomically check to see if we already have an adapter - if adapter_name in _ADAPTERS: - return _ADAPTERS[adapter_name] + def lookup_adapter(self, adapter_name: str) -> Adapter: + return self.adapters[adapter_name] - adapter_type = get_adapter_class_by_name(adapter_name) + def reset_adapters(self): + """Clear the adapters. This is useful for tests, which change configs. + """ + with self.lock: + for adapter in self.adapters.values(): + adapter.cleanup_connections() + self.adapters.clear() - with _ADAPTER_LOCK: - # check again, in case something was setting it before - if adapter_name in _ADAPTERS: - return _ADAPTERS[adapter_name] + def cleanup_connections(self): + """Only clean up the adapter connections list without resetting the actual + adapters. + """ + with self.lock: + for adapter in self.adapters.values(): + adapter.cleanup_connections() - adapter = adapter_type(config) - _ADAPTERS[adapter_name] = adapter - return adapter + +FACTORY: AdpaterContainer = AdpaterContainer() + + +def register_adapter(config: HasCredentials) -> None: + return FACTORY.register_adapter(config) + + +def get_adapter(config: HasCredentials): + return FACTORY.lookup_adapter(config.credentials.type) def reset_adapters(): """Clear the adapters. This is useful for tests, which change configs. """ - with _ADAPTER_LOCK: - for adapter in _ADAPTERS.values(): - adapter.cleanup_connections() - _ADAPTERS.clear() + FACTORY.reset_adapters() def cleanup_connections(): """Only clean up the adapter connections list without resetting the actual adapters. 
""" - with _ADAPTER_LOCK: - for adapter in _ADAPTERS.values(): - adapter.cleanup_connections() + FACTORY.cleanup_connections() + + +def get_adapter_class_by_name(name: str) -> Type[BaseAdapter]: + return FACTORY.get_adapter_class_by_name(name) + + +def get_relation_class_by_name(name: str) -> Type[BaseRelation]: + return FACTORY.get_relation_class_by_name(name) + + +def load_plugin(name: str) -> Type[Credentials]: + return FACTORY.load_plugin(name) diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index 56b4be0adf2..e0e1fe112aa 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -3,6 +3,7 @@ from typing import ( Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType ) +from typing_extensions import Protocol from hologram import JsonSchemaMixin from hologram.helpers import ( @@ -116,3 +117,7 @@ def to_dict(self, omit_none=True, validate=False, with_aliases=False): if canonical_name in serialized }) return serialized + + +class HasCredentials(Protocol): + credentials: Credentials diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py index 2e5f3e7a1c9..99722681bc0 100644 --- a/core/dbt/task/base.py +++ b/core/dbt/task/base.py @@ -2,6 +2,7 @@ from abc import ABCMeta, abstractmethod from typing import Type, Union +from dbt.adapters.factory import register_adapter from dbt.config import RuntimeConfig, Project from dbt.config.profile import read_profile, PROFILES_DIR from dbt import tracking @@ -137,6 +138,10 @@ def from_args(cls, args): class ConfiguredTask(RequiresProjectTask): ConfigType = RuntimeConfig + def __init__(self, args, config): + super().__init__(args, config) + register_adapter(self.config) + class ProjectOnlyTask(RequiresProjectTask): ConfigType = Project From 773c979955f61694f272198a7f90ed0694f91284 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 09:42:15 -0600 Subject: [PATCH 04/13] Refactor internal RPC logic to support just getting parsed manifests RemoteCallableResult -> RPCResult RemoteCallable -> RemoteMethod - move some things from RPCTask -> RemoteMethod - recursive_subclasses classmethod things in core/dbt/rpc now are all based on RemoteMethods, not RPCTasks --- core/dbt/parser/search.py | 2 + core/dbt/rpc/__init__.py | 7 -- core/dbt/rpc/logger.py | 8 +-- core/dbt/rpc/{task.py => method.py} | 99 +++++++---------------------- core/dbt/rpc/response_manager.py | 4 +- core/dbt/rpc/task_handler.py | 10 +-- core/dbt/rpc/task_manager.py | 11 ++-- core/dbt/task/remote.py | 63 +++++++++++++++++- 8 files changed, 104 insertions(+), 100 deletions(-) rename core/dbt/rpc/{task.py => method.py} (50%) diff --git a/core/dbt/parser/search.py b/core/dbt/parser/search.py index b343cc98189..08384d7bd2d 100644 --- a/core/dbt/parser/search.py +++ b/core/dbt/parser/search.py @@ -32,6 +32,7 @@ def path(self): @dataclass class BlockContents(FileBlock): + file: SourceFile # if you remove this, mypy will get upset block: BlockTag @property @@ -45,6 +46,7 @@ def contents(self): @dataclass class FullBlock(FileBlock): + file: SourceFile # if you remove this, mypy will get upset block: BlockTag @property diff --git a/core/dbt/rpc/__init__.py b/core/dbt/rpc/__init__.py index 67c92e7928b..2d82f709ad2 100644 --- a/core/dbt/rpc/__init__.py +++ b/core/dbt/rpc/__init__.py @@ -34,10 +34,3 @@ - `kills` all processes (triggering the end of all processes, right!?) - exits (all remaining threads should die here!) 
""" -from dbt.rpc.error import ( # noqa - dbt_error, server_error, invalid_params, RPCException -) -from dbt.rpc.logger import RemoteCallableResult # noqa -from dbt.rpc.task import RemoteCallable # noqa -from dbt.rpc.task_manager import TaskManager # noqa -from dbt.rpc.response_manager import ResponseManager # noqa diff --git a/core/dbt/rpc/logger.py b/core/dbt/rpc/logger.py index 2bd06e84f19..508ad27772b 100644 --- a/core/dbt/rpc/logger.py +++ b/core/dbt/rpc/logger.py @@ -16,7 +16,7 @@ from dbt.utils import restrict_to -RemoteCallableResult = Union[ +RPCResult = Union[ RemoteCompileResult, RemoteExecutionResult, RemoteCatalogResults, @@ -73,10 +73,10 @@ class QueueResultMessage(QueueMessage): message_type: QueueMessageType = field( metadata=restrict_to(QueueMessageType.Result) ) - result: RemoteCallableResult + result: RPCResult @classmethod - def from_result(cls, result: RemoteCallableResult): + def from_result(cls, result: RPCResult): return cls( message_type=QueueMessageType.Result, result=result, @@ -101,7 +101,7 @@ def emit(self, record: logbook.LogRecord): def emit_error(self, error: JSONRPCError): self.queue.put_nowait(QueueErrorMessage.from_error(error)) - def emit_result(self, result: RemoteCallableResult): + def emit_result(self, result: RPCResult): self.queue.put_nowait(QueueResultMessage.from_result(result)) diff --git a/core/dbt/rpc/task.py b/core/dbt/rpc/method.py similarity index 50% rename from core/dbt/rpc/task.py rename to core/dbt/rpc/method.py index e698db1b955..39f3b24ee95 100644 --- a/core/dbt/rpc/task.py +++ b/core/dbt/rpc/method.py @@ -1,23 +1,30 @@ -import base64 import inspect from abc import abstractmethod -from typing import Union, List, Optional, Type, TypeVar, Generic +from typing import List, Optional, Type, TypeVar, Generic from dbt.contracts.rpc import RPCParameters from dbt.exceptions import NotImplementedException, InternalException -from dbt.rpc.logger import RemoteCallableResult, RemoteExecutionResult -from dbt.rpc.error import invalid_params -from dbt.task.compile import CompileTask +from dbt.rpc.logger import RPCResult Parameters = TypeVar('Parameters', bound=RPCParameters) -Result = TypeVar('Result', bound=RemoteCallableResult) +Result = TypeVar('Result', bound=RPCResult) -class RemoteCallable(Generic[Parameters, Result]): +# If you call recursive_subclasses on a subclass of RemoteMethod, it should +# only return subtypes of the given subclass. +T = TypeVar('T', bound='RemoteMethod') + + +class RemoteMethod(Generic[Parameters, Result]): METHOD_NAME: Optional[str] = None is_async = False + def __init__(self, args, config, manifest): + self.args = args + self.config = config + self.manifest = manifest.deepcopy(config=config) + @classmethod def get_parameters(cls) -> Type[Parameters]: argspec = inspect.getfullargspec(cls.set_args) @@ -41,66 +48,6 @@ def get_parameters(cls) -> Type[Parameters]: ) return params_type - @abstractmethod - def set_args(self, params: Parameters): - raise NotImplementedException( - 'set_args not implemented' - ) - - @abstractmethod - def handle_request(self) -> Result: - raise NotImplementedException( - 'handle_request not implemented' - ) - - @staticmethod - def _listify( - value: Optional[Union[str, List[str]]] - ) -> Optional[List[str]]: - if value is None: - return None - elif isinstance(value, str): - return [value] - else: - return value - - def decode_sql(self, sql: str) -> str: - """Base64 decode a string. This should only be used for sql in calls. 
- - :param str sql: The base64 encoded form of the original utf-8 string - :return str: The decoded utf-8 string - """ - # JSON is defined as using "unicode", we'll go a step further and - # mandate utf-8 (though for the base64 part, it doesn't really matter!) - base64_sql_bytes = str(sql).encode('utf-8') - - try: - sql_bytes = base64.b64decode(base64_sql_bytes, validate=True) - except ValueError: - self.raise_invalid_base64(sql) - - return sql_bytes.decode('utf-8') - - @staticmethod - def raise_invalid_base64(sql): - raise invalid_params( - data={ - 'message': 'invalid base64-encoded sql input', - 'sql': str(sql), - } - ) - - -# If you call recursive_subclasses on a subclass of RPCTask, it should only -# return subtypes of the given subclass. -T = TypeVar('T', bound='RPCTask') - - -class RPCTask(CompileTask, RemoteCallable[Parameters, RemoteExecutionResult]): - def __init__(self, args, config, manifest): - super().__init__(args, config) - self._base_manifest = manifest.deepcopy(config=config) - @classmethod def recursive_subclasses( cls: Type[T], named_only: bool = True @@ -116,12 +63,14 @@ def recursive_subclasses( classes = [c for c in classes if c.METHOD_NAME is not None] return classes - def get_result( - self, results, elapsed_time, generated_at - ) -> RemoteExecutionResult: - return RemoteExecutionResult( - results=results, - elapsed_time=elapsed_time, - generated_at=generated_at, - logs=[], + @abstractmethod + def set_args(self, params: Parameters): + raise NotImplementedException( + 'set_args not implemented' + ) + + @abstractmethod + def handle_request(self) -> Result: + raise NotImplementedException( + 'handle_request not implemented' ) diff --git a/core/dbt/rpc/response_manager.py b/core/dbt/rpc/response_manager.py index 2d4a4252b57..a53a101564f 100644 --- a/core/dbt/rpc/response_manager.py +++ b/core/dbt/rpc/response_manager.py @@ -17,7 +17,7 @@ from dbt.logger import GLOBAL_LOGGER as logger from dbt.rpc.logger import RequestContext from dbt.rpc.task_handler import RequestTaskHandler -from dbt.rpc.task import RemoteCallable +from dbt.rpc.method import RemoteMethod from dbt.rpc.task_manager import TaskManager @@ -50,7 +50,7 @@ def __getitem__(self, key) -> Callable[..., Dict[str, Any]]: ) if handler is None: raise KeyError(key) - elif isinstance(handler, RemoteCallable): + elif isinstance(handler, RemoteMethod): # the handler must be a task. Wrap it in a task handler so it can # go async return RequestTaskHandler( diff --git a/core/dbt/rpc/task_handler.py b/core/dbt/rpc/task_handler.py index 9d260607929..a260fea256d 100644 --- a/core/dbt/rpc/task_handler.py +++ b/core/dbt/rpc/task_handler.py @@ -22,14 +22,14 @@ timeout_error, ) from dbt.rpc.logger import ( - RemoteCallableResult, + RPCResult, QueueSubscriber, QueueLogHandler, QueueErrorMessage, QueueResultMessage, QueueTimeoutMessage, ) -from dbt.rpc.task import RPCTask +from dbt.rpc.method import RemoteMethod from dbt.utils import env_set_truthy # we use this in typing only... @@ -95,7 +95,7 @@ def sigterm_handler(signum, frame): def _task_bootstrap( - task: RPCTask, + task: RemoteMethod, queue, # typing: Queue[Tuple[QueueMessageType, Any]] params: JsonSchemaMixin, ) -> None: @@ -241,7 +241,7 @@ def tags(self) -> Optional[Dict[str, Any]]: return None return self.task_params.task_tags - def _wait_for_results(self) -> RemoteCallableResult: + def _wait_for_results(self) -> RPCResult: """Wait for results off the queue. If there is an exception raised, raise an appropriate RPC exception. 
@@ -279,7 +279,7 @@ def _wait_for_results(self) -> RemoteCallableResult: 'Invalid message type {} (result={})'.format(msg) ) - def get_result(self) -> RemoteCallableResult: + def get_result(self) -> RPCResult: if self.process is None: raise dbt.exceptions.InternalException( 'get_result() called before handle()' diff --git a/core/dbt/rpc/task_manager.py b/core/dbt/rpc/task_manager.py index a709d9395e0..7f3ee92c790 100644 --- a/core/dbt/rpc/task_manager.py +++ b/core/dbt/rpc/task_manager.py @@ -24,7 +24,8 @@ from dbt.logger import LogMessage from dbt.rpc.error import dbt_error, RPCException from dbt.rpc.task_handler import TaskHandlerState, RequestTaskHandler -from dbt.rpc.task import RemoteCallable, RPCTask +from dbt.rpc.method import RemoteMethod + from dbt.utils import restrict_to # import this to make sure our timedelta encoder is registered @@ -329,7 +330,7 @@ def __init__(self, args, config): self.args = args self.config = config self.tasks: Dict[uuid.UUID, RequestTaskHandler] = {} - self._rpc_task_map: Dict[str, RPCTask] = {} + self._rpc_task_map: Dict[str, RemoteMethod] = {} self._builtins: Dict[str, UnmanagedHandler] = {} self.last_compile = LastCompile(status=ManifestStatus.Init) self._lock = multiprocessing.Lock() @@ -343,7 +344,7 @@ def add_request(self, request_handler): def reserve_handler(self, task): self._rpc_task_map[task.METHOD_NAME] = None - def _assert_unique_task(self, task_type: Type[RPCTask]): + def _assert_unique_task(self, task_type: Type[RemoteMethod]): method = task_type.METHOD_NAME if method not in self._rpc_task_map: # this is weird, but hey whatever @@ -357,7 +358,7 @@ def _assert_unique_task(self, task_type: Type[RPCTask]): 'should be unique'.format(task_type, other_task) ) - def add_task_handler(self, task: Type[RPCTask], manifest: Manifest): + def add_task_handler(self, task: Type[RemoteMethod], manifest: Manifest): if task.METHOD_NAME is None: raise dbt.exceptions.InternalException( 'Task {} has no method name, cannot add it'.format(task) @@ -547,7 +548,7 @@ def compilation_error(self, *args, **kwargs): def get_handler( self, method, http_request, json_rpc_request - ) -> Optional[Union[WrappedHandler, RemoteCallable]]: + ) -> Optional[Union[WrappedHandler, RemoteMethod]]: # get_handler triggers a GC check. TODO: does this go somewhere else? 
self.gc_as_required() # the dispatcher's keys are method names and its values are functions diff --git a/core/dbt/task/remote.py b/core/dbt/task/remote.py index 1de95a13c52..d46b3eeb626 100644 --- a/core/dbt/task/remote.py +++ b/core/dbt/task/remote.py @@ -1,8 +1,9 @@ +import base64 import shlex import signal import threading from datetime import datetime -from typing import Type +from typing import Type, Optional, Union, List import dbt.exceptions import dbt.ui.printer @@ -23,17 +24,38 @@ from dbt.parser.rpc import RPCCallParser, RPCMacroParser from dbt.parser.util import ParserUtils from dbt.logger import GLOBAL_LOGGER as logger +from dbt.rpc.error import invalid_params from dbt.rpc.node_runners import ( RPCCompileRunner, RPCExecuteRunner ) -from dbt.rpc.task import RPCTask, Parameters +from dbt.rpc.method import RemoteMethod, Parameters +from dbt.task.runnable import GraphRunnableTask from dbt.task.generate import GenerateTask from dbt.task.run import RunTask from dbt.task.seed import SeedTask from dbt.task.test import TestTask +class RPCTask( + GraphRunnableTask, + RemoteMethod[Parameters, RemoteExecutionResult] +): + def __init__(self, args, config, manifest): + super().__init__(args, config) + RemoteMethod.__init__(self, args, config, manifest) + + def get_result( + self, results, elapsed_time, generated_at + ) -> RemoteExecutionResult: + return RemoteExecutionResult( + results=results, + elapsed_time=elapsed_time, + generated_at=generated_at, + logs=[], + ) + + class _RPCExecTask(RPCTask[RPCExecParameters]): def runtime_cleanup(self, selected_uids): """Do some pre-run cleanup that is usually performed in Task __init__. @@ -45,6 +67,32 @@ def runtime_cleanup(self, selected_uids): self._skipped_children = {} self._raise_next_tick = None + def decode_sql(self, sql: str) -> str: + """Base64 decode a string. This should only be used for sql in calls. + + :param str sql: The base64 encoded form of the original utf-8 string + :return str: The decoded utf-8 string + """ + # JSON is defined as using "unicode", we'll go a step further and + # mandate utf-8 (though for the base64 part, it doesn't really matter!) + base64_sql_bytes = str(sql).encode('utf-8') + + try: + sql_bytes = base64.b64decode(base64_sql_bytes, validate=True) + except ValueError: + self.raise_invalid_base64(sql) + + return sql_bytes.decode('utf-8') + + @staticmethod + def raise_invalid_base64(sql): + raise invalid_params( + data={ + 'message': 'invalid base64-encoded sql input', + 'sql': str(sql), + } + ) + def _extract_request_data(self, data): data = self.decode_sql(data) macro_blocks = [] @@ -179,6 +227,17 @@ def __init__(self, args, config, manifest): super().__init__(args, config, manifest) self.manifest = self._base_manifest + @staticmethod + def _listify( + value: Optional[Union[str, List[str]]] + ) -> Optional[List[str]]: + if value is None: + return None + elif isinstance(value, str): + return [value] + else: + return value + def load_manifest(self): # we started out with a manifest! 
pass From 7206c202bf501c14fb799426e639e5fa1ff7c283 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 09:49:45 -0600 Subject: [PATCH 05/13] Manifests are only parsed, not compiled at SIGHUP/startup time The _sql tasks now compile any ref'ed CTE chains at RPC call time Give RPC tasks their own folder - task/rpc_server -> task/rpc/server - task/remote -> task/rpc/{project_commands,sql_commands,base} Linker enhancements: - Expose subset graph building so multiple methods can use it - Expose a way for the linker to provide an interable of the ephemeral ancestors of a node - it's guaranteed to be ordered (so nested CTEs behave) --- core/dbt/linker.py | 84 +++-- core/dbt/main.py | 2 +- core/dbt/rpc/__init__.py | 2 +- core/dbt/task/remote.py | 331 ------------------ core/dbt/task/rpc/base.py | 22 ++ core/dbt/task/rpc/project_commands.py | 129 +++++++ .../dbt/task/{rpc_server.py => rpc/server.py} | 5 +- core/dbt/task/rpc/sql_commands.py | 196 +++++++++++ 8 files changed, 413 insertions(+), 358 deletions(-) delete mode 100644 core/dbt/task/remote.py create mode 100644 core/dbt/task/rpc/base.py create mode 100644 core/dbt/task/rpc/project_commands.py rename core/dbt/task/{rpc_server.py => rpc/server.py} (98%) create mode 100644 core/dbt/task/rpc/sql_commands.py diff --git a/core/dbt/linker.py b/core/dbt/linker.py index fca7bdd270e..479f33a1853 100644 --- a/core/dbt/linker.py +++ b/core/dbt/linker.py @@ -1,5 +1,5 @@ from queue import PriorityQueue -from typing import Iterable, Optional +from typing import Iterable, Set, Optional import networkx as nx import threading @@ -19,6 +19,11 @@ def is_blocking_dependency(node): return node.resource_type == NodeType.Model +def is_ephemeral_dependency(node): + return (node.resource_type == NodeType.Model and + node.get_materialization() == 'ephemeral') + + class GraphQueue: """A fancy queue that is backed by the dependency graph. Note: this will mutate input! @@ -167,28 +172,6 @@ def join(self): self.inner.join() -def _subset_graph(graph, include_nodes): - """Create and return a new graph that is a shallow copy of graph but with - only the nodes in include_nodes. Transitive edges across removed nodes are - preserved as explicit new edges. - """ - new_graph = nx.algorithms.transitive_closure(graph) - - include_nodes = set(include_nodes) - - for node in graph.nodes(): - if node not in include_nodes: - new_graph.remove_node(node) - - for node in include_nodes: - if node not in new_graph: - raise RuntimeError( - "Couldn't find model '{}' -- does it exist or is " - "it disabled?".format(node) - ) - return new_graph - - class Linker: def __init__(self, data=None): if data is None: @@ -215,6 +198,27 @@ def find_cycles(self): return None + def build_subset_graph(self, include_nodes: Iterable[str]): + """Create and return a new graph that is a shallow copy of the graph, + but with only the nodes in include_nodes. Transitive edges across + removed nodes are preserved as explicit new edges. 
+ """ + new_graph = nx.algorithms.transitive_closure(self.graph) + + include_nodes = set(include_nodes) + + for node in self.graph.nodes(): + if node not in include_nodes: + new_graph.remove_node(node) + + for node in include_nodes: + if node not in new_graph: + raise RuntimeError( + "Couldn't find model '{}' -- does it exist or is " + "it disabled?".format(node) + ) + return new_graph + def as_graph_queue( self, manifest: Manifest, limit_to: Optional[Iterable[str]] = None ) -> GraphQueue: @@ -226,9 +230,41 @@ def as_graph_queue( else: graph_nodes = limit_to - new_graph = _subset_graph(self.graph, graph_nodes) + new_graph = self.build_subset_graph(graph_nodes) return GraphQueue(new_graph, manifest) + def sorted_ephemeral_ancestors( + self, manifest: Manifest, unique_id: str + ) -> Iterable[str]: + """Get the ephemeral ancestors of unique_id, stopping at the first + non-ephemeral node in each chain, in graph-topological order. + """ + to_check: Set[str] = {unique_id} + ephemerals: Set[str] = set() + visited: Set[str] = set() + + while to_check: + # note that this avoids collecting unique_id itself + nextval = to_check.pop() + for pred in self.graph.predecessors(nextval): + if pred in visited: + continue + visited.add(pred) + node = manifest.expect(pred) + + if node.resource_type != NodeType.Model: + continue + if node.get_materialization() != 'ephemeral': + continue + # this is an ephemeral model! We have to find everything it + # refs and do it all over again until we exhaust them all + ephemerals.add(pred) + to_check.add(pred) + + ephemeral_graph = self.build_subset_graph(ephemerals) + # we can just topo sort this because we know there are no cycles. + return nx.topological_sort(ephemeral_graph) + def get_dependent_nodes(self, node): return nx.descendants(self.graph, node) diff --git a/core/dbt/main.py b/core/dbt/main.py index fb57053b046..d6341d73403 100644 --- a/core/dbt/main.py +++ b/core/dbt/main.py @@ -22,7 +22,7 @@ import dbt.task.freshness as freshness_task import dbt.task.run_operation as run_operation_task from dbt.task.list import ListTask -from dbt.task.rpc_server import RPCServerTask +from dbt.task.rpc.server import RPCServerTask from dbt.adapters.factory import reset_adapters, cleanup_connections import dbt.tracking diff --git a/core/dbt/rpc/__init__.py b/core/dbt/rpc/__init__.py index 2d82f709ad2..f34604bac88 100644 --- a/core/dbt/rpc/__init__.py +++ b/core/dbt/rpc/__init__.py @@ -1,6 +1,6 @@ """The `rpc` package handles most aspects of the actual execution of dbt's RPC server (except for the server itself and the client tasks, which are defined in -the `task.remote` and `task.rpc_server` modules). +the `task.remote` package). The general idea from a thread/process management perspective (ignoring the --single-threaded flag!) 
is as follows: diff --git a/core/dbt/task/remote.py b/core/dbt/task/remote.py deleted file mode 100644 index d46b3eeb626..00000000000 --- a/core/dbt/task/remote.py +++ /dev/null @@ -1,331 +0,0 @@ -import base64 -import shlex -import signal -import threading -from datetime import datetime -from typing import Type, Optional, Union, List - -import dbt.exceptions -import dbt.ui.printer -from dbt.adapters.factory import get_adapter -from dbt.clients.jinja import extract_toplevel_blocks -from dbt.compilation import compile_manifest -from dbt.contracts.rpc import ( - RPCExecParameters, - RPCCompileParameters, - RPCTestParameters, - RPCSeedParameters, - RPCDocsGenerateParameters, - RPCCliParameters, - RemoteCatalogResults, - RemoteExecutionResult, -) -from dbt.parser.results import ParseResult -from dbt.parser.rpc import RPCCallParser, RPCMacroParser -from dbt.parser.util import ParserUtils -from dbt.logger import GLOBAL_LOGGER as logger -from dbt.rpc.error import invalid_params -from dbt.rpc.node_runners import ( - RPCCompileRunner, RPCExecuteRunner -) -from dbt.rpc.method import RemoteMethod, Parameters - -from dbt.task.runnable import GraphRunnableTask -from dbt.task.generate import GenerateTask -from dbt.task.run import RunTask -from dbt.task.seed import SeedTask -from dbt.task.test import TestTask - - -class RPCTask( - GraphRunnableTask, - RemoteMethod[Parameters, RemoteExecutionResult] -): - def __init__(self, args, config, manifest): - super().__init__(args, config) - RemoteMethod.__init__(self, args, config, manifest) - - def get_result( - self, results, elapsed_time, generated_at - ) -> RemoteExecutionResult: - return RemoteExecutionResult( - results=results, - elapsed_time=elapsed_time, - generated_at=generated_at, - logs=[], - ) - - -class _RPCExecTask(RPCTask[RPCExecParameters]): - def runtime_cleanup(self, selected_uids): - """Do some pre-run cleanup that is usually performed in Task __init__. - """ - self.run_count = 0 - self.num_nodes = len(selected_uids) - self.node_results = [] - self._skipped_children = {} - self._skipped_children = {} - self._raise_next_tick = None - - def decode_sql(self, sql: str) -> str: - """Base64 decode a string. This should only be used for sql in calls. - - :param str sql: The base64 encoded form of the original utf-8 string - :return str: The decoded utf-8 string - """ - # JSON is defined as using "unicode", we'll go a step further and - # mandate utf-8 (though for the base64 part, it doesn't really matter!) 
- base64_sql_bytes = str(sql).encode('utf-8') - - try: - sql_bytes = base64.b64decode(base64_sql_bytes, validate=True) - except ValueError: - self.raise_invalid_base64(sql) - - return sql_bytes.decode('utf-8') - - @staticmethod - def raise_invalid_base64(sql): - raise invalid_params( - data={ - 'message': 'invalid base64-encoded sql input', - 'sql': str(sql), - } - ) - - def _extract_request_data(self, data): - data = self.decode_sql(data) - macro_blocks = [] - data_chunks = [] - for block in extract_toplevel_blocks(data): - if block.block_type_name == 'macro': - macro_blocks.append(block.full_block) - else: - data_chunks.append(block.full_block) - macros = '\n'.join(macro_blocks) - sql = ''.join(data_chunks) - return sql, macros - - def _get_exec_node(self): - results = ParseResult.rpc() - macro_overrides = {} - macros = self.args.macros - sql, macros = self._extract_request_data(self.args.sql) - - if macros: - macro_parser = RPCMacroParser(results, self.config) - for node in macro_parser.parse_remote(macros): - macro_overrides[node.unique_id] = node - - self._base_manifest.macros.update(macro_overrides) - rpc_parser = RPCCallParser( - results=results, - project=self.config, - root_project=self.config, - macro_manifest=self._base_manifest, - ) - node = rpc_parser.parse_remote(sql, self.args.name) - self.manifest = ParserUtils.add_new_refs( - manifest=self._base_manifest, - current_project=self.config, - node=node, - macros=macro_overrides - ) - - # don't write our new, weird manifest! - self.linker = compile_manifest(self.config, self.manifest, write=False) - return node - - def _raise_set_error(self): - if self._raise_next_tick is not None: - raise self._raise_next_tick - - def _in_thread(self, node, thread_done): - runner = self.get_runner(node) - try: - self.node_results.append(runner.safe_run(self.manifest)) - except Exception as exc: - logger.debug('Got exception {}'.format(exc), exc_info=True) - self._raise_next_tick = exc - finally: - thread_done.set() - - def set_args(self, params: RPCExecParameters): - self.args.name = params.name - self.args.sql = params.sql - self.args.macros = params.macros - - def handle_request(self) -> RemoteExecutionResult: - # we could get a ctrl+c at any time, including during parsing. - thread = None - started = datetime.utcnow() - try: - node = self._get_exec_node() - - selected_uids = [node.unique_id] - self.runtime_cleanup(selected_uids) - - thread_done = threading.Event() - thread = threading.Thread(target=self._in_thread, - args=(node, thread_done)) - thread.start() - thread_done.wait() - except KeyboardInterrupt: - adapter = get_adapter(self.config) # type: ignore - if adapter.is_cancelable(): - - for conn_name in adapter.cancel_open_connections(): - logger.debug('canceled query {}'.format(conn_name)) - if thread: - thread.join() - else: - msg = ("The {} adapter does not support query " - "cancellation. Some queries may still be " - "running!".format(adapter.type())) - - logger.debug(msg) - - raise dbt.exceptions.RPCKilledException(signal.SIGINT) - - self._raise_set_error() - - ended = datetime.utcnow() - elapsed = (ended - started).total_seconds() - return self.get_result( - results=self.node_results, - elapsed_time=elapsed, - generated_at=ended, - ) - - -class RemoteCompileTask(_RPCExecTask): - METHOD_NAME = 'compile_sql' - - def handle_request(self) -> RemoteExecutionResult: - # TODO: annotate that this is a RemoteExecutionResult of - # RemoteCompileResults. 
- return super().handle_request() - - def get_runner_type(self): - return RPCCompileRunner - - -class RemoteRunTask(_RPCExecTask, RunTask): - METHOD_NAME = 'run_sql' - - def handle_request(self) -> RemoteExecutionResult: - # TODO: annotate that this is a RemoteExecutionResult of - # RemoteRunResult. - return super().handle_request() - - def get_runner_type(self): - return RPCExecuteRunner - - -class _RPCCommandTask(RPCTask[Parameters]): - def __init__(self, args, config, manifest): - super().__init__(args, config, manifest) - self.manifest = self._base_manifest - - @staticmethod - def _listify( - value: Optional[Union[str, List[str]]] - ) -> Optional[List[str]]: - if value is None: - return None - elif isinstance(value, str): - return [value] - else: - return value - - def load_manifest(self): - # we started out with a manifest! - pass - - def handle_request(self) -> RemoteExecutionResult: - return self.run() - - -class RemoteCompileProjectTask(_RPCCommandTask[RPCCompileParameters]): - METHOD_NAME = 'compile' - - def set_args(self, params: RPCCompileParameters) -> None: - self.args.models = self._listify(params.models) - self.args.exclude = self._listify(params.exclude) - - -class RemoteRunProjectTask(_RPCCommandTask[RPCCompileParameters], RunTask): - METHOD_NAME = 'run' - - def set_args(self, params: RPCCompileParameters) -> None: - self.args.models = self._listify(params.models) - self.args.exclude = self._listify(params.exclude) - - -class RemoteSeedProjectTask(_RPCCommandTask[RPCSeedParameters], SeedTask): - METHOD_NAME = 'seed' - - def set_args(self, params: RPCSeedParameters) -> None: - self.args.show = params.show - - -class RemoteTestProjectTask(_RPCCommandTask[RPCTestParameters], TestTask): - METHOD_NAME = 'test' - - def set_args(self, params: RPCTestParameters) -> None: - self.args.models = self._listify(params.models) - self.args.exclude = self._listify(params.exclude) - self.args.data = params.data - self.args.schema = params.schema - - -class RemoteDocsGenerateProjectTask( - _RPCCommandTask[RPCDocsGenerateParameters], - GenerateTask, -): - METHOD_NAME = 'docs.generate' - - def set_args(self, params: RPCDocsGenerateParameters) -> None: - self.args.models = None - self.args.exclude = None - self.args.compile = params.compile - - def get_catalog_results( - self, nodes, generated_at, compile_results - ) -> RemoteCatalogResults: - return RemoteCatalogResults( - nodes=nodes, - generated_at=datetime.utcnow(), - _compile_results=compile_results, - logs=[], - ) - - -class RemoteRPCParameters(_RPCCommandTask[RPCCliParameters]): - METHOD_NAME = 'cli_args' - - def set_args(self, params: RPCCliParameters) -> None: - # more import cycles :( - from dbt.main import parse_args, RPCArgumentParser - split = shlex.split(params.cli) - self.args = parse_args(split, RPCArgumentParser) - - def get_rpc_task_cls(self) -> Type[_RPCCommandTask]: - # This is obnoxious, but we don't have actual access to the TaskManager - # so instead we get to dig through all the subclasses of RPCTask - # (recursively!) 
looking for a matching METHOD_NAME - candidate: Type[_RPCCommandTask] - for candidate in _RPCCommandTask.recursive_subclasses(): - if candidate.METHOD_NAME == self.args.rpc_method: - return candidate - # this shouldn't happen - raise dbt.exceptions.InternalException( - 'No matching handler found for rpc method {} (which={})' - .format(self.args.rpc_method, self.args.which) - ) - - def handle_request(self) -> RemoteExecutionResult: - cls = self.get_rpc_task_cls() - # we parsed args from the cli, so we're set on that front - task = cls(self.args, self.config, self.manifest) - return task.handle_request() diff --git a/core/dbt/task/rpc/base.py b/core/dbt/task/rpc/base.py new file mode 100644 index 00000000000..7187724d53b --- /dev/null +++ b/core/dbt/task/rpc/base.py @@ -0,0 +1,22 @@ +from dbt.contracts.rpc import RemoteExecutionResult +from dbt.task.runnable import GraphRunnableTask +from dbt.rpc.method import RemoteMethod, Parameters + + +class RPCTask( + GraphRunnableTask, + RemoteMethod[Parameters, RemoteExecutionResult] +): + def __init__(self, args, config, manifest): + super().__init__(args, config) + RemoteMethod.__init__(self, args, config, manifest) + + def get_result( + self, results, elapsed_time, generated_at + ) -> RemoteExecutionResult: + return RemoteExecutionResult( + results=results, + elapsed_time=elapsed_time, + generated_at=generated_at, + logs=[], + ) diff --git a/core/dbt/task/rpc/project_commands.py b/core/dbt/task/rpc/project_commands.py new file mode 100644 index 00000000000..b8ad3d688d4 --- /dev/null +++ b/core/dbt/task/rpc/project_commands.py @@ -0,0 +1,129 @@ +from datetime import datetime +import shlex +from typing import Type, List, Optional, Union + + +from dbt.contracts.rpc import ( + RPCCliParameters, + RPCCompileParameters, + RPCDocsGenerateParameters, + RPCSeedParameters, + RPCTestParameters, + RemoteCatalogResults, + RemoteExecutionResult, +) +from dbt.exceptions import InternalException +from dbt.task.compile import CompileTask +from dbt.task.generate import GenerateTask +from dbt.task.run import RunTask +from dbt.task.seed import SeedTask +from dbt.task.test import TestTask + +from .base import RPCTask, Parameters + + +class RPCCommandTask(RPCTask[Parameters]): + @staticmethod + def _listify( + value: Optional[Union[str, List[str]]] + ) -> Optional[List[str]]: + if value is None: + return None + elif isinstance(value, str): + return [value] + else: + return value + + def load_manifest(self): + # we started out with a manifest! 
+ pass + + def handle_request(self) -> RemoteExecutionResult: + return self.run() + + +class RemoteCompileProjectTask( + RPCCommandTask[RPCCompileParameters], CompileTask +): + METHOD_NAME = 'compile' + + def set_args(self, params: RPCCompileParameters) -> None: + self.args.models = self._listify(params.models) + self.args.exclude = self._listify(params.exclude) + + +class RemoteRunProjectTask(RPCCommandTask[RPCCompileParameters], RunTask): + METHOD_NAME = 'run' + + def set_args(self, params: RPCCompileParameters) -> None: + self.args.models = self._listify(params.models) + self.args.exclude = self._listify(params.exclude) + + +class RemoteSeedProjectTask(RPCCommandTask[RPCSeedParameters], SeedTask): + METHOD_NAME = 'seed' + + def set_args(self, params: RPCSeedParameters) -> None: + self.args.show = params.show + + +class RemoteTestProjectTask(RPCCommandTask[RPCTestParameters], TestTask): + METHOD_NAME = 'test' + + def set_args(self, params: RPCTestParameters) -> None: + self.args.models = self._listify(params.models) + self.args.exclude = self._listify(params.exclude) + self.args.data = params.data + self.args.schema = params.schema + + +class RemoteDocsGenerateProjectTask( + RPCCommandTask[RPCDocsGenerateParameters], + GenerateTask, +): + METHOD_NAME = 'docs.generate' + + def set_args(self, params: RPCDocsGenerateParameters) -> None: + self.args.models = None + self.args.exclude = None + self.args.compile = params.compile + + def get_catalog_results( + self, nodes, generated_at, compile_results + ) -> RemoteCatalogResults: + return RemoteCatalogResults( + nodes=nodes, + generated_at=datetime.utcnow(), + _compile_results=compile_results, + logs=[], + ) + + +class RemoteRPCParameters(RPCCommandTask[RPCCliParameters]): + METHOD_NAME = 'cli_args' + + def set_args(self, params: RPCCliParameters) -> None: + # more import cycles :( + from dbt.main import parse_args, RPCArgumentParser + split = shlex.split(params.cli) + self.args = parse_args(split, RPCArgumentParser) + + def get_rpc_task_cls(self) -> Type[RPCCommandTask]: + # This is obnoxious, but we don't have actual access to the TaskManager + # so instead we get to dig through all the subclasses of RPCTask + # (recursively!) looking for a matching METHOD_NAME + candidate: Type[RPCCommandTask] + for candidate in RPCCommandTask.recursive_subclasses(): + if candidate.METHOD_NAME == self.args.rpc_method: + return candidate + # this shouldn't happen + raise InternalException( + 'No matching handler found for rpc method {} (which={})' + .format(self.args.rpc_method, self.args.which) + ) + + def handle_request(self) -> RemoteExecutionResult: + cls = self.get_rpc_task_cls() + # we parsed args from the cli, so we're set on that front + task = cls(self.args, self.config, self.manifest) + return task.handle_request() diff --git a/core/dbt/task/rpc_server.py b/core/dbt/task/rpc/server.py similarity index 98% rename from core/dbt/task/rpc_server.py rename to core/dbt/task/rpc/server.py index a7e4c4dea41..026b8b2283c 100644 --- a/core/dbt/task/rpc_server.py +++ b/core/dbt/task/rpc/server.py @@ -1,3 +1,7 @@ +# import these so we can find them +from . import sql_commands # noqa +from . 
import project_commands # noqa +from .base import RPCTask import json import os import signal @@ -16,7 +20,6 @@ log_manager, ) from dbt.task.base import ConfiguredTask -from dbt.task.remote import RPCTask from dbt.utils import ForgivingJSONEncoder, env_set_truthy from dbt.rpc.response_manager import ResponseManager from dbt.rpc.task_manager import TaskManager diff --git a/core/dbt/task/rpc/sql_commands.py b/core/dbt/task/rpc/sql_commands.py new file mode 100644 index 00000000000..855a9d705d4 --- /dev/null +++ b/core/dbt/task/rpc/sql_commands.py @@ -0,0 +1,196 @@ +import base64 +from datetime import datetime +import signal +import threading + +from dbt.adapters.factory import get_adapter +from dbt.clients.jinja import extract_toplevel_blocks +from dbt.compilation import compile_manifest, compile_node +from dbt.contracts.rpc import RPCExecParameters +from dbt.contracts.rpc import RemoteExecutionResult +from dbt.exceptions import RPCKilledException +from dbt.logger import GLOBAL_LOGGER as logger +from dbt.parser.results import ParseResult +from dbt.parser.rpc import RPCCallParser, RPCMacroParser +from dbt.parser.util import ParserUtils +from dbt.rpc.error import invalid_params +from dbt.rpc.node_runners import RPCCompileRunner, RPCExecuteRunner +from dbt.task.compile import CompileTask +from dbt.task.run import RunTask + +from .base import RPCTask + + +class RemoteRunSQLTask(RPCTask[RPCExecParameters]): + def runtime_cleanup(self, selected_uids): + """Do some pre-run cleanup that is usually performed in Task __init__. + """ + self.run_count = 0 + self.num_nodes = len(selected_uids) + self.node_results = [] + self._skipped_children = {} + self._skipped_children = {} + self._raise_next_tick = None + + def decode_sql(self, sql: str) -> str: + """Base64 decode a string. This should only be used for sql in calls. + + :param str sql: The base64 encoded form of the original utf-8 string + :return str: The decoded utf-8 string + """ + # JSON is defined as using "unicode", we'll go a step further and + # mandate utf-8 (though for the base64 part, it doesn't really matter!) + base64_sql_bytes = str(sql).encode('utf-8') + + try: + sql_bytes = base64.b64decode(base64_sql_bytes, validate=True) + except ValueError: + self.raise_invalid_base64(sql) + + return sql_bytes.decode('utf-8') + + @staticmethod + def raise_invalid_base64(sql): + raise invalid_params( + data={ + 'message': 'invalid base64-encoded sql input', + 'sql': str(sql), + } + ) + + def _extract_request_data(self, data): + data = self.decode_sql(data) + macro_blocks = [] + data_chunks = [] + for block in extract_toplevel_blocks(data): + if block.block_type_name == 'macro': + macro_blocks.append(block.full_block) + else: + data_chunks.append(block.full_block) + macros = '\n'.join(macro_blocks) + sql = ''.join(data_chunks) + return sql, macros + + def _compile_ancestors(self, unique_id: str): + # this just gets a transitive closure of the nodes. 
We could build a + # special GraphQueue around this, but we do them all in the main thread + # so we only care about preserving dependency order anyway + sorted_ancestors = self.linker.sorted_ephemeral_ancestors( + self.manifest, + unique_id, + ) + # We're just compiling, so we don't need to use a graph queue + adapter = get_adapter(self.config) # type: ignore + + for unique_id in sorted_ancestors: + # for each node, compile it + overwrite it + parsed = self.manifest.expect(unique_id) + self.manifest.nodes[unique_id] = compile_node( + adapter, self.config, parsed, self.manifest, {}, write=False + ) + + def _get_exec_node(self): + results = ParseResult.rpc() + macro_overrides = {} + macros = self.args.macros + sql, macros = self._extract_request_data(self.args.sql) + + if macros: + macro_parser = RPCMacroParser(results, self.config) + for node in macro_parser.parse_remote(macros): + macro_overrides[node.unique_id] = node + + self.manifest.macros.update(macro_overrides) + rpc_parser = RPCCallParser( + results=results, + project=self.config, + root_project=self.config, + macro_manifest=self.manifest, + ) + node = rpc_parser.parse_remote(sql, self.args.name) + self.manifest = ParserUtils.add_new_refs( + manifest=self.manifest, + current_project=self.config, + node=node, + macros=macro_overrides + ) + + # don't write our new, weird manifest! + self.linker = compile_manifest(self.config, self.manifest, write=False) + self._compile_ancestors(node.unique_id) + return node + + def _raise_set_error(self): + if self._raise_next_tick is not None: + raise self._raise_next_tick + + def _in_thread(self, node, thread_done): + runner = self.get_runner(node) + try: + self.node_results.append(runner.safe_run(self.manifest)) + except Exception as exc: + logger.debug('Got exception {}'.format(exc), exc_info=True) + self._raise_next_tick = exc + finally: + thread_done.set() + + def set_args(self, params: RPCExecParameters): + self.args.name = params.name + self.args.sql = params.sql + self.args.macros = params.macros + + def handle_request(self) -> RemoteExecutionResult: + # we could get a ctrl+c at any time, including during parsing. + thread = None + started = datetime.utcnow() + try: + node = self._get_exec_node() + + selected_uids = [node.unique_id] + self.runtime_cleanup(selected_uids) + + thread_done = threading.Event() + thread = threading.Thread(target=self._in_thread, + args=(node, thread_done)) + thread.start() + thread_done.wait() + except KeyboardInterrupt: + adapter = get_adapter(self.config) # type: ignore + if adapter.is_cancelable(): + + for conn_name in adapter.cancel_open_connections(): + logger.debug('canceled query {}'.format(conn_name)) + if thread: + thread.join() + else: + msg = ("The {} adapter does not support query " + "cancellation. 
Some queries may still be " + "running!".format(adapter.type())) + + logger.debug(msg) + + raise RPCKilledException(signal.SIGINT) + + self._raise_set_error() + + ended = datetime.utcnow() + elapsed = (ended - started).total_seconds() + return self.get_result( + results=self.node_results, + elapsed_time=elapsed, + generated_at=ended, + ) + + +class RemoteCompileTask(RemoteRunSQLTask, CompileTask): + METHOD_NAME = 'compile_sql' + + def get_runner_type(self): + return RPCCompileRunner + + +class RemoteRunTask(RemoteRunSQLTask, RunTask): + METHOD_NAME = 'run_sql' + + def get_runner_type(self): + return RPCExecuteRunner From 66ff79dfbdd9a19ddd63227d8bc9d14914497687 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 08:54:17 -0600 Subject: [PATCH 06/13] Refactor RPC tests --- test/integration/048_rpc_test/test_rpc.py | 500 +++++++++++----------- test/integration/base.py | 1 - 2 files changed, 259 insertions(+), 242 deletions(-) diff --git a/test/integration/048_rpc_test/test_rpc.py b/test/integration/048_rpc_test/test_rpc.py index 8c3ede85f8b..9dc377d3833 100644 --- a/test/integration/048_rpc_test/test_rpc.py +++ b/test/integration/048_rpc_test/test_rpc.py @@ -276,11 +276,95 @@ def assertSuccessfulRunResult(self, data, raw_sql, compiled_sql=None, table=None self.assertEqual(result['table'], table) self.assertResultHasTimings(result, 'compile', 'execute') + def assertHasErrorData(self, error, expected_error_data): + self.assertIn('data', error) + error_data = error['data'] + for key, value in expected_error_data.items(): + self.assertIn(key, error_data) + self.assertEqual(error_data[key], value) + return error_data + + def assertRunning(self, sleepers): + sleeper_ps_result = self.query('ps', completed=False, active=True).json() + result = self.assertIsResult(sleeper_ps_result) + self.assertEqual(len(result['rows']), len(sleepers)) + result_map = {rd['request_id']: rd for rd in result['rows']} + for _, request_id in sleepers: + found = result_map[request_id] + self.assertEqual(found['request_id'], request_id) + self.assertEqual(found['method'], 'run_sql') + self.assertEqual(found['state'], 'running') + self.assertEqual(found['timeout'], None) + + def kill_and_assert(self, request_token, request_id): + kill_response = self.query('kill', task_id=request_token).json() + result = self.assertIsResult(kill_response) + self.assertEqual(result['status'], 'killed') + + poll_id = 90891 + + poll_response = self.poll_for_result(request_token, poll_id).json() + error = self.assertIsErrorWithCode(poll_response, 10009, poll_id) + self.assertEqual(error['message'], 'RPC process killed') + self.assertIn('data', error) + error_data = error['data'] + self.assertEqual(error_data['signum'], 2) + self.assertEqual(error_data['message'], 'RPC process killed by signal 2') + self.assertIn('logs', error_data) + return error_data + + def get_sleep_query(self, duration=15, request_id=90890): + sleep_query = self.query( + 'run_sql', + 'select * from pg_sleep({})'.format(duration), + name='sleeper', + _test_request_id=request_id + ).json() + result = self.assertIsResult(sleep_query, id_=request_id) + self.assertIn('request_token', result) + request_token = result['request_token'] + return request_token, request_id + + def wait_for_running(self, timeout=25, raise_on_timeout=True): + started = time.time() + time.sleep(0.5) + elapsed = time.time() - started + + while elapsed < timeout: + status = self.assertIsResult(self.query('status').json()) + if status['status'] == 'running': + return status + time.sleep(0.5) + 
elapsed = time.time() - started + + status = self.assertIsResult(self.query('status').json()) + if raise_on_timeout: + self.assertEqual( + status['status'], + 'ready', + f'exceeded max time of {timeout}: {elapsed} seconds elapsed' + ) + return status + + def run_command_with_id(self, cmd, id_): + self.assertIsResult(self.async_query(cmd, _test_request_id=id_).json(), id_) + + def make_many_requests(self, num_requests): + stored = [] + for idx in range(num_requests): + response = self.query('run_sql', 'select 1 as id', name='run').json() + result = self.assertIsResult(response) + self.assertIn('request_token', result) + token = result['request_token'] + self.poll_for_result(token) + stored.append(token) + return stored + @mark.flaky(rerun_filter=addr_in_use) -class TestRPCServer(HasRPCServer): +class TestRPCServerCompileRun(HasRPCServer): @use_profile('postgres') - def test_compile_postgres(self): + def test_compile_sql_postgres(self): trivial = self.async_query( 'compile_sql', 'select 1 as id', @@ -364,7 +448,7 @@ def test_compile_postgres(self): ) @use_profile('postgres') - def test_run_postgres(self): + def test_run_sql_postgres(self): # seed + run dbt to make models before using them! self.run_dbt_with_vars(['seed']) self.run_dbt_with_vars(['run']) @@ -486,18 +570,6 @@ def test_run_postgres(self): table={'column_names': ['id'], 'rows': [[1.0]]} ) - def _get_sleep_query(self, duration=15, request_id=90890): - sleep_query = self.query( - 'run_sql', - 'select * from pg_sleep({})'.format(duration), - name='sleeper', - _test_request_id=request_id - ).json() - result = self.assertIsResult(sleep_query, id_=request_id) - self.assertIn('request_token', result) - request_token = result['request_token'] - return request_token, request_id - @mark.flaky(rerun_filter=None) @use_profile('postgres') def test_ps_kill_postgres(self): @@ -512,7 +584,7 @@ def test_ps_kill_postgres(self): self.assertIn('tags', done_result) self.assertEqual(done_result['tags'], task_tags) - request_token, request_id = self._get_sleep_query() + request_token, request_id = self.get_sleep_query() empty_ps_result = self.query('ps', completed=False, active=False).json() result = self.assertIsResult(empty_ps_result) @@ -585,27 +657,10 @@ def test_ps_kill_postgres(self): self.assertGreater(rowdict[1]['elapsed'], 0) self.assertIsNone(rowdict[1]['tags']) - def kill_and_assert(self, request_token, request_id): - kill_response = self.query('kill', task_id=request_token).json() - result = self.assertIsResult(kill_response) - self.assertEqual(result['status'], 'killed') - - poll_id = 90891 - - poll_response = self.poll_for_result(request_token, poll_id).json() - error = self.assertIsErrorWithCode(poll_response, 10009, poll_id) - self.assertEqual(error['message'], 'RPC process killed') - self.assertIn('data', error) - error_data = error['data'] - self.assertEqual(error_data['signum'], 2) - self.assertEqual(error_data['message'], 'RPC process killed by signal 2') - self.assertIn('logs', error_data) - return error_data - @mark.flaky(rerun_filter=lambda *a, **kw: True) @use_profile('postgres') def test_ps_kill_longwait_postgres(self): - request_token, request_id = self._get_sleep_query() + request_token, request_id = self.get_sleep_query() # the test above frequently kills the process during parsing of the # requested node. 
That's also a useful test, but we should test that @@ -667,14 +722,6 @@ def test_invalid_requests_postgres(self): self.assertIn('logs', error_data) self.assertTrue(len(error_data['logs']) > 0) - def assertHasErrorData(self, error, expected_error_data): - self.assertIn('data', error) - error_data = error['data'] - for key, value in expected_error_data.items(): - self.assertIn(key, error_data) - self.assertEqual(error_data[key], value) - return error_data - @use_profile('postgres') def test_timeout_postgres(self): data = self.async_query( @@ -694,144 +741,201 @@ def test_timeout_postgres(self): self.assertIn('logs', error_data) self.assertTrue(len(error_data['logs']) > 0) - @use_profile('postgres') - def test_seed_project_postgres(self): - # testing "dbt seed" is tricky so we'll just jam some sql in there - self.run_sql_file("seed.sql") - result = self.async_query('seed', show=True).json() + +@mark.flaky(rerun_filter=addr_in_use) +class TestRPCServerProjects(HasRPCServer): + def assertHasResults(self, result, expected, *, missing=None, num_expected=None): dct = self.assertIsResult(result) - self.assertTablesEqual('source', 'seed_expected') self.assertIn('results', dct) results = dct['results'] - self.assertEqual(len(results), 4) - self.assertEqual( - set(r['node']['name'] for r in results), + + if num_expected is None: + num_expected = len(expected) + actual = {r['node']['name'] for r in results} + self.assertEqual(len(actual), num_expected) + self.assertTrue(expected.issubset(actual)) + if missing: + for item in missing: + self.assertNotIn(item, actual) + + def correct_seed_result(self, result): + self.assertTablesEqual('source', 'seed_expected') + self.assertHasResults( + result, {'expected_multi_source', 'other_source_table', 'other_table', 'source'} ) + def assertHasTestResults(self, results, expected, pass_results=None): + self.assertEqual(len(results), expected) + + if pass_results is None: + pass_results = expected + + passes = 0 + for result in results: + # TODO: should this be included even when it's 'none'? Should + # results have all these crazy keys? 
(no) + self.assertIn('fail', result) + if result['status'] == 0.0: + self.assertIsNone(result['fail']) + passes += 1 + else: + self.assertTrue(result['fail']) + self.assertEqual(passes, pass_results) + + @use_profile('postgres') + def test_seed_project_postgres(self): + # testing "dbt seed" is tricky so we'll just jam some sql in there + self.run_sql_file("seed.sql") + + result = self.async_query('seed', show=True).json() + self.correct_seed_result(result) + + result = self.async_query('seed', show=False).json() + self.correct_seed_result(result) + @use_profile('postgres') def test_seed_project_cli_postgres(self): self.run_sql_file("seed.sql") + result = self.async_query('cli_args', cli='seed --show').json() - dct = self.assertIsResult(result) - self.assertTablesEqual('source', 'seed_expected') - self.assertIn('results', dct) - results = dct['results'] - self.assertEqual(len(results), 4) - self.assertEqual( - set(r['node']['name'] for r in results), - {'expected_multi_source', 'other_source_table', 'other_table', 'source'} - ) + self.correct_seed_result(result) + result = self.async_query('cli_args', cli='seed').json() + self.correct_seed_result(result) @use_profile('postgres') def test_compile_project_postgres(self): self.run_dbt_with_vars(['seed']) + result = self.async_query('compile').json() - dct = self.assertIsResult(result) - self.assertIn('results', dct) - results = dct['results'] - self.assertEqual(len(results), 11) - compiled = set(r['node']['name'] for r in results) - self.assertTrue(compiled.issuperset( - {'descendant_model', 'multi_source_model', 'nonsource_descendant'} - )) - self.assertNotIn('ephemeral_model', compiled) + self.assertHasResults( + result, + {'descendant_model', 'multi_source_model', 'nonsource_descendant'}, + missing=['ephemeral_model'], + num_expected=11, + ) + + result = self.async_query('compile', models=['source:test_source+']).json() + self.assertHasResults( + result, + {'descendant_model', 'multi_source_model'}, + missing=['ephemeral_model', 'nonsource_descendant'], + num_expected=6, + ) @use_profile('postgres') def test_compile_project_cli_postgres(self): self.run_dbt_with_vars(['seed']) + result = self.async_query('cli_args', cli='compile').json() + self.assertHasResults( + result, + {'descendant_model', 'multi_source_model', 'nonsource_descendant'}, + missing=['ephemeral_model'], + num_expected=11, + ) + result = self.async_query('cli_args', cli='compile --models=source:test_source+').json() - dct = self.assertIsResult(result) - self.assertIn('results', dct) - results = dct['results'] - self.assertEqual(len(results), 6) - compiled = set(r['node']['name'] for r in results) - self.assertTrue(compiled.issuperset( - {'descendant_model', 'multi_source_model'} - )) - self.assertNotIn('ephemeral_model', compiled) - self.assertNotIn('nonsource_descendant', compiled) + self.assertHasResults( + result, + {'descendant_model', 'multi_source_model'}, + missing=['ephemeral_model', 'nonsource_descendant'], + num_expected=6, + ) @use_profile('postgres') def test_run_project_postgres(self): self.run_dbt_with_vars(['seed']) result = self.async_query('run').json() - dct = self.assertIsResult(result) - self.assertIn('results', dct) - results = dct['results'] - self.assertEqual(len(results), 3) - self.assertEqual( - set(r['node']['name'] for r in results), - {'descendant_model', 'multi_source_model', 'nonsource_descendant'} - ) + self.assertHasResults(result, {'descendant_model', 'multi_source_model', 'nonsource_descendant'}) self.assertTablesEqual('multi_source_model', 
'expected_multi_source') @use_profile('postgres') def test_run_project_cli_postgres(self): self.run_dbt_with_vars(['seed']) result = self.async_query('cli_args', cli='run').json() - dct = self.assertIsResult(result) - self.assertIn('results', dct) - results = dct['results'] - self.assertEqual(len(results), 3) - self.assertEqual( - set(r['node']['name'] for r in results), - {'descendant_model', 'multi_source_model', 'nonsource_descendant'} - ) + self.assertHasResults(result, {'descendant_model', 'multi_source_model', 'nonsource_descendant'}) self.assertTablesEqual('multi_source_model', 'expected_multi_source') @use_profile('postgres') def test_test_project_postgres(self): self.run_dbt_with_vars(['seed']) - result = self.async_query('run').json() - dct = self.assertIsResult(result) - result = self.async_query('test').json() + self.run_dbt_with_vars(['run']) + data = self.async_query('test').json() + result = self.assertIsResult(data) + self.assertIn('results', result) + self.assertHasTestResults(result['results'], 4) + + @use_profile('postgres') + def test_test_project_cli_postgres(self): + self.run_dbt_with_vars(['seed']) + self.run_dbt_with_vars(['run']) + data = self.async_query('cli_args', cli='test').json() + result = self.assertIsResult(data) + self.assertIn('results', result) + self.assertHasTestResults(result['results'], 4) + + def assertManifestExists(self, length): + self.assertTrue(os.path.exists('target/manifest.json')) + with open('target/manifest.json') as fp: + manifest = json.load(fp) + self.assertIn('nodes', manifest) + self.assertEqual(len(manifest['nodes']), length) + + def assertHasDocsGenerated(self, result, expected): dct = self.assertIsResult(result) - self.assertIn('results', dct) - results = dct['results'] - self.assertEqual(len(results), 4) - for result in results: - self.assertEqual(result['status'], 0.0) - # TODO: should this be included even when it's 'none'? Should - # results have all these crazy keys? 
(no) - self.assertIn('fail', result) - self.assertIsNone(result['fail']) + self.assertIn('status', dct) + self.assertTrue(dct['status']) + self.assertIn('nodes', dct) + nodes = dct['nodes'] + self.assertEqual(set(nodes), expected) - def _wait_for_running(self, timeout=25, raise_on_timeout=True): - started = time.time() - time.sleep(0.5) - elapsed = time.time() - started - while elapsed < timeout: - status = self.assertIsResult(self.query('status').json()) - if status['status'] == 'running': - return status - time.sleep(0.5) - elapsed = time.time() - started + def assertCatalogExists(self): + self.assertTrue(os.path.exists('target/catalog.json')) + with open('target/catalog.json') as fp: + catalog = json.load(fp) - status = self.assertIsResult(self.query('status').json()) - if raise_on_timeout: - self.assertEqual( - status['status'], - 'ready', - f'exceeded max time of {timeout}: {elapsed} seconds elapsed' - ) - return status + def _correct_docs_generate_result(self, result): + expected = { + 'model.test.descendant_model', + 'model.test.multi_source_model', + 'model.test.nonsource_descendant', + 'seed.test.expected_multi_source', + 'seed.test.other_source_table', + 'seed.test.other_table', + 'seed.test.source', + 'source.test.other_source.test_table', + 'source.test.test_source.other_test_table', + 'source.test.test_source.test_table', + } + self.assertHasDocsGenerated(result, expected) + self.assertCatalogExists() + self.assertManifestExists(17) - def assertRunning(self, sleepers): - sleeper_ps_result = self.query('ps', completed=False, active=True).json() - result = self.assertIsResult(sleeper_ps_result) - self.assertEqual(len(result['rows']), len(sleepers)) - result_map = {rd['request_id']: rd for rd in result['rows']} - for _, request_id in sleepers: - found = result_map[request_id] - self.assertEqual(found['request_id'], request_id) - self.assertEqual(found['method'], 'run_sql') - self.assertEqual(found['state'], 'running') - self.assertEqual(found['timeout'], None) - def _add_command(self, cmd, id_): - self.assertIsResult(self.async_query(cmd, _test_request_id=id_).json(), id_) + @use_profile('postgres') + def test_docs_generate_postgres(self): + self.run_dbt_with_vars(['seed']) + self.run_dbt_with_vars(['run']) + self.assertFalse(os.path.exists('target/catalog.json')) + if os.path.exists('target/manifest.json'): + os.remove('target/manifest.json') + result = self.async_query('cli_args', cli='docs generate').json() + self._correct_docs_generate_result(result) + + @use_profile('postgres') + def test_docs_generate_postgres_cli(self): + self.run_dbt_with_vars(['seed']) + self.run_dbt_with_vars(['run']) + self.assertFalse(os.path.exists('target/catalog.json')) + if os.path.exists('target/manifest.json'): + os.remove('target/manifest.json') + result = self.async_query('cli_args', cli='docs generate').json() + self._correct_docs_generate_result(result) + + +@mark.flaky(rerun_filter=addr_in_use) +class TestRPCTaskManagement(HasRPCServer): @mark.flaky(rerun_filter=lambda *a, **kw: True) @use_profile('postgres') @@ -849,32 +953,27 @@ def test_sighup_postgres(self): done_query = self.async_query('compile_sql', 'select 1 as id', name='done').json() self.assertIsResult(done_query) sleepers = [] - command_ids = [] - sleepers.append(self._get_sleep_query(duration=60, request_id=1000)) + sleepers.append(self.get_sleep_query(duration=60, request_id=1000)) self.assertRunning(sleepers) - self._add_command('seed', 20) - command_ids.append(20) - self._add_command('run', 21) - command_ids.append(21) + 
self.run_command_with_id('seed', 20) + self.run_command_with_id('run', 21) # sighup a few times for _ in range(10): os.kill(status['pid'], signal.SIGHUP) - status = self._wait_for_running() + status = self.wait_for_running() # we should still still see our service: self.assertRunning(sleepers) - self._add_command('seed', 30) - command_ids.append(30) - self._add_command('run', 31) - command_ids.append(31) + self.run_command_with_id('seed', 30) + self.run_command_with_id('run', 31) # start a new one too - sleepers.append(self._get_sleep_query(duration=60, request_id=1001)) + sleepers.append(self.get_sleep_query(duration=60, request_id=1001)) # now we should see both self.assertRunning(sleepers) @@ -885,22 +984,11 @@ def test_sighup_postgres(self): self.assertRunning([alive]) self.kill_and_assert(*alive) - def _make_any_requests(self, num_requests): - stored = [] - for idx in range(num_requests): - response = self.query('run_sql', 'select 1 as id', name='run').json() - result = self.assertIsResult(response) - self.assertIn('request_token', result) - token = result['request_token'] - self.poll_for_result(token) - stored.append(token) - return stored - @use_profile('postgres') def test_gc_by_time_postgres(self): # make a few normal requests num_requests = 10 - self._make_any_requests(num_requests) + self.make_many_requests(num_requests) resp = self.query('ps', completed=True, active=True).json() result = self.assertIsResult(resp) @@ -920,7 +1008,7 @@ def test_gc_by_time_postgres(self): def test_gc_by_id_postgres(self): # make 10 requests, then gc half of them num_requests = 10 - stored = self._make_any_requests(num_requests) + stored = self.make_many_requests(num_requests) resp = self.query('ps', completed=True, active=True).json() result = self.assertIsResult(resp) @@ -951,7 +1039,7 @@ def test_gc_by_id_postgres(self): @use_profile('postgres') def test_postgres_gc_change_interval(self): num_requests = 10 - self._make_any_requests(num_requests) + self.make_many_requests(num_requests) # all present resp = self.query('ps', completed=True, active=True).json() @@ -978,83 +1066,13 @@ def test_postgres_gc_change_interval(self): self.assertEqual(len(result['running']), 0) # make more requests - self._make_any_requests(num_requests) + self.make_many_requests(num_requests) time.sleep(0.5) # there should be 2 left! 
resp = self.query('ps', completed=True, active=True).json() result = self.assertIsResult(resp) self.assertEqual(len(result['rows']), 2) - @use_profile('postgres') - def test_docs_generate_postgres(self): - self.run_dbt_with_vars(['seed']) - self.run_dbt_with_vars(['run']) - self.assertFalse(os.path.exists('target/catalog.json')) - if os.path.exists('target/manifest.json'): - os.remove('target/manifest.json') - result = self.async_query('docs.generate').json() - dct = self.assertIsResult(result) - self.assertTrue(os.path.exists('target/catalog.json')) - self.assertIn('status', dct) - self.assertTrue(dct['status']) - self.assertIn('nodes', dct) - nodes = dct['nodes'] - self.assertEqual(len(nodes), 10) - expected = { - 'model.test.descendant_model', - 'model.test.multi_source_model', - 'model.test.nonsource_descendant', - 'seed.test.expected_multi_source', - 'seed.test.other_source_table', - 'seed.test.other_table', - 'seed.test.source', - 'source.test.other_source.test_table', - 'source.test.test_source.other_test_table', - 'source.test.test_source.test_table', - } - for uid in expected: - self.assertIn(uid, nodes) - self.assertTrue(os.path.exists('target/manifest.json')) - with open('target/manifest.json') as fp: - manifest = json.load(fp) - self.assertIn('nodes', manifest) - self.assertEqual(len(manifest['nodes']), 17) - - @use_profile('postgres') - def test_docs_generate_postgres(self): - self.run_dbt_with_vars(['seed']) - self.run_dbt_with_vars(['run']) - self.assertFalse(os.path.exists('target/catalog.json')) - if os.path.exists('target/manifest.json'): - os.remove('target/manifest.json') - result = self.async_query('cli_args', cli='docs generate').json() - dct = self.assertIsResult(result) - self.assertTrue(os.path.exists('target/catalog.json')) - self.assertIn('status', dct) - self.assertTrue(dct['status']) - self.assertIn('nodes', dct) - nodes = dct['nodes'] - self.assertEqual(len(nodes), 10) - expected = { - 'model.test.descendant_model', - 'model.test.multi_source_model', - 'model.test.nonsource_descendant', - 'seed.test.expected_multi_source', - 'seed.test.other_source_table', - 'seed.test.other_table', - 'seed.test.source', - 'source.test.other_source.test_table', - 'source.test.test_source.other_test_table', - 'source.test.test_source.test_table', - } - for uid in expected: - self.assertIn(uid, nodes) - self.assertTrue(os.path.exists('target/manifest.json')) - with open('target/manifest.json') as fp: - manifest = json.load(fp) - self.assertIn('nodes', manifest) - self.assertEqual(len(manifest['nodes']), 17) - class FailedServerProcess(ServerProcess): def _compare_result(self, result): @@ -1070,6 +1088,12 @@ class TestRPCServerFailed(HasRPCServer): def models(self): return "malformed_models" + def tearDown(self): + # prevent an OperationalError where the server closes on us in the + # background + self.adapter.cleanup_connections() + super().tearDown() + @use_profile('postgres') def test_postgres_status_error(self): status = self.assertIsResult(self.query('status').json()) @@ -1092,9 +1116,3 @@ def test_postgres_status_error(self): None) self.assertIn('message', data) self.assertIn('Invalid test config', str(data['message'])) - - def tearDown(self): - # prevent an OperationalError where the server closes on us in the - # background - self.adapter.cleanup_connections() - super().tearDown() diff --git a/test/integration/base.py b/test/integration/base.py index d6122029a20..25635314908 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -6,7 +6,6 @@ import time 
import traceback import unittest -import warnings from contextlib import contextmanager from datetime import datetime from functools import wraps From 75c8feaeb9603fd4c72f6879b5eea3ce572ff8c9 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 10:12:36 -0600 Subject: [PATCH 07/13] Make ManifestMetadata a first-class object --- core/dbt/config/project.py | 4 ++ core/dbt/contracts/graph/manifest.py | 70 +++++++++------------------- core/dbt/parser/manifest.py | 24 +++++----- core/dbt/parser/util.py | 2 +- core/dbt/rpc/method.py | 2 +- test/unit/test_manifest.py | 19 ++------ 6 files changed, 44 insertions(+), 77 deletions(-) diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 56d5b67e369..e9455726b21 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -19,6 +19,7 @@ from dbt.utils import parse_cli_vars from dbt.source_config import SourceConfig +from dbt.contracts.graph.manifest import ManifestMetadata from dbt.contracts.project import Project as ProjectContract from dbt.contracts.project import PackageConfig @@ -453,3 +454,6 @@ def validate_version(self): ] ) raise DbtProjectError(msg) + + def get_metadata(self) -> ManifestMetadata: + return ManifestMetadata(self.hashed_name()) diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 374ebb0882f..8287093324f 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass, field from datetime import datetime -from typing import Dict, List, Optional, Union, Mapping +from typing import Dict, List, Optional, Union, Mapping, Any from uuid import UUID from hologram import JsonSchemaMixin @@ -11,7 +11,6 @@ ParsedDocumentation from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.util import Writable, Replaceable -from dbt.config import Project from dbt.exceptions import raise_duplicate_resource_name, InternalException from dbt.logger import GLOBAL_LOGGER as logger from dbt.node_types import NodeType @@ -166,9 +165,21 @@ def remote(cls, contents: str) -> 'SourceFile': @dataclass class ManifestMetadata(JsonSchemaMixin, Replaceable): - project_id: Optional[str] - user_id: Optional[UUID] - send_anonymous_usage_stats: Optional[bool] + project_id: Optional[str] = None + user_id: Optional[UUID] = None + send_anonymous_usage_stats: Optional[bool] = None + + def __post_init__(self): + if tracking.active_user is None: + return + + if self.user_id is None: + self.user_id = tracking.active_user.id + + if self.send_anonymous_usage_stats is None: + self.send_anonymous_usage_stats = ( + not tracking.active_user.do_not_track + ) def _sort_values(dct): @@ -197,7 +208,7 @@ def _deepcopy(value): return value.from_dict(value.to_dict()) -@dataclass(init=False) +@dataclass class Manifest: """The manifest for the full graph, after parsing and during compilation. 
""" @@ -207,27 +218,8 @@ class Manifest: generated_at: datetime disabled: List[ParsedNode] files: Mapping[str, SourceFile] - metadata: ManifestMetadata = field(init=False) - - def __init__( - self, - nodes: Mapping[str, CompileResultNode], - macros: Mapping[str, ParsedMacro], - docs: Mapping[str, ParsedDocumentation], - generated_at: datetime, - disabled: List[ParsedNode], - files: Mapping[str, SourceFile], - config: Optional[Project] = None, - ) -> None: - self.metadata = self.get_metadata(config) - self.nodes = nodes - self.macros = macros - self.docs = docs - self.generated_at = generated_at - self.disabled = disabled - self.files = files - self.flat_graph = None - super(Manifest, self).__init__() + metadata: ManifestMetadata = field(default_factory=ManifestMetadata) + flat_graph: Dict[str, Any] = field(default_factory=dict) @classmethod def from_macros(cls, macros=None, files=None) -> 'Manifest': @@ -242,7 +234,6 @@ def from_macros(cls, macros=None, files=None) -> 'Manifest': generated_at=datetime.utcnow(), disabled=[], files=files, - config=None, ) def update_node(self, new_node): @@ -259,25 +250,6 @@ def update_node(self, new_node): ) self.nodes[unique_id] = new_node - @staticmethod - def get_metadata(config: Optional[Project]) -> ManifestMetadata: - project_id = None - user_id = None - send_anonymous_usage_stats = None - - if config is not None: - project_id = config.hashed_name() - - if tracking.active_user is not None: - user_id = tracking.active_user.id - send_anonymous_usage_stats = not tracking.active_user.do_not_track - - return ManifestMetadata( - project_id=project_id, - user_id=user_id, - send_anonymous_usage_stats=send_anonymous_usage_stats, - ) - def build_flat_graph(self): """This attribute is used in context.common by each node, so we want to only build it once and avoid any concurrency issues around it. 
@@ -431,14 +403,14 @@ def get_used_schemas(self, resource_types=None): def get_used_databases(self): return frozenset(node.database for node in self.nodes.values()) - def deepcopy(self, config=None): + def deepcopy(self): return Manifest( nodes={k: _deepcopy(v) for k, v in self.nodes.items()}, macros={k: _deepcopy(v) for k, v in self.macros.items()}, docs={k: _deepcopy(v) for k, v in self.docs.items()}, generated_at=self.generated_at, disabled=[_deepcopy(n) for n in self.disabled], - config=config, + metadata=self.metadata, files={k: _deepcopy(v) for k, v in self.files.items()}, ) diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 4324a54626f..af97ec2f50f 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -15,18 +15,18 @@ from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.manifest import Manifest, FilePath, FileHash from dbt.parser.base import BaseParser -from dbt.parser import AnalysisParser -from dbt.parser import DataTestParser -from dbt.parser import DocumentationParser -from dbt.parser import HookParser -from dbt.parser import MacroParser -from dbt.parser import ModelParser -from dbt.parser import ParseResult -from dbt.parser import SchemaParser -from dbt.parser import SeedParser -from dbt.parser import SnapshotParser -from dbt.parser import ParserUtils +from dbt.parser.analysis import AnalysisParser +from dbt.parser.data_test import DataTestParser +from dbt.parser.docs import DocumentationParser +from dbt.parser.hooks import HookParser +from dbt.parser.macros import MacroParser +from dbt.parser.models import ModelParser +from dbt.parser.results import ParseResult +from dbt.parser.schemas import SchemaParser from dbt.parser.search import FileBlock +from dbt.parser.seeds import SeedParser +from dbt.parser.snapshots import SnapshotParser +from dbt.parser.util import ParserUtils from dbt.version import __version__ @@ -263,7 +263,7 @@ def create_manifest(self) -> Manifest: macros=self.results.macros, docs=self.results.docs, generated_at=datetime.utcnow(), - config=self.root_project, + metadata=self.root_project.get_metadata(), disabled=disabled, files=self.results.files, ) diff --git a/core/dbt/parser/util.py b/core/dbt/parser/util.py index 74801898f96..add93a7cdba 100644 --- a/core/dbt/parser/util.py +++ b/core/dbt/parser/util.py @@ -259,7 +259,7 @@ def add_new_refs(cls, manifest, current_project: Project, node, macros): insert the new node into it as if it were part of regular ref processing """ - manifest = manifest.deepcopy(config=current_project) + manifest = manifest.deepcopy() # it's ok for macros to silently override a local project macro name manifest.macros.update(macros) diff --git a/core/dbt/rpc/method.py b/core/dbt/rpc/method.py index 39f3b24ee95..16159fd4a7a 100644 --- a/core/dbt/rpc/method.py +++ b/core/dbt/rpc/method.py @@ -23,7 +23,7 @@ class RemoteMethod(Generic[Parameters, Result]): def __init__(self, args, config, manifest): self.args = args self.config = config - self.manifest = manifest.deepcopy(config=config) + self.manifest = manifest.deepcopy() @classmethod def get_parameters(cls) -> Type[Parameters]: diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py index 218e410eb66..69c28c4dc22 100644 --- a/test/unit/test_manifest.py +++ b/test/unit/test_manifest.py @@ -266,14 +266,11 @@ def test__build_flat_graph(self): self.assertEqual(frozenset(node), REQUIRED_PARSED_NODE_KEYS) @mock.patch.object(tracking, 'active_user') - def test_get_metadata(self, mock_user): + def 
test_metadata(self, mock_user): mock_user.id = 'cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf' mock_user.do_not_track = True - config = mock.MagicMock() - # md5 of 'test' - config.hashed_name.return_value = '098f6bcd4621d373cade4e832627b4f6' self.assertEqual( - Manifest.get_metadata(config), + ManifestMetadata(project_id='098f6bcd4621d373cade4e832627b4f6'), ManifestMetadata( project_id='098f6bcd4621d373cade4e832627b4f6', user_id='cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf', @@ -286,17 +283,11 @@ def test_get_metadata(self, mock_user): def test_no_nodes_with_metadata(self, mock_user): mock_user.id = 'cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf' mock_user.do_not_track = True - config = mock.MagicMock() - # md5 of 'test' - config.hashed_name.return_value = '098f6bcd4621d373cade4e832627b4f6' + metadata = ManifestMetadata(project_id='098f6bcd4621d373cade4e832627b4f6') manifest = Manifest(nodes={}, macros={}, docs={}, generated_at=datetime.utcnow(), disabled=[], - config=config, files={}) - metadata = { - 'project_id': '098f6bcd4621d373cade4e832627b4f6', - 'user_id': 'cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf', - 'send_anonymous_usage_stats': False, - } + metadata=metadata, files={}) + self.assertEqual( manifest.writable_manifest().to_dict(), { From 4aa4295508cd216d9f1caa430e2a530efb3ec920 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 10:03:58 -0600 Subject: [PATCH 08/13] Make mypy totally happy Some circular import cleanups remove is_type function, just compare to resource_type Add type checking for dbt deps --- core/dbt/adapters/base/impl.py | 8 +- core/dbt/adapters/base/plugin.py | 4 +- core/dbt/adapters/factory.py | 4 +- core/dbt/compilation.py | 12 +- core/dbt/config/profile.py | 4 +- core/dbt/context/base.py | 2 +- core/dbt/context/common.py | 19 +- core/dbt/context/operation.py | 2 +- core/dbt/context/parser.py | 2 +- core/dbt/context/runtime.py | 6 +- core/dbt/contracts/connection.py | 11 +- core/dbt/contracts/graph/compiled.py | 14 +- core/dbt/contracts/graph/parsed.py | 44 ++- core/dbt/contracts/graph/unparsed.py | 8 +- core/dbt/contracts/project.py | 8 +- core/dbt/contracts/util.py | 19 +- core/dbt/graph/selector.py | 2 +- core/dbt/linker.py | 4 +- core/dbt/task/deps.py | 341 +++++++++--------- core/dbt/utils.py | 10 +- .../test_concurrent_transaction.py | 4 +- .../038_caching_test/test_caching.py | 4 +- test/integration/base.py | 4 +- test/unit/test_compiler.py | 8 +- test/unit/test_deps.py | 38 +- test/unit/utils.py | 6 +- tox.ini | 45 +-- 27 files changed, 311 insertions(+), 322 deletions(-) diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 0eea68de714..84574874af2 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -13,10 +13,8 @@ import dbt.flags from dbt.clients.agate_helper import empty_table -from dbt.config import RuntimeConfig from dbt.contracts.graph.manifest import Manifest from dbt.node_types import NodeType -from dbt.parser.manifest import load_internal_manifest from dbt.logger import GLOBAL_LOGGER as logger from dbt.utils import filter_null_values @@ -196,8 +194,8 @@ class BaseAdapter(metaclass=AdapterMeta): # for use in materializations AdapterSpecificConfigs: FrozenSet[str] = frozenset() - def __init__(self, config: RuntimeConfig): - self.config: RuntimeConfig = config + def __init__(self, config): + self.config = config self.cache = RelationsCache() self.connections = self.ConnectionManager(config) self._internal_manifest_lazy: Optional[Manifest] = None @@ -280,6 +278,8 @@ def check_internal_manifest(self) -> 
Optional[Manifest]: def load_internal_manifest(self) -> Manifest: if self._internal_manifest_lazy is None: + # avoid a circular import + from dbt.parser.manifest import load_internal_manifest manifest = load_internal_manifest(self.config) self._internal_manifest_lazy = manifest return self._internal_manifest_lazy diff --git a/core/dbt/adapters/base/plugin.py b/core/dbt/adapters/base/plugin.py index c307c97d62c..d9f128ad33d 100644 --- a/core/dbt/adapters/base/plugin.py +++ b/core/dbt/adapters/base/plugin.py @@ -1,6 +1,5 @@ from typing import List, Optional, Type -from dbt.config.project import Project from dbt.adapters.base import BaseAdapter, Credentials @@ -18,6 +17,9 @@ def __init__( include_path: str, dependencies: Optional[List[str]] = None ): + # avoid an import cycle + from dbt.config.project import Project + self.adapter: Type[BaseAdapter] = adapter self.credentials: Type[Credentials] = credentials self.include_path: str = include_path diff --git a/core/dbt/adapters/factory.py b/core/dbt/adapters/factory.py index c0157b382fd..dce72369e02 100644 --- a/core/dbt/adapters/factory.py +++ b/core/dbt/adapters/factory.py @@ -74,7 +74,7 @@ def load_plugin(self, name: str) -> Type[Credentials]: return plugin.credentials - def register_adapter(self, config: HasCredentials) -> Adapter: + def register_adapter(self, config: HasCredentials) -> None: adapter_name = config.credentials.type adapter_type = self.get_adapter_class_by_name(adapter_name) @@ -110,7 +110,7 @@ def cleanup_connections(self): def register_adapter(config: HasCredentials) -> None: - return FACTORY.register_adapter(config) + FACTORY.register_adapter(config) def get_adapter(config: HasCredentials): diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index d603754c2db..b6bbc5442b6 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -6,7 +6,7 @@ import dbt.include import dbt.tracking -from dbt.utils import get_materialization, NodeType, is_type +from dbt.node_types import NodeType from dbt.linker import Linker import dbt.context.runtime @@ -140,7 +140,7 @@ def compile_node(self, node, manifest, extra_context=None): # data tests get wrapped in count(*) # TODO : move this somewhere more reasonable if 'data' in injected_node.tags and \ - is_type(injected_node, NodeType.Test): + injected_node.resource_type == NodeType.Test: injected_node.wrapped_sql = ( "select count(*) as errors " "from (\n{test_sql}\n) sbq").format( @@ -149,14 +149,14 @@ def compile_node(self, node, manifest, extra_context=None): # don't wrap schema tests or analyses. injected_node.wrapped_sql = injected_node.injected_sql - elif is_type(injected_node, NodeType.Snapshot): + elif injected_node.resource_type == NodeType.Snapshot: # unfortunately we do everything automagically for # snapshots. in the future it'd be nice to generate # the SQL at the parser level. 
pass - elif(is_type(injected_node, NodeType.Model) and - get_materialization(injected_node) == 'ephemeral'): + elif(injected_node.resource_type == NodeType.Model and + injected_node.get_materialization() == 'ephemeral'): pass else: @@ -219,7 +219,7 @@ def _is_writable(node): if not node.injected_sql: return False - if dbt.utils.is_type(node, NodeType.Snapshot): + if node.resource_type == NodeType.Snapshot: return False return True diff --git a/core/dbt/config/profile.py b/core/dbt/config/profile.py index 25c4292f282..ce8fa34e27d 100644 --- a/core/dbt/config/profile.py +++ b/core/dbt/config/profile.py @@ -2,7 +2,6 @@ from hologram import ValidationError -from dbt.adapters.factory import load_plugin from dbt.clients.system import load_file_contents from dbt.clients.yaml_helper import load_yaml_text from dbt.contracts.project import ProfileConfig, UserConfig @@ -121,6 +120,8 @@ def validate(self): @staticmethod def _credentials_from_profile(profile, profile_name, target_name): + # avoid an import cycle + from dbt.adapters.factory import load_plugin # credentials carry their 'type' in their actual type, not their # attributes. We do want this in order to pick our Credentials class. if 'type' not in profile: @@ -129,7 +130,6 @@ def _credentials_from_profile(profile, profile_name, target_name): .format(profile_name, target_name)) typename = profile.pop('type') - try: cls = load_plugin(typename) credentials = cls.from_dict(profile) diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index 3994694a13e..6956d3aefef 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -39,7 +39,7 @@ def env_var(var, default=None): def debug_here(): import sys - import ipdb + import ipdb # type: ignore frame = sys._getframe(3) ipdb.set_trace(frame) diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py index 171a89c8df0..11be906d239 100644 --- a/core/dbt/context/common.py +++ b/core/dbt/context/common.py @@ -1,7 +1,8 @@ import agate import json import os -from typing import Union, Callable +from typing import Union, Callable, Type +from typing_extensions import Protocol import dbt.clients.agate_helper from dbt.contracts.graph.compiled import CompiledSeedNode @@ -18,7 +19,7 @@ from dbt.logger import GLOBAL_LOGGER as logger from dbt.clients.jinja import get_rendered from dbt.context.base import ( - debug_here, env_var, get_context_modules, add_tracking + debug_here, env_var, get_context_modules, add_tracking, Var ) @@ -82,6 +83,20 @@ def Relation(self): return self.db_wrapper.Relation +class Config(Protocol): + def __init__(self, model, source_config): + ... + + +class Provider(Protocol): + execute: bool + Config: Type[Config] + DatabaseWrapper: Type[BaseDatabaseWrapper] + Var: Type[Var] + ref: Type[BaseResolver] + source: Type[BaseResolver] + + def _add_macro_map(context, package_name, macro_map): """Update an existing context in-place, adding the given macro map to the appropriate package namespace. 
Adapter packages get inserted into the diff --git a/core/dbt/context/operation.py b/core/dbt/context/operation.py index 50953f4068c..d892e12f013 100644 --- a/core/dbt/context/operation.py +++ b/core/dbt/context/operation.py @@ -3,7 +3,7 @@ from dbt.exceptions import raise_compiler_error -class RefResolver(runtime.BaseRefResolver): +class RefResolver(runtime.RefResolver): def __call__(self, *args): # When you call ref(), this is what happens at operation runtime target_model, name = self.resolve(args) diff --git a/core/dbt/context/parser.py b/core/dbt/context/parser.py index db98f84412c..1d95f18efb3 100644 --- a/core/dbt/context/parser.py +++ b/core/dbt/context/parser.py @@ -121,7 +121,7 @@ def __call__(self, *args): return self.Relation.create_from(self.config, self.model) -class Provider: +class Provider(dbt.context.common.Provider): execute = False Config = Config DatabaseWrapper = DatabaseWrapper diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py index c97a75b50dc..183f45c2464 100644 --- a/core/dbt/context/runtime.py +++ b/core/dbt/context/runtime.py @@ -9,7 +9,7 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa -class BaseRefResolver(dbt.context.common.BaseResolver): +class RefResolver(dbt.context.common.BaseResolver): def resolve(self, args): name = None package = None @@ -48,8 +48,6 @@ def create_relation(self, target_model, name): else: return self.Relation.create_from(self.config, target_model) - -class RefResolver(BaseRefResolver): def validate(self, resolved, args): if resolved.unique_id not in self.model.depends_on.nodes: dbt.exceptions.ref_bad_context(self.model, args) @@ -149,7 +147,7 @@ class Var(dbt.context.base.Var): pass -class Provider: +class Provider(dbt.context.common.Provider): execute = True Config = Config DatabaseWrapper = DatabaseWrapper diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index e0e1fe112aa..d8a83f4f9ae 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -29,10 +29,11 @@ class ConnectionState(StrEnum): class Connection(ExtensibleJsonSchemaMixin, Replaceable): type: Identifier name: Optional[str] - _credentials: JsonSchemaMixin = None # underscore to prevent serialization state: ConnectionState = ConnectionState.INIT transaction_open: bool = False - _handle: Optional[Any] = None # underscore to prevent serialization + # prevent serialization + _handle: Optional[Any] = None + _credentials: JsonSchemaMixin = field(init=False) def __init__( self, @@ -45,8 +46,8 @@ def __init__( ) -> None: self.type = type self.name = name - self.credentials = credentials self.state = state + self.credentials = credentials self.transaction_open = transaction_open self.handle = handle @@ -71,8 +72,8 @@ def handle(self, value): # and https://github.com/python/mypy/issues/5374 # for why we have type: ignore. Maybe someday dataclasses + abstract classes # will work. 
-@dataclass -class Credentials( # type: ignore +@dataclass # type: ignore +class Credentials( ExtensibleJsonSchemaMixin, Replaceable, metaclass=abc.ABCMeta diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index cd47718ea15..aed62f88846 100644 --- a/core/dbt/contracts/graph/compiled.py +++ b/core/dbt/contracts/graph/compiled.py @@ -14,18 +14,18 @@ ) from dbt.node_types import NodeType from dbt.contracts.util import Replaceable -from dbt.exceptions import InternalException +from dbt.exceptions import InternalException, RuntimeException from hologram import JsonSchemaMixin from dataclasses import dataclass, field -import sqlparse -from typing import Optional, List, Union +import sqlparse # type: ignore +from typing import Optional, List, Union, Dict, Type @dataclass class InjectedCTE(JsonSchemaMixin, Replaceable): id: str - sql: Optional[str] = None + sql: str # for some frustrating reason, we can't subclass from ParsedNode directly, # or typing.Union will flatten CompiledNode+ParsedNode into just ParsedNode. @@ -45,6 +45,10 @@ class CompiledNode(ParsedNode): def prepend_ctes(self, prepended_ctes: List[InjectedCTE]): self.extra_ctes_injected = True self.extra_ctes = prepended_ctes + if self.compiled_sql is None: + raise RuntimeException( + 'Cannot prepend ctes to an unparsed node', self + ) self.injected_sql = _inject_ctes_into_sql( self.compiled_sql, prepended_ctes, @@ -176,7 +180,7 @@ def _inject_ctes_into_sql(sql: str, ctes: List[InjectedCTE]) -> str: return str(parsed) -COMPILED_TYPES = { +COMPILED_TYPES: Dict[NodeType, Type[CompiledNode]] = { NodeType.Analysis: CompiledAnalysisNode, NodeType.Model: CompiledModelNode, NodeType.Operation: CompiledHookNode, diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index ef80e49fa21..cdd4062962e 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -15,7 +15,7 @@ UnparsedBaseNode, FreshnessThreshold, ExternalTable, AdditionalPropertiesAllowed ) -from dbt.contracts.util import Replaceable +from dbt.contracts.util import Replaceable, list_str from dbt.logger import GLOBAL_LOGGER as logger # noqa from dbt.node_types import NodeType @@ -61,7 +61,7 @@ class NodeConfig( vars: Dict[str, Any] = field(default_factory=dict) quoting: Dict[str, Any] = field(default_factory=dict) column_types: Dict[str, Any] = field(default_factory=dict) - tags: Union[List[str], str] = field(default_factory=list) + tags: Union[List[str], str] = field(default_factory=list_str) @classmethod def field_mapping(cls): @@ -257,7 +257,7 @@ def empty(self): @dataclass class TestConfig(NodeConfig): - severity: Severity = 'error' + severity: Severity = Severity('error') @dataclass @@ -277,8 +277,8 @@ class ParsedTestNode(ParsedNode): @dataclass(init=False) class _SnapshotConfig(NodeConfig): - unique_key: str - target_schema: str + unique_key: str = field(init=False) + target_schema: str = field(init=False) target_database: Optional[str] = None def __init__( @@ -288,15 +288,15 @@ def __init__( target_database: Optional[str] = None, **kwargs ) -> None: - self.target_database = target_database - self.target_schema = target_schema self.unique_key = unique_key + self.target_schema = target_schema + self.target_database = target_database super().__init__(**kwargs) @dataclass(init=False) class GenericSnapshotConfig(_SnapshotConfig): - strategy: str + strategy: str = field(init=False) def __init__(self, strategy: str, **kwargs) -> None: self.strategy = strategy @@ -305,10 +305,11 @@ def 
__init__(self, strategy: str, **kwargs) -> None: @dataclass(init=False) class TimestampSnapshotConfig(_SnapshotConfig): - strategy: str = field(metadata={ - 'restrict': [str(SnapshotStrategy.Timestamp)] - }) - updated_at: str + strategy: str = field( + init=False, + metadata={'restrict': [str(SnapshotStrategy.Timestamp)]}, + ) + updated_at: str = field(init=False) def __init__( self, strategy: str, updated_at: str, **kwargs @@ -320,9 +321,10 @@ def __init__( @dataclass(init=False) class CheckSnapshotConfig(_SnapshotConfig): - strategy: str = field(metadata={ - 'restrict': [str(SnapshotStrategy.Check)] - }) + strategy: str = field( + init=False, + metadata={'restrict': [str(SnapshotStrategy.Check)]}, + ) # TODO: is there a way to get this to accept tuples of strings? Adding # `Tuple[str, ...]` to the list of types results in this: # ['email'] is valid under each of {'type': 'array', 'items': @@ -330,7 +332,7 @@ class CheckSnapshotConfig(_SnapshotConfig): # but without it, parsing gets upset about values like `('email',)` # maybe hologram itself should support this behavior? It's not like tuples # are meaningful in json - check_cols: Union[All, List[str]] + check_cols: Union[All, List[str]] = field(init=False) def __init__( self, strategy: str, check_cols: Union[All, List[str]], @@ -361,7 +363,8 @@ def _create_if_else_chain( 'if-then-else' chain. This results is much better/more consistent errors from jsonschema. """ - result = schema = {} + schema: Dict[str, Any] = {} + result: Dict[str, Any] = {} criteria = criteria[:] while criteria: if_clause, then_clause = criteria.pop() @@ -492,7 +495,12 @@ def has_freshness(self): return bool(self.freshness) and self.loaded_at_field is not None -PARSED_TYPES = { +ParsedResource = Union[ + ParsedMacro, ParsedNode, ParsedDocumentation, ParsedSourceDefinition +] + + +PARSED_TYPES: Dict[NodeType, Type[ParsedResource]] = { NodeType.Analysis: ParsedAnalysisNode, NodeType.Documentation: ParsedDocumentation, NodeType.Macro: ParsedMacro, diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index 5ecde24c083..9ab03aabd1f 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -68,11 +68,7 @@ def __post_init__(self): @dataclass class ColumnDescription(JsonSchemaMixin, Replaceable): - columns: Optional[List[NamedTested]] = field(default_factory=list) - - def __post_init__(self): - if self.columns is None: - self.columns = [] + columns: List[NamedTested] = field(default_factory=list) @dataclass @@ -84,7 +80,6 @@ class NodeDescription(NamedTested): class UnparsedNodeUpdate(ColumnDescription, NodeDescription): def __post_init__(self): NodeDescription.__post_init__(self) - ColumnDescription.__post_init__(self) class TimePeriod(StrEnum): @@ -205,7 +200,6 @@ class UnparsedSourceTableDefinition(ColumnDescription, NodeDescription): def __post_init__(self): NodeDescription.__post_init__(self) - ColumnDescription.__post_init__(self) @dataclass diff --git a/core/dbt/contracts/project.py b/core/dbt/contracts/project.py index 9cb3c32c951..b3119e86424 100644 --- a/core/dbt/contracts/project.py +++ b/core/dbt/contracts/project.py @@ -1,4 +1,4 @@ -from dbt.contracts.util import Replaceable, Mergeable +from dbt.contracts.util import Replaceable, Mergeable, list_str from dbt.logger import GLOBAL_LOGGER as logger # noqa from dbt import tracking from dbt.ui import printer @@ -151,8 +151,8 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable): log_path: Optional[str] = None modules_path: Optional[str] = 
None
     quoting: Optional[Quoting] = None
-    on_run_start: Optional[List[str]] = field(default_factory=list)
-    on_run_end: Optional[List[str]] = field(default_factory=list)
+    on_run_start: Optional[List[str]] = field(default_factory=list_str)
+    on_run_end: Optional[List[str]] = field(default_factory=list_str)
     require_dbt_version: Optional[Union[List[str], str]] = None
     models: Dict[str, Any] = field(default_factory=dict)
     seeds: Dict[str, Any] = field(default_factory=dict)
@@ -200,7 +200,7 @@ class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable):
 
 
 @dataclass
-class ConfiguredQuoting(JsonSchemaMixin, Replaceable):
+class ConfiguredQuoting(Quoting, Replaceable):
     identifier: bool
     schema: bool
     database: Optional[bool]
diff --git a/core/dbt/contracts/util.py b/core/dbt/contracts/util.py
index 842d9bbc87f..12e6005b80a 100644
--- a/core/dbt/contracts/util.py
+++ b/core/dbt/contracts/util.py
@@ -1,8 +1,25 @@
 import dataclasses
+from typing import List
 
 from dbt.clients.system import write_json
 
 
+def list_str() -> List[str]:
+    """Mypy gets upset about stuff like:
+
+    from dataclasses import dataclass, field
+    from typing import Optional, List
+
+    @dataclass
+    class Foo:
+        x: Optional[List[str]] = field(default_factory=list)
+
+
+    Because `list` could be any kind of list, I guess
+    """
+    return []
+
+
 class Replaceable:
     def replace(self, **kwargs):
         return dataclasses.replace(self, **kwargs)
@@ -26,4 +43,4 @@ def merged(self, *args):
 
 class Writable:
     def write(self, path: str, omit_none: bool = False):
-        write_json(path, self.to_dict(omit_none=omit_none))
+        write_json(path, self.to_dict(omit_none=omit_none))  # type: ignore
diff --git a/core/dbt/graph/selector.py b/core/dbt/graph/selector.py
index 07daf947a71..cfe2d931177 100644
--- a/core/dbt/graph/selector.py
+++ b/core/dbt/graph/selector.py
@@ -1,6 +1,6 @@
 from enum import Enum
 
-import networkx as nx
+import networkx as nx  # type: ignore
 
 from dbt.logger import GLOBAL_LOGGER as logger
 from dbt.utils import is_enabled, coalesce
diff --git a/core/dbt/linker.py b/core/dbt/linker.py
index 479f33a1853..e0e2d5d2776 100644
--- a/core/dbt/linker.py
+++ b/core/dbt/linker.py
@@ -1,6 +1,6 @@
 from queue import PriorityQueue
 from typing import Iterable, Set, Optional
-import networkx as nx
+import networkx as nx  # type: ignore
 
 import threading
 
@@ -254,7 +254,7 @@ def sorted_ephemeral_ancestors(
 
         if node.resource_type != NodeType.Model:
             continue
-        if node.get_materialization() != 'ephemeral':
+        if node.get_materialization() != 'ephemeral':  # type: ignore
             continue
         # this is an ephemeral model! 
We have to find everything it # refs and do it all over again until we exhaust them all diff --git a/core/dbt/task/deps.py b/core/dbt/task/deps.py index 4e5a499a567..d5af4efb099 100644 --- a/core/dbt/task/deps.py +++ b/core/dbt/task/deps.py @@ -4,7 +4,9 @@ import shutil import tempfile from dataclasses import dataclass, field -from typing import Union, Dict, Optional, List +from typing import ( + Union, Dict, Optional, List, Type, Iterator, NoReturn, Generic, TypeVar, +) import dbt.utils import dbt.deprecations @@ -50,54 +52,78 @@ def _initialize_downloads(): RegistryPackageContract] -def _parse_package(dict_: dict) -> PackageContract: - only_1_keys = ['package', 'git', 'local'] - specified = [k for k in only_1_keys if dict_.get(k)] - if len(specified) > 1: - dbt.exceptions.raise_dependency_error( - 'Packages should not contain more than one of {}; ' - 'yours has {} of them - {}' - .format(only_1_keys, len(specified), specified)) - if dict_.get('package'): - return RegistryPackageContract.from_dict(dict_) - if dict_.get('git'): - if dict_.get('version'): - msg = ("Keyword 'version' specified for git package {}.\nDid " - "you mean 'revision'?".format(dict_.get('git'))) - dbt.exceptions.raise_dependency_error(msg) - return GitPackageContract.from_dict(dict_) - if dict_.get('local'): - return LocalPackageContract.from_dict(dict_) - dbt.exceptions.raise_dependency_error( - 'Malformed package definition. Must contain package, git, or local.') - - def md5sum(s: str): return hashlib.md5(s.encode('latin-1')).hexdigest() -@dataclass -class Pinned(metaclass=abc.ABCMeta): - _cached_metadata: Optional[ProjectPackageMetadata] = field(init=False) - - def __post_init__(self): - self._cached_metadata = None - - def __str__(self): - version = self.get_version() - if not version: - return self.name +PackageContractType = TypeVar('PackageContractType', bound=PackageContract) - return '{}@{}'.format(self.name, version) +class BasePackage(metaclass=abc.ABCMeta): @abc.abstractproperty - def name(self): + def name(self) -> str: raise NotImplementedError + def all_names(self) -> List[str]: + return [self.name] + @abc.abstractmethod - def source_type(self): + def source_type(self) -> str: raise NotImplementedError + +class LocalPackageMixin: + def __init__(self, local: str) -> None: + super().__init__() + self.local = local + + @property + def name(self): + return self.local + + def source_type(self): + return 'local' + + +class GitPackageMixin: + def __init__(self, git: str) -> None: + super().__init__() + self.git = git + + @property + def name(self): + return self.git + + def source_type(self) -> str: + return 'git' + + +class RegistryPackageMixin: + def __init__(self, package: str) -> None: + super().__init__() + self.package = package + + @property + def name(self): + return self.package + + def source_type(self) -> str: + return 'hub' + + +class PinnedPackage(BasePackage): + def __init__(self) -> None: + if hasattr(self, '_cached_metadata'): + raise ValueError('already here') + self._cached_metadata: Optional[ProjectPackageMetadata] = None + + def __str__(self) -> str: + version = self.get_version() + if not version: + return self.name + + return '{}@{}'.format(self.name, version) + @abc.abstractmethod def get_version(self) -> Optional[str]: raise NotImplementedError @@ -128,16 +154,9 @@ def get_installation_path(self, project): return os.path.join(project.modules_path, dest_dirname) -@dataclass -class LocalPinned(Pinned): - local: str - - @property - def name(self): - return self.local - - def 
source_type(self): - return 'local' +class LocalPinnedPackage(LocalPackageMixin, PinnedPackage): + def __init__(self, local: str) -> None: + super().__init__(local) def get_version(self): return None @@ -177,24 +196,15 @@ def install(self, project): shutil.copytree(src_path, dest_path) -@dataclass -class GitPinned(Pinned): - git: str - revision: str - warn_unpinned: bool = True - _checkout_name: str = field(init=False) - - def __post_init__(self): - super().__post_init__() +class GitPinnedPackage(GitPackageMixin, PinnedPackage): + def __init__( + self, git: str, revision: str, warn_unpinned: bool = True + ) -> None: + super().__init__(git) + self.revision = revision + self.warn_unpinned = warn_unpinned self._checkout_name = md5sum(self.git) - @property - def name(self): - return self.git - - def source_type(self): - return 'git' - def get_version(self): return self.revision @@ -244,10 +254,10 @@ def install(self, project): system.move(self._checkout(), dest_path) -@dataclass -class RegistryPinned(Pinned): - package: str - version: str +class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage): + def __init__(self, package: str, version: str) -> None: + super().__init__(package) + self.version = version @property def name(self): @@ -280,61 +290,54 @@ def install(self, project): system.untar_package(tar_path, deps_path, package_name) -class Package(metaclass=abc.ABCMeta): +SomePinned = TypeVar('SomePinned', bound=PinnedPackage) +SomeUnpinned = TypeVar('SomeUnpinned', bound='UnpinnedPackage') + + +class UnpinnedPackage(Generic[SomePinned], BasePackage): @abc.abstractclassmethod def from_contract(cls, contract): raise NotImplementedError - @abc.abstractproperty - def name(self): + @abc.abstractmethod + def incorporate(self: SomeUnpinned, other: SomeUnpinned) -> SomeUnpinned: raise NotImplementedError - def all_names(self): - return [self.name] - - def _typecheck(self, other): - if not isinstance(other, self.__class__): - raise_dependency_error( - 'Cannot incorporate {0} ({0.__class__.__name__}) into ' - '{1} ({1.__class__.__name__}): mismatched types' - .format(other, self)) - - -@dataclass -class LocalPackage(Package): - local: str + @abc.abstractmethod + def resolved(self) -> SomePinned: + raise NotImplementedError - def source_type(self): - return 'local' - - @property - def name(self): - return self.local +class LocalUnpinnedPackage( + LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage] +): @classmethod - def from_contract(cls, contract: LocalPackageContract) -> 'LocalPackage': + def from_contract( + cls, contract: LocalPackageContract + ) -> 'LocalUnpinnedPackage': return cls(local=contract.local) def incorporate( - self, other: Union['LocalPackage', LocalPinned] - ) -> 'LocalPackage': - if isinstance(other, LocalPinned): - other = LocalPackage(local=other.local) - self._typecheck(other) - return LocalPackage(local=self.local) + self, other: 'LocalUnpinnedPackage' + ) -> 'LocalUnpinnedPackage': + return LocalUnpinnedPackage(local=self.local) - def resolved(self) -> LocalPinned: - return LocalPinned(local=self.local) + def resolved(self) -> LocalPinnedPackage: + return LocalPinnedPackage(local=self.local) -@dataclass -class GitPackage(Package): - git: str - revisions: List[str] - warn_unpinned: bool = True +class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]): + def __init__( + self, git: str, revisions: List[str], warn_unpinned: bool = True + ) -> None: + super().__init__(git) + self.revisions = revisions + self.warn_unpinned = warn_unpinned @classmethod 
- def from_contract(cls, contract: GitPackageContract) -> 'GitPackage': + def from_contract( + cls, contract: GitPackageContract + ) -> 'GitUnpinnedPackage': revisions = [contract.revision] if contract.revision else [] # we want to map None -> True @@ -342,14 +345,7 @@ def from_contract(cls, contract: GitPackageContract) -> 'GitPackage': return cls(git=contract.git, revisions=revisions, warn_unpinned=warn_unpinned) - @property - def name(self): - return self.git - - def source_type(self): - return 'git' - - def all_names(self): + def all_names(self) -> List[str]: if self.git.endswith('.git'): other = self.git[:-4] else: @@ -357,22 +353,17 @@ def all_names(self): return [self.git, other] def incorporate( - self, other: Union['GitPackage', GitPinned] - ) -> 'GitPackage': - - if isinstance(other, GitPinned): - other = GitPackage(git=other.git, revisions=[other.revision], - warn_unpinned=other.warn_unpinned) - - self._typecheck(other) - + self, other: 'GitUnpinnedPackage' + ) -> 'GitUnpinnedPackage': warn_unpinned = self.warn_unpinned and other.warn_unpinned - return GitPackage(git=self.git, - revisions=self.revisions + other.revisions, - warn_unpinned=warn_unpinned) + return GitUnpinnedPackage( + git=self.git, + revisions=self.revisions + other.revisions, + warn_unpinned=warn_unpinned, + ) - def resolved(self) -> GitPinned: + def resolved(self) -> GitPinnedPackage: requested = set(self.revisions) if len(requested) == 0: requested = {'master'} @@ -381,20 +372,20 @@ def resolved(self) -> GitPinned: 'git dependencies should contain exactly one version. ' '{} contains: {}'.format(self.git, requested)) - return GitPinned( + return GitPinnedPackage( git=self.git, revision=requested.pop(), warn_unpinned=self.warn_unpinned ) -@dataclass -class RegistryPackage(Package): - package: str - versions: List[semver.VersionSpecifier] - - @property - def name(self): - return self.package +class RegistryUnpinnedPackage( + RegistryPackageMixin, UnpinnedPackage[RegistryPinnedPackage] +): + def __init__( + self, package: str, versions: List[semver.VersionSpecifier] + ) -> None: + super().__init__(package) + self.versions = versions def _check_in_index(self): index = registry.index_cached() @@ -404,7 +395,7 @@ def _check_in_index(self): @classmethod def from_contract( cls, contract: RegistryPackageContract - ) -> 'RegistryPackage': + ) -> 'RegistryUnpinnedPackage': raw_version = contract.version if isinstance(raw_version, str): raw_version = [raw_version] @@ -416,18 +407,14 @@ def from_contract( return cls(package=contract.package, versions=versions) def incorporate( - self, other: ['RegistryPackage', RegistryPinned] - ) -> 'RegistryPackage': - if isinstance(other, RegistryPinned): - versions = [ - semver.VersionSpecifier.from_version_string(other.version) - ] - other = RegistryPackage(package=other.package, versions=versions) - self._typecheck(other) - return RegistryPackage(package=self.package, - versions=self.versions + other.versions) - - def resolved(self) -> RegistryPinned: + self, other: 'RegistryUnpinnedPackage' + ) -> 'RegistryUnpinnedPackage': + return RegistryUnpinnedPackage( + package=self.package, + versions=self.versions + other.versions, + ) + + def resolved(self) -> RegistryPinnedPackage: self._check_in_index() try: range_ = semver.reduce_versions(*self.versions) @@ -445,16 +432,12 @@ def resolved(self) -> RegistryPinned: target = semver.resolve_to_specific_version(range_, available) if not target: package_version_not_found(self.package, range_, available) - return 
RegistryPinned(package=self.package, version=target) - - -PackageResolver = Union[LocalPackage, GitPackage, RegistryPackage] -PinnedPackages = Union[LocalPinned, GitPinned, RegistryPinned] + return RegistryPinnedPackage(package=self.package, version=target) @dataclass class PackageListing: - packages: Dict[str, PackageResolver] = field(default_factory=dict) + packages: Dict[str, UnpinnedPackage] = field(default_factory=dict) def __len__(self): return len(self.packages) @@ -462,40 +445,52 @@ def __len__(self): def __bool__(self): return bool(self.packages) - def _pick_key(self, key: Package): + def _pick_key(self, key: BasePackage) -> str: for name in key.all_names(): if name in self.packages: return name return key.name - def __contains__(self, key: Package): + def __contains__(self, key: BasePackage): for name in key.all_names(): if name in self.packages: return True - def __getitem__(self, key: Package): - key = self._pick_key(key) - return self.packages[key] + def __getitem__(self, key: BasePackage): + key_str: str = self._pick_key(key) + return self.packages[key_str] + + def __setitem__(self, key: BasePackage, value): + key_str: str = self._pick_key(key) + self.packages[key_str] = value - def __setitem__(self, key: Package, value): - key = self._pick_key(key) - self.packages[key] = value + def _mismatched_types( + self, old: UnpinnedPackage, new: UnpinnedPackage + ) -> NoReturn: + raise_dependency_error( + f'Cannot incorporate {new} ({new.__class__.__name__}) in {old} ' + f'({old.__class__.__name__}): mismatched types' + ) - def incorporate(self, package: Package): - key = self._pick_key(package) + def incorporate(self, package: UnpinnedPackage): + key: str = self._pick_key(package) if key in self.packages: - self.packages[key] = self.packages[key].incorporate(package) + existing: UnpinnedPackage = self.packages[key] + if not isinstance(existing, type(package)): + self._mismatched_types(existing, package) + self.packages[key] = existing.incorporate(package) else: self.packages[key] = package - def update_from(self, src: List[PackageContract]) -> 'PackageListing': + def update_from(self, src: List[PackageContract]) -> None: + pkg: UnpinnedPackage for contract in src: if isinstance(contract, LocalPackageContract): - pkg = LocalPackage.from_contract(contract) + pkg = LocalUnpinnedPackage.from_contract(contract) elif isinstance(contract, GitPackageContract): - pkg = GitPackage.from_contract(contract) + pkg = GitUnpinnedPackage.from_contract(contract) elif isinstance(contract, RegistryPackageContract): - pkg = RegistryPackage.from_contract(contract) + pkg = RegistryUnpinnedPackage.from_contract(contract) else: raise dbt.exceptions.InternalException( 'Invalid package type {}'.format(type(contract)) @@ -504,22 +499,22 @@ def update_from(self, src: List[PackageContract]) -> 'PackageListing': @classmethod def from_contracts( - cls: 'PackageListing', src: List[PackageContract] + cls: Type['PackageListing'], src: List[PackageContract] ) -> 'PackageListing': self = cls({}) self.update_from(src) return self - def resolved(self) -> List[PinnedPackages]: + def resolved(self) -> List[PinnedPackage]: return [p.resolved() for p in self.packages.values()] - def __iter__(self): + def __iter__(self) -> Iterator[UnpinnedPackage]: return iter(self.packages.values()) def resolve_packages( packages: List[PackageContract], config -) -> List[PinnedPackages]: +) -> List[PinnedPackage]: pending = PackageListing.from_contracts(packages) final = PackageListing() diff --git a/core/dbt/utils.py b/core/dbt/utils.py 
index 652556cbff9..4299cfd2c70 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -300,12 +300,6 @@ def is_enabled(node): return node.config.enabled -def is_type(node, _type): - if hasattr(_type, 'value'): - _type = _type.value - return node.resource_type == _type - - def get_pseudo_test_path(node_name, source_path, test_type): "schema tests all come from schema.yml files. fake a source sql file" source_path_parts = split_path(source_path) @@ -390,7 +384,7 @@ def invalid_ref_test_message(node, target_model_name, target_model_package, def invalid_ref_fail_unless_test(node, target_model_name, target_model_package, disabled): - if is_type(node, NodeType.Test): + if node.resource_type == NodeType.Test: msg = invalid_ref_test_message(node, target_model_name, target_model_package, disabled) if disabled: @@ -406,7 +400,7 @@ def invalid_ref_fail_unless_test(node, target_model_name, def invalid_source_fail_unless_test(node, target_name, target_table_name): - if is_type(node, NodeType.Test): + if node.resource_type == NodeType.Test: msg = dbt.exceptions.source_disabled_message(node, target_name, target_table_name) dbt.exceptions.warn_or_error(msg, log_fmt='WARNING: {}') diff --git a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py index 471cda24d04..7d185cb0458 100644 --- a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py +++ b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py @@ -1,10 +1,10 @@ from test.integration.base import DBTIntegrationTest, use_profile import threading -from dbt.adapters.factory import ADAPTER_TYPES +from dbt.adapters.factory import FACTORY def get_adapter_standalone(config): - cls = ADAPTER_TYPES[config.credentials.type] + cls = FACTORY.adapter_types[config.credentials.type] return cls(config) diff --git a/test/integration/038_caching_test/test_caching.py b/test/integration/038_caching_test/test_caching.py index 203ccf30742..2953b21bf4f 100644 --- a/test/integration/038_caching_test/test_caching.py +++ b/test/integration/038_caching_test/test_caching.py @@ -1,5 +1,5 @@ from test.integration.base import DBTIntegrationTest, use_profile -from dbt.adapters import factory +from dbt.adapters.factory import FACTORY class TestBaseCaching(DBTIntegrationTest): @property @@ -19,7 +19,7 @@ def run_and_get_adapter(self): # we want to inspect the adapter that dbt used for the run, which is # not self.adapter. You can't do this until after you've run dbt once. 
self.run_dbt(['run']) - return factory._ADAPTERS[self.adapter_type] + return FACTORY.adapters[self.adapter_type] def cache_run(self): adapter = self.run_and_get_adapter() diff --git a/test/integration/base.py b/test/integration/base.py index 25635314908..28eb4f18ef4 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -16,7 +16,7 @@ import dbt.main as dbt import dbt.flags as flags -from dbt.adapters.factory import get_adapter, reset_adapters +from dbt.adapters.factory import get_adapter, reset_adapters, register_adapter from dbt.clients.jinja import template_cache from dbt.config import RuntimeConfig from dbt.context import common @@ -395,6 +395,7 @@ def load_config(self): config = RuntimeConfig.from_args(TestArgs(kwargs)) + register_adapter(config) adapter = get_adapter(config) adapter.cleanup_connections() self.adapter_type = adapter.type() @@ -411,6 +412,7 @@ def tearDown(self): # get any current run adapter and clean up its connections before we # reset them. It'll probably be different from ours because # handle_and_check() calls reset_adapters(). + register_adapter(self.config) adapter = get_adapter(self.config) if adapter is not self.adapter: adapter.cleanup_connections() diff --git a/test/unit/test_compiler.py b/test/unit/test_compiler.py index 3a9dc92b77f..9fc0519e646 100644 --- a/test/unit/test_compiler.py +++ b/test/unit/test_compiler.py @@ -74,7 +74,7 @@ def test__prepend_ctes__already_has_cte(self): raw_sql='select * from {{ref("ephemeral")}}', compiled=True, extra_ctes_injected=False, - extra_ctes=[InjectedCTE(id='model.root.ephemeral')], + extra_ctes=[InjectedCTE(id='model.root.ephemeral', sql='select * from source_table')], injected_sql='', compiled_sql=( 'with cte as (select * from something_else) ' @@ -240,7 +240,7 @@ def test__prepend_ctes(self): raw_sql='select * from {{ref("ephemeral")}}', compiled=True, extra_ctes_injected=False, - extra_ctes=[InjectedCTE(id='model.root.ephemeral')], + extra_ctes=[InjectedCTE(id='model.root.ephemeral', sql='select * from source_table')], injected_sql='', compiled_sql='select * from __dbt__CTE__ephemeral' ), @@ -318,7 +318,7 @@ def test__prepend_ctes__multiple_levels(self): raw_sql='select * from {{ref("ephemeral")}}', compiled=True, extra_ctes_injected=False, - extra_ctes=[InjectedCTE(id='model.root.ephemeral')], + extra_ctes=[InjectedCTE(id='model.root.ephemeral', sql='select * from source_table')], injected_sql='', compiled_sql='select * from __dbt__CTE__ephemeral' ), @@ -342,7 +342,7 @@ def test__prepend_ctes__multiple_levels(self): raw_sql='select * from {{ref("ephemeral_level_two")}}', compiled=True, extra_ctes_injected=False, - extra_ctes=[InjectedCTE(id='model.root.ephemeral_level_two')], + extra_ctes=[InjectedCTE(id='model.root.ephemeral_level_two', sql='select * from source_table')], injected_sql='', compiled_sql='select * from __dbt__CTE__ephemeral_level_two' # noqa ), diff --git a/test/unit/test_deps.py b/test/unit/test_deps.py index e235feafbf2..14c7f392ef4 100644 --- a/test/unit/test_deps.py +++ b/test/unit/test_deps.py @@ -2,9 +2,11 @@ from unittest import mock import dbt.exceptions -from dbt.task.deps import GitPackage, LocalPackage, RegistryPackage, \ - LocalPackageContract, GitPackageContract, RegistryPackageContract, \ +from dbt.task.deps import ( + GitUnpinnedPackage, LocalUnpinnedPackage, RegistryUnpinnedPackage, + LocalPackageContract, GitPackageContract, RegistryPackageContract, resolve_packages +) from dbt.contracts.project import PackageConfig from dbt.semver import VersionSpecifier @@ -15,7 
+17,7 @@ class TestLocalPackage(unittest.TestCase): def test_init(self): a_contract = LocalPackageContract.from_dict({'local': '/path/to/package'}) self.assertEqual(a_contract.local, '/path/to/package') - a = LocalPackage.from_contract(a_contract) + a = LocalUnpinnedPackage.from_contract(a_contract) self.assertEqual(a.local, '/path/to/package') a_pinned = a.resolved() self.assertEqual(a_pinned.local, '/path/to/package') @@ -31,7 +33,7 @@ def test_init(self): self.assertEqual(a_contract.revision, '0.0.1') self.assertIs(a_contract.warn_unpinned, None) - a = GitPackage.from_contract(a_contract) + a = GitUnpinnedPackage.from_contract(a_contract) self.assertEqual(a.git, 'http://example.com') self.assertEqual(a.revisions, ['0.0.1']) self.assertIs(a.warn_unpinned, True) @@ -56,8 +58,8 @@ def test_resolve_ok(self): {'git': 'http://example.com', 'revision': '0.0.1', 'warn-unpinned': False} ) - a = GitPackage.from_contract(a_contract) - b = GitPackage.from_contract(b_contract) + a = GitUnpinnedPackage.from_contract(a_contract) + b = GitUnpinnedPackage.from_contract(b_contract) self.assertTrue(a.warn_unpinned) self.assertFalse(b.warn_unpinned) c = a.incorporate(b) @@ -75,8 +77,8 @@ def test_resolve_fail(self): b_contract = GitPackageContract.from_dict( {'git': 'http://example.com', 'revision': '0.0.2'} ) - a = GitPackage.from_contract(a_contract) - b = GitPackage.from_contract(b_contract) + a = GitUnpinnedPackage.from_contract(a_contract) + b = GitUnpinnedPackage.from_contract(b_contract) c = a.incorporate(b) self.assertEqual(c.git, 'http://example.com') self.assertEqual(c.revisions, ['0.0.1', '0.0.2']) @@ -89,7 +91,7 @@ def test_default_revision(self): self.assertEqual(a_contract.revision, None) self.assertIs(a_contract.warn_unpinned, None) - a = GitPackage.from_contract(a_contract) + a = GitUnpinnedPackage.from_contract(a_contract) self.assertEqual(a.git, 'http://example.com') self.assertEqual(a.revisions, []) self.assertIs(a.warn_unpinned, True) @@ -141,7 +143,7 @@ def test_init(self): self.assertEqual(a_contract.package, 'fishtown-analytics-test/a') self.assertEqual(a_contract.version, '0.1.2') - a = RegistryPackage.from_contract(a_contract) + a = RegistryUnpinnedPackage.from_contract(a_contract) self.assertEqual(a.package, 'fishtown-analytics-test/a') self.assertEqual( a.versions, @@ -175,8 +177,8 @@ def test_resolve_ok(self): package='fishtown-analytics-test/a', version='0.1.2' ) - a = RegistryPackage.from_contract(a_contract) - b = RegistryPackage.from_contract(b_contract) + a = RegistryUnpinnedPackage.from_contract(a_contract) + b = RegistryUnpinnedPackage.from_contract(b_contract) c = a.incorporate(b) self.assertEqual(c.package, 'fishtown-analytics-test/a') @@ -208,7 +210,7 @@ def test_resolve_ok(self): self.assertEqual(c_pinned.source_type(), 'hub') def test_resolve_missing_package(self): - a = RegistryPackage.from_contract(RegistryPackageContract( + a = RegistryUnpinnedPackage.from_contract(RegistryPackageContract( package='fishtown-analytics-test/b', version='0.1.2' )) @@ -219,7 +221,7 @@ def test_resolve_missing_package(self): self.assertEqual(msg, str(exc.exception)) def test_resolve_missing_version(self): - a = RegistryPackage.from_contract(RegistryPackageContract( + a = RegistryUnpinnedPackage.from_contract(RegistryPackageContract( package='fishtown-analytics-test/a', version='0.1.4' )) @@ -242,8 +244,8 @@ def test_resolve_conflict(self): package='fishtown-analytics-test/a', version='0.1.3' ) - a = RegistryPackage.from_contract(a_contract) - b = 
RegistryPackage.from_contract(b_contract) + a = RegistryUnpinnedPackage.from_contract(a_contract) + b = RegistryUnpinnedPackage.from_contract(b_contract) c = a.incorporate(b) with self.assertRaises(dbt.exceptions.DependencyException) as exc: @@ -263,8 +265,8 @@ def test_resolve_ranges(self): package='fishtown-analytics-test/a', version='<0.1.4' ) - a = RegistryPackage.from_contract(a_contract) - b = RegistryPackage.from_contract(b_contract) + a = RegistryUnpinnedPackage.from_contract(a_contract) + b = RegistryUnpinnedPackage.from_contract(b_contract) c = a.incorporate(b) self.assertEqual(c.package, 'fishtown-analytics-test/a') diff --git a/test/unit/utils.py b/test/unit/utils.py index b0675f7efa5..1be0078313e 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -59,10 +59,10 @@ def inject_adapter(value): """Inject the given adapter into the adapter factory, so your hand-crafted artisanal adapter will be available from get_adapter() as if dbt loaded it. """ - from dbt.adapters import factory + from dbt.adapters.factory import FACTORY key = value.type() - factory._ADAPTERS[key] = value - factory.ADAPTER_TYPES[key] = type(value) + FACTORY.adapters[key] = value + FACTORY.adapter_types[key] = type(value) class ContractTestCase(TestCase): diff --git a/tox.ini b/tox.ini index 94dbb13fa20..83eb19cb69b 100644 --- a/tox.ini +++ b/tox.ini @@ -11,50 +11,7 @@ deps = [testenv:mypy] basepython = python3.6 -commands = /bin/bash -c '$(which mypy) \ - core/dbt/adapters/base \ - core/dbt/adapters/sql \ - core/dbt/adapters/cache.py \ - core/dbt/clients \ - core/dbt/config \ - core/dbt/contracts/rpc.py \ - core/dbt/deprecations.py \ - core/dbt/exceptions.py \ - core/dbt/flags.py \ - core/dbt/helper_types.py \ - core/dbt/hooks.py \ - core/dbt/include \ - core/dbt/links.py \ - core/dbt/logger.py \ - core/dbt/main.py \ - core/dbt/node_runners.py \ - core/dbt/node_types.py \ - core/dbt/parser \ - core/dbt/perf_utils.py \ - core/dbt/profiler.py \ - core/dbt/py.typed \ - core/dbt/rpc \ - core/dbt/semver.py \ - core/dbt/source_config.py \ - core/dbt/task/base.py \ - core/dbt/task/clean.py \ - core/dbt/task/debug.py \ - core/dbt/task/freshness.py \ - core/dbt/task/generate.py \ - core/dbt/task/init.py \ - core/dbt/task/list.py \ - core/dbt/task/remote.py \ - core/dbt/task/run_operation.py \ - core/dbt/task/runnable.py \ - core/dbt/task/seed.py \ - core/dbt/task/serve.py \ - core/dbt/task/snapshot.py \ - core/dbt/task/test.py \ - core/dbt/tracking.py \ - core/dbt/ui \ - core/dbt/utils.py \ - core/dbt/version.py \ - core/dbt/writer.py' +commands = /bin/bash -c '$(which mypy) core/dbt' deps = -r{toxinidir}/requirements.txt -r{toxinidir}/dev_requirements.txt From b658f879f921634ca94ed31e2e2977ccb6ab94c9 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Sat, 12 Oct 2019 00:14:07 -0600 Subject: [PATCH 09/13] fix field serialization for hologram+mypy --- core/dbt/contracts/graph/parsed.py | 40 ++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index cdd4062962e..34b51cd1882 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass, field, Field from typing import ( Optional, Union, List, Dict, Any, Type, Tuple, NewType, MutableMapping ) @@ -277,8 +277,8 @@ class ParsedTestNode(ParsedNode): @dataclass(init=False) class _SnapshotConfig(NodeConfig): - unique_key: str = field(init=False) - 
target_schema: str = field(init=False) + unique_key: str = field(init=False, metadata=dict(init_required=True)) + target_schema: str = field(init=False, metadata=dict(init_required=True)) target_database: Optional[str] = None def __init__( @@ -293,10 +293,25 @@ def __init__( self.target_database = target_database super().__init__(**kwargs) + # type hacks... + @classmethod + def _get_fields(cls) -> List[Tuple[Field, str]]: # type: ignore + fields: List[Tuple[Field, str]] = [] + for old_field, name in super()._get_fields(): + new_field = old_field + # tell hologram we're really an initvar + if old_field.metadata and old_field.metadata.get('init_required'): + new_field = field(init=True, metadata=old_field.metadata) + new_field.name = old_field.name + new_field.type = old_field.type + new_field._field_type = old_field._field_type # type: ignore + fields.append((new_field, name)) + return fields + @dataclass(init=False) class GenericSnapshotConfig(_SnapshotConfig): - strategy: str = field(init=False) + strategy: str = field(init=False, metadata=dict(init_required=True)) def __init__(self, strategy: str, **kwargs) -> None: self.strategy = strategy @@ -307,9 +322,12 @@ def __init__(self, strategy: str, **kwargs) -> None: class TimestampSnapshotConfig(_SnapshotConfig): strategy: str = field( init=False, - metadata={'restrict': [str(SnapshotStrategy.Timestamp)]}, + metadata=dict( + restrict=[str(SnapshotStrategy.Timestamp)], + init_required=True, + ), ) - updated_at: str = field(init=False) + updated_at: str = field(init=False, metadata=dict(init_required=True)) def __init__( self, strategy: str, updated_at: str, **kwargs @@ -323,7 +341,10 @@ def __init__( class CheckSnapshotConfig(_SnapshotConfig): strategy: str = field( init=False, - metadata={'restrict': [str(SnapshotStrategy.Check)]}, + metadata=dict( + restrict=[str(SnapshotStrategy.Check)], + init_required=True, + ), ) # TODO: is there a way to get this to accept tuples of strings? Adding # `Tuple[str, ...]` to the list of types results in this: @@ -332,7 +353,10 @@ class CheckSnapshotConfig(_SnapshotConfig): # but without it, parsing gets upset about values like `('email',)` # maybe hologram itself should support this behavior? 
It's not like tuples # are meaningful in json - check_cols: Union[All, List[str]] = field(init=False) + check_cols: Union[All, List[str]] = field( + init=False, + metadata=dict(init_required=True), + ) def __init__( self, strategy: str, check_cols: Union[All, List[str]], From 61f8e6d4a1b0d7a081ef0a571921530bca42f971 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 11:36:46 -0600 Subject: [PATCH 10/13] fix unit tests --- test/unit/test_graph.py | 8 +++++++- test/unit/test_linker.py | 5 ++++- test/unit/test_parser.py | 3 +++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 157030d225f..c9ac80ae495 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -4,6 +4,7 @@ import dbt.clients.system import dbt.compilation +import dbt.context.parser import dbt.exceptions import dbt.flags import dbt.linker @@ -14,7 +15,6 @@ from dbt.contracts.graph.manifest import FilePath, SourceFile, FileHash from dbt.parser.results import ParseResult from dbt.parser.base import BaseParser -from dbt.parser.search import FileBlock try: from queue import Empty @@ -33,6 +33,7 @@ def tearDown(self): self.load_projects_patcher.stop() self.file_system_patcher.stop() self.get_adapter_patcher.stop() + self.get_adapter_patcher_cmn.stop() self.mock_filesystem_constructor.stop() self.mock_hook_constructor.stop() self.load_patch.stop() @@ -52,6 +53,10 @@ def setUp(self): ) self.get_adapter_patcher = patch('dbt.context.parser.get_adapter') self.factory = self.get_adapter_patcher.start() + # also patch this one + + self.get_adapter_patcher_cmn = patch('dbt.context.common.get_adapter') + self.factory_cmn = self.get_adapter_patcher_cmn.start() def mock_write_gpickle(graph, outfile): self.graph_result = graph @@ -275,6 +280,7 @@ def test__dependency_list(self): n: MagicMock(unique_id=n) for n in model_ids }) + manifest.expect.side_effect = lambda n: MagicMock(unique_id=n) queue = linker.as_graph_queue(manifest) for model_id in model_ids: diff --git a/test/unit/test_linker.py b/test/unit/test_linker.py index aa6b32430e3..d07d9b89e0a 100644 --- a/test/unit/test_linker.py +++ b/test/unit/test_linker.py @@ -11,9 +11,12 @@ def _mock_manifest(nodes): - return mock.MagicMock(nodes={ + manifest = mock.MagicMock(nodes={ n: mock.MagicMock(unique_id=n) for n in nodes }) + manifest.expect.side_effect = lambda n: mock.MagicMock(unique_id=n) + return manifest + class LinkerTest(unittest.TestCase): diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index fb5d658f3ca..eec07eef7f6 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -88,10 +88,13 @@ def setUp(self): } self.patcher = mock.patch('dbt.context.parser.get_adapter') self.factory = self.patcher.start() + self.patcher_cmn = mock.patch('dbt.context.common.get_adapter') + self.factory_cmn = self.patcher_cmn.start() self.macro_manifest = Manifest.from_macros() def tearDown(self): + self.patcher_cmn.stop() self.patcher.stop() def file_block_for(self, data: str, filename: str, searched: str): From 14d868313569b977819aec70fe13eec7ac7783d6 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 13:02:13 -0600 Subject: [PATCH 11/13] Avoid deepcopying the manifest 1x per task per sighup --- core/dbt/rpc/method.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/dbt/rpc/method.py b/core/dbt/rpc/method.py index 16159fd4a7a..29521c7246f 100644 --- a/core/dbt/rpc/method.py +++ b/core/dbt/rpc/method.py @@ -23,7 +23,7 @@ class 
RemoteMethod(Generic[Parameters, Result]): def __init__(self, args, config, manifest): self.args = args self.config = config - self.manifest = manifest.deepcopy() + self.manifest = manifest @classmethod def get_parameters(cls) -> Type[Parameters]: From 75916754a616f3926c3a705e4bca5ac4b563a3bc Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 13:12:28 -0600 Subject: [PATCH 12/13] add register_adapter to dbt debug --- core/dbt/task/debug.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/dbt/task/debug.py b/core/dbt/task/debug.py index 5935c6048c1..57990389e8a 100644 --- a/core/dbt/task/debug.py +++ b/core/dbt/task/debug.py @@ -9,7 +9,7 @@ import dbt.utils import dbt.exceptions from dbt.links import ProfileConfigDocs -from dbt.adapters.factory import get_adapter +from dbt.adapters.factory import get_adapter, register_adapter from dbt.version import get_installed_version from dbt.config import Project, Profile from dbt.clients.yaml_helper import load_yaml_text @@ -270,6 +270,7 @@ def attempt_connection(profile): """Return a string containing the error message, or None if there was no error. """ + register_adapter(profile) adapter = get_adapter(profile) try: with adapter.connection_named('debug'): From 43daea05c1f526980ad6b31de4af1a605f8651c0 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 14 Oct 2019 16:28:28 -0600 Subject: [PATCH 13/13] remove inscrutable __getattr__ override --- core/dbt/contracts/graph/manifest.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 8287093324f..c74f550219d 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -387,12 +387,6 @@ def patch_nodes(self, patches): 'not found or is disabled').format(patch.name) ) - # TODO: why is this here? - def __getattr__(self, name): - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, name) - ) - def get_used_schemas(self, resource_types=None): return frozenset({ (node.database, node.schema)