Skip to content

Commit

Permalink
[AIRFLOW-6080] fix type error from new version of mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
Qingping Hou committed Nov 27, 2019
1 parent 03c870a commit d404c9a
Show file tree
Hide file tree
Showing 17 changed files with 197 additions and 423 deletions.
4 changes: 2 additions & 2 deletions airflow/contrib/operators/file_to_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def execute(self, context):
"""Upload a file to Azure Blob Storage."""
hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
self.log.info(
'Uploading %s to wasb://%s '
'as %s'.format(self.file_path, self.container_name, self.blob_name)
'Uploading %s to wasb://%s as %s',
self.file_path, self.container_name, self.blob_name,
)
hook.load_file(self.file_path, self.container_name,
self.blob_name, **self.load_options)
2 changes: 1 addition & 1 deletion airflow/gcp/utils/field_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _validate_is_empty(full_field_path: str, value: str) -> None:
if not value:
raise GcpFieldValidationException(
"The body field '{}' can't be empty. Please provide a value."
.format(full_field_path, value))
.format(full_field_path))

def _validate_dict(self, children_validation_specs: Dict, full_field_path: str, value: Dict) -> None:
for child_validation_spec in children_validation_specs:
Expand Down
2 changes: 1 addition & 1 deletion airflow/jobs/backfill_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def _per_task_process(task, key, ti, session=None):
if open_slots <= 0:
raise NoAvailablePoolSlot(
"Not scheduling since there are "
"%s open slots in pool %s".format(
"{0} open slots in pool {1}".format(
open_slots, task.pool))

num_running_task_instances_in_dag = DAG.get_num_task_instances(
Expand Down
74 changes: 37 additions & 37 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import os
import signal
import time
from datetime import timedelta
from typing import Iterable, Optional, Union
from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, Optional, Union
from urllib.parse import quote

import dill
Expand Down Expand Up @@ -721,16 +721,16 @@ def get_dagrun(self, session):
@provide_session
def _check_and_change_state_before_execution(
self,
verbose=True,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
mark_success=False,
test_mode=False,
job_id=None,
pool=None,
session=None):
verbose: bool = True,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
mark_success: bool = False,
test_mode: bool = False,
job_id: Optional[str] = None,
pool: Optional[str] = None,
session=None) -> bool:
"""
Checks dependencies and then sets state to RUNNING if they are met. Returns
True if and only if state is set to RUNNING, which implies that task should be
Expand Down Expand Up @@ -838,7 +838,7 @@ def _check_and_change_state_before_execution(

# Closing all pooled connections to prevent
# "max number of connections reached"
settings.engine.dispose()
settings.engine.dispose() # type: ignore
if verbose:
if mark_success:
self.log.info("Marking success for %s on %s", self.task, self.execution_date)
Expand All @@ -850,11 +850,11 @@ def _check_and_change_state_before_execution(
@Sentry.enrich_errors
def _run_raw_task(
self,
mark_success=False,
test_mode=False,
job_id=None,
pool=None,
session=None):
mark_success: bool = False,
test_mode: bool = False,
job_id: Optional[str] = None,
pool: Optional[str] = None,
session=None) -> None:
"""
Immediately runs the task (without checking or changing db state
before execution) and then sets the appropriate final state after
Expand All @@ -876,7 +876,7 @@ def _run_raw_task(
self.hostname = get_hostname()
self.operator = task.__class__.__name__

context = {}
context = {} # type: Dict
actual_start_date = timezone.utcnow()
try:
if not mark_success:
Expand Down Expand Up @@ -974,16 +974,16 @@ def signal_handler(signum, frame):
@provide_session
def run(
self,
verbose=True,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
mark_success=False,
test_mode=False,
job_id=None,
pool=None,
session=None):
verbose: bool = True,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
mark_success: bool = False,
test_mode: bool = False,
job_id: Optional[str] = None,
pool: Optional[str] = None,
session=None) -> None:
res = self._check_and_change_state_before_execution(
verbose=verbose,
ignore_all_deps=ignore_all_deps,
Expand Down Expand Up @@ -1096,11 +1096,11 @@ def is_eligible_to_retry(self):
return self.task.retries and self.try_number <= self.max_tries

@provide_session
def get_template_context(self, session=None):
def get_template_context(self, session=None) -> Dict[str, Any]:
task = self.task
from airflow import macros

params = {}
params = {} # type: Dict[str, Any]
run_id = ''
dag_run = None
if hasattr(task, 'dag'):
Expand Down Expand Up @@ -1236,7 +1236,7 @@ def overwrite_params_with_dag_run_conf(self, params, dag_run):
if dag_run and dag_run.conf:
params.update(dag_run.conf)

def render_templates(self, context=None) -> None:
def render_templates(self, context: Optional[Dict] = None) -> None:
"""Render templates in the operator fields."""
if not context:
context = self.get_template_context()
Expand Down Expand Up @@ -1281,17 +1281,17 @@ def render(key, content):
html_content = render('html_content_template', default_html_content)
send_email(self.task.email, subject, html_content)

def set_duration(self):
def set_duration(self) -> None:
if self.end_date and self.start_date:
self.duration = (self.end_date - self.start_date).total_seconds()
else:
self.duration = None

def xcom_push(
self,
key,
value,
execution_date=None):
key: str,
value: Any,
execution_date: Optional[datetime] = None) -> None:
"""
Make an XCom available for tasks to pull.
Expand Down Expand Up @@ -1324,7 +1324,7 @@ def xcom_pull(
task_ids: Optional[Union[str, Iterable[str]]] = None,
dag_id: Optional[str] = None,
key: str = XCOM_RETURN_KEY,
include_prior_dates: bool = False):
include_prior_dates: bool = False) -> Any:
"""
Pull XComs that optionally meet certain criteria.
Expand Down
10 changes: 5 additions & 5 deletions airflow/operators/dagrun_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ def __init__(
self.trigger_dag_id = trigger_dag_id
self.conf = conf

if execution_date is None or isinstance(execution_date, (str, datetime.datetime)):
self.execution_date = execution_date
else:
if not isinstance(execution_date, (str, datetime.datetime, type(None))):
raise TypeError(
"Expected str or datetime.datetime type for execution_date. "
"Expected str or datetime.datetime type for execution_date."
"Got {}".format(type(execution_date))
)

self.execution_date: Optional[datetime.datetime] = execution_date # type: ignore

def execute(self, context: Dict):
if isinstance(self.execution_date, datetime.datetime):
run_id = "trig__{}".format(self.execution_date.isoformat())
Expand All @@ -72,7 +72,7 @@ def execute(self, context: Dict):
run_id = "trig__{}".format(timezone.utcnow().isoformat())

# Ignore MyPy type for self.execution_date because it doesn't pick up the timezone.parse() for strings
trigger_dag( # type: ignore
trigger_dag(
dag_id=self.trigger_dag_id,
run_id=run_id,
conf=self.conf,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def write_version(filename: str = os.path.join(*["airflow", "git_version"])):
############################################################################################################

if PY3:
devel += ['mypy==0.720']
devel += ['mypy==0.740']
else:
devel += ['unittest2']

Expand Down
2 changes: 1 addition & 1 deletion tests/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_backfill(self, mock_run):

output = stdout.getvalue()
self.assertIn("Dry run of DAG example_bash_operator on {}\n".format(DEFAULT_DATE.isoformat()), output)
self.assertIn("Task runme_0\n".format(DEFAULT_DATE.isoformat()), output)
self.assertIn("Task runme_0\n", output)

mock_run.assert_not_called() # Dry run shouldn't run the backfill

Expand Down
6 changes: 2 additions & 4 deletions tests/cli/commands/test_db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,13 @@ def test_cli_upgradedb(self, mock_upgradedb):
mock_upgradedb.assert_called_once_with()

@mock.patch("airflow.cli.commands.db_command.subprocess")
@mock.patch( # type: ignore
"airflow.cli.commands.db_command.NamedTemporaryFile",
**{'return_value.__enter__.return_value.name': "/tmp/name"}
)
@mock.patch("airflow.cli.commands.db_command.NamedTemporaryFile")
@mock.patch(
"airflow.cli.commands.db_command.settings.engine.url",
make_url("mysql://root@mysql/airflow")
)
def test_cli_shell_mysql(self, mock_tmp_file, mock_subprocess):
mock_tmp_file.return_value.__enter__.return_value.name = "/tmp/name"
db_command.shell(self.parser.parse_args(['db', 'shell']))
mock_subprocess.Popen.assert_called_once_with(
['mysql', '--defaults-extra-file=/tmp/name']
Expand Down
8 changes: 4 additions & 4 deletions tests/gcp/hooks/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,11 @@ def test_get_credentials_and_project_id_with_default_auth(self, mock_auth_defaul
mock_auth_default.assert_called_once_with(scopes=self.instance.scopes)
self.assertEqual(('CREDENTIALS', 'PROJECT_ID'), result)

@mock.patch( # type: ignore
@mock.patch(
MODULE_NAME + '.google.oauth2.service_account.Credentials.from_service_account_file',
**{'return_value.project_id': "PROJECT_ID"}
)
def test_get_credentials_and_project_id_with_service_account_file(self, mock_from_service_account_file):
mock_from_service_account_file.return_value.project_id = "PROJECT_ID"
self.instance.extras = {
'extra__google_cloud_platform__key_path': "KEY_PATH.json"
}
Expand Down Expand Up @@ -400,11 +400,11 @@ def test_get_credentials_and_project_id_with_service_account_file_and_unknown_ke
with self.assertRaises(AirflowException):
self.instance._get_credentials_and_project_id()

@mock.patch( # type: ignore
@mock.patch(
MODULE_NAME + '.google.oauth2.service_account.Credentials.from_service_account_info',
**{'return_value.project_id': "PROJECT_ID"}
)
def test_get_credentials_and_project_id_with_service_account_info(self, mock_from_service_account_file):
mock_from_service_account_file.return_value.project_id = "PROJECT_ID"
service_account = {
'private_key': "PRIVATE_KEY"
}
Expand Down
4 changes: 1 addition & 3 deletions tests/gcp/hooks/test_bigquery_dts.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def setUp(self) -> None:
"airflow.gcp.hooks.bigquery_dts.GoogleCloudBaseHook.__init__",
new=mock_base_gcp_hook_no_default_project_id,
):
self.hook = BiqQueryDataTransferServiceHook( # type: ignore
gcp_conn_id=None
)
self.hook = BiqQueryDataTransferServiceHook()
self.hook._get_credentials = mock.MagicMock( # type: ignore
return_value=CREDENTIALS
)
Expand Down
2 changes: 1 addition & 1 deletion tests/gcp/hooks/test_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def test_data_flow_valid_job_id(self):
cmd = [
'echo', 'additional unit test lines.\n' +
'https://console.cloud.google.com/dataflow/jobsDetail/locations/us-central1/'
'jobs/test-job-id?project=XXX'.format(TEST_JOB_ID)
'jobs/{}?project=XXX'.format(TEST_JOB_ID)
]
self.assertEqual(_DataflowRunner(cmd).wait_for_done(), TEST_JOB_ID)

Expand Down
Loading

0 comments on commit d404c9a

Please sign in to comment.