From 28b8067f9cb9eedf849d7a9bdf47322f7ccf15ab Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Thu, 2 Mar 2017 16:44:22 -0500 Subject: [PATCH] graph refactor (#292) --- CHANGELOG.md | 6 + dbt/adapters/postgres.py | 14 +- dbt/adapters/snowflake.py | 4 +- dbt/archival.py | 70 -- dbt/compat.py | 45 + dbt/compilation.py | 810 +++++++-------- dbt/compiled_model.py | 196 ---- dbt/contracts/common.py | 17 + dbt/contracts/connection.py | 40 +- dbt/contracts/graph/__init__.py | 0 dbt/contracts/graph/compiled.py | 36 + dbt/contracts/graph/parsed.py | 68 ++ dbt/contracts/graph/unparsed.py | 27 + dbt/contracts/project.py | 19 + dbt/exceptions.py | 6 +- dbt/flags.py | 1 + dbt/graph/selector.py | 29 +- dbt/linker.py | 14 + dbt/main.py | 8 +- dbt/model.py | 368 +------ dbt/parser.py | 409 ++++++++ dbt/runner.py | 951 +++++++++-------- dbt/schema_tester.py | 118 --- dbt/source.py | 56 +- dbt/task/archive.py | 16 +- dbt/task/compile.py | 13 +- dbt/task/run.py | 16 +- dbt/task/test.py | 28 +- dbt/templates.py | 16 +- dbt/utils.py | 69 +- .../test_simple_dependency_with_configs.py | 6 +- .../test_schema_test_graph_selection.py | 4 +- .../test_schema_tests.py | 5 +- .../integration/010_permission_tests/seed.sql | 24 +- .../010_permission_tests/tearDown.sql | 2 +- .../010_permission_tests/test_permissions.py | 5 +- .../test_invalid_models.py | 11 +- .../test_context_vars.py | 2 + .../models/hooks.sql | 0 test/integration/014_hook_tests/seed.sql | 39 + .../integration/014_hook_tests/seed_model.sql | 19 + .../seed.sql => 014_hook_tests/seed_run.sql} | 4 +- .../014_hook_tests/test_model_hooks.py | 140 +++ .../test_run_hooks.py} | 13 +- .../test_cli_invocation.py | 3 +- test/integration/base.py | 2 + test/unit/test_compiler.py | 342 +++++++ test/unit/test_graph.py | 101 +- test/unit/test_graph_selection.py | 47 +- test/unit/test_parser.py | 957 ++++++++++++++++++ test/unit/test_runner.py | 278 +++++ tox.ini | 2 +- 52 files changed, 3538 insertions(+), 1938 deletions(-) delete mode 100644 dbt/archival.py create mode 100644 dbt/compat.py delete mode 100644 dbt/compiled_model.py create mode 100644 dbt/contracts/common.py create mode 100644 dbt/contracts/graph/__init__.py create mode 100644 dbt/contracts/graph/compiled.py create mode 100644 dbt/contracts/graph/parsed.py create mode 100644 dbt/contracts/graph/unparsed.py create mode 100644 dbt/contracts/project.py create mode 100644 dbt/parser.py delete mode 100644 dbt/schema_tester.py rename test/integration/{014_pre_post_run_hook_tests => 014_hook_tests}/models/hooks.sql (100%) create mode 100644 test/integration/014_hook_tests/seed.sql create mode 100644 test/integration/014_hook_tests/seed_model.sql rename test/integration/{014_pre_post_run_hook_tests/seed.sql => 014_hook_tests/seed_run.sql} (80%) create mode 100644 test/integration/014_hook_tests/test_model_hooks.py rename test/integration/{014_pre_post_run_hook_tests/test_pre_post_run_hooks.py => 014_hook_tests/test_run_hooks.py} (90%) create mode 100644 test/unit/test_compiler.py create mode 100644 test/unit/test_parser.py create mode 100644 test/unit/test_runner.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d061621aaef..bf6e6a3912e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## dbt 0.7.2 (unreleased) + +### Changes + +- Graph refactor: fix common issues with load order ([#292](https://github.com/fishtown-analytics/dbt/pull/292)) + ## dbt 0.7.1 (February 28, 2017) ### Overview diff --git a/dbt/adapters/postgres.py b/dbt/adapters/postgres.py index 06a74bef52a..f886ed3c073 100644 --- 
a/dbt/adapters/postgres.py +++ b/dbt/adapters/postgres.py @@ -291,7 +291,7 @@ def rename(cls, profile, from_name, to_name, model_name=None): @classmethod def execute_model(cls, profile, model): - parts = re.split(r'-- (DBT_OPERATION .*)', model.compiled_contents) + parts = re.split(r'-- (DBT_OPERATION .*)', model.get('wrapped_sql')) connection = cls.get_connection(profile) if flags.STRICT_MODE: @@ -317,7 +317,7 @@ def call_expand_target_column_types(kwargs): func_map[function](kwargs) else: handle, cursor = cls.add_query_to_transaction( - part, connection, model.name) + part, connection, model.get('name')) handle.commit() @@ -504,6 +504,16 @@ def commit(cls, profile): handle = connection.get('handle') handle.commit() + @classmethod + def rollback(cls, profile): + connection = cls.get_connection(profile) + + if flags.STRICT_MODE: + validate_connection(connection) + + handle = connection.get('handle') + handle.rollback() + @classmethod def get_status(cls, cursor): return cursor.statusmessage diff --git a/dbt/adapters/snowflake.py b/dbt/adapters/snowflake.py index c89c64fefbb..231d26ec4f1 100644 --- a/dbt/adapters/snowflake.py +++ b/dbt/adapters/snowflake.py @@ -182,7 +182,7 @@ def rename(cls, profile, from_name, to_name, model_name=None): @classmethod def execute_model(cls, profile, model): - parts = re.split(r'-- (DBT_OPERATION .*)', model.compiled_contents) + parts = re.split(r'-- (DBT_OPERATION .*)', model.get('wrapped_sql')) connection = cls.get_connection(profile) if flags.STRICT_MODE: @@ -216,7 +216,7 @@ def call_expand_target_column_types(kwargs): func_map[function](kwargs) else: handle, cursor = cls.add_query_to_transaction( - part, connection, model.name) + part, connection, model.get('name')) handle.commit() diff --git a/dbt/archival.py b/dbt/archival.py deleted file mode 100644 index 74245922cf1..00000000000 --- a/dbt/archival.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import print_function -import dbt.schema -import dbt.templates -import jinja2 - -from dbt.adapters.factory import get_adapter - - -class Archival(object): - - def __init__(self, project, archive_model): - self.archive_model = archive_model - self.project = project - - def compile(self): - source_schema = self.archive_model.source_schema - target_schema = self.archive_model.target_schema - source_table = self.archive_model.source_table - target_table = self.archive_model.target_table - unique_key = self.archive_model.unique_key - updated_at = self.archive_model.updated_at - - profile = self.project.run_environment() - adapter = get_adapter(profile) - - adapter.create_schema(profile, target_schema) - - source_columns = adapter.get_columns_in_table( - profile, source_schema, source_table) - - if len(source_columns) == 0: - raise RuntimeError( - 'Source table "{}"."{}" does not ' - 'exist'.format(source_schema, source_table)) - - extra_cols = [ - dbt.schema.Column("valid_from", "timestamp", None), - dbt.schema.Column("valid_to", "timestamp", None), - dbt.schema.Column("scd_id", "text", None), - dbt.schema.Column("dbt_updated_at", "timestamp", None) - ] - - dest_columns = source_columns + extra_cols - - adapter.create_table( - profile, - target_schema, - target_table, - dest_columns, - sort=updated_at, - dist=unique_key - ) - - env = jinja2.Environment() - - ctx = { - "columns": source_columns, - "updated_at": updated_at, - "unique_key": unique_key, - "source_schema": source_schema, - "source_table": source_table, - "target_schema": target_schema, - "target_table": target_table - } - - base_query = 
dbt.templates.SCDArchiveTemplate - template = env.from_string(base_query, globals=ctx) - rendered = template.render(ctx) - - return rendered diff --git a/dbt/compat.py b/dbt/compat.py new file mode 100644 index 00000000000..26cba2577d3 --- /dev/null +++ b/dbt/compat.py @@ -0,0 +1,45 @@ +import codecs + +WHICH_PYTHON = None + +try: + basestring + WHICH_PYTHON = 2 +except NameError: + WHICH_PYTHON = 3 + +if WHICH_PYTHON == 2: + basestring = basestring +else: + basestring = str + + +def to_unicode(s): + if WHICH_PYTHON == 2: + return unicode(s) + else: + return str(s) + + +def to_string(s): + if WHICH_PYTHON == 2: + if isinstance(s, unicode): + return s + elif isinstance(s, basestring): + return to_unicode(s) + else: + return to_unicode(str(s)) + else: + if isinstance(s, basestring): + return s + else: + return str(s) + + +def write_file(path, s): + if WHICH_PYTHON == 2: + with codecs.open(path, 'w', encoding='utf-8') as f: + return f.write(to_string(s)) + else: + with open(path, 'w') as f: + return f.write(to_string(s)) diff --git a/dbt/compilation.py b/dbt/compilation.py index efcf5aa8572..c2e0eaa103e 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -8,12 +8,20 @@ import dbt.project import dbt.utils +from dbt.model import Model, NodeType from dbt.source import Source from dbt.utils import find_model_by_fqn, find_model_by_name, \ - split_path, This, Var, compiler_error, to_string + split_path, This, Var, compiler_error from dbt.linker import Linker from dbt.runtime import RuntimeContext + +import dbt.compat +import dbt.contracts.graph.compiled +import dbt.contracts.graph.parsed +import dbt.contracts.project +import dbt.flags +import dbt.parser import dbt.templates from dbt.adapters.factory import get_adapter @@ -26,10 +34,35 @@ graph_file_name = 'graph.yml' +def compile_and_print_status(project, args): + compiler = Compiler(project, args) + compiler.initialize() + names = { + NodeType.Model: 'models', + NodeType.Test: 'tests', + NodeType.Archive: 'archives', + NodeType.Analysis: 'analyses', + } + + results = { + NodeType.Model: 0, + NodeType.Test: 0, + NodeType.Archive: 0, + NodeType.Analysis: 0, + } + + results.update(compiler.compile()) + + stat_line = ", ".join( + ["{} {}".format(ct, names.get(t)) for t, ct in results.items()]) + + logger.info("Compiled {}".format(stat_line)) + + def compile_string(string, ctx): try: env = jinja2.Environment() - template = env.from_string(str(string), globals=ctx) + template = env.from_string(dbt.compat.to_string(string), globals=ctx) return template.render(ctx) except jinja2.exceptions.TemplateSyntaxError as e: compiler_error(None, str(e)) @@ -37,12 +70,100 @@ def compile_string(string, ctx): compiler_error(None, str(e)) +def prepend_ctes(model, all_models): + model, _, all_models = recursively_prepend_ctes(model, all_models) + + return (model, all_models) + + +def recursively_prepend_ctes(model, all_models): + if dbt.flags.STRICT_MODE: + dbt.contracts.graph.compiled.validate_one(model) + dbt.contracts.graph.compiled.validate(all_models) + + model = model.copy() + prepend_ctes = [] + + if model.get('all_ctes_injected') is True: + return (model, model.get('extra_cte_ids'), all_models) + + for cte_id in model.get('extra_cte_ids'): + cte_to_add = all_models.get(cte_id) + cte_to_add, new_prepend_ctes, all_models = recursively_prepend_ctes( + cte_to_add, all_models) + + prepend_ctes = new_prepend_ctes + prepend_ctes + new_cte_name = '__dbt__CTE__{}'.format(cte_to_add.get('name')) + prepend_ctes.append(' {} as (\n{}\n)'.format( + new_cte_name, + 
cte_to_add.get('compiled_sql'))) + + model['extra_ctes_injected'] = True + model['extra_cte_sql'] = prepend_ctes + model['injected_sql'] = inject_ctes_into_sql( + model.get('compiled_sql'), + model.get('extra_cte_sql')) + + all_models[model.get('unique_id')] = model + + return (model, prepend_ctes, all_models) + + +def inject_ctes_into_sql(sql, ctes): + """ + `ctes` is a list of CTEs in the form: + + [ "__dbt__CTE__ephemeral as (select * from table)", + "__dbt__CTE__events as (select id, type from events)" ] + + Given `sql` like: + + "with internal_cte as (select * from sessions) + select * from internal_cte" + + This will spit out: + + "with __dbt__CTE__ephemeral as (select * from table), + __dbt__CTE__events as (select id, type from events), + with internal_cte as (select * from sessions) + select * from internal_cte" + + (Whitespace enhanced for readability.) + """ + if len(ctes) == 0: + return sql + + parsed_stmts = sqlparse.parse(sql) + parsed = parsed_stmts[0] + + with_stmt = None + for token in parsed.tokens: + if token.is_keyword and token.normalized == 'WITH': + with_stmt = token + break + + if with_stmt is None: + # no with stmt, add one, and inject CTEs right at the beginning + first_token = parsed.token_first() + with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with') + parsed.insert_before(first_token, with_stmt) + else: + # stmt exists, add a comma (which will come after injected CTEs) + trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ',') + parsed.insert_after(with_stmt, trailing_comma) + + parsed.insert_after( + with_stmt, + sqlparse.sql.Token(sqlparse.tokens.Keyword, ", ".join(ctes))) + + return dbt.compat.to_string(parsed) + + class Compiler(object): def __init__(self, project, args): self.project = project self.args = args - - self.macro_generator = None + self.parsed_models = None def initialize(self): if not os.path.exists(self.project['target-path']): @@ -51,118 +172,37 @@ def initialize(self): if not os.path.exists(self.project['modules-path']): os.makedirs(self.project['modules-path']) - def model_sources(self, this_project, own_project=None): - if own_project is None: - own_project = this_project - - paths = own_project.get('source-paths', []) - return Source( - this_project, - own_project=own_project - ).get_models(paths) - def get_macros(self, this_project, own_project=None): if own_project is None: own_project = this_project paths = own_project.get('macro-paths', []) return Source(this_project, own_project=own_project).get_macros(paths) - def get_archives(self, project): - return Source( - project, - own_project=project - ).get_archives() - - def project_schemas(self, project): - source_paths = project.get('source-paths', []) - return Source(project).get_schemas(source_paths) - - def project_tests(self, project): - source_paths = project.get('test-paths', []) - return Source(project).get_tests(source_paths) - - def analysis_sources(self, project): - paths = project.get('analysis-paths', []) - return Source(project).get_analyses(paths) - - def validate_models_unique(self, models, error_type): - found_models = defaultdict(list) - for model in models: - found_models[model.name].append(model) - for model_name, model_list in found_models.items(): - if len(model_list) > 1: - models_str = "\n - ".join( - [str(model) for model in model_list]) - - error_msg = "Found {} models with the same name.\n" \ - " Name='{}'\n" \ - " - {}".format( - len(model_list), model_name, models_str - ) - - error_type(model_list[0], error_msg) - def __write(self, 
build_filepath, payload): target_path = os.path.join(self.project['target-path'], build_filepath) if not os.path.exists(os.path.dirname(target_path)): os.makedirs(os.path.dirname(target_path)) - with open(target_path, 'w') as f: - f.write(to_string(payload)) + dbt.compat.write_file(target_path, payload) def __model_config(self, model, linker): def do_config(*args, **kwargs): - if len(args) == 1 and len(kwargs) == 0: - opts = args[0] - elif len(args) == 0 and len(kwargs) > 0: - opts = kwargs - else: - raise RuntimeError( - "Invalid model config given inline in {}".format(model) - ) + return '' - if type(opts) != dict: - raise RuntimeError( - "Invalid model config given inline in {}".format(model) - ) - - model.update_in_model_config(opts) - model.add_to_prologue("Config specified in model: {}".format(opts)) - return "" return do_config - def model_can_reference(self, src_model, other_model): - """ - returns True if the src_model can reference the other_model. Models - can access other models in their package and dependency models, but - a dependency model cannot access models "up" the dependency chain. - """ - - # hack for now b/c we don't support recursive dependencies - return ( - other_model.own_project['name'] == src_model.own_project['name'] or - src_model.own_project['name'] == src_model.project['name'] - ) - - def __ref(self, linker, ctx, model, all_models): - schema = ctx['env']['schema'] - - source_model = tuple(model.fqn) - linker.add_node(source_model) + def __ref(self, ctx, model, all_models): + schema = ctx.get('env', {}).get('schema') def do_ref(*args): + target_model_name = None + target_model_package = None + if len(args) == 1: - other_model_name = args[0] - other_model = find_model_by_name(all_models, other_model_name) + target_model_name = args[0] elif len(args) == 2: - other_model_package, other_model_name = args - - other_model = find_model_by_name( - all_models, - other_model_name, - package_namespace=other_model_package - ) + target_model_package, target_model_name = args else: compiler_error( model, @@ -171,54 +211,97 @@ def do_ref(*args): ) ) - other_model_fqn = tuple(other_model.fqn[:-1] + [other_model_name]) - src_fqn = ".".join(source_model) - ref_fqn = ".".join(other_model_fqn) + target_model = dbt.utils.find_model_by_name( + all_models, + target_model_name, + target_model_package) - if not other_model.is_enabled: - raise RuntimeError( + if target_model is None: + compiler_error( + model, + "Model '{}' depends on model '{}' which was not found." + .format(model.get('unique_id'), target_model_name)) + + target_model_id = target_model.get('unique_id') + + if target_model.get('config', {}) \ + .get('enabled') is False: + compiler_error( + model, "Model '{}' depends on model '{}' which is disabled in " - "the project config".format(src_fqn, ref_fqn) - ) + "the project config".format(model.get('unique_id'), + target_model.get('unique_id'))) - # this creates a trivial cycle -- should this be a compiler error? 
- # we can still interpolate the name w/o making a self-cycle - if source_model == other_model_fqn: - pass - else: - linker.dependency(source_model, other_model_fqn) + model['depends_on'].append(target_model_id) - if other_model.is_ephemeral: - linker.inject_cte(model, other_model) - return other_model.cte_name + if target_model.get('config', {}) \ + .get('materialized') == 'ephemeral': + + model['extra_cte_ids'].append(target_model_id) + return '__dbt__CTE__{}'.format(target_model.get('name')) else: - return '"{}"."{}"'.format(schema, other_model_name) + return '"{}"."{}"'.format(schema, target_model.get('name')) def wrapped_do_ref(*args): try: return do_ref(*args) except RuntimeError as e: - root = os.path.relpath( - model.root_dir, - model.project['project-root'] - ) - - filepath = os.path.join(root, model.rel_filepath) - logger.info("Compiler error in {}".format(filepath)) + logger.info("Compiler error in {}".format(model.get('path'))) logger.info("Enabled models:") - for m in all_models: - logger.info(" - {}".format(".".join(m.fqn))) + for n, m in all_models.items(): + if m.get('resource_type') == NodeType.Model: + logger.info(" - {}".format(m.get('unique_id'))) raise e return wrapped_do_ref + def get_compiler_context(self, linker, model, models, + macro_generator=None): + context = self.project.context() + + if macro_generator is not None: + for macro_data in macro_generator(context): + macro = macro_data["macro"] + macro_name = macro_data["name"] + project = macro_data["project"] + + if context.get(project.get('name')) is None: + context[project.get('name')] = {} + + context.get(project.get('name'), {}) \ + .update({macro_name: macro}) + + if model.get('package_name') == project.get('name'): + context.update({macro_name: macro}) + + adapter = get_adapter(self.project.run_environment()) + + # built-ins + context['ref'] = self.__ref(context, model, models) + context['config'] = self.__model_config(model, linker) + context['this'] = This( + context['env']['schema'], + (model.get('name') if dbt.flags.NON_DESTRUCTIVE + else '{}__dbt_tmp'.format(model.get('name'))), + model.get('name') + ) + context['var'] = Var(model, context=context) + context['target'] = self.project.get_target() + + # these get re-interpolated at runtime! + context['run_started_at'] = '{{ run_started_at }}' + context['invocation_id'] = '{{ invocation_id }}' + context['sql_now'] = adapter.date_function + + return context + def get_context(self, linker, model, models): runtime = RuntimeContext(model=model) context = self.project.context() # built-ins - context['ref'] = self.__ref(linker, context, model, models) + context['ref'] = self.__ref(context, model, models) context['config'] = self.__model_config(model, linker) context['this'] = This( context['env']['schema'], model.immediate_name, model.name @@ -235,107 +318,44 @@ def get_context(self, linker, model, models): runtime.update_global(context) - # add in macros (can we cache these somehow?) 
- for macro_data in self.macro_generator(context): - macro = macro_data["macro"] - macro_name = macro_data["name"] - project = macro_data["project"] - - runtime.update_package(project['name'], {macro_name: macro}) - - if project['name'] == self.project['name']: - runtime.update_global({macro_name: macro}) - return runtime - def compile_model(self, linker, model, models): + def compile_node(self, linker, node, nodes, macro_generator): try: - fs_loader = jinja2.FileSystemLoader(searchpath=model.root_dir) - jinja = jinja2.Environment(loader=fs_loader) + compiled_node = node.copy() + compiled_node.update({ + 'compiled': False, + 'compiled_sql': None, + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': None, + }) + + context = self.get_compiler_context(linker, compiled_node, nodes, + macro_generator) - template_contents = dbt.clients.system.load_file_contents( - model.absolute_path) + env = jinja2.sandbox.SandboxedEnvironment() - template = jinja.from_string(template_contents) - context = self.get_context(linker, model, models) + compiled_node['compiled_sql'] = env.from_string( + node.get('raw_sql')).render(context) - rendered = template.render(context) + compiled_node['compiled'] = True except jinja2.exceptions.TemplateSyntaxError as e: - compiler_error(model, str(e)) + compiler_error(node, str(e)) except jinja2.exceptions.UndefinedError as e: - compiler_error(model, str(e)) + compiler_error(node, str(e)) - return rendered + return compiled_node def write_graph_file(self, linker): filename = graph_file_name graph_path = os.path.join(self.project['target-path'], filename) linker.write_graph(graph_path) - def combine_query_with_ctes(self, model, query, ctes, compiled_models): - parsed_stmts = sqlparse.parse(query) - if len(parsed_stmts) != 1: - raise RuntimeError( - "unexpectedly parsed {} queries from model " - "{}".format(len(parsed_stmts), model) - ) - - parsed = parsed_stmts[0] - - with_stmt = None - for token in parsed.tokens: - if token.is_keyword and token.normalized == 'WITH': - with_stmt = token - break - - if with_stmt is None: - # no with stmt, add one! 
- first_token = parsed.token_first() - with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with') - parsed.insert_before(first_token, with_stmt) - else: - # stmt exists, add a comma (which will come after our injected - # CTE(s) ) - trailing_comma = sqlparse.sql.Token( - sqlparse.tokens.Punctuation, ',' - ) - parsed.insert_after(with_stmt, trailing_comma) - - cte_mapping = [ - (model.cte_name, compiled_models[model]) for model in ctes - ] - - # these newlines are important -- comments could otherwise interfere - # w/ query - cte_stmts = [ - " {} as (\n{}\n)".format(name, contents) - for (name, contents) in cte_mapping - ] - - cte_text = sqlparse.sql.Token( - sqlparse.tokens.Keyword, ", ".join(cte_stmts) - ) - parsed.insert_after(with_stmt, cte_text) - - return str(parsed) - - def __recursive_add_ctes(self, linker, model): - if model not in linker.cte_map: - return set() - - models_to_add = linker.cte_map[model] - recursive_models = [ - self.__recursive_add_ctes(linker, m) for m in models_to_add - ] - - for recursive_model_set in recursive_models: - models_to_add = models_to_add | recursive_model_set + def new_add_cte_to_rendered_query(self, linker, primary_model, + compiled_models): - return models_to_add - - def add_cte_to_rendered_query( - self, linker, primary_model, compiled_models - ): fqn_to_model = {tuple(model.fqn): model for model in compiled_models} sorted_nodes = linker.as_topological_ordering() @@ -361,154 +381,90 @@ def add_cte_to_rendered_query( ) return compiled_query - def remove_node_from_graph(self, linker, model, models): - # remove the node - children = linker.remove_node(tuple(model.fqn)) - - # check if we bricked the graph. if so: throw compilation error - for child in children: - other_model = find_model_by_fqn(models, child) - - if other_model.is_enabled: - this_fqn = ".".join(model.fqn) - that_fqn = ".".join(other_model.fqn) - compiler_error( - model, - "Model '{}' depends on model '{}' which is " - "disabled".format(that_fqn, this_fqn) - ) - - def compile_models(self, linker, models): - compiled_models = {model: self.compile_model(linker, model, models) - for model in models} - sorted_models = [find_model_by_fqn(models, fqn) - for fqn in linker.as_topological_ordering()] - - written_models = [] - for model in sorted_models: - # in-model configs were just evaluated. 
Evict anything that is - # newly-disabled - if not model.is_enabled: - self.remove_node_from_graph(linker, model, models) - continue - - injected_stmt = self.add_cte_to_rendered_query( - linker, model, compiled_models - ) - - context = self.get_context(linker, model, models) - wrapped_stmt = model.compile(injected_stmt, self.project, context) - - serialized = model.serialize() - linker.update_node_data(tuple(model.fqn), serialized) - - if model.is_ephemeral: - continue - - self.__write(model.build_path(), wrapped_stmt) - written_models.append(model) - - return compiled_models, written_models - - def compile_analyses(self, linker, compiled_models): - analyses = self.analysis_sources(self.project) - compiled_analyses = { - analysis: self.compile_model( - linker, analysis, compiled_models - ) for analysis in analyses - } - - written_analyses = [] - referenceable_models = {} - referenceable_models.update(compiled_models) - referenceable_models.update(compiled_analyses) - for analysis in analyses: - injected_stmt = self.add_cte_to_rendered_query( - linker, - analysis, - referenceable_models - ) - - serialized = analysis.serialize() - linker.update_node_data(tuple(analysis.fqn), serialized) - - build_path = analysis.build_path() - self.__write(build_path, injected_stmt) - written_analyses.append(analysis) - - return written_analyses - - def get_local_and_package_sources(self, project, source_getter): - all_sources = [] - - all_sources.extend(source_getter(project)) - - for package in dbt.utils.dependency_projects(project): - all_sources.extend(source_getter(package)) - - return all_sources + def compile_nodes(self, linker, nodes, macro_generator): + all_projects = self.get_all_projects() + + compiled_nodes = {} + injected_nodes = {} + wrapped_nodes = {} + written_nodes = [] + + for name, node in nodes.items(): + compiled_nodes[name] = self.compile_node(linker, node, nodes, + macro_generator) + + if dbt.flags.STRICT_MODE: + dbt.contracts.graph.compiled.validate(compiled_nodes) + + for name, node in compiled_nodes.items(): + node, compiled_nodes = prepend_ctes(node, compiled_nodes) + injected_nodes[name] = node + + if dbt.flags.STRICT_MODE: + dbt.contracts.graph.compiled.validate(injected_nodes) + + for name, injected_node in injected_nodes.items(): + # now turn model nodes back into the old-style model object for + # wrapping + if injected_node.get('resource_type') in [NodeType.Test, + NodeType.Analysis]: + # don't wrap tests or analyses. + injected_node['wrapped_sql'] = injected_node['injected_sql'] + wrapped_nodes[name] = injected_node + + elif injected_node.get('resource_type') == NodeType.Archive: + # unfortunately we do everything automagically for + # archives. in the future it'd be nice to generate + # the SQL at the parser level. 
+ pass - def compile_schema_tests(self, linker, models): - all_schema_specs = self.get_local_and_package_sources( - self.project, - self.project_schemas - ) + else: + model = Model( + self.project, + injected_node.get('root_path'), + injected_node.get('path'), + all_projects.get(injected_node.get('package_name'))) - schema_tests = [] + cfg = injected_node.get('config', {}) + model._config = cfg - for schema in all_schema_specs: - # compiling a SchemaFile returns >= 0 SchemaTest models - try: - schema_tests.extend(schema.compile()) - except RuntimeError as e: - logger.info("\n" + str(e)) - schema_test_path = schema.filepath - logger.info("Skipping compilation for {}...\n" - .format(schema_test_path)) - - written_tests = [] - for schema_test in schema_tests: - # show a warning if the model being tested doesn't exist - try: - source_model = find_model_by_name(models, - schema_test.model_name) - except RuntimeError as e: - dbt.utils.compiler_warning(schema_test, str(e)) - continue + context = self.get_context(linker, model, injected_nodes) - if not source_model.is_enabled: - continue + wrapped_stmt = model.compile( + injected_node.get('injected_sql'), self.project, context) - serialized = schema_test.serialize() + injected_node['wrapped_sql'] = wrapped_stmt + wrapped_nodes[name] = injected_node - model_node = tuple(source_model.fqn) - test_node = tuple(schema_test.fqn) + build_path = os.path.join('build', injected_node.get('path')) - linker.dependency(test_node, model_node) - linker.update_node_data(test_node, serialized) + if injected_node.get('resource_type') in (NodeType.Model, + NodeType.Analysis) and \ + injected_node.get('config', {}) \ + .get('materialized') != 'ephemeral': + self.__write(build_path, injected_node.get('wrapped_sql')) + written_nodes.append(injected_node) + injected_node['build_path'] = build_path - query = schema_test.render() - self.__write(schema_test.build_path(), query) - written_tests.append(schema_test) + linker.add_node(injected_node.get('unique_id')) + project = all_projects[injected_node.get('package_name')] - return written_tests + linker.update_node_data( + injected_node.get('unique_id'), + injected_node) - def compile_data_tests(self, linker, models): - tests = self.get_local_and_package_sources( - self.project, - self.project_tests - ) + for dependency in injected_node.get('depends_on'): + if compiled_nodes.get(dependency): + linker.dependency( + injected_node.get('unique_id'), + compiled_nodes.get(dependency).get('unique_id')) + else: + compiler_error( + model, + "dependency {} not found in graph!".format( + dependency)) - written_tests = [] - for data_test in tests: - serialized = data_test.serialize() - linker.update_node_data(tuple(data_test.fqn), serialized) - query = self.compile_model(linker, data_test, models) - wrapped = data_test.render(query) - self.__write(data_test.build_path(), wrapped) - written_tests.append(data_test) - - return written_tests + return wrapped_nodes, written_nodes def generate_macros(self, all_macros): def do_gen(ctx): @@ -519,32 +475,108 @@ def do_gen(ctx): return macros return do_gen - def compile_archives(self, linker, compiled_models): - all_archives = self.get_archives(self.project) - - for archive in all_archives: - sql = archive.compile() - fqn = tuple(archive.fqn) - linker.update_node_data(fqn, archive.serialize()) - self.__write(archive.build_path(), sql) - - return all_archives - - def get_models(self): - all_models = self.model_sources(this_project=self.project) - for project in 
dbt.utils.dependency_projects(self.project): - all_models.extend( - self.model_sources( - this_project=self.project, own_project=project - ) - ) - - return all_models + def get_all_projects(self): + root_project = self.project.cfg + all_projects = {root_project.get('name'): root_project} + dependency_projects = dbt.utils.dependency_projects(self.project) + + for project in dependency_projects: + name = project.cfg.get('name', 'unknown') + all_projects[name] = project.cfg + + if dbt.flags.STRICT_MODE: + dbt.contracts.project.validate_list(all_projects) + + return all_projects + + def get_parsed_models(self, root_project, all_projects, macro_generator): + parsed_models = {} + + for name, project in all_projects.items(): + parsed_models.update( + dbt.parser.load_and_parse_sql( + package_name=name, + root_project=root_project, + all_projects=all_projects, + root_dir=project.get('project-root'), + relative_dirs=project.get('source-paths', []), + resource_type=NodeType.Model, + macro_generator=macro_generator)) + + return parsed_models + + def get_parsed_analyses(self, root_project, all_projects, macro_generator): + parsed_models = {} + + for name, project in all_projects.items(): + parsed_models.update( + dbt.parser.load_and_parse_sql( + package_name=name, + root_project=root_project, + all_projects=all_projects, + root_dir=project.get('project-root'), + relative_dirs=project.get('analysis-paths', []), + resource_type=NodeType.Analysis, + macro_generator=macro_generator)) + + return parsed_models + + def get_parsed_data_tests(self, root_project, all_projects, + macro_generator): + parsed_tests = {} + + for name, project in all_projects.items(): + parsed_tests.update( + dbt.parser.load_and_parse_sql( + package_name=name, + root_project=root_project, + all_projects=all_projects, + root_dir=project.get('project-root'), + relative_dirs=project.get('test-paths', []), + resource_type=NodeType.Test, + macro_generator=macro_generator, + tags=['data'])) + + return parsed_tests + + def get_parsed_schema_tests(self, root_project, all_projects): + parsed_tests = {} + + for name, project in all_projects.items(): + parsed_tests.update( + dbt.parser.load_and_parse_yml( + package_name=name, + root_project=root_project, + all_projects=all_projects, + root_dir=project.get('project-root'), + relative_dirs=project.get('source-paths', []))) + + return parsed_tests + + def load_all_nodes(self, root_project, all_projects, macro_generator): + all_nodes = {} + + all_nodes.update(self.get_parsed_models(root_project, all_projects, + macro_generator)) + all_nodes.update(self.get_parsed_analyses(root_project, all_projects, + macro_generator)) + all_nodes.update( + self.get_parsed_data_tests(root_project, all_projects, + macro_generator)) + all_nodes.update( + self.get_parsed_schema_tests(root_project, all_projects)) + all_nodes.update( + dbt.parser.parse_archives_from_projects(root_project, + all_projects)) + + return all_nodes def compile(self): linker = Linker() - all_models = self.get_models() + root_project = self.project.cfg + all_projects = self.get_all_projects() + all_macros = self.get_macros(this_project=self.project) for project in dbt.utils.dependency_projects(self.project): @@ -552,48 +584,22 @@ def compile(self): self.get_macros(this_project=self.project, own_project=project) ) - self.macro_generator = self.generate_macros(all_macros) + macro_generator = self.generate_macros(all_macros) - enabled_models = [ - model for model in all_models - if model.is_enabled and not model.is_empty - ] + all_nodes = 
self.load_all_nodes(root_project, all_projects, + macro_generator) - compiled_models, written_models = self.compile_models( - linker, enabled_models - ) - - compilers = { - 'schema tests': self.compile_schema_tests, - 'data tests': self.compile_data_tests, - 'archives': self.compile_archives, - 'analyses': self.compile_analyses - } - - compiled = { - 'models': written_models - } + compiled_nodes, written_nodes = self.compile_nodes(linker, all_nodes, + macro_generator) - for (compile_type, compiler_f) in compilers.items(): - newly_compiled = compiler_f(linker, compiled_models) - compiled[compile_type] = newly_compiled - - self.validate_models_unique( - compiled['models'], - dbt.utils.compiler_error - ) + # TODO re-add archives - self.validate_models_unique( - compiled['data tests'], - dbt.utils.compiler_warning - ) + self.write_graph_file(linker) - self.validate_models_unique( - compiled['schema tests'], - dbt.utils.compiler_warning - ) + stats = {} - self.write_graph_file(linker) + for node_name, node in compiled_nodes.items(): + stats[node.get('resource_type')] = stats.get( + node.get('resource_type'), 0) + 1 - stats = {ttype: len(m) for (ttype, m) in compiled.items()} return stats diff --git a/dbt/compiled_model.py b/dbt/compiled_model.py deleted file mode 100644 index 139566ec5fd..00000000000 --- a/dbt/compiled_model.py +++ /dev/null @@ -1,196 +0,0 @@ -import hashlib -import jinja2 -from dbt.utils import compiler_error, to_unicode -from dbt.adapters.factory import get_adapter -import dbt.model - - -class CompiledModel(object): - def __init__(self, fqn, data): - self.fqn = fqn - self.data = data - self.nice_name = ".".join(fqn) - - # these are set just before the models are executed - self.tmp_drop_type = None - self.final_drop_type = None - self.profile = None - - self.skip = False - self._contents = None - self.compiled_contents = None - - def __getitem__(self, key): - return self.data[key] - - def hashed_name(self): - fqn_string = ".".join(self.fqn) - return hashlib.md5(fqn_string.encode('utf-8')).hexdigest() - - def context(self): - return self.data - - def hashed_contents(self): - return hashlib.md5(self.contents.encode('utf-8')).hexdigest() - - def do_skip(self): - self.skip = True - - def should_skip(self): - return self.skip - - def is_type(self, run_type): - return self.data['dbt_run_type'] == run_type - - def is_test_type(self, test_type): - return self.data.get('dbt_test_type') == test_type - - def is_test(self): - return self.data['dbt_run_type'] == dbt.model.NodeType.Test - - @property - def contents(self): - if self._contents is None: - with open(self.data['build_path']) as fh: - self._contents = to_unicode(fh.read(), 'utf-8') - return self._contents - - def compile(self, context, profile, existing): - self.prepare(existing, profile) - - contents = self.contents - try: - env = jinja2.Environment() - self.compiled_contents = env.from_string(contents).render(context) - return self.compiled_contents - except jinja2.exceptions.TemplateSyntaxError as e: - compiler_error(self, str(e)) - - @property - def materialization(self): - return self.data['materialized'] - - @property - def name(self): - return self.data['name'] - - @property - def tmp_name(self): - return self.data['tmp_name'] - - def project(self): - return {'name': self.data['project_name']} - - @property - def schema(self): - if self.profile is None: - raise RuntimeError( - "`profile` not set in compiled model {}".format(self) - ) - else: - return get_adapter(self.profile).get_default_schema(self.profile) - - def 
should_execute(self, args, existing): - if args.non_destructive and \ - self.materialization == 'view' and \ - self.name in existing: - - return False - else: - return self.data['enabled'] and self.materialization != 'ephemeral' - - def should_rename(self, args): - if args.non_destructive and self.materialization == 'table': - return False - else: - return self.materialization in ['table', 'view'] - - def prepare(self, existing, profile): - if self.materialization == 'incremental': - tmp_drop_type = None - final_drop_type = None - else: - tmp_drop_type = existing.get(self.tmp_name, None) - final_drop_type = existing.get(self.name, None) - - self.tmp_drop_type = tmp_drop_type - self.final_drop_type = final_drop_type - self.profile = profile - - def __repr__(self): - return "".format( - self.data['project_name'], self.name, self.data['build_path'] - ) - - -class CompiledTest(CompiledModel): - def __init__(self, fqn, data): - super(CompiledTest, self).__init__(fqn, data) - - def should_rename(self): - return False - - def should_execute(self, args, existing): - return True - - def prepare(self, existing, profile): - self.profile = profile - - def __repr__(self): - return "".format( - self.data['project_name'], self.name, self.data['build_path'] - ) - - -class CompiledArchive(CompiledModel): - def __init__(self, fqn, data): - super(CompiledArchive, self).__init__(fqn, data) - - def should_rename(self): - return False - - def should_execute(self, args, existing): - return True - - def prepare(self, existing, profile): - self.profile = profile - - def __repr__(self): - return "".format( - self.data['project_name'], self.name, self.data['build_path'] - ) - - -class CompiledAnalysis(CompiledModel): - def __init__(self, fqn, data): - super(CompiledAnalysis, self).__init__(fqn, data) - - def should_rename(self): - return False - - def should_execute(self, args, existing): - return False - - def __repr__(self): - return "".format( - self.data['project_name'], self.name, self.data['build_path'] - ) - - -def make_compiled_model(fqn, data): - run_type = data['dbt_run_type'] - - if run_type == dbt.model.NodeType.Model: - return CompiledModel(fqn, data) - - elif run_type == dbt.model.NodeType.Test: - return CompiledTest(fqn, data) - - elif run_type == dbt.model.NodeType.Archive: - return CompiledArchive(fqn, data) - - elif run_type == dbt.model.NodeType.Analysis: - return CompiledAnalysis(fqn, data) - - else: - raise RuntimeError("invalid run_type given: {}".format(run_type)) diff --git a/dbt/contracts/common.py b/dbt/contracts/common.py new file mode 100644 index 00000000000..4f68581b294 --- /dev/null +++ b/dbt/contracts/common.py @@ -0,0 +1,17 @@ +from voluptuous.error import Invalid, MultipleInvalid + +from dbt.exceptions import ValidationException +from dbt.logger import GLOBAL_LOGGER as logger + + +def validate_with(schema, data): + try: + schema(data) + + except MultipleInvalid as e: + logger.error(str(e)) + raise ValidationException(str(e)) + + except Invalid as e: + logger.error(str(e)) + raise ValidationException(str(e)) diff --git a/dbt/contracts/connection.py b/dbt/contracts/connection.py index 46aeb90527d..6cd2a5ad0b0 100644 --- a/dbt/contracts/connection.py +++ b/dbt/contracts/connection.py @@ -1,7 +1,7 @@ from voluptuous import Schema, Required, All, Any, Extra, Range, Optional -from voluptuous.error import MultipleInvalid -from dbt.exceptions import ValidationException +from dbt.compat import basestring +from dbt.contracts.common import validate_with from dbt.logger import GLOBAL_LOGGER as 
logger @@ -13,22 +13,22 @@ }) postgres_credentials_contract = Schema({ - Required('dbname'): str, - Required('host'): str, - Required('user'): str, - Required('pass'): str, + Required('dbname'): basestring, + Required('host'): basestring, + Required('user'): basestring, + Required('pass'): basestring, Required('port'): All(int, Range(min=0, max=65535)), - Required('schema'): str, + Required('schema'): basestring, }) snowflake_credentials_contract = Schema({ - Required('account'): str, - Required('user'): str, - Required('password'): str, - Required('database'): str, - Required('schema'): str, - Required('warehouse'): str, - Optional('role'): str, + Required('account'): basestring, + Required('user'): basestring, + Required('password'): basestring, + Required('database'): basestring, + Required('schema'): basestring, + Required('warehouse'): basestring, + Optional('role'): basestring, }) credentials_mapping = { @@ -39,11 +39,7 @@ def validate_connection(connection): - try: - connection_contract(connection) - - credentials_contract = credentials_mapping.get(connection.get('type')) - credentials_contract(connection.get('credentials')) - except MultipleInvalid as e: - logger.info(e) - raise ValidationException(str(e)) + validate_with(connection_contract, connection) + + credentials_contract = credentials_mapping.get(connection.get('type')) + validate_with(credentials_contract, connection.get('credentials')) diff --git a/dbt/contracts/graph/__init__.py b/dbt/contracts/graph/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dbt/contracts/graph/compiled.py b/dbt/contracts/graph/compiled.py new file mode 100644 index 00000000000..5277e7889f7 --- /dev/null +++ b/dbt/contracts/graph/compiled.py @@ -0,0 +1,36 @@ +from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ + Length + +from dbt.compat import basestring +from dbt.exceptions import ValidationException +from dbt.logger import GLOBAL_LOGGER as logger + +from dbt.contracts.common import validate_with +from dbt.contracts.graph.parsed import parsed_graph_item_contract + + +compiled_graph_item_contract = parsed_graph_item_contract.extend({ + # compiled fields + Required('compiled'): bool, + Required('compiled_sql'): Any(basestring, None), + + # injected fields + Required('extra_ctes_injected'): bool, + Required('extra_cte_ids'): All(list, [basestring]), + Required('extra_cte_sql'): All(list, [basestring]), + Required('injected_sql'): Any(basestring, None), +}) + + +def validate_one(compiled_graph_item): + validate_with(compiled_graph_item_contract, compiled_graph_item) + + +def validate(compiled_graph): + for k, v in compiled_graph.items(): + validate_with(compiled_graph_item_contract, v) + + if v.get('unique_id') != k: + error_msg = 'unique_id must match key name in compiled graph!' 
+ logger.info(error_msg) + raise ValidationException(error_msg) diff --git a/dbt/contracts/graph/parsed.py b/dbt/contracts/graph/parsed.py new file mode 100644 index 00000000000..320a18e47cd --- /dev/null +++ b/dbt/contracts/graph/parsed.py @@ -0,0 +1,68 @@ +from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ + Length + +from dbt.compat import basestring +from dbt.exceptions import ValidationException +from dbt.logger import GLOBAL_LOGGER as logger + +from dbt.contracts.common import validate_with +from dbt.contracts.graph.unparsed import unparsed_graph_item_contract + + +config_contract = { + Required('enabled'): bool, + Required('materialized'): Any('table', 'view', 'ephemeral', 'incremental'), + Required('post-hook'): list, + Required('pre-hook'): list, + Required('vars'): dict, + + # incremental optional fields + Optional('sql_where'): basestring, + Optional('unique_key'): basestring, + + # adapter optional fields + Optional('sort'): basestring, + Optional('dist'): basestring, +} + + +parsed_graph_item_contract = unparsed_graph_item_contract.extend({ + # identifiers + Required('unique_id'): All(basestring, Length(min=1, max=255)), + Required('fqn'): All(list, [All(basestring)]), + + # parsed fields + Required('depends_on'): All(list, + [All(basestring, Length(min=1, max=255))]), + Required('empty'): bool, + Required('config'): config_contract, + Required('tags'): All(list, [basestring]), +}) + + +def validate_one(parsed_graph_item): + validate_with(parsed_graph_item_contract, parsed_graph_item) + + materialization = parsed_graph_item.get('config', {}) \ + .get('materialized') + + if materialization == 'incremental' and \ + parsed_graph_item.get('config', {}).get('sql_where') is None: + raise ValidationException( + 'missing `sql_where` for an incremental model') + elif (materialization != 'incremental' and + parsed_graph_item.get('config', {}).get('sql_where') is not None): + raise ValidationException( + 'invalid field `sql_where` for a non-incremental model') + + +def validate(parsed_graph): + for k, v in parsed_graph.items(): + validate_one(v) + + if v.get('unique_id') != k: + error_msg = ('unique_id must match key name in parsed graph!' 
+ 'key: {}, model: {}' + .format(k, v)) + logger.info(error_msg) + raise ValidationException(error_msg) diff --git a/dbt/contracts/graph/unparsed.py b/dbt/contracts/graph/unparsed.py new file mode 100644 index 00000000000..b1b91c166d3 --- /dev/null +++ b/dbt/contracts/graph/unparsed.py @@ -0,0 +1,27 @@ +from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ + Length + +from dbt.compat import basestring +from dbt.contracts.common import validate_with +from dbt.logger import GLOBAL_LOGGER as logger + +from dbt.model import NodeType + +unparsed_graph_item_contract = Schema({ + # identifiers + Required('name'): All(basestring, Length(min=1, max=63)), + Required('package_name'): basestring, + Required('resource_type'): Any(NodeType.Model, + NodeType.Test, + NodeType.Analysis), + + # filesystem + Required('root_path'): basestring, + Required('path'): basestring, + Required('raw_sql'): basestring, +}) + + +def validate(unparsed_graph): + for item in unparsed_graph: + validate_with(unparsed_graph_item_contract, item) diff --git a/dbt/contracts/project.py b/dbt/contracts/project.py new file mode 100644 index 00000000000..5b5668eea47 --- /dev/null +++ b/dbt/contracts/project.py @@ -0,0 +1,19 @@ +from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ + Length, ALLOW_EXTRA + +from dbt.contracts.common import validate_with +from dbt.logger import GLOBAL_LOGGER as logger + +project_contract = Schema({ + Required('name'): str +}, extra=ALLOW_EXTRA) + +projects_list_contract = Schema({str: project_contract}) + + +def validate(project): + validate_with(project_contract, project) + + +def validate_list(projects): + validate_with(projects_list_contract, projects) diff --git a/dbt/exceptions.py b/dbt/exceptions.py index 991bfced3cd..b336e610e06 100644 --- a/dbt/exceptions.py +++ b/dbt/exceptions.py @@ -2,7 +2,11 @@ class Exception(BaseException): pass -class ValidationException(Exception): +class RuntimeException(RuntimeError, Exception): + pass + + +class ValidationException(RuntimeException): pass diff --git a/dbt/flags.py b/dbt/flags.py index 928c20aaeb6..3048445bf21 100644 --- a/dbt/flags.py +++ b/dbt/flags.py @@ -1 +1,2 @@ STRICT_MODE = False +NON_DESTRUCTIVE = False diff --git a/dbt/graph/selector.py b/dbt/graph/selector.py index dc6810a2ace..1fe227925bb 100644 --- a/dbt/graph/selector.py +++ b/dbt/graph/selector.py @@ -3,6 +3,9 @@ import networkx as nx from dbt.logger import GLOBAL_LOGGER as logger +import dbt.model + + SELECTOR_PARENTS = '+' SELECTOR_CHILDREN = '+' SELECTOR_GLOB = '*' @@ -43,7 +46,7 @@ def parse_spec(node_spec): def get_package_names(graph): - return set([node[0] for node in graph.nodes()]) + return set([node.split(".")[1] for node in graph.nodes()]) def is_selected_node(real_node, node_selector): @@ -79,17 +82,21 @@ def get_nodes_by_qualified_name(project, graph, qualified_name): package_names = get_package_names(graph) for node in graph.nodes(): - if len(qualified_name) == 1 and node[-1] == qualified_name[0]: + # node naming has changed to dot notation. split to tuple for + # compatibility with this code. 
+ fqn_ish = node.split('.')[1:] + + if len(qualified_name) == 1 and fqn_ish == qualified_name[0]: yield node elif qualified_name[0] in package_names: - if is_selected_node(node, qualified_name): + if is_selected_node(fqn_ish, qualified_name): yield node else: for package_name in package_names: local_qualified_node_name = (package_name,) + qualified_name - if is_selected_node(node, local_qualified_node_name): + if is_selected_node(fqn_ish, local_qualified_node_name): yield node break @@ -104,6 +111,8 @@ def get_nodes_from_spec(project, graph, spec): qualified_node_name)) additional_nodes = set() + test_nodes = set() + if select_parents: for node in selected_nodes: parent_nodes = nx.ancestors(graph, node) @@ -114,7 +123,17 @@ def get_nodes_from_spec(project, graph, spec): child_nodes = nx.descendants(graph, node) additional_nodes.update(child_nodes) - return selected_nodes | additional_nodes + model_nodes = selected_nodes | additional_nodes + + for node in model_nodes: + # include tests that depend on this node. if we aren't running tests, + # they'll be filtered out later. + child_tests = [n for n in graph.successors(node) + if graph.node.get(n).get('resource_type') == + dbt.model.NodeType.Test] + test_nodes.update(child_tests) + + return model_nodes | test_nodes def warn_if_useless_spec(spec, nodes): diff --git a/dbt/linker.py b/dbt/linker.py index b144892ae1d..a56dc29777b 100644 --- a/dbt/linker.py +++ b/dbt/linker.py @@ -1,8 +1,17 @@ import networkx as nx from collections import defaultdict + +import dbt.compilation import dbt.model +def from_file(graph_file): + linker = Linker() + linker.read_graph(graph_file) + + return linker + + class Linker(object): def __init__(self, data=None): if data is None: @@ -94,6 +103,11 @@ def dependency(self, node1, node2): self.graph.add_node(node2) self.graph.add_edge(node2, node1) + if len(list(nx.simple_cycles(self.graph))) > 0: + raise ValidationException( + "Detected a cycle when adding dependency from {} to {}" + .format(node1, node2)) + def add_node(self, node): self.graph.add_node(node) diff --git a/dbt/main.py b/dbt/main.py index cdf3180c818..ec2e6bad1da 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -28,7 +28,7 @@ def main(args=None): args = sys.argv[1:] try: - return handle(args) + handle(args) except RuntimeError as e: logger.info("Encountered an error:") @@ -186,6 +186,12 @@ def invoke_dbt(parsed): log_dir = proj.get('log-path', 'logs') + if hasattr(proj.args, 'non_destructive') and \ + proj.args.non_destructive is True: + flags.NON_DESTRUCTIVE = True + else: + flags.NON_DESTRUCTIVE = False + logger.debug("running dbt with arguments %s", parsed) task = parsed.cls(args=parsed, project=proj) diff --git a/dbt/model.py b/dbt/model.py index 2feccfa9260..7da379afb0c 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -2,27 +2,23 @@ import yaml import jinja2 import re + +from dbt.compat import basestring from dbt.templates import BaseCreateTemplate, ArchiveInsertTemplate from dbt.utils import split_path -import dbt.schema_tester import dbt.project -import dbt.archival from dbt.adapters.factory import get_adapter from dbt.utils import deep_merge, DBTConfigKeys, compiler_error, \ compiler_warning +import dbt.flags class NodeType(object): Base = 'base' Model = 'model' + Analysis = 'analysis' Test = 'test' Archive = 'archive' - Analysis = 'analysis' - - -class TestNodeType(object): - SchemaTest = 'schema' - DataTest = 'data' class SourceConfig(object): @@ -42,6 +38,7 @@ class SourceConfig(object): ] def __init__(self, active_project, own_project, fqn): + 
self._config = None self.active_project = active_project self.own_project = own_project self.fqn = fqn @@ -89,6 +86,7 @@ def config(self): cfg = self._merge(defaults, active_config, self.in_model_config) else: own_config = self.load_config_from_own_project() + cfg = self._merge( defaults, own_config, self.in_model_config, active_config ) @@ -101,7 +99,8 @@ def config(self): return cfg def is_full_refresh(self): - if hasattr(self.active_project.args, 'full_refresh'): + if hasattr(self.active_project, 'args') and \ + hasattr(self.active_project.args, 'full_refresh'): return self.active_project.args.full_refresh else: return False @@ -169,7 +168,7 @@ def get_project_config(self, project): for k in SourceConfig.ExtendDictFields: config[k] = {} - model_configs = project['models'] + model_configs = project.get('models') if model_configs is None: return config @@ -208,6 +207,7 @@ class DBTSource(object): dbt_run_type = NodeType.Base def __init__(self, project, top_dir, rel_filepath, own_project): + self._config = None self.project = project self.own_project = own_project @@ -256,6 +256,9 @@ def contents(self): @property def config(self): + if self._config is not None: + return self._config + return self.source_config.config def update_in_model_config(self, config): @@ -313,10 +316,7 @@ def tmp_name(self): return "{}__dbt_tmp".format(self.name) def is_non_destructive(self): - if hasattr(self.project.args, 'non_destructive'): - return self.project.args.non_destructive - else: - return False + return dbt.flags.NON_DESTRUCTIVE def rename_query(self, schema): opts = { @@ -399,12 +399,6 @@ def build_path(self): return os.path.join(*path_parts) def compile_string(self, ctx, string): - # python 2+3 check for stringiness - try: - basestring - except NameError: - basestring = str - # if bool/int/float/etc are passed in, don't compile anything if not isinstance(string, basestring): return string @@ -498,224 +492,6 @@ def __repr__(self): ) -class Analysis(Model): - dbt_run_type = NodeType.Analysis - - def __init__(self, project, target_dir, rel_filepath, own_project): - return super(Analysis, self).__init__( - project, - target_dir, - rel_filepath, - own_project - ) - - def build_path(self): - build_dir = 'build-analysis' - filename = "{}.sql".format(self.name) - path_parts = [build_dir] + self.fqn[:-1] + [filename] - return os.path.join(*path_parts) - - def __repr__(self): - return "".format(self.name, self.filepath) - - -class SchemaTest(DBTSource): - test_type = "base" - dbt_run_type = NodeType.Test - dbt_test_type = TestNodeType.SchemaTest - - def __init__(self, project, target_dir, rel_filepath, model_name, options): - self.schema = project.context()['env']['schema'] - self.model_name = model_name - self.options = options - self.params = self.get_params(options) - - super(SchemaTest, self).__init__( - project, target_dir, rel_filepath, project - ) - - @property - def fqn(self): - parts = split_path(self.filepath) - name, _ = os.path.splitext(parts[-1]) - return [self.project['name']] + parts[1:-1] + \ - ['schema', self.get_filename()] - - def serialize(self): - serialized = DBTSource.serialize(self).copy() - serialized['dbt_test_type'] = self.dbt_test_type - - return serialized - - def get_params(self, options): - return { - "schema": self.schema, - "table": self.model_name, - "field": options - } - - def unique_option_key(self): - return self.params - - def get_filename(self): - key = re.sub('[^0-9a-zA-Z]+', '_', self.unique_option_key()) - filename = "{test_type}_{model_name}_{key}".format( - 
test_type=self.test_type, model_name=self.model_name, key=key - ) - return filename - - def build_path(self): - build_dir = "test" - filename = "{}.sql".format(self.get_filename()) - path_parts = [build_dir] + self.fqn[:-1] + [filename] - return os.path.join(*path_parts) - - @property - def template(self): - raise NotImplementedError("not implemented") - - def render(self): - return self.template.format(**self.params) - - def __repr__(self): - class_name = self.__class__.__name__ - return "<{} {}.{}: {}>".format( - class_name, self.project['name'], self.name, self.filepath - ) - - -class NotNullSchemaTest(SchemaTest): - template = dbt.schema_tester.QUERY_VALIDATE_NOT_NULL - test_type = "not_null" - - def unique_option_key(self): - return self.params['field'] - - def describe(self): - return 'VALIDATE NOT NULL {schema}.{table}.{field}' \ - .format(**self.params) - - -class UniqueSchemaTest(SchemaTest): - template = dbt.schema_tester.QUERY_VALIDATE_UNIQUE - test_type = "unique" - - def unique_option_key(self): - return self.params['field'] - - def describe(self): - return 'VALIDATE UNIQUE {schema}.{table}.{field}'.format(**self.params) - - -class ReferentialIntegritySchemaTest(SchemaTest): - template = dbt.schema_tester.QUERY_VALIDATE_REFERENTIAL_INTEGRITY - test_type = "relationships" - - def get_params(self, options): - return { - "schema": self.schema, - "child_table": self.model_name, - "child_field": options['from'], - "parent_table": options['to'], - "parent_field": options['field'] - } - - def unique_option_key(self): - return "{child_field}_to_{parent_table}_{parent_field}" \ - .format(**self.params) - - def describe(self): - return """VALIDATE REFERENTIAL INTEGRITY - {schema}.{child_table}.{child_field} to - {schema}.{parent_table}.{parent_field}""".format(**self.params) - - -class AcceptedValuesSchemaTest(SchemaTest): - template = dbt.schema_tester.QUERY_VALIDATE_ACCEPTED_VALUES - test_type = "accepted_values" - - def get_params(self, options): - quoted_values = ["'{}'".format(v) for v in options['values']] - quoted_values_csv = ",".join(quoted_values) - return { - "schema": self.schema, - "table": self.model_name, - "field": options['field'], - "values_csv": quoted_values_csv - } - - def unique_option_key(self): - return "{field}".format(**self.params) - - def describe(self): - return """VALIDATE ACCEPTED VALUES - {schema}.{table}.{field} VALUES - ({values_csv})""".format(**self.params) - - -class SchemaFile(DBTSource): - SchemaTestMap = { - 'not_null': NotNullSchemaTest, - 'unique': UniqueSchemaTest, - 'relationships': ReferentialIntegritySchemaTest, - 'accepted_values': AcceptedValuesSchemaTest - } - - def __init__(self, project, target_dir, rel_filepath, own_project): - super(SchemaFile, self).__init__( - project, target_dir, rel_filepath, own_project - ) - self.og_target_dir = target_dir - self.schema = yaml.safe_load(self.contents) - - def get_test(self, test_type): - if test_type in SchemaFile.SchemaTestMap: - return SchemaFile.SchemaTestMap[test_type] - else: - possible_types = ", ".join(SchemaFile.SchemaTestMap.keys()) - compiler_error( - self, - "Invalid validation type given in {}: '{}'. 
Possible: {}" - .format(self.filepath, test_type, possible_types) - ) - - def do_compile(self): - schema_tests = [] - for model_name, constraint_blob in self.schema.items(): - constraints = constraint_blob.get('constraints', {}) - for constraint_type, constraint_data in constraints.items(): - if constraint_data is None: - compiler_error( - self, - "no constraints given to test: '{}.{}'" - .format(model_name, constraint_type) - ) - for params in constraint_data: - schema_test_klass = self.get_test(constraint_type) - schema_test = schema_test_klass( - self.project, - self.og_target_dir, - self.rel_filepath, - model_name, - params - ) - schema_tests.append(schema_test) - return schema_tests - - def compile(self): - try: - return self.do_compile() - except TypeError as e: - compiler_error(self, str(e)) - except AttributeError as e: - compiler_error(self, str(e)) - - def __repr__(self): - return "".format( - self.project['name'], self.model_name, self.filepath - ) - - class Csv(DBTSource): def __init__(self, project, target_dir, rel_filepath, own_project): super(Csv, self).__init__( @@ -750,119 +526,3 @@ def __repr__(self): return "".format( self.project['name'], self.name, self.filepath ) - - -class ArchiveModel(DBTSource): - dbt_run_type = NodeType.Archive - build_dir = 'archive' - template = ArchiveInsertTemplate() - - def __init__(self, project, archive_data): - - self.validate(archive_data) - - self.source_schema = archive_data['source_schema'] - self.target_schema = archive_data['target_schema'] - self.source_table = archive_data['source_table'] - self.target_table = archive_data['target_table'] - self.unique_key = archive_data['unique_key'] - self.updated_at = archive_data['updated_at'] - - rel_filepath = os.path.join(self.target_schema, self.target_table) - - super(ArchiveModel, self).__init__( - project, self.build_dir, rel_filepath, project - ) - - def validate(self, data): - required = [ - 'source_schema', - 'target_schema', - 'source_table', - 'target_table', - 'unique_key', - 'updated_at', - ] - - for key in required: - if data.get(key, None) is None: - compiler_error( - "Invalid archive config: missing required field '{}'" - .format(key) - ) - - def serialize(self): - data = DBTSource.serialize(self).copy() - - serialized = { - "source_schema": self.source_schema, - "target_schema": self.target_schema, - "source_table": self.source_table, - "target_table": self.target_table, - "unique_key": self.unique_key, - "updated_at": self.updated_at - } - - data.update(serialized) - return data - - def compile(self): - archival = dbt.archival.Archival(self.project, self) - query = archival.compile() - - sql = self.template.wrap( - self.target_schema, self.target_table, query, self.unique_key - ) - - return sql - - def build_path(self): - filename = "{}.sql".format(self.name) - path_parts = [self.build_dir] + self.fqn[:-1] + [filename] - return os.path.join(*path_parts) - - def __repr__(self): - return " {} unique:{} updated_at:{}>".format( - self.source_table, - self.target_table, - self.unique_key, - self.updated_at - ) - - -class DataTest(DBTSource): - dbt_run_type = NodeType.Test - dbt_test_type = TestNodeType.DataTest - - def __init__(self, project, target_dir, rel_filepath, own_project): - super(DataTest, self).__init__( - project, - target_dir, - rel_filepath, - own_project - ) - - def build_path(self): - build_dir = "test" - filename = "{}.sql".format(self.name) - fqn_parts = self.fqn[0:1] + ['data'] + self.fqn[1:-1] - path_parts = [build_dir] + fqn_parts + [filename] - return 
os.path.join(*path_parts) - - def serialize(self): - serialized = DBTSource.serialize(self).copy() - serialized['dbt_test_type'] = self.dbt_test_type - - return serialized - - def render(self, query): - return "select count(*) from (\n{}\n) sbq".format(query) - - @property - def immediate_name(self): - return self.name - - def __repr__(self): - return "".format( - self.project['name'], self.name, self.filepath - ) diff --git a/dbt/parser.py b/dbt/parser.py new file mode 100644 index 00000000000..ad3ce1f66e8 --- /dev/null +++ b/dbt/parser.py @@ -0,0 +1,409 @@ +import copy +import jinja2 +import jinja2.sandbox +import os +import yaml + +import dbt.flags +import dbt.model +import dbt.utils + +import dbt.contracts.graph.parsed +import dbt.contracts.graph.unparsed +import dbt.contracts.project + +from dbt.model import NodeType + +QUERY_VALIDATE_NOT_NULL = """ +with validation as ( + select {field} as f + from {ref} +) +select count(*) from validation where f is null +""" + + +QUERY_VALIDATE_UNIQUE = """ +with validation as ( + select {field} as f + from {ref} + where {field} is not null +), +validation_errors as ( + select f from validation group by f having count(*) > 1 +) +select count(*) from validation_errors +""" + + +QUERY_VALIDATE_ACCEPTED_VALUES = """ +with all_values as ( + select distinct {field} as f + from {ref} +), +validation_errors as ( + select f from all_values where f not in ({values_csv}) +) +select count(*) from validation_errors +""" + + +QUERY_VALIDATE_REFERENTIAL_INTEGRITY = """ +with parent as ( + select {parent_field} as id + from {parent_ref} +), child as ( + select {child_field} as id + from {child_ref} +) +select count(*) from child +where id not in (select id from parent) and id is not null +""" + + +class SilentUndefined(jinja2.Undefined): + """ + This class sets up the parser to just ignore undefined jinja2 calls. So, + for example, `env` is not defined here, but will not make the parser fail + with a fatal error. 
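+
+    For illustration only (this snippet is not part of the parser itself): with
+    this Undefined class, a template that calls an unknown name still renders,
+    e.g.
+
+        env = jinja2.sandbox.SandboxedEnvironment(undefined=SilentUndefined)
+        env.from_string("{{ env('schema') }}").render({})  # no UndefinedError
+
+    whereas the stock jinja2.Undefined would raise UndefinedError when the
+    undefined `env` is called.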
+ """ + def _fail_with_undefined_error(self, *args, **kwargs): + return None + + __add__ = __radd__ = __mul__ = __rmul__ = __div__ = __rdiv__ = \ + __truediv__ = __rtruediv__ = __floordiv__ = __rfloordiv__ = \ + __mod__ = __rmod__ = __pos__ = __neg__ = __call__ = \ + __getitem__ = __lt__ = __le__ = __gt__ = __ge__ = __int__ = \ + __float__ = __complex__ = __pow__ = __rpow__ = \ + _fail_with_undefined_error + + +def get_path(resource_type, package_name, resource_name): + return "{}.{}.{}".format(resource_type, package_name, resource_name) + + +def get_model_path(package_name, resource_name): + return get_path(NodeType.Model, package_name, resource_name) + + +def get_test_path(package_name, resource_name): + return get_path(NodeType.Test, package_name, resource_name) + + +def get_macro_path(package_name, resource_name): + return get_path('macros', package_name, resource_name) + + +def __ref(model): + + def ref(*args): + pass + + return ref + + +def __config(model, cfg): + + def config(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0: + opts = args[0] + elif len(args) == 0 and len(kwargs) > 0: + opts = kwargs + else: + dbt.utils.compiler_error( + model.get('name'), + "Invalid model config given inline in {}".format(model)) + + cfg.update_in_model_config(opts) + + return config + + +def get_fqn(path, package_project_config, extra=[]): + parts = dbt.utils.split_path(path) + name, _ = os.path.splitext(parts[-1]) + fqn = ([package_project_config.get('name')] + + parts[:-1] + + extra + + [name]) + + return fqn + + +def parse_node(node, node_path, root_project_config, package_project_config, + macro_generator=None, tags=[], fqn_extra=[]): + parsed_node = copy.deepcopy(node) + + parsed_node.update({ + 'depends_on': [], + }) + + fqn = get_fqn(node.get('path'), package_project_config, fqn_extra) + + config = dbt.model.SourceConfig( + root_project_config, package_project_config, fqn) + + context = {} + + if macro_generator is not None: + for macro_data in macro_generator(context): + macro = macro_data["macro"] + macro_name = macro_data["name"] + project = macro_data["project"] + + if context.get(project.get('name')) is None: + context[project.get('name')] = {} + + context.get(project.get('name'), {}) \ + .update({macro_name: macro}) + + if node.get('package_name') == project.get('name'): + context.update({macro_name: macro}) + + context['ref'] = __ref(parsed_node) + context['config'] = __config(parsed_node, config) + + env = jinja2.sandbox.SandboxedEnvironment( + undefined=SilentUndefined) + + env.from_string(node.get('raw_sql')).render(context) + + config_dict = node.get('config', {}) + config_dict.update(config.config) + + parsed_node['unique_id'] = node_path + parsed_node['config'] = config_dict + parsed_node['empty'] = (len(node.get('raw_sql').strip()) == 0) + parsed_node['fqn'] = fqn + parsed_node['tags'] = tags + + return parsed_node + + +def parse_sql_nodes(nodes, root_project, projects, macro_generator=None, + tags=[]): + to_return = {} + + dbt.contracts.graph.unparsed.validate(nodes) + + for node in nodes: + package_name = node.get('package_name') + + node_path = get_path(node.get('resource_type'), + package_name, + node.get('name')) + + # TODO if this is set, raise a compiler error + to_return[node_path] = parse_node(node, + node_path, + root_project, + projects.get(package_name), + macro_generator, + tags=tags) + + dbt.contracts.graph.parsed.validate(to_return) + + return to_return + + +def load_and_parse_sql(package_name, root_project, all_projects, root_dir, + relative_dirs, 
resource_type, macro_generator, tags=[]): + extension = "[!.#~]*.sql" + + if dbt.flags.STRICT_MODE: + dbt.contracts.project.validate_list(all_projects) + + file_matches = dbt.clients.system.find_matching( + root_dir, + relative_dirs, + extension) + + result = [] + + for file_match in file_matches: + file_contents = dbt.clients.system.load_file_contents( + file_match.get('absolute_path')) + + parts = dbt.utils.split_path(file_match.get('relative_path', '')) + name, _ = os.path.splitext(parts[-1]) + + result.append({ + 'name': name, + 'root_path': root_dir, + 'resource_type': resource_type, + 'path': file_match.get('relative_path'), + 'package_name': package_name, + 'raw_sql': file_contents + }) + + return parse_sql_nodes(result, root_project, all_projects, macro_generator, + tags) + + +def parse_schema_tests(tests, root_project, projects): + to_return = {} + + for test in tests: + test_yml = yaml.safe_load(test.get('raw_yml')) + + # validate schema test yml structure + + for model_name, test_spec in test_yml.items(): + for test_type, configs in test_spec.get('constraints', {}).items(): + for config in configs: + to_add = parse_schema_test( + test, model_name, config, test_type, + root_project, + projects.get(test.get('package_name'))) + + if to_add is not None: + to_return[to_add.get('unique_id')] = to_add + + return to_return + + +def parse_schema_test(test_base, model_name, test_config, test_type, + root_project_config, package_project_config): + if test_type == 'not_null': + raw_sql = QUERY_VALIDATE_NOT_NULL.format( + ref="{{ref('"+model_name+"')}}", field=test_config) + name_key = test_config + + elif test_type == 'unique': + raw_sql = QUERY_VALIDATE_UNIQUE.format( + ref="{{ref('"+model_name+"')}}", field=test_config) + name_key = test_config + + elif test_type == 'relationships': + if not isinstance(test_config, dict): + return None + + child_field = test_config.get('from') + parent_field = test_config.get('field') + parent_model = test_config.get('to') + + raw_sql = QUERY_VALIDATE_REFERENTIAL_INTEGRITY.format( + child_field=child_field, + child_ref="{{ref('"+model_name+"')}}", + parent_field=parent_field, + parent_ref=("{{ref('"+parent_model+"')}}")) + + name_key = '{}_to_{}_{}'.format(child_field, parent_model, + parent_field) + + elif test_type == 'accepted_values': + if not isinstance(test_config, dict): + return None + + raw_sql = QUERY_VALIDATE_ACCEPTED_VALUES.format( + ref="{{ref('"+model_name+"')}}", + field=test_config.get('field', ''), + values_csv="'{}'".format( + "','".join([str(v) for v in test_config.get('values', [])]))) + + name_key = test_config.get('field') + + else: + raise dbt.exceptions.ValidationException( + 'Unknown schema test type {}'.format(test_type)) + + name = '{}_{}_{}'.format(test_type, model_name, name_key) + + to_return = { + 'name': name, + 'resource_type': test_base.get('resource_type'), + 'package_name': test_base.get('package_name'), + 'root_path': test_base.get('root_path'), + 'path': test_base.get('path'), + 'raw_sql': raw_sql + } + + return parse_node(to_return, + get_test_path(test_base.get('package_name'), + name), + root_project_config, + package_project_config, + tags=['schema'], + fqn_extra=['schema']) + + +def load_and_parse_yml(package_name, root_project, all_projects, root_dir, + relative_dirs): + extension = "[!.#~]*.yml" + + if dbt.flags.STRICT_MODE: + dbt.contracts.project.validate_list(all_projects) + + file_matches = dbt.clients.system.find_matching( + root_dir, + relative_dirs, + extension) + + result = [] + + for file_match in 
file_matches: + file_contents = dbt.clients.system.load_file_contents( + file_match.get('absolute_path')) + + parts = dbt.utils.split_path(file_match.get('relative_path', '')) + name, _ = os.path.splitext(parts[-1]) + + result.append({ + 'name': name, + 'root_path': root_dir, + 'resource_type': NodeType.Test, + 'path': file_match.get('relative_path'), + 'package_name': package_name, + 'raw_yml': file_contents + }) + + return parse_schema_tests(result, root_project, all_projects) + + +def parse_archives_from_projects(root_project, all_projects): + archives = [] + to_return = {} + + for name, project in all_projects.items(): + archives = archives + parse_archives_from_project(project) + + for archive in archives: + node_path = get_path(archive.get('resource_type'), + archive.get('package_name'), + archive.get('name')) + + to_return[node_path] = parse_node( + archive, + node_path, + root_project, + all_projects.get(archive.get('package_name'))) + + return to_return + + +def parse_archives_from_project(project): + archives = [] + archive_configs = project.get('archive', []) + + for archive_config in archive_configs: + tables = archive_config.get('tables') + + if tables is None: + continue + + for table in tables: + config = table.copy() + config['source_schema'] = archive_config.get('source_schema') + config['target_schema'] = archive_config.get('target_schema') + + archives.append({ + 'name': table.get('target_table'), + 'root_path': project.get('project-root'), + 'resource_type': NodeType.Archive, + 'path': project.get('project-root'), + 'package_name': project.get('name'), + 'config': config, + 'raw_sql': '-- noop' + }) + + return archives diff --git a/dbt/runner.py b/dbt/runner.py index 469b4780e13..b86ed6cae66 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -1,6 +1,7 @@ - from __future__ import print_function +import jinja2 +import hashlib import psycopg2 import os import sys @@ -13,14 +14,15 @@ from dbt.adapters.factory import get_adapter from dbt.logger import GLOBAL_LOGGER as logger -import dbt.compilation -from dbt.linker import Linker + from dbt.source import Source from dbt.utils import find_model_by_fqn, find_model_by_name, \ dependency_projects -from dbt.compiled_model import make_compiled_model +from dbt.model import NodeType +import dbt.compilation import dbt.exceptions +import dbt.linker import dbt.tracking import dbt.schema import dbt.graph.selector @@ -33,312 +35,355 @@ def get_timestamp(): - return "{} |".format(time.strftime("%H:%M:%S")) + return time.strftime("%H:%M:%S") -class RunModelResult(object): - def __init__(self, model, error=None, skip=False, status=None, - execution_time=0): - self.model = model - self.error = error - self.skip = skip - self.status = status - self.execution_time = execution_time +def get_materialization(model): + return model.get('config', {}).get('materialized') - @property - def errored(self): - return self.error is not None - @property - def skipped(self): - return self.skip +def get_hash(model): + return hashlib.md5(model.get('unique_id').encode('utf-8')).hexdigest() -class BaseRunner(object): - def __init__(self, project): - self.project = project +def get_hashed_contents(model): + return hashlib.md5(model.get('raw_sql').encode('utf-8')).hexdigest() - self.profile = project.run_environment() - self.adapter = get_adapter(self.profile) - def pre_run_msg(self, model): - raise NotImplementedError("not implemented") +def is_enabled(model): + return model.get('config', {}).get('enabled') is True - def skip_msg(self, model): - return "SKIP 
relation {}.{}".format( - self.adapter.get_default_schema(self.profile), model.name) - def post_run_msg(self, result): - raise NotImplementedError("not implemented") +def print_timestamped_line(msg): + logger.info("{} | {}".format(get_timestamp(), msg)) - def pre_run_all_msg(self, models): - raise NotImplementedError("not implemented") - def post_run_all_msg(self, results): - raise NotImplementedError("not implemented") +def print_fancy_output_line(msg, status, index, total, execution_time=None): + prefix = "{timestamp} | {index} of {total} {message}".format( + timestamp=get_timestamp(), + index=index, + total=total, + message=msg) + justified = prefix.ljust(80, ".") - def post_run_all(self, models, results, context): - pass + if execution_time is None: + status_time = "" + else: + status_time = " in {execution_time:0.2f}s".format( + execution_time=execution_time) - def pre_run_all(self, models, context): - pass + output = "{justified} [{status}{status_time}]".format( + justified=justified, status=status, status_time=status_time) - def status(self, result): - raise NotImplementedError("not implemented") + logger.info(output) -class ModelRunner(BaseRunner): - run_type = dbt.model.NodeType.Model +def print_skip_line(model, schema, relation, index, num_models): + msg = 'SKIP relation {}.{}'.format(schema, relation) + print_fancy_output_line(msg, 'SKIP', index, num_models) - def pre_run_msg(self, model): - print_vars = { - "schema": self.adapter.get_default_schema(self.profile), - "model_name": model.name, - "model_type": model.materialization, - "info": "START" - } - output = ("START {model_type} model {schema}.{model_name} " - .format(**print_vars)) - return output - - def post_run_msg(self, result): - model = result.model - print_vars = { - "schema": self.adapter.get_default_schema(self.profile), - "model_name": model.name, - "model_type": model.materialization, - "info": "ERROR creating" if result.errored else "OK created" - } +def print_counts(flat_nodes): + counts = {} + for node in flat_nodes: + t = node.get('resource_type') + counts[t] = counts.get(t, 0) + 1 - output = ("{info} {model_type} model {schema}.{model_name} " - .format(**print_vars)) - return output + for k, v in counts.items(): + logger.info("") + print_timestamped_line("Running {} {}s".format(v, k)) + print_timestamped_line("") - def pre_run_all_msg(self, models): - return "{} Running {} models".format(get_timestamp(), len(models)) - def post_run_all_msg(self, results): - return ("{} Finished running {} models" - .format(get_timestamp(), len(results))) +def print_start_line(node, schema_name, index, total): + if node.get('resource_type') == NodeType.Model: + print_model_start_line(node, schema_name, index, total) + if node.get('resource_type') == NodeType.Test: + print_test_start_line(node, schema_name, index, total) - def status(self, result): - return result.status - def is_non_destructive(self): - if hasattr(self.project.args, 'non_destructive'): - return self.project.args.non_destructive - else: - return False +def print_test_start_line(model, schema_name, index, total): + msg = "START test {name}".format( + name=model.get('name')) - def execute(self, model): - profile = self.project.run_environment() - adapter = get_adapter(profile) + print_fancy_output_line(msg, 'RUN', index, total) - if model.tmp_drop_type is not None: - if model.materialization == 'table' and self.is_non_destructive(): - adapter.truncate( - profile=profile, - table=model.tmp_name, - model_name=model.name) - else: - adapter.drop( - profile=profile, 
- relation=model.tmp_name, - relation_type=model.tmp_drop_type, - model_name=model.name) - - status = adapter.execute_model( - profile=profile, - model=model) - - if model.final_drop_type is not None: - if model.materialization == 'table' and self.is_non_destructive(): - # we just inserted into this recently truncated table... - # do nothing here - pass - else: - adapter.drop( - profile=profile, - relation=model.name, - relation_type=model.final_drop_type, - model_name=model.name) - - if model.should_rename(self.project.args): - adapter.rename( - profile=profile, - from_name=model.tmp_name, - to_name=model.name, - model_name=model.name) - adapter.commit( - profile=profile) +def print_model_start_line(model, schema_name, index, total): + msg = "START {model_type} model {schema}.{relation}".format( + model_type=get_materialization(model), + schema=schema_name, + relation=model.get('name')) - return status + print_fancy_output_line(msg, 'RUN', index, total) - def __run_hooks(self, hooks, context, source): - if type(hooks) not in (list, tuple): - hooks = [hooks] - target = self.project.get_target() +def print_result_line(result, schema_name, index, total): + node = result.node - ctx = { - "target": target, - "state": "start", - "invocation_id": context['invocation_id'], - "run_started_at": context['run_started_at'] - } + if node.get('resource_type') == NodeType.Model: + print_model_result_line(result, schema_name, index, total) + elif node.get('resource_type') == NodeType.Test: + print_test_result_line(result, schema_name, index, total) - compiled_hooks = [ - dbt.compilation.compile_string(hook, ctx) for hook in hooks - ] - profile = self.project.run_environment() - adapter = get_adapter(profile) +def print_test_result_line(result, schema_name, index, total): + model = result.node + info = 'PASS' + + if result.errored: + info = "ERROR" + elif result.status > 0: + info = 'FAIL {}'.format(result.status) + elif result.status == 0: + info = 'PASS' + else: + raise RuntimeError("unexpected status: {}".format(result.status)) + + print_fancy_output_line( + "{info} {name}".format( + info=info, + name=model.get('name')), + info, + index, + total, + result.execution_time) + + +def execute_test(profile, test): + adapter = get_adapter(profile) + _, cursor = adapter.execute_one( + profile, + test.get('wrapped_sql'), + test.get('name')) + + rows = cursor.fetchall() + + adapter.commit(profile) + + cursor.close() + + if len(rows) > 1: + raise RuntimeError( + "Bad test {name}: Returned {num_rows} rows instead of 1" + .format(name=model.name, num_rows=len(rows))) + + row = rows[0] + if len(row) > 1: + raise RuntimeError( + "Bad test {name}: Returned {num_cols} cols instead of 1" + .format(name=model.name, num_cols=len(row))) - adapter.execute_all( + return row[0] + + +def print_model_result_line(result, schema_name, index, total): + model = result.node + info = 'OK created' + + if result.errored: + info = 'ERROR creating' + + print_fancy_output_line( + "{info} {model_type} model {schema}.{relation}".format( + info=info, + model_type=get_materialization(model), + schema=schema_name, + relation=model.get('name')), + result.status, + index, + total, + result.execution_time) + + +def print_results_line(results, execution_time): + stats = {} + + for result in results: + stats[result.node.get('resource_type')] = stats.get( + result.node.get('resource_type'), 0) + 1 + + stat_line = ", ".join( + ["{} {}s".format(ct, t) for t, ct in stats.items()]) + + print_timestamped_line("") + print_timestamped_line( + "Finished 
running {stat_line} in {execution_time:0.2f}s." + .format(stat_line=stat_line, execution_time=execution_time)) + + +def execute_model(profile, model, existing): + adapter = get_adapter(profile) + schema = adapter.get_default_schema(profile) + + tmp_name = '{}__dbt_tmp'.format(model.get('name')) + + if dbt.flags.NON_DESTRUCTIVE: + # for non destructive mode, we only look at the already existing table. + tmp_name = model.get('name') + + result = None + + # TRUNCATE / DROP + if get_materialization(model) == 'table' and \ + dbt.flags.NON_DESTRUCTIVE and \ + existing.get(tmp_name) == 'table': + # tables get truncated instead of dropped in non-destructive mode. + adapter.truncate( profile=profile, - queries=compiled_hooks, - model_name=source) + table=tmp_name, + model_name=model.get('name')) - adapter.commit(profile) + elif dbt.flags.NON_DESTRUCTIVE: + # never drop existing relations in non destructive mode. + pass - def pre_run_all(self, models, context): - hooks = self.project.cfg.get('on-run-start', []) - self.__run_hooks(hooks, context, 'on-run-start hooks') + elif (get_materialization(model) != 'incremental' and + existing.get(tmp_name) is not None): + # otherwise, for non-incremental things, drop them with IF EXISTS + adapter.drop( + profile=profile, + relation=tmp_name, + relation_type=existing.get(tmp_name), + model_name=model.get('name')) + + # and update the list of what exists + existing = adapter.query_for_existing(profile, schema) + + # EXECUTE + if get_materialization(model) == 'view' and dbt.flags.NON_DESTRUCTIVE and \ + model.get('name') in existing: + # views don't need to be recreated in non destructive mode since they + # will repopulate automatically. note that we won't run DDL for these + # views either. + pass + elif is_enabled(model) and get_materialization(model) != 'ephemeral': + result = adapter.execute_model(profile, model) - def post_run_all(self, models, results, context): - hooks = self.project.cfg.get('on-run-end', []) - self.__run_hooks(hooks, context, 'on-run-end hooks') + # DROP OLD RELATION AND RENAME + if dbt.flags.NON_DESTRUCTIVE: + # in non-destructive mode, we truncate and repopulate tables, and + # don't modify views. 
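+        # Summary of the branches in this section, for reference:
+        #   non-destructive + table: the table was truncated and repopulated in
+        #       place above, so there is nothing to drop or rename here
+        #   non-destructive + view:  the existing view is left untouched
+        #   otherwise (table/view):  drop the existing <name> relation if it
+        #       exists, then rename <name>__dbt_tmp to <name>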
+ pass + elif get_materialization(model) in ['table', 'view']: + # otherwise, drop tables and views, and rename tmp tables/views to + # their new names + if existing.get(model.get('name')) is not None: + adapter.drop( + profile=profile, + relation=model.get('name'), + relation_type=existing.get(model.get('name')), + model_name=model.get('name')) + adapter.rename(profile=profile, + from_name=tmp_name, + to_name=model.get('name'), + model_name=model.get('name')) -class TestRunner(ModelRunner): - run_type = dbt.model.NodeType.Test + return result - test_data_type = dbt.model.TestNodeType.DataTest - test_schema_type = dbt.model.TestNodeType.SchemaTest - def pre_run_msg(self, model): - if model.is_test_type(self.test_data_type): - return "DATA TEST {name} ".format(name=model.name) - else: - return "SCHEMA TEST {name} ".format(name=model.name) - - def post_run_msg(self, result): - model = result.model - info = self.status(result) - - return "{info} {name} ".format(info=info, name=model.name) - - def pre_run_all_msg(self, models): - return "{} Running {} tests".format(get_timestamp(), len(models)) - - def post_run_all_msg(self, results): - total = len(results) - passed = len([result for result in results if not - result.errored and not result.skipped and - result.status == 0]) - failed = len([result for result in results if not - result.errored and not result.skipped and - result.status > 0]) - errored = len([result for result in results if result.errored]) - skipped = len([result for result in results if result.skipped]) - - total_errors = failed + errored - - overview = ("PASS={passed} FAIL={total_errors} SKIP={skipped} " - "TOTAL={total}".format( - total=total, - passed=passed, - total_errors=total_errors, - skipped=skipped)) - - if total_errors > 0: - final = "Tests completed with errors" - else: - final = "All tests passed" +def execute_archive(profile, node, context): + adapter = get_adapter(profile) - return "\n{overview}\n{final}".format(overview=overview, final=final) + node_cfg = node.get('config', {}) - def status(self, result): - if result.errored: - info = "ERROR" - elif result.status > 0: - info = 'FAIL {}'.format(result.status) - elif result.status == 0: - info = 'PASS' - else: - raise RuntimeError("unexpected status: {}".format(result.status)) + source_columns = adapter.get_columns_in_table( + profile, node_cfg.get('source_schema'), node_cfg.get('source_table')) - return info + if len(source_columns) == 0: + raise RuntimeError( + 'Source table "{}"."{}" does not ' + 'exist'.format(source_schema, source_table)) - def execute(self, model): - profile = self.project.run_environment() - adapter = get_adapter(profile) + dest_columns = source_columns + [ + dbt.schema.Column("valid_from", "timestamp", None), + dbt.schema.Column("valid_to", "timestamp", None), + dbt.schema.Column("scd_id", "text", None), + dbt.schema.Column("dbt_updated_at", "timestamp", None) + ] - _, cursor = adapter.execute_one( - profile, model.compiled_contents, model.name) - rows = cursor.fetchall() + adapter.create_table( + profile, + schema=node_cfg.get('target_schema'), + table=node_cfg.get('target_table'), + columns=dest_columns, + sort=node_cfg.get('updated_at'), + dist=node_cfg.get('unique_key')) - cursor.close() + # TODO move this to inject_runtime_config, generate archive SQL + # in wrap step. can't do this right now because we actually need + # to inspect status of the schema at runtime and archive requires + # a lot of information about the schema to generate queries. 
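+    # What follows, in brief: render dbt.templates.SCDArchiveTemplate against
+    # the archive config (source/target schema and table, unique_key,
+    # updated_at) to produce the SCD select, wrap that select in an insert
+    # statement via ArchiveInsertTemplate, and store the result on the node as
+    # 'wrapped_sql' so it can be executed like any other model.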
+ template_ctx = context.copy() + template_ctx.update(node_cfg) - if len(rows) > 1: - raise RuntimeError( - "Bad test {name}: Returned {num_rows} rows instead of 1" - .format(name=model.name, num_rows=len(rows))) + env = jinja2.Environment() + select = env.from_string( + dbt.templates.SCDArchiveTemplate, + template_ctx + ).render(node_cfg) - row = rows[0] - if len(row) > 1: - raise RuntimeError( - "Bad test {name}: Returned {num_cols} cols instead of 1" - .format(name=model.name, num_cols=len(row))) + insert_stmt = dbt.templates.ArchiveInsertTemplate().wrap( + schema=node_cfg.get('target_schema'), + table=node_cfg.get('target_table'), + query=select, + unique_key=node_cfg.get('unique_key')) - return row[0] + env = jinja2.Environment() + node['wrapped_sql'] = env.from_string( + insert_stmt, + template_ctx + ).render(node_cfg) + result = adapter.execute_model( + profile=profile, + model=node) -class ArchiveRunner(BaseRunner): - run_type = dbt.model.NodeType.Archive + return result - def pre_run_msg(self, model): - print_vars = { - "schema": self.adapter.get_default_schema(self.profile), - "model_name": model.name, - } - output = ("START archive table {schema}.{model_name} " - .format(**print_vars)) - return output +def run_hooks(profile, hooks, context, source): + if type(hooks) not in (list, tuple): + hooks = [hooks] - def post_run_msg(self, result): - model = result.model - print_vars = { - "schema": self.adapter.get_default_schema(self.profile), - "model_name": model.name, - "info": "ERROR archiving" if result.errored else "OK created" - } + ctx = { + "target": profile, + "state": "start", + "invocation_id": context['invocation_id'], + "run_started_at": context['run_started_at'] + } - output = "{info} table {schema}.{model_name} ".format(**print_vars) - return output + compiled_hooks = [ + dbt.compilation.compile_string(hook, ctx) for hook in hooks + ] - def pre_run_all_msg(self, models): - return "Archiving {} tables".format(len(models)) + adapter = get_adapter(profile) - def post_run_all_msg(self, results): - return ("{} Finished archiving {} tables" - .format(get_timestamp(), len(results))) + adapter.execute_all( + profile=profile, + queries=compiled_hooks, + model_name=source) - def status(self, result): - return result.status + adapter.commit(profile) - def execute(self, model): - profile = self.project.run_environment() - adapter = get_adapter(profile) - status = adapter.execute_model( - profile=profile, - model=model) +class RunModelResult(object): + def __init__(self, node, error=None, skip=False, status=None, + execution_time=0): + self.node = node + self.error = error + self.skip = skip + self.status = status + self.execution_time = execution_time - return status + @property + def errored(self): + return self.error is not None + + @property + def skipped(self): + return self.skip class RunManager(object): @@ -358,11 +403,6 @@ def __init__(self, project, target_path, args): adapter = get_adapter(profile) schema_name = adapter.get_default_schema(profile) - self.existing_models = adapter.query_for_existing( - profile, - schema_name - ) - def call_get_columns_in_table(schema_name, table_name): return adapter.get_columns_in_table( profile, schema_name, table_name) @@ -385,109 +425,115 @@ def call_table_exists(schema, table): "already_exists": call_table_exists, } + def inject_runtime_config(self, node): + sql = dbt.compilation.compile_string(node.get('wrapped_sql'), + self.context) + + node['wrapped_sql'] = sql + + return node + def deserialize_graph(self): - logger.info("Loading 
dependency graph file") + logger.info("Loading dependency graph file.") - linker = Linker() base_target_path = self.project['target-path'] graph_file = os.path.join( base_target_path, dbt.compilation.graph_file_name ) - linker.read_graph(graph_file) - return linker + return dbt.linker.from_file(graph_file) + + def execute_node(self, node, existing): + profile = self.project.run_environment() + + logger.debug("executing node %s", node.get('unique_id')) + + if node.get('skip') is True: + return RunModelResult(node, skip=True) + + node = self.inject_runtime_config(node) - def execute_model(self, runner, model): - logger.debug("executing model %s", model) + if node.get('resource_type') == NodeType.Model: + result = execute_model(profile, node, existing) + elif node.get('resource_type') == NodeType.Test: + result = execute_test(profile, node) + elif node.get('resource_type') == NodeType.Archive: + result = execute_archive(profile, node, self.context) - result = runner.execute(model) return result - def safe_execute_model(self, data): - runner, model = data['runner'], data['model'] + def safe_execute_node(self, data): + node, existing = data start_time = time.time() error = None + try: - status = self.execute_model(runner, model) + status = self.execute_node(node, existing) except (RuntimeError, dbt.exceptions.ProgrammingException, psycopg2.ProgrammingError, psycopg2.InternalError) as e: error = "Error executing {filepath}\n{error}".format( - filepath=model['build_path'], error=str(e).strip()) + filepath=node.get('build_path'), error=str(e).strip()) status = "ERROR" logger.debug(error) if type(e) == psycopg2.InternalError and \ ABORTED_TRANSACTION_STRING == e.diag.message_primary: return RunModelResult( - model, error=ABORTED_TRANSACTION_STRING, status="SKIP") + node, + error='{}\n'.format(ABORTED_TRANSACTION_STRING), + status="SKIP") except Exception as e: error = ("Unhandled error while executing {filepath}\n{error}" .format( - filepath=model['build_path'], error=str(e).strip())) + filepath=node.get('build_path'), + error=str(e).strip())) logger.debug(error) raise e execution_time = time.time() - start_time - return RunModelResult(model, + return RunModelResult(node, error=error, status=status, execution_time=execution_time) - def as_concurrent_dep_list(self, linker, models_to_run): - # linker.as_dependency_list operates on nodes, but this method operates - # on compiled models. 
Use a dict to translate between the two - node_model_map = {m.fqn: m for m in models_to_run} - dependency_list = linker.as_dependency_list(node_model_map.keys()) + def as_concurrent_dep_list(self, linker, nodes_to_run): + dependency_list = linker.as_dependency_list(nodes_to_run) - model_dependency_list = [] - for node_level in dependency_list: - model_level = [node_model_map[n] for n in node_level] - model_dependency_list.append(model_level) + concurrent_dependency_list = [] + for level in dependency_list: + node_level = [linker.get_node(node) for node in level] + concurrent_dependency_list.append(node_level) - return model_dependency_list + return concurrent_dependency_list - def on_model_failure(self, linker, models, selected_nodes): - def skip_dependent(model): - dependent_nodes = linker.get_dependent_nodes(model.fqn) + def on_model_failure(self, linker, selected_nodes): + def skip_dependent(node): + dependent_nodes = linker.get_dependent_nodes(node.get('unique_id')) for node in dependent_nodes: if node in selected_nodes: - try: - model_to_skip = find_model_by_fqn(models, node) - model_to_skip.do_skip() - except RuntimeError as e: - pass + node_data = linker.get_node(node) + node_data['skip'] = True + linker.update_node_data(node, node_data) + return skip_dependent - def print_fancy_output_line(self, message, status, index, total, - execution_time=None): - prefix = "{timestamp} {index} of {total} {message}".format( - timestamp=get_timestamp(), - index=index, - total=total, - message=message) - justified = prefix.ljust(80, ".") - - if execution_time is None: - status_time = "" - else: - status_time = " in {execution_time:0.2f}s".format( - execution_time=execution_time) + def execute_nodes(self, node_dependency_list, on_failure, + should_run_hooks=False): + profile = self.project.run_environment() + adapter = get_adapter(profile) + schema_name = adapter.get_default_schema(profile) - output = "{justified} [{status}{status_time}]".format( - justified=justified, status=status, status_time=status_time) - logger.info(output) + flat_nodes = list(itertools.chain.from_iterable( + node_dependency_list)) - def execute_models(self, runner, model_dependency_list, on_failure): - flat_models = list(itertools.chain.from_iterable( - model_dependency_list)) + num_nodes = len(flat_nodes) - num_models = len(flat_models) - if num_models == 0: + if num_nodes == 0: logger.info("WARNING: Nothing to do. 
Try checking your model " "configs and running `dbt compile`".format( self.target_path)) @@ -497,144 +543,142 @@ def execute_models(self, runner, model_dependency_list, on_failure): logger.info("Concurrency: {} threads (target='{}')".format( num_threads, self.project.get_target().get('name')) ) - logger.info("Running!") + + existing = adapter.query_for_existing(profile, schema_name) pool = ThreadPool(num_threads) - logger.info("") - logger.info(runner.pre_run_all_msg(flat_models)) - runner.pre_run_all(flat_models, self.context) + print_counts(flat_nodes) + + start_time = time.time() + + if should_run_hooks: + run_hooks(self.project.get_target(), + self.project.cfg.get('on-run-start', []), + self.context, + 'on-run-start hooks') + + node_id_to_index_map = {node.get('unique_id'): i + 1 for (i, node) + in enumerate(flat_nodes)} - fqn_to_id_map = {model.fqn: i + 1 for (i, model) - in enumerate(flat_models)} + def get_idx(node): + return node_id_to_index_map[node.get('unique_id')] - def get_idx(model): - return fqn_to_id_map[model.fqn] + node_results = [] + for node_list in node_dependency_list: + for i, node in enumerate([node for node in node_list + if node.get('skip')]): + print_skip_line(node, schema_name, node.get('name'), + get_idx(node), num_nodes) - model_results = [] - for model_list in model_dependency_list: - for i, model in enumerate([model for model in model_list - if model.should_skip()]): - msg = runner.skip_msg(model) - self.print_fancy_output_line( - msg, 'SKIP', get_idx(model), num_models) - model_result = RunModelResult(model, skip=True) - model_results.append(model_result) + node_result = RunModelResult(node, skip=True) + node_results.append(node_result) - models_to_execute = [model for model in model_list - if not model.should_skip()] + nodes_to_execute = [node for node in node_list + if not node.get('skip')] threads = self.threads - num_models_this_batch = len(models_to_execute) - model_index = 0 + num_nodes_this_batch = len(nodes_to_execute) + node_index = 0 def on_complete(run_model_results): for run_model_result in run_model_results: - model_results.append(run_model_result) - - msg = runner.post_run_msg(run_model_result) - status = runner.status(run_model_result) - index = get_idx(run_model_result.model) - self.print_fancy_output_line( - msg, - status, - index, - num_models, - run_model_result.execution_time - ) + node_results.append(run_model_result) + + index = get_idx(run_model_result.node) + + print_result_line(run_model_result, + schema_name, + index, + num_nodes) invocation_id = dbt.tracking.active_user.invocation_id dbt.tracking.track_model_run({ "invocation_id": invocation_id, "index": index, - "total": num_models, + "total": num_nodes, "execution_time": run_model_result.execution_time, "run_status": run_model_result.status, "run_skipped": run_model_result.skip, "run_error": run_model_result.error, - "model_materialization": run_model_result.model['materialized'], # noqa - "model_id": run_model_result.model.hashed_name(), - "hashed_contents": run_model_result.model.hashed_contents(), # noqa + "model_materialization": get_materialization(run_model_result.node), # noqa + "model_id": get_hash(run_model_result.node), + "hashed_contents": get_hashed_contents(run_model_result.node), # noqa }) if run_model_result.errored: - on_failure(run_model_result.model) + on_failure(run_model_result.node) logger.info(run_model_result.error) - while model_index < num_models_this_batch: - local_models = [] + while node_index < num_nodes_this_batch: + local_nodes = [] for i in range( 
- model_index, - min(model_index + threads, num_models_this_batch)): - model = models_to_execute[i] - local_models.append(model) - msg = runner.pre_run_msg(model) - self.print_fancy_output_line( - msg, 'RUN', get_idx(model), num_models - ) - - wrapped_models_to_execute = [ - {"runner": runner, "model": model} - for model in local_models - ] + node_index, + min(node_index + threads, num_nodes_this_batch)): + node = nodes_to_execute[i] + local_nodes.append(node) + + print_start_line(node, + schema_name, + get_idx(node), + num_nodes) + map_result = pool.map_async( - self.safe_execute_model, - wrapped_models_to_execute, + self.safe_execute_node, + [(node, existing,) for node in local_nodes], callback=on_complete ) map_result.wait() run_model_results = map_result.get() - model_index += threads + node_index += threads pool.close() pool.join() - logger.info("") - logger.info(runner.post_run_all_msg(model_results)) - runner.post_run_all(flat_models, model_results, self.context) + if should_run_hooks: + run_hooks(self.project.get_target(), + self.project.cfg.get('on-run-end', []), + self.context, + 'on-run-end hooks') - return model_results + execution_time = time.time() - start_time + + print_results_line(node_results, execution_time) + + return node_results + + def get_nodes_to_run(self, graph, include_spec, exclude_spec, + resource_types, tags): - def get_nodes_to_run(self, graph, include_spec, exclude_spec, model_type): if include_spec is None: include_spec = ['*'] if exclude_spec is None: exclude_spec = [] - model_nodes = [ + to_run = [ n for n in graph.nodes() - if graph.node[n]['dbt_run_type'] == model_type + if (graph.node.get(n).get('empty') is False and + is_enabled(graph.node.get(n))) ] - model_only_graph = graph.subgraph(model_nodes) + filtered_graph = graph.subgraph(to_run) selected_nodes = dbt.graph.selector.select_nodes(self.project, - model_only_graph, + filtered_graph, include_spec, exclude_spec) - return selected_nodes - - def get_compiled_models(self, linker, nodes, node_type): - compiled_models = [] - for fqn in nodes: - compiled_model = make_compiled_model(fqn, linker.get_node(fqn)) - if not compiled_model.is_type(node_type): - continue - - if not compiled_model.should_execute(self.args, - self.existing_models): - continue - - context = self.context.copy() - context.update(compiled_model.context()) - - profile = self.project.run_environment() - compiled_model.compile(context, profile, self.existing_models) - compiled_models.append(compiled_model) + post_filter = [ + n for n in selected_nodes + if ((graph.node.get(n).get('resource_type') in resource_types) and + get_materialization(graph.node.get(n)) != 'ephemeral' and + (len(tags) == 0 or + # does the node share any tags with the run? 
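+             #  (e.g. `dbt test --schema` calls run_tests with tags=['schema'],
+             #  so only nodes parsed with the 'schema' tag pass this check)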
+ bool(set(graph.node.get(n).get('tags')) & set(tags)))) + ] - return compiled_models + return set(post_filter) def try_create_schema(self): profile = self.project.run_environment() @@ -651,120 +695,47 @@ def try_create_schema(self): logger.info(str(e)) raise - def run_models_from_graph(self, include_spec, exclude_spec): - runner = ModelRunner(self.project) + def run_types_from_graph(self, include_spec, exclude_spec, + resource_types, tags, should_run_hooks=False): linker = self.deserialize_graph() selected_nodes = self.get_nodes_to_run( linker.graph, include_spec, exclude_spec, - dbt.model.NodeType.Model) + resource_types, + tags) - compiled_models = self.get_compiled_models( + dependency_list = self.as_concurrent_dep_list( linker, - selected_nodes, - runner.run_type) + selected_nodes) self.try_create_schema() - model_dependency_list = self.as_concurrent_dep_list( - linker, - compiled_models - ) + on_failure = self.on_model_failure(linker, selected_nodes) - on_failure = self.on_model_failure(linker, compiled_models, - selected_nodes) - results = self.execute_models( - runner, model_dependency_list, on_failure - ) - - return results - - def run_tests_from_graph(self, include_spec, exclude_spec, - test_schemas, test_data): - - runner = TestRunner(self.project) - linker = self.deserialize_graph() - - selected_model_nodes = self.get_nodes_to_run( - linker.graph, - include_spec, - exclude_spec, - dbt.model.NodeType.Model) - - # just throw everything in this set, then pick out tests later - nodes_and_neighbors = set() - for model_node in selected_model_nodes: - nodes_and_neighbors.add(model_node) - neighbors = linker.graph.neighbors(model_node) - for neighbor in neighbors: - nodes_and_neighbors.add(neighbor) - - compiled_models = self.get_compiled_models( - linker, - nodes_and_neighbors, - runner.run_type) - - selected_nodes = set(cm.fqn for cm in compiled_models) - - self.try_create_schema() - - all_tests = [] - if test_schemas: - all_tests.extend([cm for cm in compiled_models - if cm.is_test_type(runner.test_schema_type)]) - - if test_data: - all_tests.extend([cm for cm in compiled_models - if cm.is_test_type(runner.test_data_type)]) - - dep_list = [all_tests] - - on_failure = self.on_model_failure(linker, all_tests, selected_nodes) - results = self.execute_models(runner, dep_list, on_failure) - - return results - - def run_archives_from_graph(self): - runner = ArchiveRunner(self.project) - linker = self.deserialize_graph() - - selected_nodes = self.get_nodes_to_run( - linker.graph, - None, - None, - dbt.model.NodeType.Archive) - - compiled_models = self.get_compiled_models( - linker, - selected_nodes, - runner.run_type) - - self.try_create_schema() - - model_dependency_list = self.as_concurrent_dep_list( - linker, - compiled_models - ) - - on_failure = self.on_model_failure(linker, compiled_models, - selected_nodes) - results = self.execute_models( - runner, model_dependency_list, on_failure - ) + results = self.execute_nodes(dependency_list, on_failure, + should_run_hooks) return results # ------------------------------------ - def run_tests(self, include_spec, exclude_spec, - test_schemas=False, test_data=False): - return self.run_tests_from_graph(include_spec, exclude_spec, - test_schemas, test_data) - def run_models(self, include_spec, exclude_spec): - return self.run_models_from_graph(include_spec, exclude_spec) - - def run_archives(self): - return self.run_archives_from_graph() + return self.run_types_from_graph(include_spec, + exclude_spec, + resource_types=[NodeType.Model], + 
tags=[], + should_run_hooks=True) + + def run_tests(self, include_spec, exclude_spec, tags): + return self.run_types_from_graph(include_spec, + exclude_spec, + [NodeType.Test], + tags) + + def run_archives(self, include_spec, exclude_spec): + return self.run_types_from_graph(include_spec, + exclude_spec, + [NodeType.Archive], + []) diff --git a/dbt/schema_tester.py b/dbt/schema_tester.py deleted file mode 100644 index 42da30a5a6d..00000000000 --- a/dbt/schema_tester.py +++ /dev/null @@ -1,118 +0,0 @@ -import os - -from dbt.logger import GLOBAL_LOGGER as logger -import dbt.targets - -import psycopg2 -import logging -import time -import datetime - - -QUERY_VALIDATE_NOT_NULL = """ -with validation as ( - select {field} as f - from "{schema}"."{table}" -) -select count(*) from validation where f is null -""" - -QUERY_VALIDATE_UNIQUE = """ -with validation as ( - select {field} as f - from "{schema}"."{table}" - where {field} is not null -), -validation_errors as ( - select f from validation group by f having count(*) > 1 -) -select count(*) from validation_errors -""" - -QUERY_VALIDATE_ACCEPTED_VALUES = """ -with all_values as ( - select distinct {field} as f - from "{schema}"."{table}" -), -validation_errors as ( - select f from all_values where f not in ({values_csv}) -) -select count(*) from validation_errors -""" - -QUERY_VALIDATE_REFERENTIAL_INTEGRITY = """ -with parent as ( - select {parent_field} as id - from "{schema}"."{parent_table}" -), child as ( - select {child_field} as id - from "{schema}"."{child_table}" -) -select count(*) from child -where id not in (select id from parent) and id is not null -""" - -DDL_TEST_RESULT_CREATE = """ -create table if not exists {schema}.dbt_test_results ( - tested_at timestamp without time zone, - model_name text, - errored bool, - skipped bool, - failed bool, - count_failures integer, - execution_time double precision -); -""" - - -class SchemaTester(object): - def __init__(self, project): - self.project = project - - self.test_started_at = datetime.datetime.now() - - def get_target(self): - target_cfg = self.project.run_environment() - return dbt.targets.get_target(target_cfg) - - def execute_query(self, model, sql): - target = self.get_target() - - with target.get_handle() as handle: - with handle.cursor() as cursor: - try: - logger.debug("SQL: %s", sql) - pre = time.time() - cursor.execute(sql) - post = time.time() - logger.debug( - "SQL status: %s in %d seconds", - cursor.statusmessage, post-pre) - except psycopg2.ProgrammingError as e: - logger.debug('programming error: %s', sql) - return e.diag.message_primary - except Exception as e: - logger.debug( - 'encountered exception while running: %s', sql) - e.model = model - raise e - - result = cursor.fetchone() - if len(result) != 1: - logger.debug("SQL: %s", sql) - logger.debug("RESULT: %s", result) - raise RuntimeError( - "Unexpected validation result. 
Expected 1 record, " - "got {}".format(len(result))) - else: - return result[0] - - def validate_schema(self, schema_test): - sql = schema_test.render() - num_rows = self.execute_query(model, sql) - if num_rows == 0: - logger.info(" OK") - yield True - else: - logger.info(" FAILED ({})".format(num_rows)) - yield False diff --git a/dbt/source.py b/dbt/source.py index afb7792fd5f..ac8747e9f78 100644 --- a/dbt/source.py +++ b/dbt/source.py @@ -1,7 +1,6 @@ import os.path import fnmatch -from dbt.model import Model, Analysis, SchemaFile, Csv, Macro, \ - ArchiveModel, DataTest +from dbt.model import Model, Csv, Macro import dbt.clients.system @@ -41,36 +40,6 @@ def get_models(self, model_dirs): Model, file_matches) - def get_analyses(self, analysis_dirs): - file_matches = dbt.clients.system.find_matching( - self.own_project_root, - analysis_dirs, - "[!.#~]*.sql") - - return self.build_models_from_file_matches( - Analysis, - file_matches) - - def get_schemas(self, schema_dirs): - file_matches = dbt.clients.system.find_matching( - self.own_project_root, - schema_dirs, - "[!.#~]*.yml") - - return self.build_models_from_file_matches( - SchemaFile, - file_matches) - - def get_tests(self, test_dirs): - file_matches = dbt.clients.system.find_matching( - self.own_project_root, - test_dirs, - "[!.#~]*.sql") - - return self.build_models_from_file_matches( - DataTest, - file_matches) - def get_csvs(self, csv_dirs): file_matches = dbt.clients.system.find_matching( self.own_project_root, @@ -90,26 +59,3 @@ def get_macros(self, macro_dirs): return self.build_models_from_file_matches( Macro, file_matches) - - def get_archives(self): - "Get Archive models defined in project config" - - if 'archive' not in self.project: - return [] - - raw_source_schemas = self.project['archive'] - - archives = [] - for schema in raw_source_schemas: - schema = schema.copy() - if 'tables' not in schema: - continue - - tables = schema.pop('tables') - for table in tables: - fields = table.copy() - fields.update(schema) - archives.append(ArchiveModel( - self.project, fields - )) - return archives diff --git a/dbt/task/archive.py b/dbt/task/archive.py index df8b5e5d048..8077adb2793 100644 --- a/dbt/task/archive.py +++ b/dbt/task/archive.py @@ -1,5 +1,6 @@ +import dbt.compilation + from dbt.runner import RunManager -from dbt.compilation import Compiler from dbt.logger import GLOBAL_LOGGER as logger @@ -8,16 +9,9 @@ def __init__(self, args, project): self.args = args self.project = project - def compile(self): - compiler = Compiler(self.project, self.args) - compiler.initialize() - compiled = compiler.compile() - - count_compiled_archives = compiled['archives'] - logger.info("Compiled {} archives".format(count_compiled_archives)) - def run(self): - self.compile() + dbt.compilation.compile_and_print_status( + self.project, self.args) runner = RunManager( self.project, @@ -25,4 +19,4 @@ def run(self): self.args ) - runner.run_archives() + runner.run_archives(['*'], []) diff --git a/dbt/task/compile.py b/dbt/task/compile.py index 04eb627d6b5..bfa207a7384 100644 --- a/dbt/task/compile.py +++ b/dbt/task/compile.py @@ -1,4 +1,5 @@ -from dbt.compilation import Compiler, CompilableEntities +import dbt.compilation + from dbt.logger import GLOBAL_LOGGER as logger @@ -8,11 +9,5 @@ def __init__(self, args, project): self.project = project def run(self): - compiler = Compiler(self.project, self.args) - compiler.initialize() - results = compiler.compile() - - stat_line = ", ".join( - ["{} {}".format(results[k], k) for k in CompilableEntities] - ) - 
logger.info("Compiled {}".format(stat_line)) + dbt.compilation.compile_and_print_status( + self.project, self.args) diff --git a/dbt/task/run.py b/dbt/task/run.py index a06ec84c532..c6f5eee9f9c 100644 --- a/dbt/task/run.py +++ b/dbt/task/run.py @@ -1,6 +1,7 @@ from __future__ import print_function -from dbt.compilation import Compiler, CompilableEntities +import dbt.compilation + from dbt.logger import GLOBAL_LOGGER as logger from dbt.runner import RunManager @@ -12,18 +13,9 @@ def __init__(self, args, project): self.args = args self.project = project - def compile(self): - compiler = Compiler(self.project, self.args) - compiler.initialize() - results = compiler.compile() - - stat_line = ", ".join([ - "{} {}".format(results[k], k) for k in CompilableEntities - ]) - logger.info("Compiled {}".format(stat_line)) - def run(self): - self.compile() + dbt.compilation.compile_and_print_status( + self.project, self.args) runner = RunManager( self.project, self.project['target-path'], self.args diff --git a/dbt/task/test.py b/dbt/task/test.py index 8875231f59e..e96737b0a33 100644 --- a/dbt/task/test.py +++ b/dbt/task/test.py @@ -1,5 +1,5 @@ +import dbt.compilation -from dbt.compilation import Compiler, CompilableEntities from dbt.runner import RunManager from dbt.logger import GLOBAL_LOGGER as logger @@ -19,38 +19,24 @@ def __init__(self, args, project): self.args = args self.project = project - def compile(self): - compiler = Compiler(self.project, self.args) - compiler.initialize() - results = compiler.compile() - - stat_line = ", ".join( - ["{} {}".format(results[k], k) for k in CompilableEntities] - ) - logger.info("Compiled {}".format(stat_line)) - def run(self): - self.compile() + dbt.compilation.compile_and_print_status( + self.project, self.args) runner = RunManager( - self.project, self.project['target-path'], self.args - ) + self.project, self.project['target-path'], self.args) include = self.args.models exclude = self.args.exclude if (self.args.data and self.args.schema) or \ (not self.args.data and not self.args.schema): - res = runner.run_tests(include, exclude, - test_schemas=True, test_data=True) + res = runner.run_tests(include, exclude, []) elif self.args.data: - res = runner.run_tests(include, exclude, - test_schemas=False, test_data=True) + res = runner.run_tests(include, exclude, ['data']) elif self.args.schema: - res = runner.run_tests(include, exclude, - test_schemas=True, test_data=False) + res = runner.run_tests(include, exclude, ['schema']) else: raise RuntimeError("unexpected") - logger.info("Done!") return res diff --git a/dbt/templates.py b/dbt/templates.py index 6c5ed2e1334..a1aef839ec7 100644 --- a/dbt/templates.py +++ b/dbt/templates.py @@ -117,11 +117,9 @@ def wrap(self, opts): with "current_data" as ( select - {% raw %} - {% for col in get_columns_in_table(source_schema, source_table) %} - "{{ col.name }}" {% if not loop.last %},{% endif %} - {% endfor %}, - {% endraw %} + {% for col in get_columns_in_table(source_schema, source_table) %} + "{{ col.name }}" {% if not loop.last %},{% endif %} + {% endfor %}, "{{ updated_at }}" as "dbt_updated_at", "{{ unique_key }}" as "dbt_pk", "{{ updated_at }}" as "valid_from", @@ -133,11 +131,9 @@ def wrap(self, opts): "archived_data" as ( select - {% raw %} - {% for col in get_columns_in_table(source_schema, source_table) %} - "{{ col.name }}" {% if not loop.last %},{% endif %} - {% endfor %}, - {% endraw %} + {% for col in get_columns_in_table(source_schema, source_table) %} + "{{ col.name }}" {% if not loop.last %},{% endif %} 
+ {% endfor %}, "{{ updated_at }}" as "dbt_updated_at", "{{ unique_key }}" as "dbt_pk", "valid_from", diff --git a/dbt/utils.py b/dbt/utils.py index 3be15b85e0a..801f613eecb 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -1,7 +1,10 @@ import os +import jinja2 import json import dbt.project + +from dbt.compat import basestring from dbt.logger import GLOBAL_LOGGER as logger DBTConfigKeys = [ @@ -34,6 +37,10 @@ def __repr__(self): def compiler_error(model, msg): if model is None: name = '' + elif isinstance(model, str): + name = model + elif isinstance(model, dict): + name = model.get('name') else: name = model.nice_name @@ -59,7 +66,14 @@ class Var(object): def __init__(self, model, context): self.model = model self.context = context - self.local_vars = model.config.get('vars', {}) + + if isinstance(model, dict) and model.get('unique_id'): + self.local_vars = model.get('config', {}).get('vars') + self.model_name = model.get('name') + else: + # still used for wrapping + self.model_name = model.nice_name + self.local_vars = model.config.get('vars', {}) def pretty_dict(self, data): return json.dumps(data, sort_keys=True, indent=4) @@ -70,7 +84,7 @@ def __call__(self, var_name, default=None): compiler_error( self.model, self.UndefinedVarError.format( - var_name, self.model.nice_name, pretty_vars + var_name, self.model_name, pretty_vars ) ) elif var_name in self.local_vars: @@ -82,37 +96,36 @@ def __call__(self, var_name, default=None): var_name, self.model.nice_name, pretty_vars ) ) - compiled = self.model.compile_string(self.context, raw) + + # if bool/int/float/etc are passed in, don't compile anything + if not isinstance(raw, basestring): + return raw + + env = jinja2.Environment() + compiled = env.from_string(raw, self.context).render(self.context) + return compiled else: return default -def find_model_by_name(models, name, package_namespace=None): - found = [] - for model in models: - if model.name == name: - if package_namespace is None: - found.append(model) - elif (package_namespace is not None and - package_namespace == model.project['name']): - found.append(model) - - nice_package_name = 'ANY' if package_namespace is None \ - else package_namespace - if len(found) == 0: - raise RuntimeError( - "Can't find a model named '{}' in package '{}' -- does it exist?" 
- .format(name, nice_package_name) - ) - elif len(found) == 1: - return found[0] - else: - raise RuntimeError( - "Model specification is ambiguous: model='{}' package='{}' -- " - "{} models match criteria: {}" - .format(name, nice_package_name, len(found), found) - ) +def model_cte_name(model): + return '__dbt__CTE__{}'.format(model.get('name')) + + +def find_model_by_name(all_models, target_model_name, + target_model_package): + + for name, model in all_models.items(): + resource_type, package_name, model_name = name.split('.') + + if (resource_type == 'model' and + ((target_model_name == model_name) and + (target_model_package is None or + target_model_package == package_name))): + return model + + return None def find_model_by_fqn(models, fqn): diff --git a/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py b/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py index 971c06bbefd..ac0744ed74f 100644 --- a/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py +++ b/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py @@ -105,6 +105,8 @@ def project_config(self): @attr(type='postgres') def test_simple_dependency(self): + self.use_default_project() + self.run_dbt(["deps"]) self.run_dbt(["run"]) @@ -114,7 +116,7 @@ def test_simple_dependency(self): self.assertTablesEqual("seed","incremental") -class TestSimpleDependencyWithModelSpecificOverriddenConfigs(BaseTestSimpleDependencyWithConfigs): +class TestSimpleDependencyWithModelSpecificOverriddenConfigsAndMaterializations(BaseTestSimpleDependencyWithConfigs): @property def project_config(self): @@ -127,7 +129,7 @@ def project_config(self): "vars": { "config_1": "ghi", "config_2": "jkl", - #"bool_config": True + "bool_config": True } }, diff --git a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py index 3acd4a62a01..c36a146a431 100644 --- a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py +++ b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py @@ -43,7 +43,9 @@ def run_schema_and_assert(self, include, exclude, expected_tests): test_task = TestTask(args, self.project) test_results = test_task.run() - ran_tests = sorted([test.model.name for test in test_results]) + print(test_results) + + ran_tests = sorted([test.node.get('name') for test in test_results]) expected_sorted = sorted(expected_tests) self.assertEqual(ran_tests, expected_sorted) diff --git a/test/integration/008_schema_tests_test/test_schema_tests.py b/test/integration/008_schema_tests_test/test_schema_tests.py index 323423a857d..ba679796e8c 100644 --- a/test/integration/008_schema_tests_test/test_schema_tests.py +++ b/test/integration/008_schema_tests_test/test_schema_tests.py @@ -25,6 +25,7 @@ def run_schema_validations(self): args = FakeArgs() test_task = TestTask(args, project) + print(project) return test_task.run() @attr(type='postgres') @@ -34,7 +35,7 @@ def test_schema_tests(self): for result in test_results: # assert that all deliberately failing tests actually fail - if 'failure' in result.model.name: + if 'failure' in result.node.get('name'): self.assertFalse(result.errored) self.assertFalse(result.skipped) self.assertTrue(result.status > 0) @@ -74,4 +75,4 @@ def test_malformed_schema_test_wont_brick_run(self): self.run_dbt() ran_tests = self.run_schema_validations() - 
self.assertEqual(ran_tests, []) + self.assertEqual(len(ran_tests), 2) diff --git a/test/integration/010_permission_tests/seed.sql b/test/integration/010_permission_tests/seed.sql index 50ae457a701..6a3e0e6cf46 100644 --- a/test/integration/010_permission_tests/seed.sql +++ b/test/integration/010_permission_tests/seed.sql @@ -1,6 +1,6 @@ -create schema private; +create schema private_010; -create table private.seed ( +create table private_010.seed ( id BIGSERIAL PRIMARY KEY, first_name VARCHAR(50), last_name VARCHAR(50), @@ -10,13 +10,13 @@ create table private.seed ( ); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Kathryn', 'Walker', 'kwalker1@ezinearticles.com', 'Female', '194.121.179.35'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Gerald', 'Ryan', 'gryan2@com.com', 'Male', '11.3.212.243'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Bonnie', 'Spencer', 'bspencer3@ameblo.jp', 'Female', '216.32.196.175'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Harold', 'Taylor', 'htaylor4@people.com.cn', 'Male', '253.10.246.136'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Jacqueline', 'Griffin', 'jgriffin5@t.co', 'Female', '16.13.192.220'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Wanda', 'Arnold', 'warnold6@google.nl', 'Female', '232.116.150.64'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Craig', 'Ortiz', 'cortiz7@sciencedaily.com', 'Male', '199.126.106.13'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Gary', 'Day', 'gday8@nih.gov', 'Male', '35.81.68.186'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Rose', 'Wright', 'rwright9@yahoo.co.jp', 'Female', '236.82.178.100'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Raymond', 'Kelley', 'rkelleya@fc2.com', 'Male', '213.65.166.67'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Kathryn', 'Walker', 'kwalker1@ezinearticles.com', 'Female', '194.121.179.35'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Gerald', 'Ryan', 'gryan2@com.com', 'Male', '11.3.212.243'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Bonnie', 'Spencer', 'bspencer3@ameblo.jp', 'Female', '216.32.196.175'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Harold', 'Taylor', 'htaylor4@people.com.cn', 'Male', '253.10.246.136'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Jacqueline', 'Griffin', 'jgriffin5@t.co', 'Female', '16.13.192.220'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Wanda', 'Arnold', 'warnold6@google.nl', 'Female', '232.116.150.64'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Craig', 'Ortiz', 'cortiz7@sciencedaily.com', 'Male', '199.126.106.13'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Gary', 'Day', 'gday8@nih.gov', 'Male', '35.81.68.186'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Rose', 'Wright', 'rwright9@yahoo.co.jp', 'Female', '236.82.178.100'); 
+insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Raymond', 'Kelley', 'rkelleya@fc2.com', 'Male', '213.65.166.67'); diff --git a/test/integration/010_permission_tests/tearDown.sql b/test/integration/010_permission_tests/tearDown.sql index 4da20cda0bf..f7125ff7824 100644 --- a/test/integration/010_permission_tests/tearDown.sql +++ b/test/integration/010_permission_tests/tearDown.sql @@ -1,2 +1,2 @@ -drop schema if exists private cascade; +drop schema if exists private_010 cascade; diff --git a/test/integration/010_permission_tests/test_permissions.py b/test/integration/010_permission_tests/test_permissions.py index a322457e596..1a01464e205 100644 --- a/test/integration/010_permission_tests/test_permissions.py +++ b/test/integration/010_permission_tests/test_permissions.py @@ -6,13 +6,14 @@ class TestPermissions(DBTIntegrationTest): def setUp(self): DBTIntegrationTest.setUp(self) + self.run_sql_file("test/integration/010_permission_tests/tearDown.sql") self.run_sql_file("test/integration/010_permission_tests/seed.sql") def tearDown(self): - DBTIntegrationTest.tearDown(self) - self.run_sql_file("test/integration/010_permission_tests/tearDown.sql") + DBTIntegrationTest.tearDown(self) + @property def schema(self): return "permission_tests_010" diff --git a/test/integration/011_invalid_model_tests/test_invalid_models.py b/test/integration/011_invalid_model_tests/test_invalid_models.py index 0f8338a4ad9..4f526a7e289 100644 --- a/test/integration/011_invalid_model_tests/test_invalid_models.py +++ b/test/integration/011_invalid_model_tests/test_invalid_models.py @@ -1,6 +1,8 @@ from nose.plugins.attrib import attr from test.integration.base import DBTIntegrationTest +from dbt.exceptions import ValidationException + class TestInvalidViewModels(DBTIntegrationTest): def setUp(self): @@ -18,7 +20,12 @@ def models(self): @attr(type='postgres') def test_view_with_incremental_attributes(self): - self.run_dbt() + try: + self.run_dbt() + # should throw + self.assertTrue(False) + except RuntimeError as e: + pass class TestInvalidDisabledModels(DBTIntegrationTest): @@ -43,7 +50,7 @@ def test_view_with_incremental_attributes(self): # should throw self.assertTrue(False) except RuntimeError as e: - self.assertTrue("config must be either True or False" in str(e)) + self.assertTrue("enabled" in str(e)) class TestInvalidModelReference(DBTIntegrationTest): diff --git a/test/integration/013_context_var_tests/test_context_vars.py b/test/integration/013_context_var_tests/test_context_vars.py index 514a0b54540..ff102dc9be7 100644 --- a/test/integration/013_context_var_tests/test_context_vars.py +++ b/test/integration/013_context_var_tests/test_context_vars.py @@ -1,6 +1,8 @@ from nose.plugins.attrib import attr from test.integration.base import DBTIntegrationTest +import dbt.flags + class TestContextVars(DBTIntegrationTest): def setUp(self): diff --git a/test/integration/014_pre_post_run_hook_tests/models/hooks.sql b/test/integration/014_hook_tests/models/hooks.sql similarity index 100% rename from test/integration/014_pre_post_run_hook_tests/models/hooks.sql rename to test/integration/014_hook_tests/models/hooks.sql diff --git a/test/integration/014_hook_tests/seed.sql b/test/integration/014_hook_tests/seed.sql new file mode 100644 index 00000000000..b889daa446f --- /dev/null +++ b/test/integration/014_hook_tests/seed.sql @@ -0,0 +1,39 @@ + +drop table run_hooks_014.on_run_hook; + +create table run_hooks_014.on_run_hook ( + "state" TEXT, -- start|end + + "target.dbname" TEXT, + 
"target.host" TEXT, + "target.name" TEXT, + "target.schema" TEXT, + "target.type" TEXT, + "target.user" TEXT, + "target.pass" TEXT, + "target.port" INTEGER, + "target.threads" INTEGER, + + "run_started_at" TEXT, + "invocation_id" TEXT +); + + +drop table model_hooks_014.on_model_hook; + +create table model_hooks_014.on_model_hook ( + "state" TEXT, -- start|end + + "target.dbname" TEXT, + "target.host" TEXT, + "target.name" TEXT, + "target.schema" TEXT, + "target.type" TEXT, + "target.user" TEXT, + "target.pass" TEXT, + "target.port" INTEGER, + "target.threads" INTEGER, + + "run_started_at" TEXT, + "invocation_id" TEXT +); diff --git a/test/integration/014_hook_tests/seed_model.sql b/test/integration/014_hook_tests/seed_model.sql new file mode 100644 index 00000000000..0ac8eb49b62 --- /dev/null +++ b/test/integration/014_hook_tests/seed_model.sql @@ -0,0 +1,19 @@ + +drop table if exists model_hooks_014.on_model_hook; + +create table model_hooks_014.on_model_hook ( + "state" TEXT, -- start|end + + "target.dbname" TEXT, + "target.host" TEXT, + "target.name" TEXT, + "target.schema" TEXT, + "target.type" TEXT, + "target.user" TEXT, + "target.pass" TEXT, + "target.port" INTEGER, + "target.threads" INTEGER, + + "run_started_at" TEXT, + "invocation_id" TEXT +); diff --git a/test/integration/014_pre_post_run_hook_tests/seed.sql b/test/integration/014_hook_tests/seed_run.sql similarity index 80% rename from test/integration/014_pre_post_run_hook_tests/seed.sql rename to test/integration/014_hook_tests/seed_run.sql index 49c34acd0fc..918c993699d 100644 --- a/test/integration/014_pre_post_run_hook_tests/seed.sql +++ b/test/integration/014_hook_tests/seed_run.sql @@ -1,5 +1,7 @@ -create table pre_post_run_hooks_014.on_run_hook ( +drop table if exists run_hooks_014.on_run_hook; + +create table run_hooks_014.on_run_hook ( "state" TEXT, -- start|end "target.dbname" TEXT, diff --git a/test/integration/014_hook_tests/test_model_hooks.py b/test/integration/014_hook_tests/test_model_hooks.py new file mode 100644 index 00000000000..2f510e4771c --- /dev/null +++ b/test/integration/014_hook_tests/test_model_hooks.py @@ -0,0 +1,140 @@ +from nose.plugins.attrib import attr +from test.integration.base import DBTIntegrationTest + + +MODEL_PRE_HOOK = """ + insert into model_hooks_014.on_model_hook ( + "state", + "target.dbname", + "target.host", + "target.name", + "target.schema", + "target.type", + "target.user", + "target.pass", + "target.port", + "target.threads", + "run_started_at", + "invocation_id" + ) VALUES ( + 'start', + '{{ target.dbname }}', + '{{ target.host }}', + '{{ target.name }}', + '{{ target.schema }}', + '{{ target.type }}', + '{{ target.user }}', + '{{ target.pass }}', + {{ target.port }}, + {{ target.threads }}, + '{{ run_started_at }}', + '{{ invocation_id }}' + ) +""" + +MODEL_POST_HOOK = """ + insert into model_hooks_014.on_model_hook ( + "state", + "target.dbname", + "target.host", + "target.name", + "target.schema", + "target.type", + "target.user", + "target.pass", + "target.port", + "target.threads", + "run_started_at", + "invocation_id" + ) VALUES ( + 'end', + '{{ target.dbname }}', + '{{ target.host }}', + '{{ target.name }}', + '{{ target.schema }}', + '{{ target.type }}', + '{{ target.user }}', + '{{ target.pass }}', + {{ target.port }}, + {{ target.threads }}, + '{{ run_started_at }}', + '{{ invocation_id }}' + ) +""" + + +class TestPrePostModelHooks(DBTIntegrationTest): + + def setUp(self): + DBTIntegrationTest.setUp(self) + + 
self.run_sql_file("test/integration/014_hook_tests/seed_model.sql") + + self.fields = [ + 'state', + 'target.dbname', + 'target.host', + 'target.name', + 'target.port', + 'target.schema', + 'target.threads', + 'target.type', + 'target.user', + 'target.pass', + 'run_started_at', + 'invocation_id' + ] + + @property + def schema(self): + return "model_hooks_014" + + @property + def project_config(self): + return { + 'models': { + 'test': { + 'pre-hook': MODEL_PRE_HOOK, + 'post-hook': MODEL_POST_HOOK, + } + } + } + + @property + def models(self): + return "test/integration/014_hook_tests/models" + + def get_ctx_vars(self, state): + field_list = ", ".join(['"{}"'.format(f) for f in self.fields]) + query = "select {field_list} from {schema}.on_model_hook where state = '{state}'".format(field_list=field_list, schema=self.schema, state=state) + + vals = self.run_sql(query, fetch='all') + self.assertFalse(len(vals) == 0, 'nothing inserted into hooks table') + self.assertFalse(len(vals) > 1, 'too many rows in hooks table') + ctx = dict([(k,v) for (k,v) in zip(self.fields, vals[0])]) + + return ctx + + def check_hooks(self, state): + ctx = self.get_ctx_vars(state) + + self.assertEqual(ctx['state'], state) + self.assertEqual(ctx['target.dbname'], 'dbt') + self.assertEqual(ctx['target.host'], 'database') + self.assertEqual(ctx['target.name'], 'default2') + self.assertEqual(ctx['target.port'], 5432) + self.assertEqual(ctx['target.schema'], self.schema) + self.assertEqual(ctx['target.threads'], 1) + self.assertEqual(ctx['target.type'], 'postgres') + self.assertEqual(ctx['target.user'], 'root') + self.assertEqual(ctx['target.pass'], '') + + self.assertTrue(ctx['run_started_at'] is not None and len(ctx['run_started_at']) > 0, 'run_started_at was not set') + self.assertTrue(ctx['invocation_id'] is not None and len(ctx['invocation_id']) > 0, 'invocation_id was not set') + + @attr(type='postgres') + def test_pre_and_post_model_hooks(self): + self.run_dbt(['run']) + + self.check_hooks('start') + self.check_hooks('end') diff --git a/test/integration/014_pre_post_run_hook_tests/test_pre_post_run_hooks.py b/test/integration/014_hook_tests/test_run_hooks.py similarity index 90% rename from test/integration/014_pre_post_run_hook_tests/test_pre_post_run_hooks.py rename to test/integration/014_hook_tests/test_run_hooks.py index 226dc9ef3df..451ab129df0 100644 --- a/test/integration/014_pre_post_run_hook_tests/test_pre_post_run_hooks.py +++ b/test/integration/014_hook_tests/test_run_hooks.py @@ -3,7 +3,7 @@ RUN_START_HOOK = """ - insert into pre_post_run_hooks_014.on_run_hook ( + insert into run_hooks_014.on_run_hook ( "state", "target.dbname", "target.host", @@ -33,7 +33,7 @@ """ RUN_END_HOOK = """ - insert into pre_post_run_hooks_014.on_run_hook ( + insert into run_hooks_014.on_run_hook ( "state", "target.dbname", "target.host", @@ -67,7 +67,7 @@ class TestPrePostRunHooks(DBTIntegrationTest): def setUp(self): DBTIntegrationTest.setUp(self) - self.run_sql_file("test/integration/014_pre_post_run_hook_tests/seed.sql") + self.run_sql_file("test/integration/014_hook_tests/seed_run.sql") self.fields = [ 'state', @@ -86,7 +86,7 @@ def setUp(self): @property def schema(self): - return "pre_post_run_hooks_014" + return "run_hooks_014" @property def project_config(self): @@ -97,7 +97,7 @@ def project_config(self): @property def models(self): - return "test/integration/014_pre_post_run_hook_tests/models" + return "test/integration/014_hook_tests/models" def get_ctx_vars(self, state): field_list = ", ".join(['"{}"'.format(f) 
for f in self.fields]) @@ -105,6 +105,7 @@ def get_ctx_vars(self, state): vals = self.run_sql(query, fetch='all') self.assertFalse(len(vals) == 0, 'nothing inserted into on_run_hook table') + self.assertFalse(len(vals) > 1, 'too many rows in hooks table') ctx = dict([(k,v) for (k,v) in zip(self.fields, vals[0])]) return ctx @@ -117,7 +118,7 @@ def check_hooks(self, state): self.assertEqual(ctx['target.host'], 'database') self.assertEqual(ctx['target.name'], 'default2') self.assertEqual(ctx['target.port'], 5432) - self.assertEqual(ctx['target.schema'], 'pre_post_run_hooks_014') + self.assertEqual(ctx['target.schema'], self.schema) self.assertEqual(ctx['target.threads'], 1) self.assertEqual(ctx['target.type'], 'postgres') self.assertEqual(ctx['target.user'], 'root') diff --git a/test/integration/015_cli_invocation_tests/test_cli_invocation.py b/test/integration/015_cli_invocation_tests/test_cli_invocation.py index e7782cbb795..d7d4ecbbc69 100644 --- a/test/integration/015_cli_invocation_tests/test_cli_invocation.py +++ b/test/integration/015_cli_invocation_tests/test_cli_invocation.py @@ -97,4 +97,5 @@ def test_toplevel_dbt_run_with_profile_dir_arg(self): # make sure the test runs against `custom_schema` for test_result in res: - self.assertTrue(self.custom_schema, test_result.model.compiled_contents) + self.assertTrue(self.custom_schema, + test_result.node.get('wrapped_sql')) diff --git a/test/integration/base.py b/test/integration/base.py index bce6ec87b22..5d697984a52 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -194,6 +194,8 @@ def tearDown(self): except: os.rename("dbt_modules", "dbt_modules-{}".format(time.time())) + self.handle.close() + @property def project_config(self): return {} diff --git a/test/unit/test_compiler.py b/test/unit/test_compiler.py new file mode 100644 index 00000000000..5d1e3053645 --- /dev/null +++ b/test/unit/test_compiler.py @@ -0,0 +1,342 @@ +from mock import MagicMock +import unittest + +import os + +import dbt.flags +import dbt.compilation + + +class CompilerTest(unittest.TestCase): + + def assertEqualIgnoreWhitespace(self, a, b): + self.assertEqual( + "".join(a.split()), + "".join(b.split())) + + def setUp(self): + dbt.flags.STRICT_MODE = True + + self.maxDiff = None + + self.root_project_config = { + 'name': 'root_project', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + } + + self.snowplow_project_config = { + 'name': 'snowplow', + 'version': '0.1', + 'project-root': os.path.abspath('./dbt_modules/snowplow'), + } + + self.model_config = { + 'enabled': True, + 'materialized': 'view', + 'post-hook': [], + 'pre-hook': [], + 'vars': {}, + } + + def test__prepend_ctes__already_has_cte(self): + ephemeral_config = self.model_config.copy() + ephemeral_config['materialized'] = 'ephemeral' + + compiled_models = { + 'model.root.view': { + 'name': 'view', + 'resource_type': 'model', + 'unique_id': 'model.root.view', + 'fqn': ['root_project', 'view'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [ + 'model.root.ephemeral' + ], + 'config': self.model_config, + 'tags': [], + 'path': 'view.sql', + 'raw_sql': 'select * from {{ref("ephemeral")}}', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [ + 'model.root.ephemeral' + ], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': ('with cte as (select * from something_else) ' + 'select * from __dbt__CTE__ephemeral') + }, + 'model.root.ephemeral': { + 'name': 'ephemeral', + 
'resource_type': 'model', + 'unique_id': 'model.root.ephemeral', + 'fqn': ['root_project', 'ephemeral'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': ephemeral_config, + 'tags': [], + 'path': 'ephemeral.sql', + 'raw_sql': 'select * from source_table', + 'compiled': True, + 'compiled_sql': 'select * from source_table', + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': '' + } + } + + result, all_models = dbt.compilation.prepend_ctes( + compiled_models['model.root.view'], + compiled_models) + + self.assertEqual(result, all_models.get('model.root.view')) + self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqualIgnoreWhitespace( + result.get('injected_sql'), + ('with __dbt__CTE__ephemeral as (' + 'select * from source_table' + '), cte as (select * from something_else) ' + 'select * from __dbt__CTE__ephemeral')) + + self.assertEqual( + all_models.get('model.root.ephemeral').get('extra_ctes_injected'), + True) + + def test__prepend_ctes__no_ctes(self): + compiled_models = { + 'model.root.view': { + 'name': 'view', + 'resource_type': 'model', + 'unique_id': 'model.root.view', + 'fqn': ['root_project', 'view'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'view.sql', + 'raw_sql': ('with cte as (select * from something_else) ' + 'select * from source_table'), + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': ('with cte as (select * from something_else) ' + 'select * from source_table') + }, + 'model.root.view_no_cte': { + 'name': 'view_no_cte', + 'resource_type': 'model', + 'unique_id': 'model.root.view_no_cte', + 'fqn': ['root_project', 'view_no_cte'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'view.sql', + 'raw_sql': 'select * from source_table', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': ('select * from source_table') + } + } + + result, all_models = dbt.compilation.prepend_ctes( + compiled_models.get('model.root.view'), + compiled_models) + + self.assertEqual(result, all_models.get('model.root.view')) + self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqualIgnoreWhitespace( + result.get('injected_sql'), + compiled_models.get('model.root.view').get('compiled_sql')) + + result, all_models = dbt.compilation.prepend_ctes( + compiled_models.get('model.root.view_no_cte'), + compiled_models) + + self.assertEqual(result, all_models.get('model.root.view_no_cte')) + self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqualIgnoreWhitespace( + result.get('injected_sql'), + compiled_models.get('model.root.view_no_cte').get('compiled_sql')) + + + def test__prepend_ctes(self): + ephemeral_config = self.model_config.copy() + ephemeral_config['materialized'] = 'ephemeral' + + compiled_models = { + 'model.root.view': { + 'name': 'view', + 'resource_type': 'model', + 'unique_id': 'model.root.view', + 'fqn': ['root_project', 'view'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [ + 'model.root.ephemeral' + ], + 'config': self.model_config, + 'tags': [], + 'path': 'view.sql', + 'raw_sql': 'select * from 
{{ref("ephemeral")}}', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [ + 'model.root.ephemeral' + ], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': 'select * from __dbt__CTE__ephemeral' + }, + 'model.root.ephemeral': { + 'name': 'ephemeral', + 'resource_type': 'model', + 'unique_id': 'model.root.ephemeral', + 'fqn': ['root_project', 'ephemeral'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': ephemeral_config, + 'tags': [], + 'path': 'ephemeral.sql', + 'raw_sql': 'select * from source_table', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': 'select * from source_table' + } + } + + result, all_models = dbt.compilation.prepend_ctes( + compiled_models['model.root.view'], + compiled_models) + + self.assertEqual(result, all_models.get('model.root.view')) + self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqualIgnoreWhitespace( + result.get('injected_sql'), + ('with __dbt__CTE__ephemeral as (' + 'select * from source_table' + ') ' + 'select * from __dbt__CTE__ephemeral')) + + self.assertEqual( + all_models.get('model.root.ephemeral').get('extra_ctes_injected'), + True) + + + def test__prepend_ctes__multiple_levels(self): + ephemeral_config = self.model_config.copy() + ephemeral_config['materialized'] = 'ephemeral' + + compiled_models = { + 'model.root.view': { + 'name': 'view', + 'resource_type': 'model', + 'unique_id': 'model.root.view', + 'fqn': ['root_project', 'view'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [ + 'model.root.ephemeral' + ], + 'config': self.model_config, + 'tags': [], + 'path': 'view.sql', + 'raw_sql': 'select * from {{ref("ephemeral")}}', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [ + 'model.root.ephemeral' + ], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': 'select * from __dbt__CTE__ephemeral' + }, + 'model.root.ephemeral': { + 'name': 'ephemeral', + 'resource_type': 'model', + 'unique_id': 'model.root.ephemeral', + 'fqn': ['root_project', 'ephemeral'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': ephemeral_config, + 'tags': [], + 'path': 'ephemeral.sql', + 'raw_sql': 'select * from {{ref("ephemeral_level_two")}}', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [ + 'model.root.ephemeral_level_two' + ], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': 'select * from __dbt__CTE__ephemeral_level_two' + }, + 'model.root.ephemeral_level_two': { + 'name': 'ephemeral_level_two', + 'resource_type': 'model', + 'unique_id': 'model.root.ephemeral_level_two', + 'fqn': ['root_project', 'ephemeral_level_two'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': ephemeral_config, + 'tags': [], + 'path': 'ephemeral_level_two.sql', + 'raw_sql': 'select * from source_table', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': 'select * from source_table' + } + + } + + result, all_models = dbt.compilation.prepend_ctes( + compiled_models['model.root.view'], + compiled_models) + + self.assertEqual(result, all_models.get('model.root.view')) + self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqualIgnoreWhitespace( + 
result.get('injected_sql'), + ('with __dbt__CTE__ephemeral_level_two as (' + 'select * from source_table' + '), __dbt__CTE__ephemeral as (' + 'select * from __dbt__CTE__ephemeral_level_two' + ') ' + 'select * from __dbt__CTE__ephemeral')) + + self.assertEqual( + all_models.get('model.root.ephemeral').get('extra_ctes_injected'), + True) + self.assertEqual( + all_models.get('model.root.ephemeral_level_two').get('extra_ctes_injected'), + True) diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index a29526dbf91..989ca3cb808 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -4,11 +4,13 @@ import unittest import dbt.compilation +import dbt.exceptions +import dbt.flags +import dbt.linker import dbt.model import dbt.project import dbt.templates import dbt.utils -import dbt.linker import networkx as nx from test.integration.base import FakeArgs @@ -25,6 +27,8 @@ def tearDown(self): dbt.clients.system.load_file_contents = self.real_load_file_contents def setUp(self): + dbt.flags.STRICT_MODE = True + def mock_write_yaml(graph, outfile): self.graph_result = graph @@ -127,7 +131,7 @@ def test__single_model(self): self.assertEquals( self.graph_result.nodes(), - [('test_models_compile', 'model_one')]) + ['model.test_models_compile.model_one']) self.assertEquals( self.graph_result.edges(), @@ -145,18 +149,14 @@ def test__two_models_simple_ref(self): six.assertCountEqual(self, self.graph_result.nodes(), [ - ('test_models_compile', 'model_one'), - ('test_models_compile', 'model_two') + 'model.test_models_compile.model_one', + 'model.test_models_compile.model_two', ]) - six.assertCountEqual(self, - self.graph_result.edges(), - [ - ( - ('test_models_compile', 'model_one'), - ('test_models_compile', 'model_two') - ) - ]) + six.assertCountEqual( + self, + self.graph_result.edges(), + [ ('model.test_models_compile.model_one','model.test_models_compile.model_two',) ]) def test__model_materializations(self): self.use_models({ @@ -190,7 +190,9 @@ def test__model_materializations(self): nodes = self.graph_result.node for model, expected in expected_materialization.items(): - actual = nodes[("test_models_compile", model)]["materialized"] + key = 'model.test_models_compile.{}'.format(model) + actual = nodes[key].get('config', {}) \ + .get('materialized') self.assertEquals(actual, expected) def test__model_enabled(self): @@ -212,11 +214,15 @@ def test__model_enabled(self): compiler = self.get_compiler(self.get_project(cfg)) compiler.compile() - six.assertCountEqual(self, - self.graph_result.nodes(), - [('test_models_compile', 'model_one')]) + six.assertCountEqual( + self, self.graph_result.nodes(), + ['model.test_models_compile.model_one', + 'model.test_models_compile.model_two']) - six.assertCountEqual(self, self.graph_result.edges(), []) + six.assertCountEqual( + self, self.graph_result.edges(), + [('model.test_models_compile.model_one', + 'model.test_models_compile.model_two',)]) def test__model_incremental_without_sql_where_fails(self): self.use_models({ @@ -257,13 +263,14 @@ def test__model_incremental(self): compiler = self.get_compiler(self.get_project(cfg)) compiler.compile() - node = ('test_models_compile', 'model_one') + node = 'model.test_models_compile.model_one' self.assertEqual(self.graph_result.nodes(), [node]) self.assertEqual(self.graph_result.edges(), []) self.assertEqual( - self.graph_result.node[node]['materialized'], + self.graph_result.node[node].get('config', {}) \ + .get('materialized'), 'incremental') def test__topological_ordering(self): @@ -284,31 +291,23 @@ 
def test__topological_ordering(self): six.assertCountEqual(self, self.graph_result.nodes(), [ - ('test_models_compile', 'model_1'), - ('test_models_compile', 'model_2'), - ('test_models_compile', 'model_3'), - ('test_models_compile', 'model_4') + 'model.test_models_compile.model_1', + 'model.test_models_compile.model_2', + 'model.test_models_compile.model_3', + 'model.test_models_compile.model_4', ]) six.assertCountEqual(self, self.graph_result.edges(), [ - ( - ('test_models_compile', 'model_1'), - ('test_models_compile', 'model_2') - ), - ( - ('test_models_compile', 'model_1'), - ('test_models_compile', 'model_3') - ), - ( - ('test_models_compile', 'model_2'), - ('test_models_compile', 'model_3') - ), - ( - ('test_models_compile', 'model_3'), - ('test_models_compile', 'model_4') - ) + ('model.test_models_compile.model_1', + 'model.test_models_compile.model_2',), + ('model.test_models_compile.model_1', + 'model.test_models_compile.model_3',), + ('model.test_models_compile.model_2', + 'model.test_models_compile.model_3',), + ('model.test_models_compile.model_3', + 'model.test_models_compile.model_4',), ]) linker = dbt.linker.Linker() @@ -316,10 +315,10 @@ def test__topological_ordering(self): actual_ordering = linker.as_topological_ordering() expected_ordering = [ - ('test_models_compile', 'model_1'), - ('test_models_compile', 'model_2'), - ('test_models_compile', 'model_3'), - ('test_models_compile', 'model_4') + 'model.test_models_compile.model_1', + 'model.test_models_compile.model_2', + 'model.test_models_compile.model_3', + 'model.test_models_compile.model_4', ] self.assertEqual(actual_ordering, expected_ordering) @@ -345,18 +344,10 @@ def test__dependency_list(self): actual_dep_list = linker.as_dependency_list() expected_dep_list = [ - [ - ('test_models_compile', 'model_1') - ], - [ - ('test_models_compile', 'model_2') - ], - [ - ('test_models_compile', 'model_3') - ], - [ - ('test_models_compile', 'model_4'), - ] + ['model.test_models_compile.model_1'], + ['model.test_models_compile.model_2'], + ['model.test_models_compile.model_3'], + ['model.test_models_compile.model_4'], ] self.assertEqual(actual_dep_list, expected_dep_list) diff --git a/test/unit/test_graph_selection.py b/test/unit/test_graph_selection.py index 8c68a9da43f..fa37f94a071 100644 --- a/test/unit/test_graph_selection.py +++ b/test/unit/test_graph_selection.py @@ -12,18 +12,12 @@ class GraphSelectionTest(unittest.TestCase): def setUp(self): integer_graph = nx.balanced_tree(2, 2, nx.DiGraph()) - simple_mapping = { - i: letter for (i, letter) in enumerate(string.ascii_lowercase) - } package_mapping = { - i: ('X' if i % 2 == 0 else 'Y', letter) + i: 'm.' + ('X' if i % 2 == 0 else 'Y') + '.' 
+ letter for (i, letter) in enumerate(string.ascii_lowercase) } - # Edges: [(a, b), (a, c), (b, d), (b, e), (c, f), (c, g)] - self.simple_graph = nx.relabel_nodes(integer_graph, simple_mapping) - # Edges: [(X.a, Y.b), (X.a, X.c), (Y.b, Y.d), (Y.b, X.e), (X.c, Y.f), (X.c, X.g)] self.package_graph = nx.relabel_nodes(integer_graph, package_mapping) @@ -78,37 +72,13 @@ def run_specs_and_assert(self, graph, include, exclude, expected): self.assertEquals(selected, expected) - # Test the select_nodes() interface - def test__single_node_selection(self): - self.run_specs_and_assert(self.simple_graph, ['a'], [], set('a')) - - def test__node_and_children(self): - self.run_specs_and_assert(self.simple_graph, ['a+'], [], set('abcdefg')) - - def test__node_and_parents(self): - self.run_specs_and_assert(self.simple_graph, ['+g'], [], set('acg')) - - def test__node_and_children_and_parents(self): - self.run_specs_and_assert(self.simple_graph, ['+c+'], [], set('acfg')) - - def test__node_and_children_and_parents_except_one(self): - self.run_specs_and_assert(self.simple_graph, ['+c+'], ['c'], set('afg')) - - def test__node_and_children_and_parents_except_many(self): - self.run_specs_and_assert(self.simple_graph, ['+c+'], ['+f'], set('g')) - - def test__multiple_node_selection(self): - self.run_specs_and_assert(self.simple_graph, ['a', 'b'], [], set('ab')) - - def test__multiple_node_selection_mixed(self): - self.run_specs_and_assert(self.simple_graph, ['a+', 'b+'], ['b', '+c'], set('defg')) def test__single_node_selection_in_package(self): self.run_specs_and_assert( self.package_graph, ['X.a'], [], - set([('X', 'a')]) + set(['m.X.a']) ) def test__multiple_node_selection_in_package(self): @@ -116,7 +86,7 @@ def test__multiple_node_selection_in_package(self): self.package_graph, ['X.a', 'b'], [], - set([('X', 'a'), ('Y', 'b')]) + set(['m.X.a', 'm.Y.b']) ) def test__select_children_except_in_package(self): @@ -124,16 +94,7 @@ def test__select_children_except_in_package(self): self.package_graph, ['X.a+'], ['b'], - set([ - ('X', 'a'), - # ('Y', 'b'), - ('X', 'c'), - ('Y', 'd'), - ('X', 'e'), - ('Y', 'f'), - ('X', 'g') - ]) - ) + set(['m.X.a','m.X.c', 'm.Y.d','m.X.e','m.Y.f','m.X.g'])) def parse_spec_and_assert(self, spec, parents, children, qualified_node_name): parsed = graph_selector.parse_spec(spec) diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py new file mode 100644 index 00000000000..6e55030f382 --- /dev/null +++ b/test/unit/test_parser.py @@ -0,0 +1,957 @@ +from mock import MagicMock +import unittest + +import os + +import dbt.flags +import dbt.parser + + +class ParserTest(unittest.TestCase): + + def find_input_by_name(self, models, name): + return next( + (model for model in models if model.get('name') == name), + {}) + + def setUp(self): + dbt.flags.STRICT_MODE = True + + self.maxDiff = None + + self.root_project_config = { + 'name': 'root', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + } + + self.snowplow_project_config = { + 'name': 'snowplow', + 'version': '0.1', + 'project-root': os.path.abspath('./dbt_modules/snowplow'), + } + + self.model_config = { + 'enabled': True, + 'materialized': 'view', + 'post-hook': [], + 'pre-hook': [], + 'vars': {}, + } + + def test__single_model(self): + models = [{ + 'name': 'model_one', + 'resource_type': 'model', + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'path': 'model_one.sql', + 'raw_sql': ("select * from events"), + }] + + self.assertEquals( + dbt.parser.parse_sql_nodes( + models, + 
self.root_project_config, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'model.root.model_one': { + 'name': 'model_one', + 'resource_type': 'model', + 'unique_id': 'model.root.model_one', + 'fqn': ['root', 'model_one'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'model_one.sql', + 'raw_sql': self.find_input_by_name( + models, 'model_one').get('raw_sql') + } + } + ) + + def test__single_model__nested_configuration(self): + models = [{ + 'name': 'model_one', + 'resource_type': 'model', + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'path': 'nested/path/model_one.sql', + 'raw_sql': ("select * from events"), + }] + + self.root_project_config = { + 'name': 'root', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + 'models': { + 'materialized': 'ephemeral', + 'root': { + 'nested': { + 'path': { + 'materialized': 'ephemeral' + } + } + } + } + } + + ephemeral_config = self.model_config.copy() + ephemeral_config.update({ + 'materialized': 'ephemeral' + }) + + self.assertEquals( + dbt.parser.parse_sql_nodes( + models, + self.root_project_config, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'model.root.model_one': { + 'name': 'model_one', + 'resource_type': 'model', + 'unique_id': 'model.root.model_one', + 'fqn': ['root', 'nested', 'path', 'model_one'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': ephemeral_config, + 'tags': [], + 'path': 'nested/path/model_one.sql', + 'raw_sql': self.find_input_by_name( + models, 'model_one').get('raw_sql') + } + } + ) + + def test__empty_model(self): + models = [{ + 'name': 'model_one', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'model_one.sql', + 'root_path': '/usr/src/app', + 'raw_sql': (" "), + }] + + self.assertEquals( + dbt.parser.parse_sql_nodes( + models, + self.root_project_config, + {'root': self.root_project_config}), + { + 'model.root.model_one': { + 'name': 'model_one', + 'resource_type': 'model', + 'unique_id': 'model.root.model_one', + 'fqn': ['root', 'model_one'], + 'empty': True, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'model_one.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'model_one').get('raw_sql') + } + } + ) + + def test__simple_dependency(self): + models = [{ + 'name': 'base', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'base.sql', + 'root_path': '/usr/src/app', + 'raw_sql': 'select * from events' + }, { + 'name': 'events_tx', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', + 'raw_sql': "select * from {{ref('base')}}" + }] + + self.assertEquals( + dbt.parser.parse_sql_nodes( + models, + self.root_project_config, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'model.root.base': { + 'name': 'base', + 'resource_type': 'model', + 'unique_id': 'model.root.base', + 'fqn': ['root', 'base'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'base.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'base').get('raw_sql') + }, + 'model.root.events_tx': { + 'name': 'events_tx', + 'resource_type': 'model', + 
'unique_id': 'model.root.events_tx', + 'fqn': ['root', 'events_tx'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'events_tx').get('raw_sql') + } + } + ) + + def test__multiple_dependencies(self): + models = [{ + 'name': 'events', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'events.sql', + 'root_path': '/usr/src/app', + 'raw_sql': 'select * from base.events', + }, { + 'name': 'sessions', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'sessions.sql', + 'root_path': '/usr/src/app', + 'raw_sql': 'select * from base.sessions', + }, { + 'name': 'events_tx', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', + 'raw_sql': ("with events as (select * from {{ref('events')}}) " + "select * from events"), + }, { + 'name': 'sessions_tx', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'sessions_tx.sql', + 'root_path': '/usr/src/app', + 'raw_sql': ("with sessions as (select * from {{ref('sessions')}}) " + "select * from sessions"), + }, { + 'name': 'multi', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'multi.sql', + 'root_path': '/usr/src/app', + 'raw_sql': ("with s as (select * from {{ref('sessions_tx')}}), " + "e as (select * from {{ref('events_tx')}}) " + "select * from e left join s on s.id = e.sid"), + }] + + self.assertEquals( + dbt.parser.parse_sql_nodes( + models, + self.root_project_config, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'model.root.events': { + 'name': 'events', + 'resource_type': 'model', + 'unique_id': 'model.root.events', + 'fqn': ['root', 'events'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'events.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'events').get('raw_sql') + }, + 'model.root.sessions': { + 'name': 'sessions', + 'resource_type': 'model', + 'unique_id': 'model.root.sessions', + 'fqn': ['root', 'sessions'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'sessions.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'sessions').get('raw_sql') + }, + 'model.root.events_tx': { + 'name': 'events_tx', + 'resource_type': 'model', + 'unique_id': 'model.root.events_tx', + 'fqn': ['root', 'events_tx'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'events_tx').get('raw_sql') + }, + 'model.root.sessions_tx': { + 'name': 'sessions_tx', + 'resource_type': 'model', + 'unique_id': 'model.root.sessions_tx', + 'fqn': ['root', 'sessions_tx'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'sessions_tx.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'sessions_tx').get('raw_sql') + }, + 'model.root.multi': { + 'name': 'multi', + 'resource_type': 'model', + 'unique_id': 'model.root.multi', + 'fqn': ['root', 'multi'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'multi.sql', + 
'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'multi').get('raw_sql') + } + } + ) + + def test__multiple_dependencies__packages(self): + models = [{ + 'name': 'events', + 'resource_type': 'model', + 'package_name': 'snowplow', + 'path': 'events.sql', + 'root_path': '/usr/src/app', + 'raw_sql': 'select * from base.events', + }, { + 'name': 'sessions', + 'resource_type': 'model', + 'package_name': 'snowplow', + 'path': 'sessions.sql', + 'root_path': '/usr/src/app', + 'raw_sql': 'select * from base.sessions', + }, { + 'name': 'events_tx', + 'resource_type': 'model', + 'package_name': 'snowplow', + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', + 'raw_sql': ("with events as (select * from {{ref('events')}}) " + "select * from events"), + }, { + 'name': 'sessions_tx', + 'resource_type': 'model', + 'package_name': 'snowplow', + 'path': 'sessions_tx.sql', + 'root_path': '/usr/src/app', + 'raw_sql': ("with sessions as (select * from {{ref('sessions')}}) " + "select * from sessions"), + }, { + 'name': 'multi', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'multi.sql', + 'root_path': '/usr/src/app', + 'raw_sql': ("with s as (select * from {{ref('snowplow', 'sessions_tx')}}), " + "e as (select * from {{ref('snowplow', 'events_tx')}}) " + "select * from e left join s on s.id = e.sid"), + }] + + self.assertEquals( + dbt.parser.parse_sql_nodes( + models, + self.root_project_config, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'model.snowplow.events': { + 'name': 'events', + 'resource_type': 'model', + 'unique_id': 'model.snowplow.events', + 'fqn': ['snowplow', 'events'], + 'empty': False, + 'package_name': 'snowplow', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'events.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'events').get('raw_sql') + }, + 'model.snowplow.sessions': { + 'name': 'sessions', + 'resource_type': 'model', + 'unique_id': 'model.snowplow.sessions', + 'fqn': ['snowplow', 'sessions'], + 'empty': False, + 'package_name': 'snowplow', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'sessions.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'sessions').get('raw_sql') + }, + 'model.snowplow.events_tx': { + 'name': 'events_tx', + 'resource_type': 'model', + 'unique_id': 'model.snowplow.events_tx', + 'fqn': ['snowplow', 'events_tx'], + 'empty': False, + 'package_name': 'snowplow', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'events_tx').get('raw_sql') + }, + 'model.snowplow.sessions_tx': { + 'name': 'sessions_tx', + 'resource_type': 'model', + 'unique_id': 'model.snowplow.sessions_tx', + 'fqn': ['snowplow', 'sessions_tx'], + 'empty': False, + 'package_name': 'snowplow', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'sessions_tx.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'sessions_tx').get('raw_sql') + }, + 'model.root.multi': { + 'name': 'multi', + 'resource_type': 'model', + 'unique_id': 'model.root.multi', + 'fqn': ['root', 'multi'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'path': 'multi.sql', + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'multi').get('raw_sql') + } + 
} + ) + + def test__in_model_config(self): + models = [{ + 'name': 'model_one', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'model_one.sql', + 'root_path': '/usr/src/app', + 'raw_sql': ("{{config({'materialized':'table'})}}" + "select * from events"), + }] + + self.model_config.update({ + 'materialized': 'table' + }) + + self.assertEquals( + dbt.parser.parse_sql_nodes( + models, + self.root_project_config, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'model.root.model_one': { + 'name': 'model_one', + 'resource_type': 'model', + 'unique_id': 'model.root.model_one', + 'fqn': ['root', 'model_one'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'tags': [], + 'root_path': '/usr/src/app', + 'path': 'model_one.sql', + 'raw_sql': self.find_input_by_name( + models, 'model_one').get('raw_sql') + } + } + ) + + def test__root_project_config(self): + self.root_project_config = { + 'name': 'root', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + 'models': { + 'materialized': 'ephemeral', + 'root': { + 'view': { + 'materialized': 'view' + } + } + } + } + + models = [{ + 'name': 'table', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'table.sql', + 'root_path': '/usr/src/app', + 'raw_sql': ("{{config({'materialized':'table'})}}" + "select * from events"), + }, { + 'name': 'ephemeral', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'ephemeral.sql', + 'root_path': '/usr/src/app', + 'raw_sql': ("select * from events"), + }, { + 'name': 'view', + 'resource_type': 'model', + 'package_name': 'root', + 'path': 'view.sql', + 'root_path': '/usr/src/app', + 'raw_sql': ("select * from events"), + }] + + self.model_config.update({ + 'materialized': 'table' + }) + + ephemeral_config = self.model_config.copy() + ephemeral_config.update({ + 'materialized': 'ephemeral' + }) + + view_config = self.model_config.copy() + view_config.update({ + 'materialized': 'view' + }) + + self.assertEquals( + dbt.parser.parse_sql_nodes( + models, + self.root_project_config, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'model.root.table': { + 'name': 'table', + 'resource_type': 'model', + 'unique_id': 'model.root.table', + 'fqn': ['root', 'table'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'path': 'table.sql', + 'config': self.model_config, + 'tags': [], + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'table').get('raw_sql') + }, + 'model.root.ephemeral': { + 'name': 'ephemeral', + 'resource_type': 'model', + 'unique_id': 'model.root.ephemeral', + 'fqn': ['root', 'ephemeral'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'path': 'ephemeral.sql', + 'config': ephemeral_config, + 'tags': [], + 'root_path': '/usr/src/app', + 'raw_sql': self.find_input_by_name( + models, 'ephemeral').get('raw_sql') + }, + 'model.root.view': { + 'name': 'view', + 'resource_type': 'model', + 'unique_id': 'model.root.view', + 'fqn': ['root', 'view'], + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'path': 'view.sql', + 'root_path': '/usr/src/app', + 'config': view_config, + 'tags': [], + 'raw_sql': self.find_input_by_name( + models, 'ephemeral').get('raw_sql') + } + } + + ) + + def test__other_project_config(self): + self.root_project_config = { + 'name': 'root', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + 
+            'models': {
+                'materialized': 'ephemeral',
+                'root': {
+                    'view': {
+                        'materialized': 'view'
+                    }
+                },
+                'snowplow': {
+                    'enabled': False,
+                    'views': {
+                        'materialized': 'view',
+                    }
+                }
+            }
+        }
+
+        self.snowplow_project_config = {
+            'name': 'snowplow',
+            'version': '0.1',
+            'project-root': os.path.abspath('./dbt_modules/snowplow'),
+            'models': {
+                'snowplow': {
+                    'enabled': False,
+                    'views': {
+                        'materialized': 'table',
+                        'sort': 'timestamp'
+                    }
+                }
+            }
+        }
+
+        models = [{
+            'name': 'table',
+            'resource_type': 'model',
+            'package_name': 'root',
+            'path': 'table.sql',
+            'root_path': '/usr/src/app',
+            'raw_sql': ("{{config({'materialized':'table'})}}"
+                        "select * from events"),
+        }, {
+            'name': 'ephemeral',
+            'resource_type': 'model',
+            'package_name': 'root',
+            'path': 'ephemeral.sql',
+            'root_path': '/usr/src/app',
+            'raw_sql': ("select * from events"),
+        }, {
+            'name': 'view',
+            'resource_type': 'model',
+            'package_name': 'root',
+            'path': 'view.sql',
+            'root_path': '/usr/src/app',
+            'raw_sql': ("select * from events"),
+        }, {
+            'name': 'disabled',
+            'resource_type': 'model',
+            'package_name': 'snowplow',
+            'path': 'disabled.sql',
+            'root_path': '/usr/src/app',
+            'raw_sql': ("select * from events"),
+        }, {
+            'name': 'package',
+            'resource_type': 'model',
+            'package_name': 'snowplow',
+            'path': 'views/package.sql',
+            'root_path': '/usr/src/app',
+            'raw_sql': ("select * from events"),
+        }]
+
+        self.model_config.update({
+            'materialized': 'table'
+        })
+
+        ephemeral_config = self.model_config.copy()
+        ephemeral_config.update({
+            'materialized': 'ephemeral'
+        })
+
+        view_config = self.model_config.copy()
+        view_config.update({
+            'materialized': 'view'
+        })
+
+        disabled_config = self.model_config.copy()
+        disabled_config.update({
+            'enabled': False,
+            'materialized': 'ephemeral'
+        })
+
+        sort_config = self.model_config.copy()
+        sort_config.update({
+            'enabled': False,
+            'materialized': 'view',
+            'sort': 'timestamp',
+        })
+
+        self.assertEquals(
+            dbt.parser.parse_sql_nodes(
+                models,
+                self.root_project_config,
+                {'root': self.root_project_config,
+                 'snowplow': self.snowplow_project_config}),
+            {
+                'model.root.table': {
+                    'name': 'table',
+                    'resource_type': 'model',
+                    'unique_id': 'model.root.table',
+                    'fqn': ['root', 'table'],
+                    'empty': False,
+                    'package_name': 'root',
+                    'depends_on': [],
+                    'path': 'table.sql',
+                    'root_path': '/usr/src/app',
+                    'config': self.model_config,
+                    'tags': [],
+                    'raw_sql': self.find_input_by_name(
+                        models, 'table').get('raw_sql')
+                },
+                'model.root.ephemeral': {
+                    'name': 'ephemeral',
+                    'resource_type': 'model',
+                    'unique_id': 'model.root.ephemeral',
+                    'fqn': ['root', 'ephemeral'],
+                    'empty': False,
+                    'package_name': 'root',
+                    'depends_on': [],
+                    'path': 'ephemeral.sql',
+                    'root_path': '/usr/src/app',
+                    'config': ephemeral_config,
+                    'tags': [],
+                    'raw_sql': self.find_input_by_name(
+                        models, 'ephemeral').get('raw_sql')
+                },
+                'model.root.view': {
+                    'name': 'view',
+                    'resource_type': 'model',
+                    'unique_id': 'model.root.view',
+                    'fqn': ['root', 'view'],
+                    'empty': False,
+                    'package_name': 'root',
+                    'depends_on': [],
+                    'path': 'view.sql',
+                    'root_path': '/usr/src/app',
+                    'config': view_config,
+                    'tags': [],
+                    'raw_sql': self.find_input_by_name(
+                        models, 'view').get('raw_sql')
+                },
+                'model.snowplow.disabled': {
+                    'name': 'disabled',
+                    'resource_type': 'model',
+                    'unique_id': 'model.snowplow.disabled',
+                    'fqn': ['snowplow', 'disabled'],
+                    'empty': False,
+                    'package_name': 'snowplow',
+                    'depends_on': [],
+                    'path': 'disabled.sql',
+                    'root_path': '/usr/src/app',
+                    'config': disabled_config,
+                    'tags': [],
+                    'raw_sql': self.find_input_by_name(
+                        models, 'disabled').get('raw_sql')
+                },
+                'model.snowplow.package': {
+                    'name': 'package',
+                    'resource_type': 'model',
+                    'unique_id': 'model.snowplow.package',
+                    'fqn': ['snowplow', 'views', 'package'],
+                    'empty': False,
+                    'package_name': 'snowplow',
+                    'depends_on': [],
+                    'path': 'views/package.sql',
+                    'root_path': '/usr/src/app',
+                    'config': sort_config,
+                    'tags': [],
+                    'raw_sql': self.find_input_by_name(
+                        models, 'package').get('raw_sql')
+                }
+            }
+        )
+
+    def test__simple_schema_test(self):
+        tests = [{
+            'name': 'test_one',
+            'resource_type': 'test',
+            'package_name': 'root',
+            'root_path': '/usr/src/app',
+            'path': 'test_one.yml',
+            'raw_sql': None,
+            'raw_yml': ('{model_one: {constraints: {not_null: [id],'
+                        'unique: [id],'
+                        'accepted_values: [{field: id, values: ["a","b"]}],'
+                        'relationships: [{from: id, to: model_two, field: id}]'
+                        '}}}')
+        }]
+
+        not_null_sql = dbt.parser.QUERY_VALIDATE_NOT_NULL \
+            .format(
+                field='id',
+                ref="{{ref('model_one')}}")
+
+        unique_sql = dbt.parser.QUERY_VALIDATE_UNIQUE \
+            .format(
+                field='id',
+                ref="{{ref('model_one')}}")
+
+        accepted_values_sql = dbt.parser.QUERY_VALIDATE_ACCEPTED_VALUES \
+            .format(
+                field='id',
+                ref="{{ref('model_one')}}",
+                values_csv="'a','b'")
+
+        relationships_sql = dbt.parser.QUERY_VALIDATE_REFERENTIAL_INTEGRITY \
+            .format(
+                parent_field='id',
+                parent_ref="{{ref('model_two')}}",
+                child_field='id',
+                child_ref="{{ref('model_one')}}")
+
+        self.assertEquals(
+            dbt.parser.parse_schema_tests(
+                tests,
+                self.root_project_config,
+                {'root': self.root_project_config,
+                 'snowplow': self.snowplow_project_config}),
+            {
+                'test.root.not_null_model_one_id': {
+                    'name': 'not_null_model_one_id',
+                    'resource_type': 'test',
+                    'unique_id': 'test.root.not_null_model_one_id',
+                    'fqn': ['root', 'schema', 'test_one'],
+                    'empty': False,
+                    'package_name': 'root',
+                    'root_path': '/usr/src/app',
+                    'depends_on': [],
+                    'config': self.model_config,
+                    'path': 'test_one.yml',
+                    'tags': ['schema'],
+                    'raw_sql': not_null_sql,
+                },
+                'test.root.unique_model_one_id': {
+                    'name': 'unique_model_one_id',
+                    'resource_type': 'test',
+                    'unique_id': 'test.root.unique_model_one_id',
+                    'fqn': ['root', 'schema', 'test_one'],
+                    'empty': False,
+                    'package_name': 'root',
+                    'root_path': '/usr/src/app',
+                    'depends_on': [],
+                    'config': self.model_config,
+                    'path': 'test_one.yml',
+                    'tags': ['schema'],
+                    'raw_sql': unique_sql,
+                },
+                'test.root.accepted_values_model_one_id': {
+                    'name': 'accepted_values_model_one_id',
+                    'resource_type': 'test',
+                    'unique_id': 'test.root.accepted_values_model_one_id',
+                    'fqn': ['root', 'schema', 'test_one'],
+                    'empty': False,
+                    'package_name': 'root',
+                    'root_path': '/usr/src/app',
+                    'depends_on': [],
+                    'config': self.model_config,
+                    'path': 'test_one.yml',
+                    'tags': ['schema'],
+                    'raw_sql': accepted_values_sql,
+                },
+                'test.root.relationships_model_one_id_to_model_two_id': {
+                    'name': 'relationships_model_one_id_to_model_two_id',
+                    'resource_type': 'test',
+                    'unique_id': 'test.root.relationships_model_one_id_to_model_two_id',
+                    'fqn': ['root', 'schema', 'test_one'],
+                    'empty': False,
+                    'package_name': 'root',
+                    'root_path': '/usr/src/app',
+                    'depends_on': [],
+                    'config': self.model_config,
+                    'path': 'test_one.yml',
+                    'tags': ['schema'],
+                    'raw_sql': relationships_sql,
+                }
+
+
+            }
+        )
+
+
+    def test__simple_data_test(self):
+        tests = [{
+            'name': 'no_events',
+            'resource_type': 'test',
+            'package_name': 'root',
+            'path': 'no_events.sql',
+            'root_path': '/usr/src/app',
+            'raw_sql': "select * from {{ref('base')}}"
+        }]
+
+        self.assertEquals(
+            dbt.parser.parse_sql_nodes(
+                tests,
+                self.root_project_config,
+                {'root': self.root_project_config,
+                 'snowplow': self.snowplow_project_config}),
+            {
+                'test.root.no_events': {
+                    'name': 'no_events',
+                    'resource_type': 'test',
+                    'unique_id': 'test.root.no_events',
+                    'fqn': ['root', 'no_events'],
+                    'empty': False,
+                    'package_name': 'root',
+                    'depends_on': [],
+                    'config': self.model_config,
+                    'path': 'no_events.sql',
+                    'root_path': '/usr/src/app',
+                    'tags': [],
+                    'raw_sql': self.find_input_by_name(
+                        tests, 'no_events').get('raw_sql')
+                }
+            }
+        )
diff --git a/test/unit/test_runner.py b/test/unit/test_runner.py
new file mode 100644
index 00000000000..e0797b85205
--- /dev/null
+++ b/test/unit/test_runner.py
@@ -0,0 +1,278 @@
+from mock import MagicMock, patch
+import unittest
+
+import os
+
+import dbt.flags
+import dbt.parser
+import dbt.runner
+
+
+class TestRunner(unittest.TestCase):
+
+    def setUp(self):
+        dbt.flags.STRICT_MODE = True
+        dbt.flags.NON_DESTRUCTIVE = True
+
+        self.profile = {
+            'type': 'postgres',
+            'dbname': 'postgres',
+            'user': 'root',
+            'host': 'database',
+            'pass': 'password123',
+            'port': 5432,
+            'schema': 'public'
+        }
+
+        self.model_config = {
+            'enabled': True,
+            'materialized': 'view',
+            'post-hook': [],
+            'pre-hook': [],
+            'vars': {},
+        }
+
+        self.model = {
+            'name': 'view',
+            'resource_type': 'model',
+            'unique_id': 'model.root.view',
+            'fqn': ['root_project', 'view'],
+            'empty': False,
+            'package_name': 'root',
+            'root_path': '/usr/src/app',
+            'depends_on': [
+                'model.root.ephemeral'
+            ],
+            'config': self.model_config,
+            'path': 'view.sql',
+            'raw_sql': 'select * from {{ref("ephemeral")}}',
+            'compiled': True,
+            'extra_ctes_injected': False,
+            'extra_cte_ids': [
+                'model.root.ephemeral'
+            ],
+            'extra_cte_sql': [],
+            'compiled_sql': 'select * from __dbt__CTE__ephemeral',
+            'injected_sql': ('with __dbt__CTE__ephemeral as ('
+                             'select * from "public"."ephemeral"'
+                             ')'
+                             'select * from __dbt__CTE__ephemeral'),
+            'wrapped_sql': ('create view "public"."view" as ('
+                            'with __dbt__CTE__ephemeral as ('
+                            'select * from "public"."ephemeral"'
+                            ')'
+                            'select * from __dbt__CTE__ephemeral'
+                            '))')
+        }
+
+        self.existing = {}
+
+        def fake_drop(profile, relation, relation_type, model_name):
+            del self.existing[relation]
+
+        def fake_query_for_existing(profile, schema):
+            return self.existing
+
+        self._drop = dbt.adapters.postgres.PostgresAdapter.drop
+        self._query_for_existing = \
+            dbt.adapters.postgres.PostgresAdapter.query_for_existing
+
+        dbt.adapters.postgres.PostgresAdapter.drop = MagicMock(
+            side_effect=fake_drop)
+
+        dbt.adapters.postgres.PostgresAdapter.query_for_existing = MagicMock(
+            side_effect=fake_query_for_existing)
+
+    def tearDown(self):
+        dbt.adapters.postgres.PostgresAdapter.drop = self._drop
+        dbt.adapters.postgres.PostgresAdapter.query_for_existing = \
+            self._query_for_existing
+
+
+
+    @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=None)
+    @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None)
+    def test__execute_model__view(self,
+                                  mock_adapter_truncate,
+                                  mock_adapter_rename,
+                                  mock_adapter_execute_model):
+        model = self.model.copy()
+
+        dbt.runner.execute_model(
+            self.profile,
+            model,
+            existing=self.existing)
+
+        dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called()
+
+        mock_adapter_truncate.assert_not_called()
+        mock_adapter_rename.assert_not_called()
+        mock_adapter_execute_model.assert_called_once()
+
+
+    @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None)
+    def test__execute_model__view__existing(self,
+                                            mock_adapter_truncate,
+                                            mock_adapter_rename,
+                                            mock_adapter_execute_model):
+        self.existing = {'view': 'table'}
+
+        model = self.model.copy()
+
+        dbt.runner.execute_model(
+            self.profile,
+            model,
+            existing=self.existing)
+
+        dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called()
+
+        mock_adapter_truncate.assert_not_called()
+        mock_adapter_rename.assert_not_called()
+        mock_adapter_execute_model.assert_not_called()
+
+
+    @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None)
+    def test__execute_model__table(self,
+                                   mock_adapter_truncate,
+                                   mock_adapter_rename,
+                                   mock_adapter_execute_model):
+        model = self.model.copy()
+        model['config']['materialized'] = 'table'
+
+        dbt.runner.execute_model(
+            self.profile,
+            model,
+            existing=self.existing)
+
+        dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called()
+
+        mock_adapter_truncate.assert_not_called()
+        mock_adapter_rename.assert_not_called()
+        mock_adapter_execute_model.assert_called_once()
+
+
+    @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None)
+    def test__execute_model__table__existing(self,
+                                             mock_adapter_truncate,
+                                             mock_adapter_rename,
+                                             mock_adapter_execute_model):
+        self.existing = {'view': 'table'}
+
+        model = self.model.copy()
+        model['config']['materialized'] = 'table'
+
+        dbt.runner.execute_model(
+            self.profile,
+            model,
+            existing=self.existing)
+
+        dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called()
+
+        mock_adapter_truncate.assert_called_once()
+        mock_adapter_rename.assert_not_called()
+        mock_adapter_execute_model.assert_called_once()
+
+
+
+    @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=None)
+    @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None)
+    def test__execute_model__view__destructive(self,
+                                               mock_adapter_truncate,
+                                               mock_adapter_rename,
+                                               mock_adapter_execute_model):
+        dbt.flags.NON_DESTRUCTIVE = False
+
+        model = self.model.copy()
+
+        dbt.runner.execute_model(
+            self.profile,
+            model,
+            existing=self.existing)
+
+        dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called()
+
+        mock_adapter_truncate.assert_not_called()
+        mock_adapter_rename.assert_called_once()
+        mock_adapter_execute_model.assert_called_once()
+
+
+    @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None)
+    def test__execute_model__view__existing__destructive(self,
+                                                         mock_adapter_truncate,
+                                                         mock_adapter_rename,
+                                                         mock_adapter_execute_model):
+        dbt.flags.NON_DESTRUCTIVE = False
+
+        self.existing = {'view': 'view'}
+        model = self.model.copy()
+
+        dbt.runner.execute_model(
+            self.profile,
+            model,
+            existing=self.existing)
+
+        dbt.adapters.postgres.PostgresAdapter.drop.assert_called_once()
+
+        mock_adapter_truncate.assert_not_called()
+        mock_adapter_rename.assert_called_once()
+        mock_adapter_execute_model.assert_called_once()
+
+
+    @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None)
+    def test__execute_model__table__destructive(self,
+                                                mock_adapter_truncate,
+                                                mock_adapter_rename,
+                                                mock_adapter_execute_model):
+        dbt.flags.NON_DESTRUCTIVE = False
+
+        model = self.model.copy()
+        model['config']['materialized'] = 'table'
+
+        dbt.runner.execute_model(
+            self.profile,
+            model,
+            existing=self.existing)
+
+        dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called()
+
+        mock_adapter_truncate.assert_not_called()
+        mock_adapter_rename.assert_called_once()
+        mock_adapter_execute_model.assert_called_once()
+
+
+    @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None)
+    def test__execute_model__table__existing__destructive(self,
+                                                          mock_adapter_truncate,
+                                                          mock_adapter_rename,
+                                                          mock_adapter_execute_model):
+        dbt.flags.NON_DESTRUCTIVE = False
+
+        self.existing = {'view': 'table'}
+
+        model = self.model.copy()
+        model['config']['materialized'] = 'table'
+
+        dbt.runner.execute_model(
+            self.profile,
+            model,
+            existing=self.existing)
+
+        dbt.adapters.postgres.PostgresAdapter.drop.assert_called_once()
+
+        mock_adapter_truncate.assert_not_called()
+        mock_adapter_rename.assert_called_once()
+        mock_adapter_execute_model.assert_called_once()
diff --git a/tox.ini b/tox.ini
index 92df45426a0..32957b36fb3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -48,7 +48,7 @@ basepython = python3.6
 passenv = *
 setenv =
     HOME=/root/
-commands = /bin/bash -c '{envpython} $(which nosetests) -v -a type=postgres {posargs} --with-coverage --cover-branches --cover-html --cover-html-dir=htmlcov test/integration/*'
+commands = /bin/bash -c '{envpython} $(which nosetests) -v -a type=postgres {posargs} test/integration/*'
 deps =
     -r{toxinidir}/requirements.txt
    -r{toxinidir}/dev_requirements.txt