Skip to content

Commit

Permalink
Add means to Duplicate connections from UI (#15574)
Browse files Browse the repository at this point in the history
Co-authored-by: Ash Berlin-Taylor <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
  • Loading branch information
3 people authored Jun 17, 2021
1 parent 621ef76 commit 2011da2
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
53 changes: 53 additions & 0 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json
import logging
import math
import re
import socket
import sys
import traceback
Expand Down Expand Up @@ -78,6 +79,7 @@
from pygments import highlight, lexers
from pygments.formatters import HtmlFormatter # noqa pylint: disable=no-name-in-module
from sqlalchemy import Date, and_, desc, func, or_, union_all
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
from wtforms import SelectField, validators
from wtforms.validators import InputRequired
Expand Down Expand Up @@ -3124,6 +3126,7 @@ class ConnectionModelView(AirflowModelView):
'edit': 'edit',
'delete': 'delete',
'action_muldelete': 'delete',
'action_mulduplicate': 'create',
}

base_permissions = [
Expand Down Expand Up @@ -3177,6 +3180,56 @@ def action_muldelete(self, items):
self.update_redirect()
return redirect(self.get_redirect())

@action(
'mulduplicate',
'Duplicate',
'Are you sure you want to duplicate the selected connections?',
single=False,
)
@provide_session
@auth.has_access(
[
(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION),
]
)
def action_mulduplicate(self, connections, session=None):
"""Duplicate Multiple connections"""
for selected_conn in connections:
new_conn_id = selected_conn.conn_id
match = re.search(r"_copy(\d+)$", selected_conn.conn_id)
if match:
conn_id_prefix = selected_conn.conn_id[: match.start()]
new_conn_id = f"{conn_id_prefix}_copy{int(match.group(1)) + 1}"
else:
new_conn_id += '_copy1'

dup_conn = Connection(
new_conn_id,
selected_conn.conn_type,
selected_conn.description,
selected_conn.host,
selected_conn.login,
selected_conn.password,
selected_conn.schema,
selected_conn.port,
selected_conn.extra,
)

try:
session.add(dup_conn)
session.commit()
flash(f"Connection {new_conn_id} added successfully.", "success")
except IntegrityError:
flash(
f"Connection {new_conn_id} can't be added. Integrity error, probably unique constraint.",
"warning",
)
session.rollback()

self.update_redirect()
return redirect(self.get_redirect())

def process_form(self, form, is_created):
"""Process form data."""
conn_type = form.data['conn_type']
Expand Down
42 changes: 42 additions & 0 deletions tests/www/views/test_views_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,45 @@ def test_prefill_form_null_extra():

cmv = ConnectionModelView()
cmv.prefill_form(form=mock_form, pk=1)


def test_duplicate_connection(admin_client):
"""Test Duplicate multiple connection with suffix"""
conn1 = Connection(
conn_id='test_duplicate_gcp_connection',
conn_type='Google Cloud',
description='Google Cloud Connection',
)
conn2 = Connection(
conn_id='test_duplicate_mysql_connection',
conn_type='FTP',
description='MongoDB2',
host='localhost',
schema='airflow',
port=3306,
)
conn3 = Connection(
conn_id='test_duplicate_postgres_connection_copy1',
conn_type='FTP',
description='Postgres',
host='localhost',
schema='airflow',
port=3306,
)
with create_session() as session:
session.query(Connection).delete()
session.add_all([conn1, conn2, conn3])
session.commit()

data = {"action": "mulduplicate", "rowid": [conn1.id, conn3.id]}
resp = admin_client.post('/connection/action_post', data=data, follow_redirects=True)
expected_result = {
'test_duplicate_gcp_connection',
'test_duplicate_gcp_connection_copy1',
'test_duplicate_mysql_connection',
'test_duplicate_postgres_connection_copy1',
'test_duplicate_postgres_connection_copy2',
}
response = {conn[0] for conn in session.query(Connection.conn_id).all()}
assert resp.status_code == 200
assert expected_result == response

0 comments on commit 2011da2

Please sign in to comment.