Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-5640] fix get_email_address_list types #6315

Merged
merged 1 commit into from
Oct 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ class derived from this one results in the creation of a task object,
:type task_id: str
:param owner: the owner of the task, using the unix username is recommended
:type owner: str
:param email: the 'to' email address(es) used in email alerts. This can be a
single email or multiple ones. Multiple addresses can be specified as a
comma or semi-colon separated string or by passing a list of strings.
:type email: str or list[str]
:param email_on_retry: Indicates whether email alerts should be sent when a
task is retried
:type email_on_retry: bool
:param email_on_failure: Indicates whether email alerts should be sent when
a task failed
:type email_on_failure: bool
:param retries: the number of retries that should be performed before
failing the task
:type retries: int
Expand Down Expand Up @@ -156,6 +166,8 @@ class derived from this one results in the creation of a task object,
DAGS. Options can be set as string or using the constants defined in
the static class ``airflow.utils.WeightRule``
:type weight_rule: str
:param queue: specifies which task queue to use
:type queue: str
:param pool: the slot pool this task should run in, slot pools are a
way to limit concurrency for certain tasks
:type pool: str
Expand Down Expand Up @@ -270,7 +282,7 @@ def __init__(
self,
task_id: str,
owner: str = conf.get('operators', 'DEFAULT_OWNER'),
email: Optional[str] = None,
email: Optional[Union[str, Iterable[str]]] = None,
email_on_retry: bool = True,
email_on_failure: bool = True,
retries: Optional[int] = conf.getint('core', 'default_task_retries', fallback=0),
Expand Down
29 changes: 20 additions & 9 deletions airflow/utils/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
# specific language governing permissions and limitations
# under the License.

import collections
import importlib
import os
import smtplib
from email.mime.application import MIMEApplication
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
from typing import Iterable, List, Union

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
Expand Down Expand Up @@ -120,13 +122,22 @@ def send_MIME_email(e_from, e_to, mime_msg, dryrun=False):
s.quit()


def get_email_address_list(address_string):
if isinstance(address_string, str):
if ',' in address_string:
address_string = [address.strip() for address in address_string.split(',')]
elif ';' in address_string:
address_string = [address.strip() for address in address_string.split(';')]
else:
address_string = [address_string]
def get_email_address_list(addresses: Union[str, Iterable[str]]) -> List[str]:
SaturnFromTitan marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(addresses, str):
return _get_email_list_from_str(addresses)

return address_string
elif isinstance(addresses, collections.abc.Iterable):
if not all(isinstance(item, str) for item in addresses):
raise TypeError("The items in your iterable must be strings.")
return list(addresses)

received_type = type(addresses).__name__
raise TypeError("Unexpected argument type: Received '{}'.".format(received_type))


def _get_email_list_from_str(addresses: str) -> List[str]:
delimiters = [",", ";"]
for delimiter in delimiters:
if delimiter in addresses:
return [address.strip() for address in addresses.split(delimiter)]
return [addresses]
SaturnFromTitan marked this conversation as resolved.
Show resolved Hide resolved
24 changes: 24 additions & 0 deletions tests/utils/test_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@

class TestEmail(unittest.TestCase):

def test_get_email_address_single_email(self):
emails_string = '[email protected]'

self.assertEqual(
get_email_address_list(emails_string), [emails_string])

def test_get_email_address_comma_sep_string(self):
emails_string = '[email protected], [email protected]'

Expand All @@ -43,3 +49,21 @@ def test_get_email_address_list(self):

self.assertEqual(
get_email_address_list(emails_list), EMAILS)

def test_get_email_address_tuple(self):
emails_tuple = ('[email protected]', '[email protected]')

self.assertEqual(
get_email_address_list(emails_tuple), EMAILS)

def test_get_email_address_invalid_type(self):
emails_string = 1

self.assertRaises(
TypeError, get_email_address_list, emails_string)

def test_get_email_address_invalid_type_in_iterable(self):
emails_list = ['[email protected]', 2]

self.assertRaises(
TypeError, get_email_address_list, emails_list)