Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add macros to flat graph #332

Merged
merged 12 commits into from
Mar 17, 2017
68 changes: 46 additions & 22 deletions dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,56 @@
import jinja2
import jinja2.sandbox

from dbt.utils import NodeType

class SilentUndefined(jinja2.Undefined):
    """
    Undefined type used so that the parser ignores undefined jinja2 calls
    instead of failing. For example, `env` is not defined here, yet a
    reference to it will not stop parsing with a fatal error.
    """
    def _fail_with_undefined_error(self, *args, **kwargs):
        # Swallow the failure: hand back None rather than raising.
        return None

    # Every operation below is routed to the silent handler above.

    # binary arithmetic and its reflected forms
    __add__ = __radd__ = _fail_with_undefined_error
    __mul__ = __rmul__ = _fail_with_undefined_error
    __div__ = __rdiv__ = _fail_with_undefined_error
    __truediv__ = __rtruediv__ = _fail_with_undefined_error
    __floordiv__ = __rfloordiv__ = _fail_with_undefined_error
    __mod__ = __rmod__ = _fail_with_undefined_error
    __pow__ = __rpow__ = _fail_with_undefined_error

    # unary operators
    __pos__ = __neg__ = _fail_with_undefined_error

    # calling and indexing
    __call__ = __getitem__ = _fail_with_undefined_error

    # ordering comparisons
    __lt__ = __le__ = __gt__ = __ge__ = _fail_with_undefined_error

    # numeric conversions
    __int__ = __float__ = __complex__ = _fail_with_undefined_error
def create_macro_capture_env(node):

class ParserMacroCapture(jinja2.Undefined):
"""
This class sets up the parser to capture macros.
"""
def __init__(self, hint=None, obj=None, name=None,
exc=None):
super(jinja2.Undefined, self).__init__()

env = jinja2.sandbox.SandboxedEnvironment()
self.node = node
self.name = name
self.package_name = node.get('package_name')

def __getattr__(self, name):

# jinja uses these for safety, so we have to override them.
# see https://github.com/pallets/jinja/blob/master/jinja2/sandbox.py#L332-L339 # noqa
if name in ['unsafe_callable', 'alters_data']:
return False

self.package_name = self.name
self.name = name

return self

def __call__(self, *args, **kwargs):
path = '{}.{}.{}'.format(NodeType.Macro,
self.package_name,
self.name)

silent_on_undefined_env = jinja2.sandbox.SandboxedEnvironment(
undefined=SilentUndefined)
if path not in self.node['depends_on']['macros']:
self.node['depends_on']['macros'].append(path)

return jinja2.sandbox.SandboxedEnvironment(
undefined=ParserMacroCapture)

def get_template(string, ctx, node=None, silent_on_undefined=False):

env = jinja2.sandbox.SandboxedEnvironment()


def get_template(string, ctx, node=None, capture_macros=False):
try:
local_env = env

if silent_on_undefined:
local_env = silent_on_undefined_env
if capture_macros is True:
local_env = create_macro_capture_env(node)

return local_env.from_string(dbt.compat.to_string(string), globals=ctx)

Expand All @@ -42,11 +62,15 @@ def get_template(string, ctx, node=None, silent_on_undefined=False):
dbt.exceptions.raise_compiler_error(node, str(e))


def get_rendered(string, ctx, node=None, silent_on_undefined=False):
def render_template(template, ctx, node=None):
try:
template = get_template(string, ctx, node, silent_on_undefined)
return template.render(ctx)

except (jinja2.exceptions.TemplateSyntaxError,
jinja2.exceptions.UndefinedError) as e:
dbt.exceptions.raise_compiler_error(node, str(e))


def get_rendered(string, ctx, node=None, capture_macros=False):
    """
    Build a template from `string` and render it with `ctx`.

    `string`, `ctx`, `node` and `capture_macros` are forwarded to
    get_template; the resulting template is then rendered through
    render_template with the same `ctx` and `node`. Returns whatever
    render_template returns.
    """
    template = get_template(string, ctx, node, capture_macros)

    # Forward the caller's node. This used to be hard-coded to None, so an
    # error raised while rendering was reported without the node it
    # belonged to, unlike errors raised while building the template.
    return render_template(template, ctx, node=node)
82 changes: 47 additions & 35 deletions dbt/compilation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import os
from collections import OrderedDict
from collections import OrderedDict, defaultdict
import sqlparse

import dbt.project
import dbt.utils

from dbt.model import Model, NodeType
from dbt.source import Source
from dbt.utils import This, Var, is_enabled, get_materialization
from dbt.model import Model
from dbt.utils import This, Var, is_enabled, get_materialization, NodeType

from dbt.linker import Linker
from dbt.runtime import RuntimeContext
Expand All @@ -30,6 +29,45 @@
graph_file_name = 'graph.gpickle'


def recursively_parse_macros_for_node(node, flat_graph, context):
# this once worked, but is now long dead
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hah

# for unique_id in node.get('depends_on', {}).get('macros'):
# TODO: make it so that we only parse the necessary macros for any node.

for unique_id, macro in flat_graph.get('macros').items():
if macro is None:
dbt.exceptions.macro_not_found(node, unique_id)

name = macro.get('name')
package_name = macro.get('package_name')

if context.get(package_name, {}).get(name) is not None:
# we've already re-parsed this macro and added it to
# the context.
continue

reparsed = dbt.parser.parse_macro_file(
macro_file_path=macro.get('path'),
macro_file_contents=macro.get('raw_sql'),
root_path=macro.get('root_path'),
package_name=package_name,
context=context)

for unique_id, macro in reparsed.items():
macro_map = {macro.get('name'): macro.get('parsed_macro')}

if context.get(package_name) is None:
context[package_name] = {}

context.get(package_name, {}) \
.update(macro_map)

if package_name == node.get('package_name'):
context.update(macro_map)

return context


def compile_and_print_status(project, args):
compiler = Compiler(project, args)
compiler.initialize()
Expand Down Expand Up @@ -245,32 +283,8 @@ def get_compiler_context(self, linker, model, flat_graph):
context['invocation_id'] = '{{ invocation_id }}'
context['sql_now'] = adapter.date_function

for unique_id, macro in flat_graph.get('macros').items():
name = macro.get('name')
package_name = macro.get('package_name')

if context.get(package_name, {}).get(name) is not None:
# we've already re-parsed this macro and added it to
# the context.
continue

reparsed = dbt.parser.parse_macro_file(
macro_file_path=macro.get('path'),
macro_file_contents=macro.get('raw_sql'),
root_path=macro.get('root_path'),
package_name=package_name)

for unique_id, macro in reparsed.items():
macro_map = {macro.get('name'): macro.get('parsed_macro')}

if context.get(package_name) is None:
context[package_name] = {}

context.get(package_name, {}) \
.update(macro_map)

if package_name == model.get('package_name'):
context.update(macro_map)
context = recursively_parse_macros_for_node(
model, flat_graph, context)

return context

Expand Down Expand Up @@ -566,14 +580,12 @@ def compile(self):

self.write_graph_file(linker)

stats = {}
stats = defaultdict(int)

for node_name, node in compiled_graph.get('nodes').items():
stats[node.get('resource_type')] = stats.get(
node.get('resource_type'), 0) + 1
stats[node.get('resource_type')] += 1

for node_name, node in compiled_graph.get('macros').items():
stats[node.get('resource_type')] = stats.get(
node.get('resource_type'), 0) + 1
stats[node.get('resource_type')] += 1

return stats
7 changes: 6 additions & 1 deletion dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import jinja2.runtime

from dbt.compat import basestring
from dbt.model import NodeType
from dbt.utils import NodeType

from dbt.contracts.common import validate_with
from dbt.contracts.graph.unparsed import unparsed_node_contract, \
Expand Down Expand Up @@ -53,6 +53,11 @@
Required('unique_id'): All(basestring, Length(min=1, max=255)),
Required('tags'): All(set),

# parsed fields
Required('depends_on'): {
Required('macros'): [All(basestring, Length(min=1, max=255))],
},

# contents
Required('parsed_macro'): jinja2.runtime.Macro

Expand Down
2 changes: 1 addition & 1 deletion dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dbt.compat import basestring
from dbt.contracts.common import validate_with

from dbt.model import NodeType
from dbt.utils import NodeType

unparsed_base_contract = Schema({
# identifiers
Expand Down
7 changes: 7 additions & 0 deletions dbt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,10 @@ def dependency_not_found(model, target_model_name):
model,
"'{}' depends on '{}' which is not in the graph!"
.format(model.get('unique_id'), target_model_name))


def macro_not_found(model, target_macro_id):
    """
    Report through raise_compiler_error that `model` references the macro
    `target_macro_id`, which is not defined. The message names the model by
    its 'unique_id'.
    """
    message = "'{}' references macro '{}' which is not defined!".format(
        model.get('unique_id'),
        target_macro_id)

    raise_compiler_error(model, message)
6 changes: 2 additions & 4 deletions dbt/graph/selector.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@

# import dbt.utils.compiler_error
import networkx as nx
from dbt.logger import GLOBAL_LOGGER as logger

import dbt.model

from dbt.utils import NodeType

SELECTOR_PARENTS = '+'
SELECTOR_CHILDREN = '+'
Expand Down Expand Up @@ -130,7 +128,7 @@ def get_nodes_from_spec(project, graph, spec):
# they'll be filtered out later.
child_tests = [n for n in graph.successors(node)
if graph.node.get(n).get('resource_type') ==
dbt.model.NodeType.Test]
NodeType.Test]
test_nodes.update(child_tests)

return model_nodes | test_nodes
Expand Down
4 changes: 2 additions & 2 deletions dbt/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict

import dbt.compilation
import dbt.model
from dbt.utils import NodeType


def from_file(graph_file):
Expand Down Expand Up @@ -65,7 +65,7 @@ def is_blocking_dependency(self, node_data):
if 'dbt_run_type' not in node_data or 'materialized' not in node_data:
return False

return node_data['dbt_run_type'] == dbt.model.NodeType.Model \
return node_data['dbt_run_type'] == NodeType.Model \
and node_data['materialized'] != 'ephemeral'

def as_dependency_list(self, limit_to=None):
Expand Down
17 changes: 4 additions & 13 deletions dbt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,14 @@
from dbt.adapters.factory import get_adapter
from dbt.compat import basestring

import dbt.clients.jinja
import dbt.flags

from dbt.templates import BaseCreateTemplate, ArchiveInsertTemplate
from dbt.utils import split_path
from dbt.templates import BaseCreateTemplate
from dbt.utils import split_path, NodeType
import dbt.project
from dbt.utils import deep_merge, DBTConfigKeys, compiler_error, \
compiler_warning

from dbt.utils import deep_merge, DBTConfigKeys, compiler_error

class NodeType(object):
    # Namespace of string constants, one per kind of node. Other code
    # compares a node's 'resource_type' against these (e.g. NodeType.Test)
    # and uses them as the prefix of macro ids (NodeType.Macro).
    Base = 'base'
    Model = 'model'
    Analysis = 'analysis'
    Test = 'test'
    Archive = 'archive'
    Macro = 'macro'
import dbt.clients.jinja


class SourceConfig(object):
Expand Down
Loading