Add annotations to the code
Fokko committed Dec 5, 2019
1 parent ace777e commit 9bb3c69
Showing 4 changed files with 42 additions and 36 deletions.
10 changes: 5 additions & 5 deletions core/dbt/adapters/base/impl.py
@@ -115,7 +115,7 @@ class SchemaSearchMap(dict):
"""A utility class to keep track of what information_schema tables to
search for what schemas
"""
def add(self, relation):
def add(self, relation: BaseRelation):
key = relation.information_schema_only()
if key not in self:
self[key] = set()
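For context on the annotated `add` method: it buckets each relation under its information_schema relation so the schemas to search can be grouped per information_schema. A minimal usage sketch, assuming an already-configured `adapter` instance (the database and schema names below are invented):

search_map = SchemaSearchMap()
search_map.add(adapter.Relation.create(database='analytics', schema='staging'))
search_map.add(adapter.Relation.create(database='analytics', schema='marts'))
# Both schemas now sit under the same information_schema key, so dbt can
# search that information_schema once for both of them.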
@@ -225,7 +225,7 @@ def clear_transaction(self) -> None:
def commit_if_has_connection(self) -> None:
self.connections.commit_if_has_connection()

def nice_connection_name(self):
def nice_connection_name(self) -> str:
conn = self.connections.get_if_exists()
if conn is None or conn.name is None:
return '<None>'
@@ -234,7 +234,7 @@ def nice_connection_name(self):
@contextmanager
def connection_named(
self, name: str, node: Optional[CompileResultNode] = None
):
) -> Connection:
try:
self.connections.query_header.set(name, node)
conn = self.acquire_connection(name)
@@ -306,7 +306,7 @@ def load_internal_manifest(self) -> Manifest:
###
# Caching methods
###
def _schema_is_cached(self, database: str, schema: str):
def _schema_is_cached(self, database: str, schema: str) -> bool:
"""Check if the schema is cached, and by default logs if it is not."""

if dbt.flags.USE_CACHE is False:
@@ -345,7 +345,7 @@ def _get_cache_schemas(

def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
"""Populate the relations cache for the given schemas. Returns an
iteratble of the schemas populated, as strings.
iterable of the schemas populated, as strings.
"""
if not dbt.flags.USE_CACHE:
return
24 changes: 14 additions & 10 deletions core/dbt/adapters/cache.py
@@ -1,21 +1,25 @@
from collections import namedtuple
import threading
from copy import deepcopy
from typing import Optional

from dbt.adapters.cache import _CachedRelation
from dbt.logger import CACHE_LOGGER as logger
import dbt.exceptions

from core.dbt.adapters.factory import BaseRelation

_ReferenceKey = namedtuple('_ReferenceKey', 'database schema identifier')


def _lower(value):
def _lower(value: Optional[str]) -> Optional[str]:
"""Postgres schemas can be None so we can't just call lower()."""
if value is None:
return None
return value.lower()


def _make_key(relation):
def _make_key(relation: BaseRelation) -> _ReferenceKey:
"""Make _ReferenceKeys with lowercase values for the cache so we don't have
to keep track of quoting
"""
@@ -24,10 +28,10 @@ def _make_key(relation):
_lower(relation.identifier))


def dot_separated(key):
def dot_separated(key: _ReferenceKey) -> str:
"""Return the key in dot-separated string form.
:param key _ReferenceKey: The key to stringify.
:param _ReferenceKey key: The key to stringify.
"""
return '.'.join(map(str, key))
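A small, self-contained sketch of how `_make_key` and `dot_separated` fit together; `FakeRelation` is a stand-in for `BaseRelation` with only the three attributes `_make_key` reads, and the values are invented:

from collections import namedtuple

# Stand-in for a BaseRelation; only database/schema/identifier are needed here.
FakeRelation = namedtuple('FakeRelation', 'database schema identifier')

key = _make_key(FakeRelation('Analytics', 'Staging', 'Orders'))
# -> _ReferenceKey(database='analytics', schema='staging', identifier='orders')
print(dot_separated(key))
# -> analytics.staging.orders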

@@ -41,25 +45,25 @@ class _CachedRelation:
that refer to this relation.
:attr BaseRelation inner: The underlying dbt relation.
"""
def __init__(self, inner):
def __init__(self, inner: BaseRelation):
self.referenced_by = {}
self.inner = inner

def __str__(self):
def __str__(self) -> str:
return (
'_CachedRelation(database={}, schema={}, identifier={}, inner={})'
).format(self.database, self.schema, self.identifier, self.inner)

@property
def database(self):
def database(self) -> Optional[str]:
return _lower(self.inner.database)

@property
def schema(self):
def schema(self) -> Optional[str]:
return _lower(self.inner.schema)

@property
def identifier(self):
def identifier(self) -> Optional[str]:
return _lower(self.inner.identifier)

def __copy__(self):
@@ -82,7 +86,7 @@ def key(self):
"""
return _make_key(self)

def add_reference(self, referrer):
def add_reference(self, referrer: _CachedRelation):
"""Add a reference from referrer to self, indicating that if this node
were drop...cascaded, the referrer would be dropped as well.
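Reusing the `FakeRelation` stand-in from the sketch above, the relationship that `add_reference` records can be illustrated roughly as follows (this is not the cache's real drop logic, only the dependency it bookkeeps):

table = _CachedRelation(FakeRelation('analytics', 'staging', 'orders'))
view = _CachedRelation(FakeRelation('analytics', 'staging', 'orders_view'))

# Record that the view is built on the table: if the table were
# drop ... cascade'd, the view would be dropped with it.
table.add_reference(view)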
42 changes: 22 additions & 20 deletions core/dbt/adapters/sql/impl.py
@@ -1,5 +1,5 @@
import agate
from typing import Any, Optional, Tuple, Type
from typing import Any, Optional, Tuple, Type, List

import dbt.clients.agate_helper
from dbt.contracts.connection import Connection
@@ -9,6 +9,7 @@
from dbt.adapters.sql import SQLConnectionManager
from dbt.logger import GLOBAL_LOGGER as logger

from core.dbt.adapters.factory import BaseRelation

LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation'
@@ -38,6 +39,7 @@ class SQLAdapter(BaseAdapter):
- list_relations_without_caching
- get_columns_in_relation
"""

ConnectionManager: Type[SQLConnectionManager]
connections: SQLConnectionManager

@@ -63,35 +65,35 @@ def add_query(
abridge_sql_log)

@classmethod
def convert_text_type(cls, agate_table, col_idx):
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "text"

@classmethod
def convert_number_type(cls, agate_table, col_idx):
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "float8" if decimals else "integer"

@classmethod
def convert_boolean_type(cls, agate_table, col_idx):
def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "boolean"

@classmethod
def convert_datetime_type(cls, agate_table, col_idx):
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "timestamp without time zone"

@classmethod
def convert_date_type(cls, agate_table, col_idx):
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "date"

@classmethod
def convert_time_type(cls, agate_table, col_idx):
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "time"

@classmethod
def is_cancelable(cls):
def is_cancelable(cls) -> bool:
return True

def expand_column_types(self, goal, current):
def expand_column_types(self, goal, current, model_name: Optional[str] = None):
reference_columns = {
c.name: c for c in
self.get_columns_in_relation(goal)
@@ -114,7 +116,7 @@ def expand_column_types(self, goal, current):

self.alter_column_type(current, column_name, new_type)

def alter_column_type(self, relation, column_name, new_column_type):
def alter_column_type(self, relation, column_name, new_column_type) -> None:
"""
1. Create a new column (w/ temp name and correct type)
2. Copy data over to it
@@ -131,7 +133,7 @@ def alter_column_type(self, relation, column_name, new_column_type):
kwargs=kwargs
)
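The truncated docstring above numbers the steps of the retyping approach. A hedged SQL illustration of that pattern follows; it is not the actual `alter_column_type` macro shipped with dbt, the table and column names are invented, and steps 3 and 4 are assumed from the docstring's numbering since they are cut off here:

ALTER_COLUMN_SKETCH = """
alter table analytics.orders add column amount__tmp numeric;        -- 1. new column, correct type
update analytics.orders set amount__tmp = amount;                   -- 2. copy data over
alter table analytics.orders drop column amount cascade;            -- 3. drop the old column (assumed)
alter table analytics.orders rename column amount__tmp to amount;   -- 4. rename into place (assumed)
"""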

def drop_relation(self, relation):
def drop_relation(self, relation, model_name: Optional[str] = None):
if relation.type is None:
dbt.exceptions.raise_compiler_error(
'Tried to drop relation {}, but its type is null.'
@@ -143,13 +145,13 @@ def drop_relation(self, relation):
kwargs={'relation': relation}
)

def truncate_relation(self, relation):
def truncate_relation(self, relation, model_name: Optional[str] = None):
self.execute_macro(
TRUNCATE_RELATION_MACRO_NAME,
kwargs={'relation': relation}
)

def rename_relation(self, from_relation, to_relation):
def rename_relation(self, from_relation, to_relation, model_name: Optional[str] = None):
self.cache_renamed(from_relation, to_relation)

kwargs = {'from_relation': from_relation, 'to_relation': to_relation}
@@ -158,13 +160,13 @@ def rename_relation(self, from_relation, to_relation):
kwargs=kwargs
)

def get_columns_in_relation(self, relation):
def get_columns_in_relation(self, relation: str, model_name: Optional[str] = None):
return self.execute_macro(
GET_COLUMNS_IN_RELATION_MACRO_NAME,
kwargs={'relation': relation}
)

def create_schema(self, database, schema):
def create_schema(self, database: str, schema: str, model_name: Optional[str] = None) -> None:
logger.debug('Creating schema "{}"."{}".', database, schema)
kwargs = {
'database_name': self.quote_as_configured(database, 'database'),
@@ -173,7 +175,7 @@ def create_schema(self, database, schema):
self.execute_macro(CREATE_SCHEMA_MACRO_NAME, kwargs=kwargs)
self.commit_if_has_connection()

def drop_schema(self, database, schema):
def drop_schema(self, database: str, schema: str, model_name: Optional[str] = None) -> None:
logger.debug('Dropping schema "{}"."{}".', database, schema)
kwargs = {
'database_name': self.quote_as_configured(database, 'database'),
@@ -182,7 +184,7 @@ def drop_schema(self, database, schema):
self.execute_macro(DROP_SCHEMA_MACRO_NAME,
kwargs=kwargs)

def list_relations_without_caching(self, information_schema, schema):
def list_relations_without_caching(self, information_schema, schema, model_name: Optional[str] = None) -> List[BaseRelation]:
kwargs = {'information_schema': information_schema, 'schema': schema}
results = self.execute_macro(
LIST_RELATIONS_MACRO_NAME,
@@ -209,18 +211,18 @@ def list_relations_without_caching(self, information_schema, schema):
))
return relations

def quote(cls, identifier):
def quote(self, identifier):
return '"{}"'.format(identifier)

def list_schemas(self, database):
def list_schemas(self, database: str, model_name: Optional[str] = None) -> List[str]:
results = self.execute_macro(
LIST_SCHEMAS_MACRO_NAME,
kwargs={'database': database}
)

return [row[0] for row in results]

def check_schema_exists(self, database, schema):
def check_schema_exists(self, database: str, schema: str, model_name: Optional[str] = None) -> bool:
information_schema = self.Relation.create(
database=database,
schema=schema,
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/parsed.py
@@ -396,7 +396,7 @@ def _create_if_else_chain(
key: str,
criteria: List[Tuple[str, Type[JsonSchemaMixin]]],
default: Type[JsonSchemaMixin]
) -> dict:
) -> Dict[str, Any]:
"""Mutate a given schema key that contains a 'oneOf' to instead be an
'if-then-else' chain. This results in much better/more consistent errors
from jsonschema.
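As a rough illustration of the transformation that docstring describes (the discriminator key 'resource_type' and the sub-schemas below are invented, not dbt's real contracts):

# Invented example: a 'oneOf' keyed on a discriminator property, rewritten as
# an if/then/else chain so jsonschema can report which branch failed and why,
# instead of a generic "is not valid under any of the given schemas" error.
model_schema = {'properties': {'resource_type': {'enum': ['model']}}}
test_schema = {'properties': {'resource_type': {'enum': ['test']}}}
default_schema = {'type': 'object'}

before = {'oneOf': [model_schema, test_schema]}

after = {
    'if': {'properties': {'resource_type': {'enum': ['model']}}},
    'then': model_schema,
    'else': {
        'if': {'properties': {'resource_type': {'enum': ['test']}}},
        'then': test_schema,
        'else': default_schema,
    },
}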