Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support the basic dbt 1.6 metric #870

Merged
merged 9 commits into from
Sep 1, 2023
253 changes: 253 additions & 0 deletions piperider_cli/dbtutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DbtProfileInvalidError, \
DbtProfileBigQueryAuthWithTokenUnsupportedError, DbtRunTimeError
from piperider_cli.metrics_engine import Metric
from piperider_cli.metrics_engine.metrics import SemanticModel
from piperider_cli.statistics import Statistics

console = Console()
Expand Down Expand Up @@ -350,6 +351,258 @@
return metrics


def is_dbt_schema_version_16(manifest: Dict):
    """Return True when the manifest was produced by dbt 1.6 or newer.

    The decision is based on the manifest schema version found in
    ``manifest['metadata']['dbt_schema_version']``, e.g.
    ``'https://schemas.getdbt.com/dbt/manifest/v10.json'``. Schema version
    10 and above is treated as dbt >= 1.6.

    :param manifest: the parsed ``manifest.json`` content.
    :return: ``True`` if the schema version is ``v10`` or higher; ``False``
        if it is lower, missing, or not in the expected ``v<number>`` form.
    """
    metadata = manifest.get('metadata') or {}
    schema_url = metadata.get('dbt_schema_version')
    if not schema_url:
        # Without a schema version the manifest cannot be proven to be >= 1.6.
        return False

    # Last path segment is '<version>.json'; keep only the '<version>' part.
    version = schema_url.split('/')[-1].split('.')[0]
    number = version[1:]
    if not version.startswith('v') or not number.isdecimal():
        return False

    return int(number) >= 10


def load_metric_jinja_string_template(value: str):
    """Compile a metric filter string into a Jinja2 template.

    The template environment exposes a ``Dimension`` global that returns its
    argument unchanged, so rendering ``{{ Dimension('a__b') }}`` yields
    ``a__b``.

    :param value: the Jinja template source to compile.
    :return: the compiled template; call ``.render()`` on it to get the text.
    """
    from jinja2 import BaseLoader, Environment

    environment = Environment(loader=BaseLoader())
    # Identity function: the dimension reference is emitted as plain text.
    environment.globals['Dimension'] = lambda var: var

    return environment.from_string(value)


def get_support_time_grains(grain: str):
    """List the supported time grains that are not finer than ``grain``.

    Known grains, from finest to coarsest, are day, week, month, quarter and
    year; only day, month and year are supported for querying.

    :param grain: the base granularity of the time dimension.
    :return: the supported grains at or above ``grain``, finest first.
    :raises ValueError: if ``grain`` is not one of the known grains.
    """
    ordered_grains = ('day', 'week', 'month', 'quarter', 'year')
    queryable_grains = ('day', 'month', 'year')

    # Everything from the base grain upwards can be derived from the data.
    reachable = ordered_grains[ordered_grains.index(grain):]

    return [candidate for candidate in queryable_grains if candidate in reachable]


def find_derived_time_grains(manifest: Dict, metric: Dict):
    """Resolve the supported time grains of a metric from its semantic model.

    The first node in ``metric['depends_on']['nodes']`` is looked up in
    ``manifest['semantic_models']``. Within that model the metric's measure
    is located by name, its aggregation time dimension is taken from the
    measure (falling back to the model's ``defaults``), and the granularity
    of that dimension is converted with ``get_support_time_grains``.

    :param manifest: the parsed dbt manifest.
    :param metric: a metric entry of the manifest.
    :return: the list of supported time grains, or ``None`` when they cannot
        be determined (no semantic-model dependency, unknown model, measure
        or time dimension, or no granularity declared).
    """
    nodes = (metric.get('depends_on') or {}).get('nodes') or []
    if not nodes or not nodes[0].startswith('semantic_model.'):
        # e.g. a derived metric, whose dependencies are other metrics.
        return None

    semantic_model = (manifest.get('semantic_models') or {}).get(nodes[0])
    if semantic_model is None:
        return None

    measure_name = ((metric.get('type_params') or {}).get('measure') or {}).get('name')
    measure = next(
        (obj for obj in semantic_model.get('measures') or [] if obj.get('name') == measure_name),
        None)
    if measure is None:
        return None

    # A measure-level time dimension overrides the model default.
    defaults = semantic_model.get('defaults') or {}
    agg_time_dimension = measure.get('agg_time_dimension') or defaults.get('agg_time_dimension')
    if agg_time_dimension is None:
        return None

    # find dimension definition - time
    for obj in semantic_model.get('dimensions') or []:
        if obj.get('name') == agg_time_dimension:
            grain = (obj.get('type_params') or {}).get('time_granularity')
            if grain is None:
                return None
            return get_support_time_grains(grain)

    return None


def get_dbt_state_metrics_16(dbt_state_dir: str, dbt_tag: Optional[str] = None, dbt_resources: Optional[dict] = None):
    """Build ``Metric`` objects from the metrics of a dbt >= 1.6 manifest.

    Metrics of type ``simple``, ``derived`` and ``ratio`` are converted.
    A metric is skipped (with a console message and a ``nosupport``
    statistic) when it uses something that is not handled here: another
    metric type, the ``median`` aggregation, an ``offset_window``, a filter
    on a dimension that does not belong to the primary entity, or a
    reference to a metric that is itself skipped.

    :param dbt_state_dir: directory holding the dbt state (manifest).
    :param dbt_tag: when given (and ``dbt_resources`` is empty), only
        metrics carrying this tag are selected.
    :param dbt_resources: when given, only metrics whose manifest key is
        listed in ``dbt_resources['metrics']`` are selected.
    :return: the list of created ``Metric`` objects; empty when the
        manifest is older than dbt 1.6.
    """
    manifest = _get_state_manifest(dbt_state_dir)

    if not is_dbt_schema_version_16(manifest):
        console.print("[[bold yellow]Skip[/bold yellow]] Metric query is not supported for dbt < 1.6")
        return []

    manifest_metrics = manifest.get('metrics') or {}
    semantic_models = manifest.get('semantic_models') or {}
    # Derived and ratio metrics reference their inputs by name, not by key.
    metric_map = {metric.get('name'): metric for metric in manifest_metrics.values()}

    def is_chosen(key, metric):
        """Apply the resource/tag selection and record why a metric is dropped."""
        statistics = Statistics()
        if dbt_resources:
            chosen = key in dbt_resources['metrics']
            if not chosen:
                statistics.add_field_one('filter')
            return chosen
        if dbt_tag is not None:
            # 'tags' may be absent or None; treat that as "no tags".
            chosen = dbt_tag in (metric.get('tags') or [])
            if not chosen:
                statistics.add_field_one('notag')
            return chosen
        return True

    def _skip(metric_name, reason):
        """Report an unsupported metric and return None for the caller to propagate."""
        console.print(f"[[bold yellow]Skip[/bold yellow]] Metric '{metric_name}'. {reason}")
        Statistics().add_field_one('nosupport')
        return None

    def _parse_filter(where_filter):
        """Render a filter template and split it into field, operator and value."""
        rendered = load_metric_jinja_string_template(where_filter.get('where_sql_template')).render()
        # Split at most twice so that a value containing spaces stays intact.
        parts = rendered.strip().split(None, 2)
        if len(parts) != 3:
            return None
        field, operator, value = parts
        return {'field': field, 'operator': operator, 'value': value}

    def _merge_time_grains(current, ref):
        """Keep only the grains that the referenced metric can also provide."""
        if ref.time_grains is None:
            return current
        return [grain for grain in current if grain in ref.time_grains]

    def _create_simple_metric(metric, ref_filter):
        metric_name = metric.get('name')

        # The metric's own filter comes first, then the one supplied by the
        # derived/ratio metric that references it.
        metric_filter = []
        for where_filter in (metric.get('filter'), ref_filter):
            if where_filter is None:
                continue
            parsed = _parse_filter(where_filter)
            if parsed is None:
                return _skip(metric_name, "Filter expression is not supported.")
            metric_filter.append(parsed)

        nodes = (metric.get('depends_on') or {}).get('nodes') or []
        if not nodes or not nodes[0].startswith('semantic_model.'):
            return _skip(metric_name, "Simple type metric should depend on semantic model.")
        semantic_model = semantic_models.get(nodes[0])
        if semantic_model is None:
            return _skip(metric_name, f"Semantic model '{nodes[0]}' not found.")

        node_relation = semantic_model.get('node_relation')
        table = node_relation.get('alias')
        schema = node_relation.get('schema_name')
        database = node_relation.get('database')

        # find measure definition
        measure_name = metric.get('type_params').get('measure').get('name')
        measure = next(
            (obj for obj in semantic_model.get('measures') if obj.get('name') == measure_name),
            None)
        if measure is None:
            return _skip(metric_name, "Measure not found.")
        expression = measure.get('expr') if measure.get('expr') is not None else measure.get('name')
        calculation_method = measure.get('agg')

        primary_entity = next(
            (entity.get('name') for entity in semantic_model.get('entities')
             if entity.get('type') == 'primary'),
            None)

        dimensions = semantic_model.get('dimensions')
        # A measure-level time dimension overrides the model default.
        defaults = semantic_model.get('defaults') or {}
        agg_time_dimension = measure.get('agg_time_dimension') or defaults.get('agg_time_dimension')
        timestamp = None
        time_grains = None
        if agg_time_dimension is not None:
            # find dimension definition - time
            for obj in dimensions:
                if obj.get('name') == agg_time_dimension:
                    # Same fallback as for categorical dimensions below.
                    timestamp = obj.get('expr') or obj.get('name')
                    grain = obj.get('type_params').get('time_granularity')
                    time_grains = get_support_time_grains(grain)
                    break

        for f in metric_filter:
            # find dimension definition - categorical
            for obj in dimensions:
                if obj.get('name') in f['field']:
                    expr = obj.get('expr') or obj.get('name')
                    f['field'] = f['field'].replace(obj.get('name'), expr)
                    break

        for f in metric_filter:
            # Only dimensions qualified by the primary entity live on this table.
            if primary_entity is None or primary_entity not in f['field']:
                return _skip(metric_name, "Dimension of foreign entities is not supported.")
            f['field'] = f['field'].replace(f'{primary_entity}__', '')

        if calculation_method == 'median':
            return _skip(metric_name, "Aggregation type 'median' is not supported.")

        model = SemanticModel(metric_name, table, schema, database, expression, timestamp,
                              filters=metric_filter)
        return Metric(metric_name, model=model, calculation_method=calculation_method,
                      time_grains=time_grains,
                      label=metric.get('label'), description=metric.get('description'),
                      ref_id=metric.get('unique_id'))

    def _create_derived_metric(metric):
        metric_name = metric.get('name')
        ref_metrics = []
        time_grains = ['day', 'month', 'year']
        expr = metric.get('type_params').get('expr')
        for ref_metric in metric.get('type_params').get('metrics'):
            if ref_metric.get('offset_window') is not None:
                return _skip(metric_name, "Derived metric property 'offset_window' is not supported.")
            ref = _create_metric(ref_metric.get('name'), ref_filter=ref_metric.get('filter'))
            if ref is None:
                # Never hand a None entry to the metric engine.
                return _skip(metric_name, f"Referenced metric '{ref_metric.get('name')}' is not supported.")
            ref_metrics.append(ref)
            time_grains = _merge_time_grains(time_grains, ref)

            # The expression is evaluated against metric names, not aliases.
            if ref_metric.get('alias') is not None:
                expr = expr.replace(ref_metric.get('alias'), ref_metric.get('name'))

        return Metric(metric_name,
                      calculation_method='derived',
                      expression=expr,
                      time_grains=time_grains,
                      label=metric.get('label'), description=metric.get('description'),
                      ref_metrics=ref_metrics,
                      ref_id=metric.get('unique_id'))

    def _create_ratio_metric(metric):
        metric_name = metric.get('name')
        ref_metrics = []
        time_grains = ['day', 'month', 'year']
        numerator = metric.get('type_params').get('numerator')
        denominator = metric.get('type_params').get('denominator')
        for operand in (numerator, denominator):
            ref = _create_metric(operand.get('name'), ref_filter=operand.get('filter'))
            if ref is None:
                return _skip(metric_name, f"Referenced metric '{operand.get('name')}' is not supported.")
            ref_metrics.append(ref)
            time_grains = _merge_time_grains(time_grains, ref)

        # NOTE(review): a reviewer suggested passing calculation_method,
        # numerator and denominator to Metric and generating the SQL there;
        # this was deferred to a follow-up story.
        return Metric(metric_name,
                      calculation_method='derived',
                      expression=f"{numerator.get('name')} / {denominator.get('name')}",
                      time_grains=time_grains,
                      label=metric.get('label'), description=metric.get('description'),
                      ref_metrics=ref_metrics,
                      ref_id=metric.get('unique_id'))

    def _create_metric(name, ref_filter=None):
        """Create the Metric called ``name``; return None when it is skipped."""
        metric = metric_map.get(name)
        if metric is None:
            return _skip(name, "Metric not found.")

        metric_type = metric.get('type')
        if metric_type == 'simple':
            return _create_simple_metric(metric, ref_filter)
        if metric_type == 'derived':
            return _create_derived_metric(metric)
        if metric_type == 'ratio':
            return _create_ratio_metric(metric)
        return _skip(name, f"Metric type '{metric_type}' is not supported.")

    metrics = []
    for key, metric in manifest_metrics.items():
        statistics = Statistics()
        statistics.add_field_one('total')

        if not is_chosen(key, metric):
            continue

        m = _create_metric(metric.get('name'))
        if m is not None:
            metrics.append(m)

    return metrics


def check_dbt_manifest(dbt_state_dir: str) -> bool:
path = os.path.join(dbt_state_dir, 'manifest.json')
if os.path.isabs(path) is False:
Expand Down
59 changes: 34 additions & 25 deletions piperider_cli/metrics_engine/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
return value


class Metric:
class SemanticModel:
def __init__(
self,
name,
Expand All @@ -34,27 +34,38 @@
database,
expression,
timestamp,
calculation_method,
time_grains=None,
dimensions=None,
filters=None,
label=None,
description=None,
ref_id=None,
):
self.name = name
self.table = table
self.database = database
self.schema = schema.lower() if schema is not None else None
self.expression = expression
self.timestamp = timestamp
self.filters = filters


class Metric:
def __init__(
self,
name: str,
model: SemanticModel = None,
calculation_method=None,
time_grains=None,
expression: str = None,
label=None,
description=None,
ref_metrics=None,
ref_id=None,
):
self.name = name
self.model = model
self.calculation_method = calculation_method
self.time_grains = time_grains
self.dimensions = dimensions
self.filters = filters
self.expression = expression
self.label = label
self.description = description
self.ref_metrics: List[Metric] = []
self.ref_metrics: List[Metric] = ref_metrics or []
self.ref_id = ref_id


Expand Down Expand Up @@ -82,12 +93,8 @@
for grain in metric.time_grains:
if grain not in ['day', 'week', 'month', 'quarter', 'year']:
continue
if not metric.dimensions:
yield grain, []
else:
for r in range(1, len(metric.dimensions) + 1):
for dims in itertools.combinations(metric.dimensions, r):
yield grain, list(dims)

yield grain, []

Check warning on line 97 in piperider_cli/metrics_engine/metrics.py

View check run for this annotation

Codecov / codecov/patch

piperider_cli/metrics_engine/metrics.py#L97

Added line #L97 was not covered by tests

@staticmethod
def _compose_query_name(grain: str, dimensions: List[str], label=False) -> str:
Expand Down Expand Up @@ -134,12 +141,13 @@
)
else:
# Source model
model = metric.model

Check warning on line 144 in piperider_cli/metrics_engine/metrics.py

View check run for this annotation

Codecov / codecov/patch

piperider_cli/metrics_engine/metrics.py#L144

Added line #L144 was not covered by tests
if self.data_source.type_name == 'bigquery':
source_model = text(f"`{metric.database}.{metric.schema}.{metric.table}`")
source_model = text(f"`{model.database}.{model.schema}.{model.table}`")

Check warning on line 146 in piperider_cli/metrics_engine/metrics.py

View check run for this annotation

Codecov / codecov/patch

piperider_cli/metrics_engine/metrics.py#L146

Added line #L146 was not covered by tests
elif self.data_source.type_name == 'databricks':
source_model = text(f"{metric.schema}.{metric.table}")
source_model = text(f"{model.schema}.{model.table}")

Check warning on line 148 in piperider_cli/metrics_engine/metrics.py

View check run for this annotation

Codecov / codecov/patch

piperider_cli/metrics_engine/metrics.py#L148

Added line #L148 was not covered by tests
else:
source_model = text(f"{metric.database}.{metric.schema}.{metric.table}")
source_model = text(f"{model.database}.{model.schema}.{model.table}")

Check warning on line 150 in piperider_cli/metrics_engine/metrics.py

View check run for this annotation

Codecov / codecov/patch

piperider_cli/metrics_engine/metrics.py#L150

Added line #L150 was not covered by tests

# Base model
# 1. map expression to 'c'
Expand All @@ -149,16 +157,16 @@
start_date = self.date_trunc(grain, func.current_date()) - self._interval(grain,
self._slot_count_by_grain(grain))
stmt = select(
literal_column(metric.expression).label('c'),
self.date_trunc(grain, func.cast(literal_column(metric.timestamp), Date)).label('d'),
literal_column(model.expression).label('c'),
self.date_trunc(grain, func.cast(literal_column(model.timestamp), Date)).label('d'),
).select_from(
source_model
).where(
func.cast(literal_column(metric.timestamp), Date) >= start_date
func.cast(literal_column(model.timestamp), Date) >= start_date
)
for f in metric.filters:
for f in model.filters:

Check warning on line 167 in piperider_cli/metrics_engine/metrics.py

View check run for this annotation

Codecov / codecov/patch

piperider_cli/metrics_engine/metrics.py#L167

Added line #L167 was not covered by tests
stmt = stmt.where(text(f"{f.get('field')} {f.get('operator')} {f.get('value')}"))
base_model = stmt.cte(f"{metric.name}_base_model")
base_model = stmt.cte(f"{model.name}_base_model")

Check warning on line 169 in piperider_cli/metrics_engine/metrics.py

View check run for this annotation

Codecov / codecov/patch

piperider_cli/metrics_engine/metrics.py#L169

Added line #L169 was not covered by tests

# Aggregated model
# 1. select 'd'
Expand Down Expand Up @@ -304,7 +312,8 @@

total_param = len(list(self._get_query_param(metric)))
completed_param = 0
engine = self.data_source.get_engine_by_database(metric.database)
engine = self.data_source.get_engine_by_database(

Check warning on line 315 in piperider_cli/metrics_engine/metrics.py

View check run for this annotation

Codecov / codecov/patch

piperider_cli/metrics_engine/metrics.py#L315

Added line #L315 was not covered by tests
metric.model.database if metric.model is not None else None)

query_results = {}
futures = []
Expand Down
Loading
Loading