Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add annotations to the code #1982

Merged
merged 6 commits into from
Dec 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class SchemaSearchMap(dict):
"""A utility class to keep track of what information_schema tables to
search for what schemas
"""
def add(self, relation):
def add(self, relation: BaseRelation):
key = relation.information_schema_only()
if key not in self:
self[key] = set()
Expand Down Expand Up @@ -225,7 +225,7 @@ def clear_transaction(self) -> None:
def commit_if_has_connection(self) -> None:
self.connections.commit_if_has_connection()

def nice_connection_name(self):
def nice_connection_name(self) -> str:
conn = self.connections.get_if_exists()
if conn is None or conn.name is None:
return '<None>'
Expand All @@ -234,7 +234,7 @@ def nice_connection_name(self):
@contextmanager
def connection_named(
self, name: str, node: Optional[CompileResultNode] = None
):
) -> Connection:
try:
self.connections.query_header.set(name, node)
conn = self.acquire_connection(name)
Expand Down Expand Up @@ -306,7 +306,7 @@ def load_internal_manifest(self) -> Manifest:
###
# Caching methods
###
def _schema_is_cached(self, database: str, schema: str):
def _schema_is_cached(self, database: str, schema: str) -> bool:
"""Check if the schema is cached, and by default logs if it is not."""

if dbt.flags.USE_CACHE is False:
Expand Down Expand Up @@ -345,7 +345,7 @@ def _get_cache_schemas(

def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
"""Populate the relations cache for the given schemas. Returns an
iteratble of the schemas populated, as strings.
iterable of the schemas populated, as strings.
"""
if not dbt.flags.USE_CACHE:
return
Expand Down
21 changes: 10 additions & 11 deletions core/dbt/adapters/cache.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from collections import namedtuple
from copy import deepcopy
from typing import List, Iterable
from typing import List, Iterable, Optional
import threading

from dbt.logger import CACHE_LOGGER as logger
import dbt.exceptions


_ReferenceKey = namedtuple('_ReferenceKey', 'database schema identifier')


def _lower(value):
def _lower(value: Optional[str]) -> Optional[str]:
"""Postgres schemas can be None so we can't just call lower()."""
if value is None:
return None
return value.lower()


def _make_key(relation):
def _make_key(relation) -> _ReferenceKey:
"""Make _ReferenceKeys with lowercase values for the cache so we don't have
to keep track of quoting
"""
Expand All @@ -26,10 +25,10 @@ def _make_key(relation):
_lower(relation.identifier))


def dot_separated(key):
def dot_separated(key: _ReferenceKey) -> str:
"""Return the key in dot-separated string form.

:param key _ReferenceKey: The key to stringify.
:param _ReferenceKey key: The key to stringify.
"""
return '.'.join(map(str, key))

Expand All @@ -47,21 +46,21 @@ def __init__(self, inner):
self.referenced_by = {}
self.inner = inner

def __str__(self):
def __str__(self) -> str:
return (
'_CachedRelation(database={}, schema={}, identifier={}, inner={})'
).format(self.database, self.schema, self.identifier, self.inner)

@property
def database(self):
def database(self) -> Optional[str]:
return _lower(self.inner.database)

@property
def schema(self):
def schema(self) -> Optional[str]:
return _lower(self.inner.schema)

@property
def identifier(self):
def identifier(self) -> Optional[str]:
return _lower(self.inner.identifier)

def __copy__(self):
Expand All @@ -84,7 +83,7 @@ def key(self):
"""
return _make_key(self)

def add_reference(self, referrer):
def add_reference(self, referrer: '_CachedRelation'):
"""Add a reference from referrer to self, indicating that if this node
were drop...cascaded, the referrer would be dropped as well.

Expand Down
4 changes: 2 additions & 2 deletions core/dbt/adapters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def reset_adapters(self):
self.adapters.clear()

def cleanup_connections(self):
"""Only clean up the adapter connections list without resetting the actual
adapters.
"""Only clean up the adapter connections list without resetting the
actual adapters.
"""
with self.lock:
for adapter in self.adapters.values():
Expand Down
44 changes: 28 additions & 16 deletions core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import agate
from typing import Any, Optional, Tuple, Type
from typing import Any, Optional, Tuple, Type, List

import dbt.clients.agate_helper
from dbt.contracts.connection import Connection
Expand All @@ -9,6 +9,7 @@
from dbt.adapters.sql import SQLConnectionManager
from dbt.logger import GLOBAL_LOGGER as logger

from dbt.adapters.factory import BaseRelation

LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation'
Expand Down Expand Up @@ -38,6 +39,7 @@ class SQLAdapter(BaseAdapter):
- list_relations_without_caching
- get_columns_in_relation
"""

ConnectionManager: Type[SQLConnectionManager]
connections: SQLConnectionManager

Expand All @@ -63,32 +65,38 @@ def add_query(
abridge_sql_log)

@classmethod
def convert_text_type(cls, agate_table, col_idx):
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "text"

@classmethod
def convert_number_type(cls, agate_table, col_idx):
def convert_number_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "float8" if decimals else "integer"

@classmethod
def convert_boolean_type(cls, agate_table, col_idx):
def convert_boolean_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
return "boolean"

@classmethod
def convert_datetime_type(cls, agate_table, col_idx):
def convert_datetime_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
return "timestamp without time zone"

@classmethod
def convert_date_type(cls, agate_table, col_idx):
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "date"

@classmethod
def convert_time_type(cls, agate_table, col_idx):
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "time"

@classmethod
def is_cancelable(cls):
def is_cancelable(cls) -> bool:
return True

def expand_column_types(self, goal, current):
Expand All @@ -114,7 +122,9 @@ def expand_column_types(self, goal, current):

self.alter_column_type(current, column_name, new_type)

def alter_column_type(self, relation, column_name, new_column_type):
def alter_column_type(
self, relation, column_name, new_column_type
) -> None:
"""
1. Create a new column (w/ temp name and correct type)
2. Copy data over to it
Expand Down Expand Up @@ -158,13 +168,13 @@ def rename_relation(self, from_relation, to_relation):
kwargs=kwargs
)

def get_columns_in_relation(self, relation):
def get_columns_in_relation(self, relation: str):
return self.execute_macro(
GET_COLUMNS_IN_RELATION_MACRO_NAME,
kwargs={'relation': relation}
)

def create_schema(self, database, schema):
def create_schema(self, database: str, schema: str) -> None:
logger.debug('Creating schema "{}"."{}".', database, schema)
kwargs = {
'database_name': self.quote_as_configured(database, 'database'),
Expand All @@ -175,7 +185,7 @@ def create_schema(self, database, schema):
# we can't update the cache here, as if the schema already existed we
# don't want to (incorrectly) say that it's empty

def drop_schema(self, database, schema):
def drop_schema(self, database: str, schema: str) -> None:
logger.debug('Dropping schema "{}"."{}".', database, schema)
kwargs = {
'database_name': self.quote_as_configured(database, 'database'),
Expand All @@ -185,7 +195,9 @@ def drop_schema(self, database, schema):
# we can update the cache here
self.cache.drop_schema(database, schema)

def list_relations_without_caching(self, information_schema, schema):
def list_relations_without_caching(
self, information_schema, schema
) -> List[BaseRelation]:
kwargs = {'information_schema': information_schema, 'schema': schema}
results = self.execute_macro(
LIST_RELATIONS_MACRO_NAME,
Expand All @@ -212,18 +224,18 @@ def list_relations_without_caching(self, information_schema, schema):
))
return relations

def quote(cls, identifier):
def quote(self, identifier):
return '"{}"'.format(identifier)

def list_schemas(self, database):
def list_schemas(self, database: str) -> List[str]:
results = self.execute_macro(
LIST_SCHEMAS_MACRO_NAME,
kwargs={'database': database}
)

return [row[0] for row in results]

def check_schema_exists(self, database, schema):
def check_schema_exists(self, database: str, schema: str) -> bool:
information_schema = self.Relation.create(
database=database,
schema=schema,
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def _create_if_else_chain(
key: str,
criteria: List[Tuple[str, Type[JsonSchemaMixin]]],
default: Type[JsonSchemaMixin]
) -> dict:
) -> Dict[str, Any]:
"""Mutate a given schema key that contains a 'oneOf' to instead be an
'if-then-else' chain. This results is much better/more consistent errors
from jsonschema.
Expand Down