Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error handling in cli and introduce consistency #12764

Merged
merged 4 commits into from
Dec 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions airflow/cli/commands/celery_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""Celery command"""
import sys

from multiprocessing import Process
from typing import Optional

Expand Down Expand Up @@ -95,8 +95,7 @@ def _serve_logs(skip_serve_logs: bool = False) -> Optional[Process]:
def worker(args):
"""Starts Airflow Celery worker"""
if not settings.validate_session():
print("Worker exiting... database connection precheck failed! ")
sys.exit(1)
raise SystemExit("Worker exiting, database connection precheck failed.")

autoscale = args.autoscale
skip_serve_logs = args.skip_serve_logs
Expand Down
7 changes: 2 additions & 5 deletions airflow/cli/commands/config_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
"""Config sub-commands"""
import io
import sys

import pygments
from pygments.lexers.configs import IniLexer
Expand All @@ -39,12 +38,10 @@ def show_config(args):
def get_value(args):
"""Get one value from configuration"""
if not conf.has_section(args.section):
print(f'The section [{args.section}] is not found in config.', file=sys.stderr)
sys.exit(1)
raise SystemExit(f'The section [{args.section}] is not found in config.')

if not conf.has_option(args.section, args.option):
print(f'The option [{args.section}/{args.option}] is not found in config.', file=sys.stderr)
sys.exit(1)
raise SystemExit(f'The option [{args.section}/{args.option}] is not found in config.')

value = conf.get(args.section, args.option)
print(value)
44 changes: 17 additions & 27 deletions airflow/cli/commands/connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,13 @@ def _format_connections(conns: List[Connection], fmt: str) -> str:


def _is_stdout(fileio: io.TextIOWrapper) -> bool:
if fileio.name == '<stdout>':
return True
return False
return fileio.name == '<stdout>'


def _valid_uri(uri: str) -> bool:
"""Check if a URI is valid, by checking if both scheme and netloc are available"""
uri_parts = urlparse(uri)
if uri_parts.scheme == '' or uri_parts.netloc == '':
return False
return True
return uri_parts.scheme != '' and uri_parts.netloc != ''


def connections_export(args):
Expand All @@ -140,11 +136,10 @@ def connections_export(args):
_, filetype = os.path.splitext(args.file.name)
filetype = filetype.lower()
if filetype not in allowed_formats:
msg = (
f"Unsupported file format. "
f"The file must have the extension {', '.join(allowed_formats)}"
raise SystemExit(
f"Unsupported file format. The file must have "
f"the extension {', '.join(allowed_formats)}."
)
raise SystemExit(msg)

connections = session.query(Connection).order_by(Connection.conn_id).all()
msg = _format_connections(connections, filetype)
Expand All @@ -153,7 +148,7 @@ def connections_export(args):
if _is_stdout(args.file):
print("Connections successfully exported.", file=sys.stderr)
else:
print(f"Connections successfully exported to {args.file.name}")
print(f"Connections successfully exported to {args.file.name}.")


alternative_conn_specs = ['conn_type', 'conn_host', 'conn_login', 'conn_password', 'conn_schema', 'conn_port']
Expand All @@ -167,20 +162,19 @@ def connections_add(args):
invalid_args = []
if args.conn_uri:
if not _valid_uri(args.conn_uri):
msg = f'The URI provided to --conn-uri is invalid: {args.conn_uri}'
raise SystemExit(msg)
raise SystemExit(f'The URI provided to --conn-uri is invalid: {args.conn_uri}')
for arg in alternative_conn_specs:
if getattr(args, arg) is not None:
invalid_args.append(arg)
elif not args.conn_type:
missing_args.append('conn-uri or conn-type')
if missing_args:
msg = f'The following args are required to add a connection: {missing_args!r}'
raise SystemExit(msg)
raise SystemExit(f'The following args are required to add a connection: {missing_args!r}')
if invalid_args:
msg = 'The following args are not compatible with the add flag and --conn-uri flag: {invalid!r}'
msg = msg.format(invalid=invalid_args)
raise SystemExit(msg)
raise SystemExit(
f'The following args are not compatible with the '
f'add flag and --conn-uri flag: {invalid_args!r}'
)

if args.conn_uri:
new_conn = Connection(conn_id=args.conn_id, description=args.conn_description, uri=args.conn_uri)
Expand All @@ -201,7 +195,7 @@ def connections_add(args):
with create_session() as session:
if not session.query(Connection).filter(Connection.conn_id == new_conn.conn_id).first():
session.add(new_conn)
msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n'
msg = 'Successfully added `conn_id`={conn_id} : {uri}'
msg = msg.format(
conn_id=new_conn.conn_id,
uri=args.conn_uri
Expand All @@ -223,7 +217,7 @@ def connections_add(args):
)
print(msg)
else:
msg = f'\n\tA connection with `conn_id`={new_conn.conn_id} already exists\n'
msg = f'A connection with `conn_id`={new_conn.conn_id} already exists.'
raise SystemExit(msg)


Expand All @@ -234,13 +228,9 @@ def connections_delete(args):
try:
to_delete = session.query(Connection).filter(Connection.conn_id == args.conn_id).one()
except exc.NoResultFound:
msg = f'\n\tDid not find a connection with `conn_id`={args.conn_id}\n'
raise SystemExit(msg)
raise SystemExit(f'Did not find a connection with `conn_id`={args.conn_id}')
except exc.MultipleResultsFound:
msg = f'\n\tFound more than one connection with `conn_id`={args.conn_id}\n'
raise SystemExit(msg)
raise SystemExit(f'Found more than one connection with `conn_id`={args.conn_id}')
else:
deleted_conn_id = to_delete.conn_id
session.delete(to_delete)
msg = f'\n\tSuccessfully deleted `conn_id`={deleted_conn_id}\n'
print(msg)
print(f"Successfully deleted connection with `conn_id`={to_delete.conn_id}")
6 changes: 2 additions & 4 deletions airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,10 @@ def dag_show(args):
imgcat = args.imgcat

if filename and imgcat:
print(
raise SystemExit(
"Option --save and --imgcat are mutually exclusive. "
"Please remove one option to execute the command.",
file=sys.stderr,
)
sys.exit(1)
elif filename:
_save_dot_to_file(dot, filename)
elif imgcat:
Expand All @@ -197,7 +195,7 @@ def _display_dot_via_imgcat(dot: Dot):
proc = subprocess.Popen("imgcat", stdout=subprocess.PIPE, stdin=subprocess.PIPE)
except OSError as e:
if e.errno == errno.ENOENT:
raise AirflowException(
raise SystemExit(
"Failed to execute. Make sure the imgcat executables are on your systems \'PATH\'"
)
else:
Expand Down
2 changes: 1 addition & 1 deletion airflow/cli/commands/db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def initdb(args):
"""Initializes the metadata database"""
print("DB: " + repr(settings.engine.url))
db.initdb()
print("Done.")
print("Initialization done")


def resetdb(args):
Expand Down
2 changes: 1 addition & 1 deletion airflow/cli/commands/kubernetes_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def cleanup_pods(args):
try:
_delete_pod(pod.metadata.name, namespace)
except ApiException as e:
print(f"can't remove POD: {e}", file=sys.stderr)
print(f"Can't remove POD: {e}", file=sys.stderr)
continue
print(f'No action taken on pod {pod_name}')
continue_token = pod_list.metadata._continue # pylint: disable=protected-access
Expand Down
20 changes: 11 additions & 9 deletions airflow/cli/commands/pool_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""Pools sub-commands"""
import json
import os
import sys
from json import JSONDecodeError

from airflow.api.client import get_current_api_client
Expand Down Expand Up @@ -52,8 +51,11 @@ def pool_list(args):
def pool_get(args):
"""Displays pool info by a given name"""
api_client = get_current_api_client()
pools = [api_client.get_pool(name=args.pool)]
_show_pools(pools=pools, output=args.output)
try:
pools = [api_client.get_pool(name=args.pool)]
_show_pools(pools=pools, output=args.output)
except PoolNotFound:
raise SystemExit(f"Pool {args.pool} does not exist")


@cli_utils.action_logging
Expand All @@ -74,18 +76,19 @@ def pool_delete(args):
api_client.delete_pool(name=args.pool)
print("Pool deleted")
except PoolNotFound:
sys.exit(f"Pool {args.pool} does not exist")
raise SystemExit(f"Pool {args.pool} does not exist")


@cli_utils.action_logging
@suppress_logs_and_warning()
def pool_import(args):
"""Imports pools from the file"""
if not os.path.exists(args.file):
sys.exit("Missing pools file.")
_, failed = pool_import_helper(args.file)
raise SystemExit("Missing pools file.")
pools, failed = pool_import_helper(args.file)
if len(failed) > 0:
sys.exit(f"Failed to update pool(s): {', '.join(failed)}")
raise SystemExit(f"Failed to update pool(s): {', '.join(failed)}")
print(f"Uploaded {len(pools)} pool(s)")


def pool_export(args):
Expand All @@ -103,15 +106,14 @@ def pool_import_helper(filepath):
try: # pylint: disable=too-many-nested-blocks
pools_json = json.loads(data)
except JSONDecodeError as e:
sys.exit("Invalid json file: " + str(e))
raise SystemExit("Invalid json file: " + str(e))
pools = []
failed = []
for k, v in pools_json.items():
if isinstance(v, dict) and len(v) == 2:
pools.append(api_client.create_pool(name=k, slots=v["slots"], description=v["description"]))
else:
failed.append(k)
print(f"{len(pools)} of {len(pools_json)} pool(s) successfully updated.")
return pools, failed


Expand Down
2 changes: 2 additions & 0 deletions airflow/cli/commands/role_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ def roles_list(args):


@cli_utils.action_logging
@suppress_logs_and_warning()
def roles_create(args):
"""Creates new empty role in DB"""
appbuilder = cached_app().appbuilder # pylint: disable=no-member
for role_name in args.role:
appbuilder.sm.add_role(role_name)
print(f"Added {len(args.role)} role(s)")
36 changes: 15 additions & 21 deletions airflow/cli/commands/user_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import random
import re
import string
import sys

from airflow.cli.simple_table import AirflowConsole
from airflow.utils import cli as cli_utils
Expand Down Expand Up @@ -59,16 +58,16 @@ def users_create(args):
password = getpass.getpass('Password:')
password_confirmation = getpass.getpass('Repeat for confirmation:')
if password != password_confirmation:
raise SystemExit('Passwords did not match!')
raise SystemExit('Passwords did not match')

if appbuilder.sm.find_user(args.username):
print(f'{args.username} already exist in the db')
return
user = appbuilder.sm.add_user(args.username, args.firstname, args.lastname, args.email, role, password)
if user:
print(f'{args.role} user {args.username} created.')
print(f'{args.role} user {args.username} created')
else:
raise SystemExit('Failed to create user.')
raise SystemExit('Failed to create user')


@cli_utils.action_logging
Expand All @@ -79,12 +78,12 @@ def users_delete(args):
try:
user = next(u for u in appbuilder.sm.get_all_users() if u.username == args.username)
except StopIteration:
raise SystemExit(f'{args.username} is not a valid user.')
raise SystemExit(f'{args.username} is not a valid user')

if appbuilder.sm.del_register_user(user):
print(f'User {args.username} deleted.')
print(f'User {args.username} deleted')
else:
raise SystemExit('Failed to delete user.')
raise SystemExit('Failed to delete user')


@cli_utils.action_logging
Expand All @@ -110,16 +109,16 @@ def users_manage_role(args, remove=False):
if role in user.roles:
user.roles = [r for r in user.roles if r != role]
appbuilder.sm.update_user(user)
print(f'User "{user}" removed from role "{args.role}".')
print(f'User "{user}" removed from role "{args.role}"')
else:
raise SystemExit(f'User "{user}" is not a member of role "{args.role}".')
raise SystemExit(f'User "{user}" is not a member of role "{args.role}"')
else:
if role in user.roles:
raise SystemExit(f'User "{user}" is already a member of role "{args.role}".')
raise SystemExit(f'User "{user}" is already a member of role "{args.role}"')
else:
user.roles.append(role)
appbuilder.sm.update_user(user)
print(f'User "{user}" added to role "{args.role}".')
print(f'User "{user}" added to role "{args.role}"')


def users_export(args):
Expand Down Expand Up @@ -153,16 +152,14 @@ def users_import(args):
"""Imports users from the json file"""
json_file = getattr(args, 'import')
if not os.path.exists(json_file):
print("File '{}' does not exist")
sys.exit(1)
raise SystemExit(f"File '{json_file}' does not exist")

users_list = None # pylint: disable=redefined-outer-name
try:
with open(json_file) as file:
users_list = json.loads(file.read())
except ValueError as e:
print(f"File '{json_file}' is not valid JSON. Error: {e}")
sys.exit(1)
raise SystemExit(f"File '{json_file}' is not valid JSON. Error: {e}")

users_created, users_updated = _import_users(users_list)
if users_created:
Expand All @@ -183,16 +180,14 @@ def _import_users(users_list): # pylint: disable=redefined-outer-name
role = appbuilder.sm.find_role(rolename)
if not role:
valid_roles = appbuilder.sm.get_all_roles()
print(f"Error: '{rolename}' is not a valid role. Valid roles are: {valid_roles}")
sys.exit(1)
raise SystemExit(f"Error: '{rolename}' is not a valid role. Valid roles are: {valid_roles}")
else:
roles.append(role)

required_fields = ['username', 'firstname', 'lastname', 'email', 'roles']
for field in required_fields:
if not user.get(field):
print(f"Error: '{field}' is a required field, but was not specified")
sys.exit(1)
raise SystemExit(f"Error: '{field}' is a required field, but was not specified")

existing_user = appbuilder.sm.find_user(email=user['email'])
if existing_user:
Expand All @@ -202,12 +197,11 @@ def _import_users(users_list): # pylint: disable=redefined-outer-name
existing_user.last_name = user['lastname']

if existing_user.username != user['username']:
print(
raise SystemExit(
"Error: Changing the username is not allowed - "
"please delete and recreate the user with "
"email '{}'".format(user['email'])
)
sys.exit(1)

appbuilder.sm.update_user(existing_user)
users_updated.append(user['email'])
Expand Down
Loading