Skip to content

Commit

Permalink
[#2990] Normalize global CLI args/flags
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Aug 31, 2021
1 parent 052e54d commit 0b6099c
Show file tree
Hide file tree
Showing 44 changed files with 272 additions and 519 deletions.
12 changes: 0 additions & 12 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,6 @@ def _close_handle(cls, connection: Connection) -> None:
@classmethod
def _rollback(cls, connection: Connection) -> None:
"""Roll back the given connection."""
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In _rollback, got {connection} - not a Connection!'
)

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
f'Tried to rollback transaction on connection '
Expand All @@ -257,12 +251,6 @@ def _rollback(cls, connection: Connection) -> None:

@classmethod
def close(cls, connection: Connection) -> Connection:
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In close, got {connection} - not a Connection!'
)

# if the connection is in closed or init, there's nothing to do
if connection.state in {ConnectionState.CLOSED, ConnectionState.INIT}:
return connection
Expand Down
20 changes: 4 additions & 16 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
get_relation_returned_multiple_results,
InternalException, NotImplementedException, RuntimeException,
)
from dbt import flags

from dbt import deprecations
from dbt.adapters.protocol import (
Expand Down Expand Up @@ -289,9 +288,7 @@ def clear_macro_manifest(self):
def _schema_is_cached(self, database: Optional[str], schema: str) -> bool:
"""Check if the schema is cached, and by default logs if it is not."""

if flags.USE_CACHE is False:
return False
elif (database, schema) not in self.cache:
if (database, schema) not in self.cache:
logger.debug(
'On "{}": cache miss for schema "{}.{}", this is inefficient'
.format(self.nice_connection_name(), database, schema)
Expand Down Expand Up @@ -340,9 +337,6 @@ def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
"""Populate the relations cache for the given schemas. Returns an
iterable of the schemas populated, as strings.
"""
if not flags.USE_CACHE:
return

cache_schemas = self._get_cache_schemas(manifest)
with executor(self.config) as tpe:
futures: List[Future[List[BaseRelation]]] = []
Expand Down Expand Up @@ -375,9 +369,6 @@ def set_relations_cache(
"""Run a query that gets a populated cache of the relations in the
database and set the cache on this adapter.
"""
if not flags.USE_CACHE:
return

with self.cache.lock:
if clear:
self.cache.clear()
Expand All @@ -391,8 +382,7 @@ def cache_added(self, relation: Optional[BaseRelation]) -> str:
raise_compiler_error(
'Attempted to cache a null relation for {}'.format(name)
)
if flags.USE_CACHE:
self.cache.add(relation)
self.cache.add(relation)
# so jinja doesn't render things
return ''

Expand All @@ -406,8 +396,7 @@ def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
raise_compiler_error(
'Attempted to drop a null relation for {}'.format(name)
)
if flags.USE_CACHE:
self.cache.drop(relation)
self.cache.drop(relation)
return ''

@available
Expand All @@ -428,8 +417,7 @@ def cache_renamed(
.format(src_name, dst_name, name)
)

if flags.USE_CACHE:
self.cache.rename(from_relation, to_relation)
self.cache.rename(from_relation, to_relation)
return ''

###
Expand Down
14 changes: 0 additions & 14 deletions core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Connection, ConnectionState, AdapterResponse
)
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import flags


class SQLConnectionManager(BaseConnectionManager):
Expand Down Expand Up @@ -144,13 +143,6 @@ def add_commit_query(self):

def begin(self):
connection = self.get_thread_connection()

if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In begin, got {connection} - not a Connection!'
)

if connection.transaction_open is True:
raise dbt.exceptions.InternalException(
'Tried to begin a new transaction on connection "{}", but '
Expand All @@ -163,12 +155,6 @@ def begin(self):

def commit(self):
connection = self.get_thread_connection()
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In commit, got {connection} - not a Connection!'
)

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
'Tried to commit transaction on connection "{}", but '
Expand Down
2 changes: 0 additions & 2 deletions core/dbt/context/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,6 @@ def flags(self) -> Any:
The list of valid flags are:
- `flags.STRICT_MODE`: True if `--strict` (or `-S`) was provided on the
command line
- `flags.FULL_REFRESH`: True if `--full-refresh` was provided on the
command line
- `flags.NON_DESTRUCTIVE`: True if `--non-destructive` was provided on
Expand Down
4 changes: 1 addition & 3 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,7 @@ def __post_init__(self):
self.user_id = tracking.active_user.id

if self.send_anonymous_usage_stats is None:
self.send_anonymous_usage_stats = (
not tracking.active_user.do_not_track
)
self.send_anonymous_usage_stats = flags.SEND_ANONYMOUS_USAGE_STATS

@classmethod
def default(cls):
Expand Down
12 changes: 0 additions & 12 deletions core/dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,6 @@ def patch(self, patch: 'ParsedNodePatch'):
self.columns = patch.columns
self.meta = patch.meta
self.docs = patch.docs
if flags.STRICT_MODE:
# It seems odd that an instance can be invalid
# Maybe there should be validation or restrictions
# elsewhere?
assert isinstance(self, dbtClassMixin)
dct = self.to_dict(omit_none=False)
self.validate(dct)

def get_materialization(self):
return self.config.materialized
Expand Down Expand Up @@ -509,11 +502,6 @@ def patch(self, patch: ParsedMacroPatch):
self.meta = patch.meta
self.docs = patch.docs
self.arguments = patch.arguments
if flags.STRICT_MODE:
# What does this actually validate?
assert isinstance(self, dbtClassMixin)
dct = self.to_dict(omit_none=False)
self.validate(dct)

def same_contents(self, other: Optional['ParsedMacro']) -> bool:
if other is None:
Expand Down
25 changes: 11 additions & 14 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dbt.contracts.util import Replaceable, Mergeable, list_str
from dbt.contracts.connection import UserConfigContract, QueryComment
from dbt.contracts.connection import QueryComment
from dbt.helper_types import NoValue
from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt import tracking
Expand Down Expand Up @@ -225,23 +225,20 @@ def validate(cls, data):


@dataclass
class UserConfig(ExtensibleDbtClassMixin, Replaceable, UserConfigContract):
class UserConfig(ExtensibleDbtClassMixin, Replaceable):
send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS
use_colors: Optional[bool] = None
partial_parse: Optional[bool] = None
printer_width: Optional[int] = None

def set_values(self, cookie_dir):
if self.send_anonymous_usage_stats:
tracking.initialize_tracking(cookie_dir)
else:
tracking.do_not_track()

if self.use_colors is not None:
ui.use_colors(self.use_colors)

if self.printer_width:
ui.printer_width(self.printer_width)
# new config below
write_json: Optional[bool] = None
warn_error: Optional[bool] = None
log_format: Optional[bool] = None
debug: Optional[bool] = None
no_version_check: Optional[bool] = None
fail_fast: Optional[bool] = None
profiles_dir: Optional[str] = None
use_experimental_parser: Optional[bool] = None


@dataclass
Expand Down
126 changes: 72 additions & 54 deletions core/dbt/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,40 @@
from pathlib import Path
from typing import Optional

# initially all flags are set to None, the on-load call of reset() will set
# them for their first time.
STRICT_MODE = None
FULL_REFRESH = None
USE_CACHE = None
WARN_ERROR = None
TEST_NEW_PARSER = None

STRICT_MODE = False # Only here for backwards compatibility
FULL_REFRESH = False # subcommand
STORE_FAILURES = False # subcommand

# Global CLI commands
USE_EXPERIMENTAL_PARSER = None
WARN_ERROR = None
WRITE_JSON = None
PARTIAL_PARSE = None
USE_COLORS = None
STORE_FAILURES = None
PROFILES_DIR = None
DEBUG = None
LOG_FORMAT = None
NO_VERSION_CHECK = None
FAIL_FAST = None
SEND_ANONYMOUS_USAGE_STATS = None
PRINTER_WIDTH = None

# Global CLI defaults
flag_defaults = {
"USE_EXPERIMENTAL_PARSER": False,
"WARN_ERROR": False,
"WRITE_JSON": True,
"PARTIAL_PARSE": False,
"USE_COLORS": True,
"PROFILES_DIR": None,
"DEBUG": False,
"LOG_FORMAT": None,
"NO_VERSION_CHECK": False,
"FAIL_FAST": False,
"SEND_ANONYMOUS_USAGE_STATS": True,
"PRINTER_WIDTH": 80
}


def env_set_truthy(key: str) -> Optional[str]:
Expand Down Expand Up @@ -50,56 +72,52 @@ def _get_context():
return multiprocessing.get_context('spawn')


# This is not a flag, it's a place to store the lock
MP_CONTEXT = _get_context()


def reset():
global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \
USE_EXPERIMENTAL_PARSER, WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS, \
STORE_FAILURES

STRICT_MODE = False
FULL_REFRESH = False
USE_CACHE = True
WARN_ERROR = False
TEST_NEW_PARSER = False
USE_EXPERIMENTAL_PARSER = False
WRITE_JSON = True
PARTIAL_PARSE = False
MP_CONTEXT = _get_context()
USE_COLORS = True
STORE_FAILURES = False


def set_from_args(args):
global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \
USE_EXPERIMENTAL_PARSER, WRITE_JSON, PARTIAL_PARSE, MP_CONTEXT, USE_COLORS, \
STORE_FAILURES

USE_CACHE = getattr(args, 'use_cache', USE_CACHE)
def set_from_args(args, user_config):
global STRICT_MODE, FULL_REFRESH, WARN_ERROR, \
USE_EXPERIMENTAL_PARSER, WRITE_JSON, PARTIAL_PARSE, USE_COLORS, \
STORE_FAILURES, PROFILES_DIR, DEBUG, LOG_FORMAT,\
NO_VERSION_CHECK, FAIL_FAST, SEND_ANONYMOUS_USAGE_STATS, PRINTER_WIDTH

STRICT_MODE = False # backwards compatibility
# cli args without user_config or env var option
FULL_REFRESH = getattr(args, 'full_refresh', FULL_REFRESH)
STRICT_MODE = getattr(args, 'strict', STRICT_MODE)
WARN_ERROR = (
STRICT_MODE or
getattr(args, 'warn_error', STRICT_MODE or WARN_ERROR)
)

TEST_NEW_PARSER = getattr(args, 'test_new_parser', TEST_NEW_PARSER)
USE_EXPERIMENTAL_PARSER = getattr(args, 'use_experimental_parser', USE_EXPERIMENTAL_PARSER)
WRITE_JSON = getattr(args, 'write_json', WRITE_JSON)
PARTIAL_PARSE = getattr(args, 'partial_parse', None)
MP_CONTEXT = _get_context()

# The use_colors attribute will always have a value because it is assigned
# None by default from the add_mutually_exclusive_group function
use_colors_override = getattr(args, 'use_colors')

if use_colors_override is not None:
USE_COLORS = use_colors_override

STORE_FAILURES = getattr(args, 'store_failures', STORE_FAILURES)


# initialize everything to the defaults on module load
reset()
# global cli flags with env var and user_config alternatives
USE_EXPERIMENTAL_PARSER = get_flag_value('USE_EXPERIMENTAL_PARSER', args, user_config)
WARN_ERROR = get_flag_value('WARN_ERROR', args, user_config)
WRITE_JSON = get_flag_value('WRITE_JSON', args, user_config)
PARTIAL_PARSE = get_flag_value('PARTIAL_PARSE', args, user_config)
USE_COLORS = get_flag_value('USE_COLORS', args, user_config)
PROFILES_DIR = get_flag_value('PROFILES_DIR', args, user_config)
DEBUG = get_flag_value('DEBUG', args, user_config)
LOG_FORMAT = get_flag_value('LOG_FORMAT', args, user_config)
NO_VERSION_CHECK = get_flag_value('NO_VERSION_CHECK', args, user_config)
FAIL_FAST = get_flag_value('FAIL_FAST', args, user_config)
SEND_ANONYMOUS_USAGE_STATS = get_flag_value('SEND_ANONYMOUS_USAGE_STATS', args, user_config)
PRINTER_WIDTH = get_flag_value('PRINTER_WIDTH', args, user_config)


def get_flag_value(flag, args, user_config):
lc_flag = flag.lower()
flag_value = getattr(args, lc_flag, None)
if flag_value is None:
env_flag = f"DBT_{flag}"
env_value = os.getenv(env_flag)
if env_value is not None:
if flag in ['LOG_FORMAT', 'PROFILES_DIR', 'PRINTER_WIDTH']:
flag_value = env_value
else:
flag_value = True if env_value else False
elif user_config is not None and getattr(user_config, lc_flag, None):
flag_value = getattr(user_config, lc_flag)
else:
flag_value = flag_defaults[flag]
if flag == 'PROFILES_DIR' and flag_value:
flag_value = os.path.abspath(flag_value)

return flag_value
Loading

0 comments on commit 0b6099c

Please sign in to comment.