From 362f6b7a39e5e34aa061c3f9769e3989ca115088 Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Fri, 20 Sep 2024 09:31:35 -0400 Subject: [PATCH] feat(admin): add prohibited email domains (#16747) * feat(admin): add prohibited email domains Signed-off-by: Mike Fiedler * add permissions Signed-off-by: Mike Fiedler * add views and templates Signed-off-by: Mike Fiedler * feat: handle non-exact domain inputs Signed-off-by: Mike Fiedler * refactor query to use `exists()` subquery Signed-off-by: Mike Fiedler * fix: disallow using the live service during extraction If we have to do this a third time, we probably want to wrap the extractor in a utility function. Signed-off-by: Mike Fiedler --------- Signed-off-by: Mike Fiedler --- tests/common/db/accounts.py | 19 +- tests/unit/admin/test_routes.py | 15 ++ .../views/test_prohibited_email_domains.py | 219 ++++++++++++++++++ tests/unit/test_config.py | 4 + warehouse/admin/routes.py | 16 ++ warehouse/admin/templates/admin/base.html | 4 +- .../admin/prohibited_email_domains/list.html | 183 +++++++++++++++ .../admin/views/prohibited_email_domains.py | 130 +++++++++++ warehouse/authnz/_permissions.py | 3 + warehouse/config.py | 4 + 10 files changed, 594 insertions(+), 3 deletions(-) create mode 100644 tests/unit/admin/views/test_prohibited_email_domains.py create mode 100644 warehouse/admin/templates/admin/prohibited_email_domains/list.html create mode 100644 warehouse/admin/views/prohibited_email_domains.py diff --git a/tests/common/db/accounts.py b/tests/common/db/accounts.py index 435bd1717938..00e2756012cf 100644 --- a/tests/common/db/accounts.py +++ b/tests/common/db/accounts.py @@ -13,13 +13,21 @@ import datetime import factory +import faker from argon2 import PasswordHasher -from warehouse.accounts.models import Email, ProhibitedUserName, User +from warehouse.accounts.models import ( + Email, + ProhibitedEmailDomain, + ProhibitedUserName, + User, +) from .base import WarehouseFactory +fake = faker.Faker() + class UserFactory(WarehouseFactory): class Meta: @@ -90,6 +98,15 @@ class Meta: transient_bounces = 0 +class ProhibitedEmailDomainFactory(WarehouseFactory): + class Meta: + model = ProhibitedEmailDomain + + # TODO: Replace when factory_boy supports `unique`. + # See https://github.com/FactoryBoy/factory_boy/pull/997 + domain = factory.Sequence(lambda _: fake.unique.domain_name()) + + class ProhibitedUsernameFactory(WarehouseFactory): class Meta: model = ProhibitedUserName diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 9b1125594bda..ab481bf4cd75 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -284,6 +284,21 @@ def test_includeme(): "/admin/prohibited_user_names/bulk/", domain=warehouse, ), + pretend.call( + "admin.prohibited_email_domains.list", + "/admin/prohibited_email_domains/", + domain=warehouse, + ), + pretend.call( + "admin.prohibited_email_domains.add", + "/admin/prohibited_email_domains/add/", + domain=warehouse, + ), + pretend.call( + "admin.prohibited_email_domains.remove", + "/admin/prohibited_email_domains/remove/", + domain=warehouse, + ), pretend.call( "admin.observations.list", "/admin/observations/", diff --git a/tests/unit/admin/views/test_prohibited_email_domains.py b/tests/unit/admin/views/test_prohibited_email_domains.py new file mode 100644 index 000000000000..9530aea62f43 --- /dev/null +++ b/tests/unit/admin/views/test_prohibited_email_domains.py @@ -0,0 +1,219 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pretend +import pytest + +from pyramid.httpexceptions import HTTPBadRequest, HTTPSeeOther + +from warehouse.admin.views import prohibited_email_domains as views + +from ....common.db.accounts import ProhibitedEmailDomain, ProhibitedEmailDomainFactory + + +class TestProhibitedEmailDomainsList: + def test_no_query(self, db_request): + prohibited = sorted( + ProhibitedEmailDomainFactory.create_batch(30), + key=lambda b: b.created, + ) + + result = views.prohibited_email_domains(db_request) + + assert result == {"prohibited_email_domains": prohibited[:25], "query": None} + + def test_with_page(self, db_request): + prohibited = sorted( + ProhibitedEmailDomainFactory.create_batch(30), + key=lambda b: b.created, + ) + db_request.GET["page"] = "2" + + result = views.prohibited_email_domains(db_request) + + assert result == {"prohibited_email_domains": prohibited[25:], "query": None} + + def test_with_invalid_page(self): + request = pretend.stub(params={"page": "not an integer"}) + + with pytest.raises(HTTPBadRequest): + views.prohibited_email_domains(request) + + def test_basic_query(self, db_request): + prohibited = sorted( + ProhibitedEmailDomainFactory.create_batch(30), + key=lambda b: b.created, + ) + db_request.GET["q"] = prohibited[0].domain + + result = views.prohibited_email_domains(db_request) + + assert result == { + "prohibited_email_domains": [prohibited[0]], + "query": prohibited[0].domain, + } + + def test_wildcard_query(self, db_request): + prohibited = sorted( + ProhibitedEmailDomainFactory.create_batch(30), + key=lambda b: b.created, + ) + db_request.GET["q"] = f"{prohibited[0].domain[:-1]}%" + + result = views.prohibited_email_domains(db_request) + + assert result == { + "prohibited_email_domains": [prohibited[0]], + "query": f"{prohibited[0].domain[:-1]}%", + } + + +class TestProhibitedEmailDomainsAdd: + def test_no_email_domain(self, db_request): + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/add/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {} + + with pytest.raises(HTTPSeeOther): + views.add_prohibited_email_domain(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Email domain is required.", queue="error") + ] + + def test_invalid_domain(self, db_request): + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/add/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {"email_domain": "invalid"} + + with pytest.raises(HTTPSeeOther): + views.add_prohibited_email_domain(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Invalid domain name 'invalid'", queue="error") + ] + + def test_duplicate_domain(self, db_request): + existing_domain = ProhibitedEmailDomainFactory.create() + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/add/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {"email_domain": existing_domain.domain} + + with pytest.raises(HTTPSeeOther): + views.add_prohibited_email_domain(db_request) + + assert db_request.session.flash.calls == [ + pretend.call( + f"Email domain '{existing_domain.domain}' already exists.", + queue="error", + ) + ] + + @pytest.mark.parametrize( + ("input_domain", "expected_domain"), + [ + ("example.com", "example.com"), + ("mail.example.co.uk", "example.co.uk"), + ("https://example.com/", "example.com"), + ], + ) + def test_success(self, db_request, input_domain, expected_domain): + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/list/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = { + "email_domain": input_domain, + "is_mx_record": "on", + "comment": "testing", + } + + response = views.add_prohibited_email_domain(db_request) + + assert response.status_code == 303 + assert response.headers["Location"] == "/admin/prohibited_email_domains/list/" + assert db_request.session.flash.calls == [ + pretend.call("Prohibited email domain added.", queue="success") + ] + + query = db_request.db.query(ProhibitedEmailDomain).filter( + ProhibitedEmailDomain.domain == expected_domain + ) + assert query.count() == 1 + assert query.one().is_mx_record + assert query.one().comment == "testing" + + +class TestProhibitedEmailDomainsRemove: + def test_no_domain_name(self, db_request): + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/remove/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {} + + with pytest.raises(HTTPSeeOther): + views.remove_prohibited_email_domain(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Domain name is required.", queue="error") + ] + + def test_domain_not_found(self, db_request): + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/remove/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {"domain_name": "example.com"} + + with pytest.raises(HTTPSeeOther): + views.remove_prohibited_email_domain(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Domain not found.", queue="error") + ] + + def test_success(self, db_request): + domain = ProhibitedEmailDomainFactory.create() + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/list/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {"domain_name": domain.domain} + + response = views.remove_prohibited_email_domain(db_request) + + assert response.status_code == 303 + assert response.headers["Location"] == "/admin/prohibited_email_domains/list/" + assert db_request.session.flash.calls == [ + pretend.call( + f"Prohibited email domain '{domain.domain}' removed.", queue="success" + ) + ] + + query = db_request.db.query(ProhibitedEmailDomain).filter( + ProhibitedEmailDomain.domain == domain.domain + ) + assert query.count() == 0 diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 71b27843713b..ba312034fefb 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -574,6 +574,8 @@ def test_root_factory_access_control_list(): Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, Permissions.AdminOrganizationsWrite, + Permissions.AdminProhibitedEmailDomainsRead, + Permissions.AdminProhibitedEmailDomainsWrite, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedProjectsWrite, Permissions.AdminProhibitedUsernameRead, @@ -604,6 +606,7 @@ def test_root_factory_access_control_list(): Permissions.AdminObservationsRead, Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, + Permissions.AdminProhibitedEmailDomainsRead, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedUsernameRead, Permissions.AdminProjectsRead, @@ -629,6 +632,7 @@ def test_root_factory_access_control_list(): Permissions.AdminObservationsRead, Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, + Permissions.AdminProhibitedEmailDomainsRead, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedUsernameRead, Permissions.AdminProjectsRead, diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 6fb87d7a4520..0f8b941d7599 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -293,6 +293,22 @@ def includeme(config): "/admin/prohibited_user_names/bulk/", domain=warehouse, ) + # Prohibited Email related Admin pages + config.add_route( + "admin.prohibited_email_domains.list", + "/admin/prohibited_email_domains/", + domain=warehouse, + ) + config.add_route( + "admin.prohibited_email_domains.add", + "/admin/prohibited_email_domains/add/", + domain=warehouse, + ) + config.add_route( + "admin.prohibited_email_domains.remove", + "/admin/prohibited_email_domains/remove/", + domain=warehouse, + ) # Observation related Admin pages config.add_route( diff --git a/warehouse/admin/templates/admin/base.html b/warehouse/admin/templates/admin/base.html index e2bb9eefa842..7416eb3e08c8 100644 --- a/warehouse/admin/templates/admin/base.html +++ b/warehouse/admin/templates/admin/base.html @@ -175,9 +175,9 @@ diff --git a/warehouse/admin/templates/admin/prohibited_email_domains/list.html b/warehouse/admin/templates/admin/prohibited_email_domains/list.html new file mode 100644 index 000000000000..3579ac9a4967 --- /dev/null +++ b/warehouse/admin/templates/admin/prohibited_email_domains/list.html @@ -0,0 +1,183 @@ +{# + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. +-#} + +{% extends "admin/base.html" %} + +{% import "admin/utils/pagination.html" as pagination %} + +{% set perms_admin_prohibited_email_domain_write = request.has_permission(Permissions.AdminProhibitedEmailDomainsWrite) %} + +{% block title %} + Prohibited Email Domains +{% endblock title %} + +{% block breadcrumb %} + + +{% endblock breadcrumb %} + +{% block content %} +
+
+
+
+ + +
+ +
+
+
+
+
+
+
+ + + + + + + + + + + + + {% for prohibited_email_domain in prohibited_email_domains %} + + + + + + + + + {% endfor %} + +
Domain NameMX record?Prohibited byProhibited onComment
{{ prohibited_email_domain.domain }} + {% if prohibited_email_domain.is_mx_record %}{% endif %} + + + {{ prohibited_email_domain.prohibited_by.username }} + + {{ prohibited_email_domain.created | format_datetime }}{{ prohibited_email_domain.comment }} + + + + + +
+
+ +
+
+
+
+

Prohibit email domain

+
+
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ +
+
+
+{% endblock content %} diff --git a/warehouse/admin/views/prohibited_email_domains.py b/warehouse/admin/views/prohibited_email_domains.py new file mode 100644 index 000000000000..73c4d9852aee --- /dev/null +++ b/warehouse/admin/views/prohibited_email_domains.py @@ -0,0 +1,130 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paginate_sqlalchemy import SqlalchemyOrmPage as SQLAlchemyORMPage +from pyramid.httpexceptions import HTTPBadRequest, HTTPSeeOther +from pyramid.view import view_config +from sqlalchemy import exists, select +from tldextract import TLDExtract + +from warehouse.accounts.models import ProhibitedEmailDomain +from warehouse.authnz import Permissions +from warehouse.utils.paginate import paginate_url_factory + + +@view_config( + route_name="admin.prohibited_email_domains.list", + renderer="admin/prohibited_email_domains/list.html", + permission=Permissions.AdminProhibitedEmailDomainsRead, + request_method="GET", + uses_session=True, +) +def prohibited_email_domains(request): + q = request.params.get("q") + + try: + page_num = int(request.params.get("page", 1)) + except ValueError: + raise HTTPBadRequest("'page' must be an integer.") from None + + prohibited_email_domains_query = request.db.query(ProhibitedEmailDomain).order_by( + ProhibitedEmailDomain.created.desc() + ) + + if q: + prohibited_email_domains_query = prohibited_email_domains_query.filter( + ProhibitedEmailDomain.domain.ilike(q) + ) + + prohibited_email_domains = SQLAlchemyORMPage( + prohibited_email_domains_query, + page=page_num, + items_per_page=25, + url_maker=paginate_url_factory(request), + ) + + return {"prohibited_email_domains": prohibited_email_domains, "query": q} + + +@view_config( + route_name="admin.prohibited_email_domains.add", + permission=Permissions.AdminProhibitedEmailDomainsWrite, + request_method="POST", + uses_session=True, + require_methods=False, +) +def add_prohibited_email_domain(request): + email_domain = request.POST.get("email_domain") + if not email_domain: + request.session.flash("Email domain is required.", queue="error") + raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + # validate that the domain is valid + extractor = TLDExtract(suffix_list_urls=()) # Updated during image build + registered_domain = extractor(email_domain).registered_domain + if not registered_domain: + request.session.flash(f"Invalid domain name '{email_domain}'", queue="error") + raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + # make sure we don't have a duplicate entry + if request.db.scalar( + select(exists().where(ProhibitedEmailDomain.domain == registered_domain)) + ): + request.session.flash( + f"Email domain '{registered_domain}' already exists.", queue="error" + ) + raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + + # Add the domain to the database + is_mx_record = bool(request.POST.get("is_mx_record")) + comment = request.POST.get("comment") + + prohibited_email_domain = ProhibitedEmailDomain( + domain=registered_domain, + is_mx_record=is_mx_record, + prohibited_by=request.user, + comment=comment, + ) + + request.db.add(prohibited_email_domain) + request.session.flash("Prohibited email domain added.", queue="success") + + return HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + + +@view_config( + route_name="admin.prohibited_email_domains.remove", + permission=Permissions.AdminProhibitedEmailDomainsWrite, + request_method="POST", + uses_session=True, + require_methods=False, +) +def remove_prohibited_email_domain(request): + domain_name = request.POST.get("domain_name") + if not domain_name: + request.session.flash("Domain name is required.", queue="error") + raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + + domain_record = ( + request.db.query(ProhibitedEmailDomain) + .filter(ProhibitedEmailDomain.domain == domain_name) + .first() + ) + + if not domain_record: + request.session.flash("Domain not found.", queue="error") + raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + + request.db.delete(domain_record) + request.session.flash( + f"Prohibited email domain '{domain_record.domain}' removed.", queue="success" + ) + + return HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) diff --git a/warehouse/authnz/_permissions.py b/warehouse/authnz/_permissions.py index 373788380ca2..62d60ac07088 100644 --- a/warehouse/authnz/_permissions.py +++ b/warehouse/authnz/_permissions.py @@ -60,6 +60,9 @@ class Permissions(StrEnum): AdminOrganizationsRead = "admin:organizations:read" AdminOrganizationsWrite = "admin:organizations:write" + AdminProhibitedEmailDomainsRead = "admin:prohibited-email-domains:read" + AdminProhibitedEmailDomainsWrite = "admin:prohibited-email-domains:write" + AdminProhibitedProjectsRead = "admin:prohibited-projects:read" AdminProhibitedProjectsWrite = "admin:prohibited-projects:write" diff --git a/warehouse/config.py b/warehouse/config.py index 5234e08363d0..32065f72933c 100644 --- a/warehouse/config.py +++ b/warehouse/config.py @@ -87,6 +87,8 @@ class RootFactory: Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, Permissions.AdminOrganizationsWrite, + Permissions.AdminProhibitedEmailDomainsRead, + Permissions.AdminProhibitedEmailDomainsWrite, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedProjectsWrite, Permissions.AdminProhibitedUsernameRead, @@ -117,6 +119,7 @@ class RootFactory: Permissions.AdminObservationsRead, Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, + Permissions.AdminProhibitedEmailDomainsRead, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedUsernameRead, Permissions.AdminProjectsRead, @@ -142,6 +145,7 @@ class RootFactory: Permissions.AdminObservationsRead, Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, + Permissions.AdminProhibitedEmailDomainsRead, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedUsernameRead, Permissions.AdminProjectsRead,