From 9bb3c69bdb92414cf13693e5ce4bbe7b8e224d36 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Thu, 5 Dec 2019 19:15:07 +0100 Subject: [PATCH] Add annotations to the code --- core/dbt/adapters/base/impl.py | 10 +++---- core/dbt/adapters/cache.py | 24 ++++++++++------- core/dbt/adapters/sql/impl.py | 42 ++++++++++++++++-------------- core/dbt/contracts/graph/parsed.py | 2 +- 4 files changed, 42 insertions(+), 36 deletions(-) diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 15b534cb5aa..d10aecd6441 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -115,7 +115,7 @@ class SchemaSearchMap(dict): """A utility class to keep track of what information_schema tables to search for what schemas """ - def add(self, relation): + def add(self, relation: BaseRelation): key = relation.information_schema_only() if key not in self: self[key] = set() @@ -225,7 +225,7 @@ def clear_transaction(self) -> None: def commit_if_has_connection(self) -> None: self.connections.commit_if_has_connection() - def nice_connection_name(self): + def nice_connection_name(self) -> str: conn = self.connections.get_if_exists() if conn is None or conn.name is None: return '' @@ -234,7 +234,7 @@ def nice_connection_name(self): @contextmanager def connection_named( self, name: str, node: Optional[CompileResultNode] = None - ): + ) -> Connection: try: self.connections.query_header.set(name, node) conn = self.acquire_connection(name) @@ -306,7 +306,7 @@ def load_internal_manifest(self) -> Manifest: ### # Caching methods ### - def _schema_is_cached(self, database: str, schema: str): + def _schema_is_cached(self, database: str, schema: str) -> bool: """Check if the schema is cached, and by default logs if it is not.""" if dbt.flags.USE_CACHE is False: @@ -345,7 +345,7 @@ def _get_cache_schemas( def _relations_cache_for_schemas(self, manifest: Manifest) -> None: """Populate the relations cache for the given schemas. 
Returns an - iteratble of the schemas populated, as strings. + iterable of the schemas populated, as strings. """ if not dbt.flags.USE_CACHE: return diff --git a/core/dbt/adapters/cache.py b/core/dbt/adapters/cache.py index ca858cb95d3..1b0bec6fc4f 100644 --- a/core/dbt/adapters/cache.py +++ b/core/dbt/adapters/cache.py @@ -1,21 +1,25 @@ from collections import namedtuple import threading from copy import deepcopy +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: +    from dbt.adapters.base.relation import BaseRelation from dbt.logger import CACHE_LOGGER as logger import dbt.exceptions _ReferenceKey = namedtuple('_ReferenceKey', 'database schema identifier') -def _lower(value): +def _lower(value: Optional[str]) -> Optional[str]: """Postgres schemas can be None so we can't just call lower().""" if value is None: return None return value.lower() -def _make_key(relation): +def _make_key(relation: 'BaseRelation') -> _ReferenceKey: """Make _ReferenceKeys with lowercase values for the cache so we don't have to keep track of quoting """ @@ -24,10 +28,10 @@ def _make_key(relation): _lower(relation.identifier)) -def dot_separated(key): +def dot_separated(key: _ReferenceKey) -> str: """Return the key in dot-separated string form. - :param key _ReferenceKey: The key to stringify. + :param _ReferenceKey key: The key to stringify. """ return '.'.join(map(str, key)) @@ -41,25 +45,25 @@ class _CachedRelation: that refer to this relation. :attr BaseRelation inner: The underlying dbt relation. 
""" - def __init__(self, inner): + def __init__(self, inner: 'BaseRelation'): self.referenced_by = {} self.inner = inner - def __str__(self): + def __str__(self) -> str: return ( '_CachedRelation(database={}, schema={}, identifier={}, inner={})' ).format(self.database, self.schema, self.identifier, self.inner) @property - def database(self): + def database(self) -> Optional[str]: return _lower(self.inner.database) @property - def schema(self): + def schema(self) -> Optional[str]: return _lower(self.inner.schema) @property - def identifier(self): + def identifier(self) -> Optional[str]: return _lower(self.inner.identifier) def __copy__(self): @@ -82,7 +86,7 @@ def key(self): """ return _make_key(self) - def add_reference(self, referrer): + def add_reference(self, referrer: '_CachedRelation'): """Add a reference from referrer to self, indicating that if this node were drop...cascaded, the referrer would be dropped as well. diff --git a/core/dbt/adapters/sql/impl.py b/core/dbt/adapters/sql/impl.py index e15bc143a2c..80a9db08456 100644 --- a/core/dbt/adapters/sql/impl.py +++ b/core/dbt/adapters/sql/impl.py @@ -1,5 +1,5 @@ import agate -from typing import Any, Optional, Tuple, Type +from typing import Any, Optional, Tuple, Type, List import dbt.clients.agate_helper from dbt.contracts.connection import Connection @@ -9,6 +9,7 @@ from dbt.adapters.sql import SQLConnectionManager from dbt.logger import GLOBAL_LOGGER as logger +from dbt.adapters.base.relation import BaseRelation LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching' GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation' @@ -38,6 +39,7 @@ class SQLAdapter(BaseAdapter): - list_relations_without_caching - get_columns_in_relation """ + ConnectionManager: Type[SQLConnectionManager] connections: SQLConnectionManager @@ -63,35 +65,35 @@ def add_query( abridge_sql_log) @classmethod - def convert_text_type(cls, agate_table, col_idx): + def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> 
str: return "text" @classmethod - def convert_number_type(cls, agate_table, col_idx): + def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) return "float8" if decimals else "integer" @classmethod - def convert_boolean_type(cls, agate_table, col_idx): + def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "boolean" @classmethod - def convert_datetime_type(cls, agate_table, col_idx): + def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp without time zone" @classmethod - def convert_date_type(cls, agate_table, col_idx): + def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "date" @classmethod - def convert_time_type(cls, agate_table, col_idx): + def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "time" @classmethod - def is_cancelable(cls): + def is_cancelable(cls) -> bool: return True - def expand_column_types(self, goal, current): + def expand_column_types(self, goal, current, model_name: Optional[str] = None): reference_columns = { c.name: c for c in self.get_columns_in_relation(goal) @@ -114,7 +116,7 @@ def expand_column_types(self, goal, current): self.alter_column_type(current, column_name, new_type) - def alter_column_type(self, relation, column_name, new_column_type): + def alter_column_type(self, relation, column_name, new_column_type) -> None: """ 1. Create a new column (w/ temp name and correct type) 2. Copy data over to it @@ -131,7 +133,7 @@ def alter_column_type(self, relation, column_name, new_column_type): kwargs=kwargs ) - def drop_relation(self, relation): + def drop_relation(self, relation, model_name: Optional[str] = None): if relation.type is None: dbt.exceptions.raise_compiler_error( 'Tried to drop relation {}, but its type is null.' 
@@ -143,13 +145,13 @@ def drop_relation(self, relation): kwargs={'relation': relation} ) - def truncate_relation(self, relation): + def truncate_relation(self, relation, model_name: Optional[str] = None): self.execute_macro( TRUNCATE_RELATION_MACRO_NAME, kwargs={'relation': relation} ) - def rename_relation(self, from_relation, to_relation): + def rename_relation(self, from_relation, to_relation, model_name: Optional[str] = None): self.cache_renamed(from_relation, to_relation) kwargs = {'from_relation': from_relation, 'to_relation': to_relation} @@ -158,13 +160,13 @@ def rename_relation(self, from_relation, to_relation): kwargs=kwargs ) - def get_columns_in_relation(self, relation): + def get_columns_in_relation(self, relation, model_name: Optional[str] = None): return self.execute_macro( GET_COLUMNS_IN_RELATION_MACRO_NAME, kwargs={'relation': relation} ) - def create_schema(self, database, schema): + def create_schema(self, database: str, schema: str, model_name: Optional[str] = None) -> None: logger.debug('Creating schema "{}"."{}".', database, schema) kwargs = { 'database_name': self.quote_as_configured(database, 'database'), @@ -173,7 +175,7 @@ def create_schema(self, database, schema): self.execute_macro(CREATE_SCHEMA_MACRO_NAME, kwargs=kwargs) self.commit_if_has_connection() - def drop_schema(self, database, schema): + def drop_schema(self, database: str, schema: str, model_name: Optional[str] = None) -> None: logger.debug('Dropping schema "{}"."{}".', database, schema) kwargs = { 'database_name': self.quote_as_configured(database, 'database'), @@ -182,7 +184,7 @@ def drop_schema(self, database, schema): self.execute_macro(DROP_SCHEMA_MACRO_NAME, kwargs=kwargs) - def list_relations_without_caching(self, information_schema, schema): + def list_relations_without_caching(self, information_schema, schema, model_name: Optional[str] = None) -> List[BaseRelation]: kwargs = {'information_schema': information_schema, 'schema': schema} results = 
self.execute_macro( LIST_RELATIONS_MACRO_NAME, @@ -209,10 +211,10 @@ def list_relations_without_caching(self, information_schema, schema): )) return relations - def quote(cls, identifier): + def quote(self, identifier): return '"{}"'.format(identifier) - def list_schemas(self, database): + def list_schemas(self, database: str, model_name: Optional[str] = None) -> List[str]: results = self.execute_macro( LIST_SCHEMAS_MACRO_NAME, kwargs={'database': database} @@ -220,7 +222,7 @@ def list_schemas(self, database): return [row[0] for row in results] - def check_schema_exists(self, database, schema): + def check_schema_exists(self, database: str, schema: str, model_name: Optional[str] = None) -> bool: information_schema = self.Relation.create( database=database, schema=schema, diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index 4c5b12303a9..8f34f0d51b1 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -396,7 +396,7 @@ def _create_if_else_chain( key: str, criteria: List[Tuple[str, Type[JsonSchemaMixin]]], default: Type[JsonSchemaMixin] -) -> dict: +) -> Dict[str, Any]: """Mutate a given schema key that contains a 'oneOf' to instead be an 'if-then-else' chain. This results is much better/more consistent errors from jsonschema.