Skip to content

Commit

Permalink
refactor: factor out jinja interactions (#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
cmcarthur authored and Connor McArthur committed Mar 3, 2017
1 parent 6c0f59d commit 1a101ad
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 165 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Changes

- Graph refactor: fix common issues with load order ([#292](https://github.com/fishtown-analytics/dbt/pull/292))
- Refactor: factor out jinja interactions ([#309](https://github.com/fishtown-analytics/dbt/pull/309))
- Speedup: detect cycles at the end of compilation ([#307](https://github.com/fishtown-analytics/dbt/pull/307))
- Speedup: write graph file with gpickle instead of yaml ([#306](https://github.com/fishtown-analytics/dbt/pull/306))

Expand Down
52 changes: 52 additions & 0 deletions dbt/clients/jinja.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import dbt.compat
import dbt.exceptions

import jinja2
import jinja2.sandbox


class SilentUndefined(jinja2.Undefined):
    """
    Undefined type that stays quiet instead of raising.

    The failure hook is overridden to return None, and every operator hook
    listed below is bound to that same function. So, for example, `env` is
    not defined here, but using it in one of these ways will not make the
    parser fail with a fatal error.
    """
    def _fail_with_undefined_error(self, *args, **kwargs):
        return None

    # Arithmetic operators and their reflected counterparts.
    __add__ = _fail_with_undefined_error
    __radd__ = _fail_with_undefined_error
    __mul__ = _fail_with_undefined_error
    __rmul__ = _fail_with_undefined_error
    __div__ = _fail_with_undefined_error
    __rdiv__ = _fail_with_undefined_error
    __truediv__ = _fail_with_undefined_error
    __rtruediv__ = _fail_with_undefined_error
    __floordiv__ = _fail_with_undefined_error
    __rfloordiv__ = _fail_with_undefined_error
    __mod__ = _fail_with_undefined_error
    __rmod__ = _fail_with_undefined_error
    __pow__ = _fail_with_undefined_error
    __rpow__ = _fail_with_undefined_error

    # Unary operators.
    __pos__ = _fail_with_undefined_error
    __neg__ = _fail_with_undefined_error

    # Calling and subscripting.
    __call__ = _fail_with_undefined_error
    __getitem__ = _fail_with_undefined_error

    # Ordering comparisons.
    __lt__ = _fail_with_undefined_error
    __le__ = _fail_with_undefined_error
    __gt__ = _fail_with_undefined_error
    __ge__ = _fail_with_undefined_error

    # Numeric conversions.
    __int__ = _fail_with_undefined_error
    __float__ = _fail_with_undefined_error
    __complex__ = _fail_with_undefined_error


# Default environment: sandboxed, constructed with no arguments.
env = jinja2.sandbox.SandboxedEnvironment()

# Sandboxed environment that uses SilentUndefined as its undefined type.
silent_on_undefined_env = jinja2.sandbox.SandboxedEnvironment(
    undefined=SilentUndefined)


def get_template(string, ctx, node=None, silent_on_undefined=False):
    """
    Build a jinja2 template from `string`, with `ctx` as its globals.

    The module-level sandboxed environment is used, or the
    silent-on-undefined one when `silent_on_undefined` is truthy. Template
    syntax errors and undefined errors are handed to
    dbt.exceptions.raise_compiler_error together with `node`.
    """
    try:
        chosen_env = silent_on_undefined_env if silent_on_undefined else env
        build = chosen_env.from_string
        return build(dbt.compat.to_string(string), globals=ctx)

    except (jinja2.exceptions.TemplateSyntaxError,
            jinja2.exceptions.UndefinedError) as exc:
        dbt.exceptions.raise_compiler_error(node, str(exc))


def get_rendered(string, ctx, node=None, silent_on_undefined=False):
    """
    Render `string` as a jinja2 template and return the result.

    The template comes from get_template (same arguments) and is rendered
    with `ctx` as its context. Template syntax errors and undefined errors
    raised while rendering are handed to
    dbt.exceptions.raise_compiler_error together with `node`.
    """
    try:
        return get_template(
            string, ctx, node, silent_on_undefined
        ).render(ctx)

    except (jinja2.exceptions.TemplateSyntaxError,
            jinja2.exceptions.UndefinedError) as exc:
        dbt.exceptions.raise_compiler_error(node, str(exc))
132 changes: 34 additions & 98 deletions dbt/compilation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import fnmatch
import jinja2
from collections import defaultdict
import time
import sqlparse
Expand All @@ -11,15 +9,15 @@
from dbt.model import Model, NodeType
from dbt.source import Source
from dbt.utils import find_model_by_fqn, find_model_by_name, \
split_path, This, Var, compiler_error
split_path, This, Var, is_enabled, get_materialization

from dbt.linker import Linker
from dbt.runtime import RuntimeContext

import dbt.compat
import dbt.contracts.graph.compiled
import dbt.contracts.graph.parsed
import dbt.contracts.project
import dbt.exceptions
import dbt.flags
import dbt.parser
import dbt.templates
Expand Down Expand Up @@ -59,17 +57,6 @@ def compile_and_print_status(project, args):
logger.info("Compiled {}".format(stat_line))


def compile_string(string, ctx):
    """
    Render `string` as a jinja2 template, with `ctx` as globals and context.

    Template syntax errors and undefined errors are passed to
    compiler_error with no associated model.
    """
    try:
        environment = jinja2.Environment()
        source = dbt.compat.to_string(string)
        return environment.from_string(source, globals=ctx).render(ctx)
    except (jinja2.exceptions.TemplateSyntaxError,
            jinja2.exceptions.UndefinedError) as exc:
        compiler_error(None, str(exc))


def prepend_ctes(model, all_models):
model, _, all_models = recursively_prepend_ctes(model, all_models)

Expand Down Expand Up @@ -204,39 +191,24 @@ def do_ref(*args):
elif len(args) == 2:
target_model_package, target_model_name = args
else:
compiler_error(
model,
"ref() takes at most two arguments ({} given)".format(
len(args)
)
)
dbt.exceptions.ref_invalid_args(model, args)

target_model = dbt.utils.find_model_by_name(
all_models,
target_model_name,
target_model_package)

if target_model is None:
compiler_error(
model,
"Model '{}' depends on model '{}' which was not found."
.format(model.get('unique_id'), target_model_name))
dbt.exceptions.ref_target_not_found(model, target_model_name)

target_model_id = target_model.get('unique_id')
if is_enabled(model) and not is_enabled(target_model):
dbt.exceptions.ref_disabled_dependency(model, target_model)

if target_model.get('config', {}).get('enabled') is False and \
model.get('config', {}).get('enabled') is True:
compiler_error(
model,
"Model '{}' depends on model '{}' which is disabled in "
"the project config".format(model.get('unique_id'),
target_model.get('unique_id')))
target_model_id = target_model.get('unique_id')

model['depends_on'].append(target_model_id)

if target_model.get('config', {}) \
.get('materialized') == 'ephemeral':

if get_materialization(target_model) == 'ephemeral':
model['extra_cte_ids'].append(target_model_id)
return '__dbt__CTE__{}'.format(target_model.get('name'))
else:
Expand Down Expand Up @@ -295,6 +267,8 @@ def get_compiler_context(self, linker, model, models,
return context

def get_context(self, linker, model, models):
# THIS IS STILL USED FOR WRAPPING, BUT SHOULD GO AWAY
# - Connor
runtime = RuntimeContext(model=model)

context = self.project.context()
Expand All @@ -321,30 +295,26 @@ def get_context(self, linker, model, models):

def compile_node(self, linker, node, nodes, macro_generator):
logger.debug("Compiling {}".format(node.get('unique_id')))
try:
compiled_node = node.copy()
compiled_node.update({
'compiled': False,
'compiled_sql': None,
'extra_ctes_injected': False,
'extra_cte_ids': [],
'extra_cte_sql': [],
'injected_sql': None,
})

context = self.get_compiler_context(linker, compiled_node, nodes,
macro_generator)

env = jinja2.sandbox.SandboxedEnvironment()

compiled_node['compiled_sql'] = env.from_string(
node.get('raw_sql')).render(context)

compiled_node['compiled'] = True
except jinja2.exceptions.TemplateSyntaxError as e:
compiler_error(node, str(e))
except jinja2.exceptions.UndefinedError as e:
compiler_error(node, str(e))

compiled_node = node.copy()
compiled_node.update({
'compiled': False,
'compiled_sql': None,
'extra_ctes_injected': False,
'extra_cte_ids': [],
'extra_cte_sql': [],
'injected_sql': None,
})

context = self.get_compiler_context(linker, compiled_node, nodes,
macro_generator)

compiled_node['compiled_sql'] = dbt.clients.jinja.get_rendered(
node.get('raw_sql'),
context,
node)

compiled_node['compiled'] = True

return compiled_node

Expand All @@ -353,34 +323,6 @@ def write_graph_file(self, linker):
graph_path = os.path.join(self.project['target-path'], filename)
linker.write_graph(graph_path)

def new_add_cte_to_rendered_query(self, linker, primary_model,
                                  compiled_models):
    """
    Return the query for `primary_model`, combined with required CTEs.

    A model is required when its `is_ephemeral` attribute is truthy and it
    appears in the result of `__recursive_add_ctes`. When no model
    qualifies, the query looked up in `compiled_models` is returned as-is;
    otherwise the result of `combine_query_with_ctes` is returned.
    """
    by_fqn = {tuple(entry.fqn): entry for entry in compiled_models}
    ordered_nodes = linker.as_topological_ordering()

    wanted = self.__recursive_add_ctes(linker, primary_model)

    # collected in topological sort order -- significant for CTEs
    ctes = [
        by_fqn[node]
        for node in ordered_nodes
        if node in by_fqn
        and by_fqn[node].is_ephemeral
        and by_fqn[node] in wanted
    ]

    query = compiled_models[primary_model]
    if len(ctes) != 0:
        return self.combine_query_with_ctes(
            primary_model, query, ctes, compiled_models
        )
    return query

def compile_nodes(self, linker, nodes, macro_generator):
all_projects = self.get_all_projects()

Expand Down Expand Up @@ -440,8 +382,7 @@ def compile_nodes(self, linker, nodes, macro_generator):

if injected_node.get('resource_type') in (NodeType.Model,
NodeType.Analysis) and \
injected_node.get('config', {}) \
.get('materialized') != 'ephemeral':
get_materialization(injected_node) != 'ephemeral':
self.__write(build_path, injected_node.get('wrapped_sql'))
written_nodes.append(injected_node)
injected_node['build_path'] = build_path
Expand All @@ -459,10 +400,7 @@ def compile_nodes(self, linker, nodes, macro_generator):
injected_node.get('unique_id'),
compiled_nodes.get(dependency).get('unique_id'))
else:
compiler_error(
model,
"dependency {} not found in graph!".format(
dependency))
dbt.exceptions.dependency_not_found(model, dependency)

cycle = linker.find_cycles()

Expand Down Expand Up @@ -586,8 +524,8 @@ def compile(self):

for project in dbt.utils.dependency_projects(self.project):
all_macros.extend(
self.get_macros(this_project=self.project, own_project=project)
)
self.get_macros(this_project=self.project,
own_project=project))

macro_generator = self.generate_macros(all_macros)

Expand All @@ -597,8 +535,6 @@ def compile(self):
compiled_nodes, written_nodes = self.compile_nodes(linker, all_nodes,
macro_generator)

# TODO re-add archives

self.write_graph_file(linker)

stats = {}
Expand Down
53 changes: 53 additions & 0 deletions dbt/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from dbt.compat import basestring


class Exception(BaseException):
    # NOTE(review): inside this module the name shadows the builtin
    # `Exception`, and the class derives from BaseException directly.
    pass

Expand All @@ -10,6 +13,10 @@ class ValidationException(RuntimeException):
pass


class CompilationException(RuntimeException):
    # Raised by raise_compiler_error with a formatted compilation message.
    pass


class NotImplementedException(Exception):
    # Subclass of this module's Exception; adds no behavior of its own.
    pass

Expand All @@ -20,3 +27,49 @@ class ProgrammingException(Exception):

class FailedToConnectException(Exception):
    # Subclass of this module's Exception; adds no behavior of its own.
    pass


def raise_compiler_error(node, msg):
    """
    Raise a CompilationException describing a failure for `node`.

    The name shown in the message is chosen from `node`:
      - None: '<None>'
      - a string: the string itself
      - a dict: its 'name' entry
      - anything else: its `nice_name` attribute

    When a dict has no usable 'name' or an object has no `nice_name`,
    '<Unknown>' is reported, so the compilation message is still raised
    instead of showing 'None' or failing with an AttributeError.

    This function always raises; it never returns.
    """
    unknown = '<Unknown>'

    if node is None:
        name = '<None>'
    elif isinstance(node, basestring):
        name = node
    elif isinstance(node, dict):
        # A missing or empty name falls back to the placeholder.
        name = node.get('name') or unknown
    else:
        # Do not let a missing attribute hide the real compile error.
        name = getattr(node, 'nice_name', unknown)

    raise CompilationException(
        "! Compilation error while compiling model {}:\n! {}\n"
        .format(name, msg))


def ref_invalid_args(model, args):
    """Report that ref() was called with `len(args)` arguments."""
    given = len(args)
    message = "ref() takes at most two arguments ({} given)".format(given)
    raise_compiler_error(model, message)


def ref_target_not_found(model, target_model_name):
    """Report that `model` depends on a model name that was not found."""
    template = "Model '{}' depends on model '{}' which was not found."
    message = template.format(model.get('unique_id'), target_model_name)
    raise_compiler_error(model, message)


def ref_disabled_dependency(model, target_model):
    """Report that `model` depends on a model disabled in the config."""
    template = ("Model '{}' depends on model '{}' which is disabled in "
                "the project config")
    source_id = model.get('unique_id')
    target_id = target_model.get('unique_id')
    raise_compiler_error(model, template.format(source_id, target_id))


def dependency_not_found(model, target_model_name):
    """Report that a dependency of `model` is not in the graph."""
    template = "'{}' depends on '{}' which is not in the graph!"
    message = template.format(model.get('unique_id'), target_model_name)
    raise_compiler_error(model, message)
Loading

0 comments on commit 1a101ad

Please sign in to comment.