Skip to content

Commit

Permalink
feat(admin): add prohibited email domains (#16747)
Browse files Browse the repository at this point in the history
* feat(admin): add prohibited email domains

Signed-off-by: Mike Fiedler <[email protected]>

* add permissions

Signed-off-by: Mike Fiedler <[email protected]>

* add views and templates

Signed-off-by: Mike Fiedler <[email protected]>

* feat: handle non-exact domain inputs

Signed-off-by: Mike Fiedler <[email protected]>

* refactor query to use `exists()` subquery

Signed-off-by: Mike Fiedler <[email protected]>

* fix: disallow using the live service during extraction

If we have to do this a third time, we probably want to wrap the
extractor in a utility function.

Signed-off-by: Mike Fiedler <[email protected]>

---------

Signed-off-by: Mike Fiedler <[email protected]>
  • Loading branch information
miketheman committed Sep 20, 2024
1 parent f2150d1 commit 362f6b7
Show file tree
Hide file tree
Showing 10 changed files with 594 additions and 3 deletions.
19 changes: 18 additions & 1 deletion tests/common/db/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@
import datetime

import factory
import faker

from argon2 import PasswordHasher

from warehouse.accounts.models import Email, ProhibitedUserName, User
from warehouse.accounts.models import (
Email,
ProhibitedEmailDomain,
ProhibitedUserName,
User,
)

from .base import WarehouseFactory

fake = faker.Faker()


class UserFactory(WarehouseFactory):
class Meta:
Expand Down Expand Up @@ -90,6 +98,15 @@ class Meta:
transient_bounces = 0


class ProhibitedEmailDomainFactory(WarehouseFactory):
class Meta:
model = ProhibitedEmailDomain

# TODO: Replace when factory_boy supports `unique`.
# See https://github.com/FactoryBoy/factory_boy/pull/997
domain = factory.Sequence(lambda _: fake.unique.domain_name())


class ProhibitedUsernameFactory(WarehouseFactory):
class Meta:
model = ProhibitedUserName
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/admin/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,21 @@ def test_includeme():
"/admin/prohibited_user_names/bulk/",
domain=warehouse,
),
pretend.call(
"admin.prohibited_email_domains.list",
"/admin/prohibited_email_domains/",
domain=warehouse,
),
pretend.call(
"admin.prohibited_email_domains.add",
"/admin/prohibited_email_domains/add/",
domain=warehouse,
),
pretend.call(
"admin.prohibited_email_domains.remove",
"/admin/prohibited_email_domains/remove/",
domain=warehouse,
),
pretend.call(
"admin.observations.list",
"/admin/observations/",
Expand Down
219 changes: 219 additions & 0 deletions tests/unit/admin/views/test_prohibited_email_domains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pretend
import pytest

from pyramid.httpexceptions import HTTPBadRequest, HTTPSeeOther

from warehouse.admin.views import prohibited_email_domains as views

from ....common.db.accounts import ProhibitedEmailDomain, ProhibitedEmailDomainFactory


class TestProhibitedEmailDomainsList:
def test_no_query(self, db_request):
prohibited = sorted(
ProhibitedEmailDomainFactory.create_batch(30),
key=lambda b: b.created,
)

result = views.prohibited_email_domains(db_request)

assert result == {"prohibited_email_domains": prohibited[:25], "query": None}

def test_with_page(self, db_request):
prohibited = sorted(
ProhibitedEmailDomainFactory.create_batch(30),
key=lambda b: b.created,
)
db_request.GET["page"] = "2"

result = views.prohibited_email_domains(db_request)

assert result == {"prohibited_email_domains": prohibited[25:], "query": None}

def test_with_invalid_page(self):
request = pretend.stub(params={"page": "not an integer"})

with pytest.raises(HTTPBadRequest):
views.prohibited_email_domains(request)

def test_basic_query(self, db_request):
prohibited = sorted(
ProhibitedEmailDomainFactory.create_batch(30),
key=lambda b: b.created,
)
db_request.GET["q"] = prohibited[0].domain

result = views.prohibited_email_domains(db_request)

assert result == {
"prohibited_email_domains": [prohibited[0]],
"query": prohibited[0].domain,
}

def test_wildcard_query(self, db_request):
prohibited = sorted(
ProhibitedEmailDomainFactory.create_batch(30),
key=lambda b: b.created,
)
db_request.GET["q"] = f"{prohibited[0].domain[:-1]}%"

result = views.prohibited_email_domains(db_request)

assert result == {
"prohibited_email_domains": [prohibited[0]],
"query": f"{prohibited[0].domain[:-1]}%",
}


class TestProhibitedEmailDomainsAdd:
def test_no_email_domain(self, db_request):
db_request.method = "POST"
db_request.route_path = lambda a: "/admin/prohibited_email_domains/add/"
db_request.session = pretend.stub(
flash=pretend.call_recorder(lambda *a, **kw: None)
)
db_request.POST = {}

with pytest.raises(HTTPSeeOther):
views.add_prohibited_email_domain(db_request)

assert db_request.session.flash.calls == [
pretend.call("Email domain is required.", queue="error")
]

def test_invalid_domain(self, db_request):
db_request.method = "POST"
db_request.route_path = lambda a: "/admin/prohibited_email_domains/add/"
db_request.session = pretend.stub(
flash=pretend.call_recorder(lambda *a, **kw: None)
)
db_request.POST = {"email_domain": "invalid"}

with pytest.raises(HTTPSeeOther):
views.add_prohibited_email_domain(db_request)

assert db_request.session.flash.calls == [
pretend.call("Invalid domain name 'invalid'", queue="error")
]

def test_duplicate_domain(self, db_request):
existing_domain = ProhibitedEmailDomainFactory.create()
db_request.method = "POST"
db_request.route_path = lambda a: "/admin/prohibited_email_domains/add/"
db_request.session = pretend.stub(
flash=pretend.call_recorder(lambda *a, **kw: None)
)
db_request.POST = {"email_domain": existing_domain.domain}

with pytest.raises(HTTPSeeOther):
views.add_prohibited_email_domain(db_request)

assert db_request.session.flash.calls == [
pretend.call(
f"Email domain '{existing_domain.domain}' already exists.",
queue="error",
)
]

@pytest.mark.parametrize(
("input_domain", "expected_domain"),
[
("example.com", "example.com"),
("mail.example.co.uk", "example.co.uk"),
("https://example.com/", "example.com"),
],
)
def test_success(self, db_request, input_domain, expected_domain):
db_request.method = "POST"
db_request.route_path = lambda a: "/admin/prohibited_email_domains/list/"
db_request.session = pretend.stub(
flash=pretend.call_recorder(lambda *a, **kw: None)
)
db_request.POST = {
"email_domain": input_domain,
"is_mx_record": "on",
"comment": "testing",
}

response = views.add_prohibited_email_domain(db_request)

assert response.status_code == 303
assert response.headers["Location"] == "/admin/prohibited_email_domains/list/"
assert db_request.session.flash.calls == [
pretend.call("Prohibited email domain added.", queue="success")
]

query = db_request.db.query(ProhibitedEmailDomain).filter(
ProhibitedEmailDomain.domain == expected_domain
)
assert query.count() == 1
assert query.one().is_mx_record
assert query.one().comment == "testing"


class TestProhibitedEmailDomainsRemove:
def test_no_domain_name(self, db_request):
db_request.method = "POST"
db_request.route_path = lambda a: "/admin/prohibited_email_domains/remove/"
db_request.session = pretend.stub(
flash=pretend.call_recorder(lambda *a, **kw: None)
)
db_request.POST = {}

with pytest.raises(HTTPSeeOther):
views.remove_prohibited_email_domain(db_request)

assert db_request.session.flash.calls == [
pretend.call("Domain name is required.", queue="error")
]

def test_domain_not_found(self, db_request):
db_request.method = "POST"
db_request.route_path = lambda a: "/admin/prohibited_email_domains/remove/"
db_request.session = pretend.stub(
flash=pretend.call_recorder(lambda *a, **kw: None)
)
db_request.POST = {"domain_name": "example.com"}

with pytest.raises(HTTPSeeOther):
views.remove_prohibited_email_domain(db_request)

assert db_request.session.flash.calls == [
pretend.call("Domain not found.", queue="error")
]

def test_success(self, db_request):
domain = ProhibitedEmailDomainFactory.create()
db_request.method = "POST"
db_request.route_path = lambda a: "/admin/prohibited_email_domains/list/"
db_request.session = pretend.stub(
flash=pretend.call_recorder(lambda *a, **kw: None)
)
db_request.POST = {"domain_name": domain.domain}

response = views.remove_prohibited_email_domain(db_request)

assert response.status_code == 303
assert response.headers["Location"] == "/admin/prohibited_email_domains/list/"
assert db_request.session.flash.calls == [
pretend.call(
f"Prohibited email domain '{domain.domain}' removed.", queue="success"
)
]

query = db_request.db.query(ProhibitedEmailDomain).filter(
ProhibitedEmailDomain.domain == domain.domain
)
assert query.count() == 0
4 changes: 4 additions & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ def test_root_factory_access_control_list():
Permissions.AdminObservationsWrite,
Permissions.AdminOrganizationsRead,
Permissions.AdminOrganizationsWrite,
Permissions.AdminProhibitedEmailDomainsRead,
Permissions.AdminProhibitedEmailDomainsWrite,
Permissions.AdminProhibitedProjectsRead,
Permissions.AdminProhibitedProjectsWrite,
Permissions.AdminProhibitedUsernameRead,
Expand Down Expand Up @@ -604,6 +606,7 @@ def test_root_factory_access_control_list():
Permissions.AdminObservationsRead,
Permissions.AdminObservationsWrite,
Permissions.AdminOrganizationsRead,
Permissions.AdminProhibitedEmailDomainsRead,
Permissions.AdminProhibitedProjectsRead,
Permissions.AdminProhibitedUsernameRead,
Permissions.AdminProjectsRead,
Expand All @@ -629,6 +632,7 @@ def test_root_factory_access_control_list():
Permissions.AdminObservationsRead,
Permissions.AdminObservationsWrite,
Permissions.AdminOrganizationsRead,
Permissions.AdminProhibitedEmailDomainsRead,
Permissions.AdminProhibitedProjectsRead,
Permissions.AdminProhibitedUsernameRead,
Permissions.AdminProjectsRead,
Expand Down
16 changes: 16 additions & 0 deletions warehouse/admin/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,22 @@ def includeme(config):
"/admin/prohibited_user_names/bulk/",
domain=warehouse,
)
# Prohibited Email related Admin pages
config.add_route(
"admin.prohibited_email_domains.list",
"/admin/prohibited_email_domains/",
domain=warehouse,
)
config.add_route(
"admin.prohibited_email_domains.add",
"/admin/prohibited_email_domains/add/",
domain=warehouse,
)
config.add_route(
"admin.prohibited_email_domains.remove",
"/admin/prohibited_email_domains/remove/",
domain=warehouse,
)

# Observation related Admin pages
config.add_route(
Expand Down
4 changes: 2 additions & 2 deletions warehouse/admin/templates/admin/base.html
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@
</a>
</li>
<li class="nav-item">
<a href="#" class="nav-link">
<a href="{{ request.route_path('admin.prohibited_email_domains.list') }}" class="nav-link">
<i class="nav-icon fa fa-envelope fa-flip-vertical"></i>
<p>Email Domains <span class="right badge badge-warning">TODO</span></p>
<p>Email Domains</p>
</a>
</li>
</ul>
Expand Down
Loading

0 comments on commit 362f6b7

Please sign in to comment.