Skip to content

Commit

Permalink
4.8.0
Browse files Browse the repository at this point in the history
- Fix DNSSEC test
- Add `DMARCRecordStartsWithWhitespace` exception (PR #97)
- Properly parse DMARC and BIMI records for domains that do not have an identified base domain (PR #98)
- Add `ignore_unrelated_records` argument to `query_dmarc_record()` (Slight modification of PR #99 - Close issue #91)
- Mark syntax error positions (Slight modification of PR #100)
  • Loading branch information
seanthegeek committed Sep 7, 2023
1 parent 5dde906 commit bb758a7
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 47 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
Changelog
=========

4.8.0
-----

- Fix DNSSEC test
- Add `DMARCRecordStartsWithWhitespace` exception (PR #97)
- Properly parse DMARC and BIMI records for domains that do not have an identified base domain (PR #98)
- Add `ignore_unrelated_records` argument to `query_dmarc_record()` (Slight modification of PR #99 - Close issue #91)
- Mark syntax error positions (Slight modification of PR #100)

4.7.0
-----

Expand Down
98 changes: 51 additions & 47 deletions checkdmarc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
import platform
import shutil
import atexit
from ssl import SSLError, CertificateError, create_default_context
from ssl import SSLError, create_default_context

from io import StringIO
from expiringdict import ExpiringDict

import publicsuffixlist
import dns
import dns.resolver
import dns.dnssec
import dns.exception
import timeout_decorator
from pyleri import (Grammar,
Expand All @@ -35,7 +36,7 @@
)
import ipaddress

"""Copyright 2019 Sean Whalen
"""Copyright 2019-2023 Sean Whalen
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -49,7 +50,7 @@
See the License for the specific language governing permissions and
limitations under the License."""

__version__ = "4.7.0"
__version__ = "4.8.0"

DMARC_VERSION_REGEX_STRING = r"v *= *DMARC1;"
BIMI_VERSION_REGEX_STRING = r"v=BIMI1;"
Expand Down Expand Up @@ -79,7 +80,7 @@
TLS_CACHE = ExpiringDict(max_len=200000, max_age_seconds=1800)
STARTTLS_CACHE = ExpiringDict(max_len=200000, max_age_seconds=1800)

SYNTAX_ERROR_MARK_SIGN = "➞"
SYNTAX_ERROR_MARKER = "➞"

TMPDIR = tempfile.mkdtemp()

Expand Down Expand Up @@ -627,7 +628,7 @@ def _query_dns(domain, record_type, nameservers=None, resolver=None,
if cache is None:
cache = DNS_CACHE
if type(cache) is ExpiringDict:
records = cache.get(cache_key, None)
records = cache.get(cache_key)
if records:
return records
if not resolver:
Expand Down Expand Up @@ -786,7 +787,7 @@ def _get_reverse_dns(ip_address, nameservers=None, resolver=None, timeout=2.0):
"""
try:
name = dns.reversename.from_address(ip_address)
name = str(dns.reversename.from_address(ip_address))
hostnames = _query_dns(name, "PTR", nameservers=nameservers,
resolver=resolver, timeout=timeout)
except dns.resolver.NXDOMAIN:
Expand Down Expand Up @@ -830,7 +831,8 @@ def _get_txt_records(domain, nameservers=None, resolver=None, timeout=2.0):
return records


def _query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0, raise_for_unrelated_records=True):
def _query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0,
ignore_unrelated_records=False):
"""
Queries DNS for a DMARC record
Expand Down Expand Up @@ -859,8 +861,8 @@ def _query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0, ra
elif record.strip().startswith("v=DMARC1"):
raise DMARCRecordStartsWithWhitespace(
"Found a DMARC record that starts with whitespace. "
"Please remove the whitespace, as some implementations may not "
"process it correctly."
"Please remove the whitespace, as some implementations "
"may not process it correctly."
)
else:
unrelated_records.append(record)
Expand All @@ -870,13 +872,15 @@ def _query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0, ra
"Multiple DMARC policy records are not permitted - "
"https://tools.ietf.org/html/rfc7489#section-6.6.3")
if len(unrelated_records) > 0:
if raise_for_unrelated_records:
if not ignore_unrelated_records:
raise UnrelatedTXTRecordFoundAtDMARC(
"Unrelated TXT records were discovered. These should be "
"removed, as some receivers may not expect to find "
"unrelated TXT records "
"at {0}\n\n{1}".format(target, "\n\n".join(unrelated_records)))
dmarc_record = [record for record in records if record.startswith(dmarc_prefix)][0]
"at {0}\n\n{1}".format(target, "\n\n".join(
unrelated_records)))
dmarc_record = [record for record in records if record.startswith(
dmarc_prefix)][0]

except dns.resolver.NoAnswer:
try:
Expand Down Expand Up @@ -973,7 +977,8 @@ def _query_bmi_record(domain, selector="default", nameservers=None,
return bimi_record


def query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0, raise_for_unrelated_records=True):
def query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0,
ignore_unrelated_records=False):
"""
Queries DNS for a DMARC record
Expand All @@ -983,6 +988,7 @@ def query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0, rai
resolver (dns.resolver.Resolver): A resolver object to use for DNS
requests
timeout (float): number of seconds to wait for a record from DNS
ignore_unrelated_records (bool): Ignore unrelated TXT records
Returns:
OrderedDict: An ``OrderedDict`` with the following keys:
Expand All @@ -1001,8 +1007,10 @@ def query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0, rai
warnings = []
base_domain = get_base_domain(domain)
location = domain.lower()
record = _query_dmarc_record(domain, nameservers=nameservers,
resolver=resolver, timeout=timeout, raise_for_unrelated_records=raise_for_unrelated_records)
record = _query_dmarc_record(
domain, nameservers=nameservers,
resolver=resolver, timeout=timeout,
ignore_unrelated_records=ignore_unrelated_records)
try:
root_records = _query_dns(domain.lower(), "TXT",
nameservers=nameservers, resolver=resolver,
Expand Down Expand Up @@ -1069,8 +1077,9 @@ def query_bimi_record(domain, selector="default", nameservers=None,
except Exception:
pass

if record is None and base_domain and domain != base_domain and selector != "default":
record = _query_bmi_record(base_domain, selector="default",
if (record is None and base_domain
and domain != base_domain and selector != "default"):
record = _query_bmi_record(base_domain,
nameservers=nameservers, resolver=resolver,
timeout=timeout)
location = base_domain
Expand Down Expand Up @@ -1283,7 +1292,7 @@ def verify_dmarc_report_destination(source_domain, destination_domain,
def parse_dmarc_record(record, domain, parked=False,
include_tag_descriptions=False,
nameservers=None, resolver=None, timeout=2.0,
syntax_error_mark_sign=SYNTAX_ERROR_MARK_SIGN):
syntax_error_marker=SYNTAX_ERROR_MARKER):
"""
Parses a DMARC record
Expand All @@ -1296,6 +1305,7 @@ def parse_dmarc_record(record, domain, parked=False,
resolver (dns.resolver.Resolver): A resolver object to use for DNS
requests
timeout (float): number of seconds to wait for an answer from DNS
syntax_error_marker (str): The maker for
Returns:
OrderedDict: An ``OrderedDict`` with the following keys:
Expand Down Expand Up @@ -1337,13 +1347,15 @@ def parse_dmarc_record(record, domain, parked=False,
if not parsed_record.is_valid:
expecting = list(
map(lambda x: str(x).strip('"'), list(parsed_record.expecting)))
record_marked = record[:parsed_record.pos] + syntax_error_mark_sign + record[parsed_record.pos:]
raise DMARCSyntaxError("Error: Expected {0} at position {1} (marked with {2}) in: "
marked_record = (record[:parsed_record.pos] + syntax_error_marker +
record[parsed_record.pos:])
raise DMARCSyntaxError("Error: Expected {0} at position {1} "
"(marked with {2}) in: "
"{3}".format(
" or ".join(expecting),
parsed_record.pos,
syntax_error_mark_sign,
record_marked))
syntax_error_marker,
marked_record))

pairs = DMARC_TAG_VALUE_REGEX.findall(record)
tags = OrderedDict()
Expand Down Expand Up @@ -1638,7 +1650,7 @@ def query_spf_record(domain, nameservers=None, resolver=None, timeout=2.0):
def parse_spf_record(record, domain, parked=False, seen=None,
nameservers=None, resolver=None,
recursion=None, timeout=2.0,
syntax_error_mark_sign=SYNTAX_ERROR_MARK_SIGN):
syntax_error_marker=SYNTAX_ERROR_MARKER):
"""
Parses an SPF record, including resolving ``a``, ``mx``, and ``include``
mechanisms
Expand Down Expand Up @@ -1689,14 +1701,14 @@ def parse_spf_record(record, domain, parked=False, seen=None,
expecting = list(
map(lambda x: str(x).strip('"'), list(parsed_record.expecting)))
expecting = " or ".join(expecting)
record_marked = record[:pos] + syntax_error_mark_sign + record[pos:]
marked_record = record[:pos] + syntax_error_marker + record[pos:]
raise SPFSyntaxError(
"{0}: Expected {1} at position {2} (marked with {3}) in: {4}".format(
domain,
expecting,
pos,
syntax_error_mark_sign,
record_marked))
"{0}: Expected {1} at position {2} "
"(marked with {3}) in: {4}".format(domain,
expecting,
pos,
syntax_error_marker,
marked_record))
matches = SPF_MECHANISM_REGEX.findall(record.lower())
parsed = OrderedDict([("pass", []),
("neutral", []),
Expand Down Expand Up @@ -1959,7 +1971,7 @@ def test_tls(hostname, ssl_context=None, cache=None):
"""
tls = False
if cache:
cached_result = cache.get(hostname, None)
cached_result = cache.get(hostname)
if cached_result is not None:
if cached_result["error"] is not None:
raise SMTPError(cached_result["error"])
Expand Down Expand Up @@ -2014,11 +2026,6 @@ def test_tls(hostname, ssl_context=None, cache=None):
if cache:
cache[hostname] = dict(tls=False, error=error)
raise SMTPError(error)
except CertificateError as e:
error = "Certificate error: {0}".format(e.__str__())
if cache:
cache[hostname] = dict(tls=False, error=error)
raise SMTPError(error)
except smtplib.SMTPConnectError as e:
message = e.__str__()
error_code = int(message.lstrip("(").split(",")[0])
Expand Down Expand Up @@ -2074,7 +2081,7 @@ def test_starttls(hostname, ssl_context=None, cache=None):
"""
starttls = False
if cache:
cached_result = cache.get(hostname, None)
cached_result = cache.get(hostname)
if cached_result is not None:
if cached_result["error"] is not None:
raise SMTPError(cached_result["error"])
Expand Down Expand Up @@ -2134,11 +2141,6 @@ def test_starttls(hostname, ssl_context=None, cache=None):
if cache:
cache[hostname] = dict(starttls=False, error=error)
raise SMTPError(error)
except CertificateError as e:
error = "Certificate error: {0}".format(e.__str__())
if cache:
cache[hostname] = dict(starttls=False, error=error)
raise SMTPError(error)
except smtplib.SMTPConnectError as e:
message = e.__str__()
error_code = int(message.lstrip("(").split(",")[0])
Expand Down Expand Up @@ -2375,16 +2377,18 @@ def test_dnssec(domain, nameservers=None, timeout=2.0):
nameservers = dns.resolver.Resolver().nameservers

request = dns.message.make_query(get_base_domain(domain),
dns.rdatatype.NS,
dns.rdatatype.DNSKEY,
want_dnssec=True)
for nameserver in nameservers:
try:
response = dns.query.udp(request, nameserver, timeout=timeout)
if response is not None:
for record in response.answer:
if record.rdtype == dns.rdatatype.RRSIG:
if response.flags & dns.flags.AD:
return True
answer = response.answer
if len(answer) != 2:
return False
name = dns.name.from_text(f'{domain.lower()}.')
dns.dnssec.validate(answer[0], answer[1], {name: answer[0]})
return True
except Exception as e:
logging.debug("DNSSEC query error: {0}".format(e))

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ classifiers = [
]
dependencies = [
"dnspython>=2.0.0",
"cryptography>=41.0.3",
"expiringdict>=1.1.4",
"publicsuffixlist>=0.10.0",
"pyleri>=1.3.2",
Expand Down

0 comments on commit bb758a7

Please sign in to comment.