diff --git a/.coveragerc b/.coveragerc index d315b87..de936df 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,7 +2,7 @@ # https://coverage.readthedocs.io/en/latest/config.html [run] -source = src/example +source = src/cyhy_db omit = branch = true diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 9e4ff7b..e3d24b7 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -19,10 +19,10 @@ updates: - dependency-name: hashicorp/setup-terraform - dependency-name: mxschmitt/action-tmate - dependency-name: step-security/harden-runner - # # Managed by cisagov/cyhy-db - # - dependency-name: actions/download-artifact - # - dependency-name: actions/upload-artifact - # - dependency-name: github/codeql-action + # Managed by cisagov/skeleton-python-library + - dependency-name: actions/download-artifact + - dependency-name: actions/upload-artifact + - dependency-name: github/codeql-action package-ecosystem: github-actions schedule: interval: weekly diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e429274..fb1d7d1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -180,10 +180,6 @@ jobs: fail-fast: false matrix: python-version: - - "3.7" - - "3.8" - - "3.9" - - "3.10" - "3.11" - "3.12" steps: @@ -286,10 +282,6 @@ jobs: fail-fast: false matrix: python-version: - - "3.7" - - "3.8" - - "3.9" - - "3.10" - "3.11" - "3.12" steps: @@ -341,10 +333,6 @@ jobs: fail-fast: false matrix: python-version: - - "3.7" - - "3.8" - - "3.9" - - "3.10" - "3.11" - "3.12" steps: diff --git a/.gitignore b/.gitignore index 242b4aa..82ac921 100644 --- a/.gitignore +++ b/.gitignore @@ -5,8 +5,12 @@ ## Python ## __pycache__ .coverage +.hypothesis .mypy_cache .pytest_cache .python-version *.egg-info dist + +## VSCode ## +.vscode diff --git a/README.md b/README.md index 139655b..c3afd47 100644 --- a/README.md +++ b/README.md @@ -5,20 +5,104 @@ [![Coverage 
Status](https://coveralls.io/repos/github/cisagov/cyhy-db/badge.svg?branch=develop)](https://coveralls.io/github/cisagov/cyhy-db?branch=develop) [![Known Vulnerabilities](https://snyk.io/test/github/cisagov/cyhy-db/develop/badge.svg)](https://snyk.io/test/github/cisagov/cyhy-db) -This is a generic skeleton project that can be used to quickly get a -new [cisagov](https://github.com/cisagov) Python library GitHub -project started. This skeleton project contains [licensing -information](LICENSE), as well as -[pre-commit hooks](https://pre-commit.com) and -[GitHub Actions](https://github.com/features/actions) configurations -appropriate for a Python library project. - -## New Repositories from a Skeleton ## - -Please see our [Project Setup guide](https://github.com/cisagov/development-guide/tree/develop/project_setup) -for step-by-step instructions on how to start a new repository from -a skeleton. This will save you time and effort when configuring a -new repository! +This repository implements a Python module for interacting with a Cyber Hygiene database. + +## Pre-requisites ## + +- [Python 3.11](https://www.python.org/downloads/) or newer +- A running [MongoDB](https://www.mongodb.com/) instance that you have access to + +## Starting a Local MongoDB Instance for Testing ## + +> [!IMPORTANT] +> This requires [Docker](https://www.docker.com/) to be installed in +> order for this to work. + +You can start a local MongoDB instance in a container with the following +command: + +```console +pytest -vs --mongo-express +``` + +> [!NOTE] +> The command `pytest -vs --mongo-express` not only starts a local +> MongoDB instance, but also runs all the `cyhy-db` unit tests, which will +> create various collections and documents in the database. 
+ +Sample output (trimmed to highlight the important parts): + +```console + +MongoDB is accessible at mongodb://mongoadmin:secret@localhost:32859 with database named "test" +Mongo Express is accessible at http://admin:pass@localhost:8081 + +Press Enter to stop Mongo Express and MongoDB containers... +``` + +Based on the example output above, you can access the MongoDB instance at +`mongodb://mongoadmin:secret@localhost:32859` and the Mongo Express web +interface at `http://admin:pass@localhost:8081`. Note that the MongoDB +containers will remain running until you press "Enter" in that terminal. + +## Example Usage ## + +Once you have a MongoDB instance running, the sample Python code below +demonstrates how to initialize the database, create a new request document, save +it, and then retrieve it. + +```python +import asyncio +from cyhy_db import initialize_db +from cyhy_db.models import RequestDoc +from cyhy_db.models.request_doc import Agency + +async def main(): + # Initialize the CyHy database + await initialize_db("mongodb://mongoadmin:secret@localhost:32859", "test") + + # Create a new CyHy request document and save it in the database + new_request = RequestDoc( + agency=Agency(name="Acme Industries", acronym="AI") + ) + await new_request.save() + + # Find the request document and print its agency information + request = await RequestDoc.get("AI") + print(request.agency) + +asyncio.run(main()) +``` + +Output: + +```console +name='Acme Industries' acronym='AI' type=None contacts=[] location=None +``` + +## Additional Testing Options ## + +> [!WARNING] +> The default usernames and passwords are for testing purposes only. +> Do not use them in production environments. Always set strong, unique +> credentials. 
+ +### Environment Variables ### + +| Variable | Description | Default | +|----------|-------------|---------| +| `MONGO_INITDB_ROOT_USERNAME` | The MongoDB root username | `mongoadmin` | +| `MONGO_INITDB_ROOT_PASSWORD` | The MongoDB root password | `secret` | +| `DATABASE_NAME` | The name of the database to use for testing | `test` | +| `MONGO_EXPRESS_PORT` | The port to use for the Mongo Express web interface | `8081` | + +### Pytest Options ### + +| Option | Description | Default | +|--------|-------------|---------| +| `--mongo-express` | Start a local MongoDB instance and Mongo Express web interface | n/a | +| `--mongo-image-tag` | The tag of the MongoDB Docker image to use | `docker.io/mongo:latest` | +| `--runslow` | Run slow tests | n/a | ## Contributing ## diff --git a/bump_version.sh b/bump_version.sh index bd520bd..df7c371 100755 --- a/bump_version.sh +++ b/bump_version.sh @@ -6,7 +6,7 @@ set -o nounset set -o errexit set -o pipefail -VERSION_FILE=src/example/_version.py +VERSION_FILE=src/cyhy_db/_version.py HELP_INFORMATION="bump_version.sh (show|major|minor|patch|prerelease|build|finalize)" diff --git a/pytest.ini b/pytest.ini index ed958e0..caca126 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,4 @@ [pytest] -addopts = -v -ra --cov +addopts = -v -ra --cov --log-cli-level=INFO +asyncio_default_fixture_loop_scope = session +asyncio_mode = auto diff --git a/setup-env b/setup-env index ac7ecfc..b3554cb 100755 --- a/setup-env +++ b/setup-env @@ -251,7 +251,7 @@ for req_file in "requirements-dev.txt" "requirements-test.txt" "requirements.txt done # Install all necessary mypy type stubs -mypy --install-types src/ +mypy --install-types --non-interactive src/ # Install git pre-commit hooks now or later. 
pre-commit install ${INSTALL_HOOKS:+"--install-hooks"} diff --git a/setup.py b/setup.py index fdb21eb..a03d864 100644 --- a/setup.py +++ b/setup.py @@ -42,10 +42,10 @@ def get_version(version_file): setup( - name="example", + name="cyhy-db", # Versions should comply with PEP440 - version=get_version("src/example/_version.py"), - description="Example Python library", + version=get_version("src/cyhy_db/_version.py"), + description="CyHy Database Python library", long_description=readme(), long_description_content_type="text/markdown", # Landing page for CISA's cybersecurity mission @@ -75,38 +75,37 @@ def get_version(version_file): # that you indicate whether you support Python 2, Python 3 or both. "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", ], - python_requires=">=3.7", + python_requires=">=3.11", # What does your project relate to? - keywords="skeleton", + keywords=["cyhy", "database"], packages=find_packages(where="src"), package_dir={"": "src"}, - package_data={"example": ["data/*.txt"]}, + package_data={"cyhy_db": ["py.typed"]}, py_modules=[splitext(basename(path))[0] for path in glob("src/*.py")], include_package_data=True, - install_requires=["docopt", "schema", "setuptools >= 24.2.0"], + install_requires=[ + "beanie", + "pydantic[email, hypothesis]", # hypothesis plugin is currently disabled: https://github.com/pydantic/pydantic/issues/4682 + "setuptools", + ], extras_require={ "test": [ + "pytest-asyncio", "coverage", - # coveralls 1.11.0 added a service number for calls from - # GitHub Actions. 
This caused a regression which resulted in a 422 - # response from the coveralls API with the message: - # Unprocessable Entity for url: https://coveralls.io/api/v1/jobs - # 1.11.1 fixed this issue, but to ensure expected behavior we'll pin - # to never grab the regression version. - "coveralls != 1.11.0", + "coveralls", + "docker", + "hypothesis", + "mimesis-factory", + "mimesis", "pre-commit", "pytest-cov", + "pytest-factoryboy", "pytest", ] }, - # Conveniently allows one to run the CLI tool as `example` - entry_points={"console_scripts": ["example = example.example:main"]}, + entry_points={}, ) diff --git a/src/example/__init__.py b/src/cyhy_db/__init__.py similarity index 70% rename from src/example/__init__.py rename to src/cyhy_db/__init__.py index 556a7d2..36ea325 100644 --- a/src/example/__init__.py +++ b/src/cyhy_db/__init__.py @@ -1,10 +1,11 @@ -"""The example library.""" +"""The cyhy_db package provides an interface to a CyHy database.""" # We disable a Flake8 check for "Module imported but unused (F401)" here because # although this import is not directly used, it populates the value # package_name.__version__, which is used to get version information about this # Python package. 
+ from ._version import __version__ # noqa: F401 -from .example import example_div +from .db import initialize_db -__all__ = ["example_div"] +__all__ = ["initialize_db"] diff --git a/src/cyhy_db/_version.py b/src/cyhy_db/_version.py new file mode 100644 index 0000000..5becc17 --- /dev/null +++ b/src/cyhy_db/_version.py @@ -0,0 +1 @@ +__version__ = "1.0.0" diff --git a/src/cyhy_db/db.py b/src/cyhy_db/db.py new file mode 100644 index 0000000..c74e14b --- /dev/null +++ b/src/cyhy_db/db.py @@ -0,0 +1,54 @@ +"""CyHy database top-level functions.""" + +# Third-Party Libraries +from beanie import Document, View, init_beanie +from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase + +from .models import ( + CVEDoc, + HostDoc, + HostScanDoc, + KEVDoc, + NotificationDoc, + PlaceDoc, + PortScanDoc, + ReportDoc, + RequestDoc, + SnapshotDoc, + SystemControlDoc, + TallyDoc, + TicketDoc, + VulnScanDoc, +) + +ALL_MODELS: list[type[Document] | type[View] | str] = [ + CVEDoc, + HostDoc, + HostScanDoc, + KEVDoc, + NotificationDoc, + PlaceDoc, + PortScanDoc, + RequestDoc, + ReportDoc, + SnapshotDoc, + SystemControlDoc, + TallyDoc, + TicketDoc, + VulnScanDoc, +] + +# Note: ScanDoc is intentionally excluded from the list of models to be imported +# or initialized because it is an abstract base class. + + +async def initialize_db(db_uri: str, db_name: str) -> AsyncIOMotorDatabase: + """Initialize the database.""" + try: + client: AsyncIOMotorClient = AsyncIOMotorClient(db_uri) + db: AsyncIOMotorDatabase = client[db_name] + await init_beanie(database=db, document_models=ALL_MODELS) + return db + except Exception as e: + print(f"Failed to initialize database with error: {e}") + raise diff --git a/src/cyhy_db/models/__init__.py b/src/cyhy_db/models/__init__.py new file mode 100644 index 0000000..8e06369 --- /dev/null +++ b/src/cyhy_db/models/__init__.py @@ -0,0 +1,47 @@ +"""This module contains the models for the CyHy database. 
+ +# Imports are ordered to avoid a circular import. +# isort is disabled for this file as it will break the ordering. + +isort:skip_file +""" + +# Scan documents (order matters) +from .scan_doc import ScanDoc +from .host_scan_doc import HostScanDoc +from .port_scan_doc import PortScanDoc +from .vuln_scan_doc import VulnScanDoc + +# Snapshot documents (order matters) +from .snapshot_doc import SnapshotDoc +from .report_doc import ReportDoc + +# Other documents +from .cve_doc import CVEDoc +from .host_doc import HostDoc +from .kev_doc import KEVDoc +from .notification_doc import NotificationDoc +from .place_doc import PlaceDoc +from .request_doc import RequestDoc +from .system_control_doc import SystemControlDoc +from .tally_doc import TallyDoc +from .ticket_doc import TicketDoc + + +__all__ = [ + "CVEDoc", + "HostDoc", + "HostScanDoc", + "KEVDoc", + "NotificationDoc", + "PlaceDoc", + "PortScanDoc", + "RequestDoc", + "ReportDoc", + "ScanDoc", + "SnapshotDoc", + "SystemControlDoc", + "TallyDoc", + "TicketDoc", + "VulnScanDoc", +] diff --git a/src/cyhy_db/models/cve_doc.py b/src/cyhy_db/models/cve_doc.py new file mode 100644 index 0000000..86956ae --- /dev/null +++ b/src/cyhy_db/models/cve_doc.py @@ -0,0 +1,52 @@ +"""The model for CVE (Common Vulnerabilities and Exposures) documents.""" + +# Standard Python Libraries +from typing import Any, Dict + +# Third-Party Libraries +from beanie import Document, Indexed +from pydantic import ConfigDict, Field, model_validator + +from .enum import CVSSVersion + + +class CVEDoc(Document): + """The CVE document model.""" + + # Validate on assignment so severity is calculated + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + cvss_score: float = Field(ge=0.0, le=10.0) + cvss_version: CVSSVersion = Field(default=CVSSVersion.V3_1) + # See: https://github.com/cisagov/cyhy-db/issues/7 + # CVE ID as a string + id: str = Indexed(primary_field=True) # type: ignore[assignment] + severity: int = Field(ge=1, le=4, 
default=1) + + @model_validator(mode="before") + def calculate_severity(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Calculate CVE severity based on the CVSS score and version.""" + if values["cvss_version"] == CVSSVersion.V2: + if values["cvss_score"] == 10: + values["severity"] = 4 + elif values["cvss_score"] >= 7.0: + values["severity"] = 3 + elif values["cvss_score"] >= 4.0: + values["severity"] = 2 + else: + values["severity"] = 1 + else: # CVSS versions 3.0 or 3.1 + if values["cvss_score"] >= 9.0: + values["severity"] = 4 + elif values["cvss_score"] >= 7.0: + values["severity"] = 3 + elif values["cvss_score"] >= 4.0: + values["severity"] = 2 + else: + values["severity"] = 1 + return values + + class Settings: + """Beanie settings.""" + + name = "cves" diff --git a/src/cyhy_db/models/enum.py b/src/cyhy_db/models/enum.py new file mode 100644 index 0000000..02b77b0 --- /dev/null +++ b/src/cyhy_db/models/enum.py @@ -0,0 +1,124 @@ +"""The enumerations used in CyHy.""" + +# Standard Python Libraries +from enum import StrEnum, auto + + +class AgencyType(StrEnum): + """Agency types.""" + + FEDERAL = auto() + LOCAL = auto() + PRIVATE = auto() + STATE = auto() + TERRITORIAL = auto() + TRIBAL = auto() + + +class ControlAction(StrEnum): + """Commander control actions.""" + + PAUSE = auto() + STOP = auto() + + +class ControlTarget(StrEnum): + """Commander control targets.""" + + COMMANDER = auto() + + +class CVSSVersion(StrEnum): + """CVSS versions.""" + + V2 = "2.0" + V3 = "3.0" + V3_1 = "3.1" + + +class DayOfWeek(StrEnum): + """Days of the week.""" + + MONDAY = auto() + TUESDAY = auto() + WEDNESDAY = auto() + THURSDAY = auto() + FRIDAY = auto() + SATURDAY = auto() + SUNDAY = auto() + + +class PocType(StrEnum): + """Point of contact types.""" + + DISTRO = auto() + TECHNICAL = auto() + + +class Protocol(StrEnum): + """Network protocols.""" + + TCP = auto() + UDP = auto() + + +class ReportPeriod(StrEnum): + """CyHy reporting periods.""" + + MONTHLY = auto() + 
QUARTERLY = auto() + WEEKLY = auto() + + +class ReportType(StrEnum): + """CyHy report types.""" + + BOD = auto() + CYBEX = auto() + CYHY = auto() + CYHY_THIRD_PARTY = auto() + DNSSEC = auto() + PHISHING = auto() + + +class ScanType(StrEnum): + """CyHy scan types.""" + + CYHY = auto() + DNSSEC = auto() + PHISHING = auto() + + +class Scheduler(StrEnum): + """CyHy schedulers.""" + + PERSISTENT1 = auto() + + +class Stage(StrEnum): + """CyHy scan stages.""" + + NETSCAN1 = auto() + NETSCAN2 = auto() + PORTSCAN = auto() + VULNSCAN = auto() + + +class Status(StrEnum): + """CyHy scan statuses.""" + + DONE = auto() + READY = auto() + RUNNING = auto() + WAITING = auto() + + +class TicketAction(StrEnum): + """Actions for ticket events.""" + + CHANGED = auto() + CLOSED = auto() + OPENED = auto() + REOPENED = auto() + UNVERIFIED = auto() + VERIFIED = auto() diff --git a/src/cyhy_db/models/exceptions.py b/src/cyhy_db/models/exceptions.py new file mode 100644 index 0000000..6f2d4fd --- /dev/null +++ b/src/cyhy_db/models/exceptions.py @@ -0,0 +1,43 @@ +"""The exceptions used in CyHy.""" + + +class PortScanNotFoundException(Exception): + """Exception raised when a referenced PortScanDoc is not found.""" + + def __init__(self, ticket_id, port_scan_id, port_scan_time, *args): + """Initialize the exception with the given ticket ID, port scan ID, and port scan time. + + Args: + ticket_id (str): The ID of the ticket. + port_scan_id (str): The ID of the port scan. + port_scan_time (datetime): The time of the port scan. + *args: Additional arguments to pass to the base Exception class. 
+ """ + message = "Ticket {}: referenced PortScanDoc {} at time {} not found".format( + ticket_id, port_scan_id, port_scan_time + ) + self.ticket_id = ticket_id + self.port_scan_id = port_scan_id + self.port_scan_time = port_scan_time + super().__init__(message, *args) + + +class VulnScanNotFoundException(Exception): + """Exception raised when a referenced VulnScanDoc is not found.""" + + def __init__(self, ticket_id, vuln_scan_id, vuln_scan_time, *args): + """Initialize the exception with the given ticket ID, vulnerability scan ID, and vulnerability scan time. + + Args: + ticket_id (str): The ID of the ticket. + vuln_scan_id (str): The ID of the vulnerability scan document. + vuln_scan_time (str): The time of the vulnerability scan. + *args: Additional arguments to pass to the base exception class. + """ + message = "Ticket {}: referenced VulnScanDoc {} at time {} not found".format( + ticket_id, vuln_scan_id, vuln_scan_time + ) + self.ticket_id = ticket_id + self.vuln_scan_id = vuln_scan_id + self.vuln_scan_time = vuln_scan_time + super().__init__(message, *args) diff --git a/src/cyhy_db/models/host_doc.py b/src/cyhy_db/models/host_doc.py new file mode 100644 index 0000000..066e8e8 --- /dev/null +++ b/src/cyhy_db/models/host_doc.py @@ -0,0 +1,132 @@ +"""The model for CyHy host documents.""" + +# Standard Python Libraries +from datetime import datetime +from ipaddress import IPv4Address, ip_address +import random +from typing import Any, Dict, Optional, Tuple + +# Third-Party Libraries +from beanie import Document, Insert, Replace, ValidateOnSave, before_event +from pydantic import BaseModel, ConfigDict, Field, model_validator +from pymongo import ASCENDING, IndexModel + +from ..utils import deprecated, utcnow +from .enum import Stage, Status + + +class State(BaseModel): + """The state of a host.""" + + reason: str + up: bool + + +class HostDoc(Document): + """The host document model.""" + + model_config = ConfigDict(extra="forbid") + + # See: 
https://github.com/cisagov/cyhy-db/issues/7 + # IP address as an integer + id: int = Field(default_factory=int) # type: ignore[assignment] + ip: IPv4Address = Field(...) + last_change: datetime = Field(default_factory=utcnow) + latest_scan: Dict[Stage, datetime] = Field(default_factory=dict) + loc: Optional[Tuple[float, float]] = Field(default=None) + next_scan: Optional[datetime] = Field(default=None) + owner: str = Field(...) + priority: int = Field(default=0) + r: float = Field(default_factory=random.random) + stage: Stage = Field(default=Stage.NETSCAN1) + state: State = Field(default_factory=lambda: State(reason="new", up=False)) + status: Status = Field(default=Status.WAITING) + + @model_validator(mode="before") + def calculate_ip_int(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Calculate the integer representation of an IP address.""" + # ip may still be string if it was just set + values["_id"] = int(ip_address(values["ip"])) + return values + + @before_event(Insert, Replace, ValidateOnSave) + async def before_save(self): + """Set data just prior to saving a host document.""" + self.last_change = utcnow() + + class Settings: + """Beanie settings.""" + + name = "hosts" + indexes = [ + IndexModel( + [ + ("status", ASCENDING), + ("stage", ASCENDING), + ("owner", ASCENDING), + ("priority", ASCENDING), + ("r", ASCENDING), + ], + name="claim", + ), + IndexModel( + [ + ("ip", ASCENDING), + ], + name="ip", + ), + IndexModel( + [ + ("state.up", ASCENDING), + ("owner", ASCENDING), + ], + name="up", + ), + IndexModel( + [ + ("next_scan", ASCENDING), + ("state.up", ASCENDING), + ("status", ASCENDING), + ], + sparse=True, + name="next_scan", + ), + IndexModel( + [ + ("owner", ASCENDING), + ], + name="owner", + ), + IndexModel( + [ + ("owner", ASCENDING), + ("state.up", ASCENDING), + ("latest_scan.VULNSCAN", ASCENDING), + ], + name="latest_scan_done", + ), + ] + + def set_state(self, nmap_says_up, has_open_ports, reason=None): + """Set state.up based on different 
stage evidence. + + nmap has a concept of up which is different from our definition. An nmap + "up" just means it got a reply, not that there are any open ports. Note + either argument can be None. + """ + if has_open_ports: # Only PORTSCAN sends in has_open_ports + self.state = State(up=True, reason="open-port") + elif has_open_ports is False: + self.state = State(up=False, reason="no-open") + elif nmap_says_up is False: # NETSCAN says host is down + self.state = State(up=False, reason=reason) + + # TODO: There are a lot of functions in the Python 2 version that may or may not be used. + # Instead of porting them all over, we should just port them as they are needed. + # And rewrite things that can be done better in Python 3. + + @classmethod + @deprecated("Use HostDoc.find_one(HostDoc.ip == ip) instead.") + async def get_by_ip(cls, ip: IPv4Address): + """Return a host document with the given IP address.""" + return await cls.find_one(cls.ip == ip) diff --git a/src/cyhy_db/models/host_scan_doc.py b/src/cyhy_db/models/host_scan_doc.py new file mode 100644 index 0000000..4d5dc93 --- /dev/null +++ b/src/cyhy_db/models/host_scan_doc.py @@ -0,0 +1,32 @@ +"""The model for CyHy host scan documents.""" + +# Standard Python Libraries +from typing import List + +# Third-Party Libraries +from pydantic import ConfigDict +from pymongo import ASCENDING, IndexModel + +from . 
import ScanDoc + + +class HostScanDoc(ScanDoc): + """The host scan document model.""" + + model_config = ConfigDict(extra="forbid") + + accuracy: int + classes: List[dict] = [] + line: int + name: str + + class Settings: + """Beanie settings.""" + + name = "host_scans" + indexes = ScanDoc.Abstract_Settings.indexes + [ + IndexModel( + [("latest", ASCENDING), ("owner", ASCENDING)], name="latest_owner" + ), + IndexModel([("owner", ASCENDING)], name="owner"), + ] diff --git a/src/cyhy_db/models/kev_doc.py b/src/cyhy_db/models/kev_doc.py new file mode 100644 index 0000000..f2745be --- /dev/null +++ b/src/cyhy_db/models/kev_doc.py @@ -0,0 +1,20 @@ +"""The model for KEV (Known Exploited Vulnerabilities) documents.""" + +# Third-Party Libraries +from beanie import Document +from pydantic import ConfigDict, Field + + +class KEVDoc(Document): + """The KEV document model.""" + + model_config = ConfigDict(extra="forbid") + + # See: https://github.com/cisagov/cyhy-db/issues/7 + id: str = Field(default_factory=str) # type: ignore[assignment] + known_ransomware: bool + + class Settings: + """Beanie settings.""" + + name = "kevs" diff --git a/src/cyhy_db/models/notification_doc.py b/src/cyhy_db/models/notification_doc.py new file mode 100644 index 0000000..9754e71 --- /dev/null +++ b/src/cyhy_db/models/notification_doc.py @@ -0,0 +1,25 @@ +"""The model for notification documents.""" + +# Standard Python Libraries +from typing import List + +# Third-Party Libraries +from beanie import BeanieObjectId, Document +from pydantic import ConfigDict, Field + + +class NotificationDoc(Document): + """The notification document model.""" + + model_config = ConfigDict(extra="forbid") + + generated_for: List[str] = Field( + default=[] + ) # list of owners built as notifications are generated + ticket_id: BeanieObjectId = Field(...) 
# ticket id that triggered the notification + ticket_owner: str # owner of the ticket + + class Settings: + """Beanie settings.""" + + name = "notifications" diff --git a/src/cyhy_db/models/place_doc.py b/src/cyhy_db/models/place_doc.py new file mode 100644 index 0000000..ed1e356 --- /dev/null +++ b/src/cyhy_db/models/place_doc.py @@ -0,0 +1,38 @@ +"""The model for place documents.""" + +# Standard Python Libraries +from typing import Optional + +# Third-Party Libraries +from beanie import Document +from pydantic import ConfigDict, Field + + +class PlaceDoc(Document): + """The place document model.""" + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + class_: str = Field(alias="class") # 'class' is a reserved keyword in Python + country_name: str + country: str + county_fips: Optional[str] = None + county: Optional[str] = None + elevation_feet: Optional[int] = None + elevation_meters: Optional[int] = None + # See: https://github.com/cisagov/cyhy-db/issues/7 + # GNIS FEATURE_ID (INCITS 446-2008) - https://geonames.usgs.gov/domestic/index.html + id: int = Field(default_factory=int) # type: ignore[assignment] + latitude_dec: float + latitude_dms: Optional[str] = None + longitude_dec: float + longitude_dms: Optional[str] = None + name: str + state_fips: str + state_name: str + state: str + + class Settings: + """Beanie settings.""" + + name = "places" diff --git a/src/cyhy_db/models/port_scan_doc.py b/src/cyhy_db/models/port_scan_doc.py new file mode 100644 index 0000000..bbf1304 --- /dev/null +++ b/src/cyhy_db/models/port_scan_doc.py @@ -0,0 +1,46 @@ +"""The model for CyHy port scan documents.""" + +# Standard Python Libraries +from typing import Dict + +# Third-Party Libraries +from pydantic import ConfigDict +from pymongo import ASCENDING, IndexModel + +from . 
import ScanDoc +from .enum import Protocol + + +class PortScanDoc(ScanDoc): + """The port scan document model.""" + + model_config = ConfigDict(extra="forbid") + + port: int + protocol: Protocol + reason: str + service: Dict = {} # Assuming no specific structure for "service" + state: str + + class Settings: + """Beanie settings.""" + + name = "port_scans" + indexes = ScanDoc.Abstract_Settings.indexes + [ + IndexModel( + [("latest", ASCENDING), ("owner", ASCENDING), ("state", ASCENDING)], + name="latest_owner_state", + ), + IndexModel( + [("latest", ASCENDING), ("service.name", ASCENDING)], + name="latest_service_name", + ), + IndexModel( + [("latest", ASCENDING), ("time", ASCENDING)], + name="latest_time", + ), + IndexModel( + [("owner", ASCENDING)], + name="owner", + ), + ] diff --git a/src/cyhy_db/models/report_doc.py b/src/cyhy_db/models/report_doc.py new file mode 100644 index 0000000..2ce4f88 --- /dev/null +++ b/src/cyhy_db/models/report_doc.py @@ -0,0 +1,44 @@ +"""The model for CyHy report documents.""" + +# Standard Python Libraries +from datetime import datetime +from typing import List + +# Third-Party Libraries +from beanie import Document, Link +from pydantic import ConfigDict, Field +from pymongo import ASCENDING, IndexModel + +from . 
import SnapshotDoc +from ..utils import utcnow +from .enum import ReportType + + +class ReportDoc(Document): + """The report document model.""" + + model_config = ConfigDict(extra="forbid") + + generated_time: datetime = Field(default_factory=utcnow) + owner: str + report_types: List[ReportType] + snapshots: List[Link[SnapshotDoc]] + + class Settings: + """Beanie settings.""" + + name = "reports" + indexes = [ + IndexModel( + [ + ("owner", ASCENDING), + ], + name="owner", + ), + IndexModel( + [ + ("generated_time", ASCENDING), + ], + name="generated_time", + ), + ] diff --git a/src/cyhy_db/models/request_doc.py b/src/cyhy_db/models/request_doc.py new file mode 100644 index 0000000..39778a1 --- /dev/null +++ b/src/cyhy_db/models/request_doc.py @@ -0,0 +1,134 @@ +"""The model for CyHy request documents.""" + +# Standard Python Libraries +from datetime import datetime, time +from ipaddress import IPv4Network +from typing import List, Optional + +# Third-Party Libraries +from beanie import Document, Insert, Link, Replace, ValidateOnSave, before_event +from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator + +from ..utils import utcnow +from .enum import ( + AgencyType, + DayOfWeek, + PocType, + ReportPeriod, + ReportType, + ScanType, + Scheduler, + Stage, +) + +BOGUS_ID = "bogus_id_replace_me" + + +class Contact(BaseModel): + """A point of contact for the entity.""" + + model_config = ConfigDict(extra="forbid") + + email: EmailStr + name: str + phone: str + type: PocType + + +class Location(BaseModel): + """A location with various geographical identifiers.""" + + model_config = ConfigDict(extra="forbid") + + country_name: str + country: str + county_fips: str + county: str + gnis_id: int + name: str + state_fips: str + state_name: str + state: str + + +class Agency(BaseModel): + """Model representing a CyHy-enrolled entity.""" + + model_config = ConfigDict(extra="forbid") + + name: str + acronym: str + type: Optional[AgencyType] = 
Field(default=None) + contacts: List[Contact] = Field(default=[]) + location: Optional[Location] = Field(default=None) + + +class ScanLimit(BaseModel): + """Scan limits for a specific scan type.""" + + model_config = ConfigDict(extra="forbid") + + scan_type: ScanType + concurrent: int = Field(ge=0) + + +class Window(BaseModel): + """A day and time window for scheduling scans.""" + + model_config = ConfigDict(extra="forbid") + + day: DayOfWeek = Field(default=DayOfWeek.SUNDAY) + duration: int = Field(default=168, ge=0, le=168) + start: time = Field(default=time(0, 0, 0)) + + @field_validator("start", mode="before") + @classmethod + def parse_time(cls, v): + """Parse and validate a time representation.""" + if isinstance(v, str): + # Parse the string to datetime.time + return datetime.strptime(v, "%H:%M:%S").time() + elif isinstance(v, time): + return v + else: + raise ValueError( + "Invalid time format. Expected a string in '%H:%M:%S' format or datetime.time instance." + ) + + +class RequestDoc(Document): + """The request document model.""" + + model_config = ConfigDict(extra="forbid") + + agency: Agency + children: List[Link["RequestDoc"]] = Field(default=[]) + enrolled: datetime = Field(default_factory=utcnow) + # See: https://github.com/cisagov/cyhy-db/issues/7 + id: str = Field(default=BOGUS_ID) # type: ignore[assignment] + init_stage: Stage = Field(default=Stage.NETSCAN1) + key: Optional[str] = Field(default=None) + networks: List[IPv4Network] = Field(default=[]) + period_start: datetime = Field(default_factory=utcnow) + report_period: ReportPeriod = Field(default=ReportPeriod.WEEKLY) + report_types: List[ReportType] = Field(default=[]) + retired: bool = False + scan_limits: List[ScanLimit] = Field(default=[]) + scan_types: List[ScanType] = Field(default=[]) + scheduler: Scheduler = Field(default=Scheduler.PERSISTENT1) + stakeholder: bool = False + windows: List[Window] = Field(default=[Window()]) + + @before_event(Insert, Replace, ValidateOnSave) + async def 
set_id_to_acronym(self): + """Set the id to the agency acronym if it is the default value.""" + if self.id == BOGUS_ID: + self.id = self.agency.acronym + + class Settings: + """Beanie settings.""" + + bson_encoders = { + time: lambda value: value.strftime("%H:%M:%S") + } # Register custom encoder for datetime.time + name = "requests" diff --git a/src/cyhy_db/models/scan_doc.py b/src/cyhy_db/models/scan_doc.py new file mode 100644 index 0000000..a552714 --- /dev/null +++ b/src/cyhy_db/models/scan_doc.py @@ -0,0 +1,122 @@ +"""ScanDoc model for use as the base of other scan document classes.""" + +# Standard Python Libraries +from abc import ABC +from datetime import datetime +from ipaddress import IPv4Address, ip_address +from typing import Any, Dict, Iterable, List, Union + +# Third-Party Libraries +from beanie import Document, Link +from beanie.operators import In, Push, Set +from bson import ObjectId +from bson.dbref import DBRef +from pydantic import ConfigDict, Field, model_validator +from pymongo import ASCENDING, IndexModel + +from ..utils import utcnow +from .snapshot_doc import SnapshotDoc + + +class ScanDoc(Document, ABC): + """The abstract base class for scan-like documents.""" + + # Validate on assignment so ip_int is recalculated as ip is set + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + ip: IPv4Address = Field(...) + ip_int: int = Field(...) + latest: bool = Field(default=True) + owner: str = Field(...) + snapshots: List[Link["SnapshotDoc"]] = Field(default=[]) + source: str = Field(...) 
+ time: datetime = Field(default_factory=utcnow) + + @model_validator(mode="before") + def calculate_ip_int(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Calculate the integer representation of an IP address.""" + # ip may still be string if it was just set + values["ip_int"] = int(ip_address(values["ip"])) + return values + + class Settings: + """Beanie settings to be used during testing.""" + + # These settings are intended for use only during testing. See + # Abstract_Settings below. + + name = "PyTest_ScanDocs" + + class Abstract_Settings: + """Beanie settings to be inherited by subclasses.""" + + # This class is intentionally not named "Settings" to prevent Beanie from + # automatically applying these indices, which would result in the creation of + # an unnecessary collection. Instead, subclasses should define their own + # "Settings" class and include these indices along with any additional + # subclass-specific indices. + + indexes = [ + IndexModel( + [("latest", ASCENDING), ("ip_int", ASCENDING)], name="latest_ip" + ), + IndexModel([("time", ASCENDING), ("owner", ASCENDING)], name="time_owner"), + IndexModel([("int_ip", ASCENDING)], name="int_ip"), + IndexModel([("snapshots", ASCENDING)], name="snapshots", sparse=True), + ] + + @classmethod + async def reset_latest_flag_by_owner(cls, owner: str): + """Reset the latest flag for all scans for a given owner.""" + # flake8 E712 is "comparison to True should be 'if cond is True:' or 'if + # cond:'" but this is unavoidable due to Beanie syntax. 
+ await cls.find(cls.latest == True, cls.owner == owner).update_many( # noqa E712 + Set({cls.latest: False}) + ) + + @classmethod + async def reset_latest_flag_by_ip( + cls, + ips: ( + int + | IPv4Address + | Iterable[int] + | Iterable[IPv4Address] + | Iterable[str] + | str + ), + ): + """Reset the latest flag for all scans for a given IP address.""" + if isinstance(ips, (int, IPv4Address, str)): + ip_ints = [int(ip_address(ips))] + else: + # Are you questing for 100% test coverage? + # Here be dragons: https://github.com/nedbat/coveragepy/issues/1617 + ip_ints = [int(ip_address(x)) for x in ips] + + # flake8 E712 is "comparison to True should be 'if cond is True:' or 'if + # cond:'" but this is unavoidable due to Beanie syntax. + await cls.find( + cls.latest == True, In(cls.ip_int, ip_ints) # noqa E712 + ).update_many(Set({cls.latest: False})) + + @classmethod + async def tag_latest( + cls, owners: List[str], snapshot: Union[SnapshotDoc, ObjectId, str] + ): + """Tag the latest scan for given owners with a snapshot id.""" + from . import SnapshotDoc + + if isinstance(snapshot, SnapshotDoc): + ref = DBRef(SnapshotDoc.Settings.name, snapshot.id) + elif isinstance(snapshot, ObjectId): + ref = DBRef(SnapshotDoc.Settings.name, snapshot) + elif isinstance(snapshot, str): + ref = DBRef(SnapshotDoc.Settings.name, ObjectId(snapshot)) + else: + raise ValueError("Invalid snapshot type") + # flake8 E712 is "comparison to True should be 'if cond is True:' or 'if + # cond:'" but this is unavoidable due to Beanie syntax. 
+ await cls.find( + cls.latest == True, In(cls.owner, owners) # noqa E712 + ).update_many(Push({cls.snapshots: ref})) diff --git a/src/cyhy_db/models/snapshot_doc.py b/src/cyhy_db/models/snapshot_doc.py new file mode 100644 index 0000000..d4bffd2 --- /dev/null +++ b/src/cyhy_db/models/snapshot_doc.py @@ -0,0 +1,124 @@ +"""The model for CyHy snapshot documents.""" + +# Standard Python Libraries +from datetime import datetime +from ipaddress import IPv4Network +from typing import Dict, List + +# Third-Party Libraries +from beanie import Document +from pydantic import BaseModel, ConfigDict, Field +from pymongo import ASCENDING, IndexModel + +from ..utils import utcnow + + +class VulnerabilityCounts(BaseModel): + """The model for vulnerability counts.""" + + model_config = ConfigDict(extra="forbid") + + critical: int = 0 + high: int = 0 + medium: int = 0 + low: int = 0 + total: int = 0 + + +class WorldData(BaseModel): + """The model for aggregated metrics of all CyHy entities.""" + + model_config = ConfigDict(extra="forbid") + + cvss_average_all: float = 0.0 + cvss_average_vulnerable: float = 0.0 + host_count: int = 0 + unique_vulnerabilities: VulnerabilityCounts = Field( + default_factory=VulnerabilityCounts + ) + vulnerable_host_count: int = 0 + vulnerabilities: VulnerabilityCounts = Field(default_factory=VulnerabilityCounts) + + +class TicketMetrics(BaseModel): + """The model for ticket metrics.""" + + model_config = ConfigDict(extra="forbid") + + max: int = 0 + median: int = 0 + + +class TicketOpenMetrics(BaseModel): + """The model for open ticket metrics.""" + + model_config = ConfigDict(extra="forbid") + + # Numbers in this section refer to how long open tix were open AT this date/time + tix_open_as_of_date: datetime = Field(default_factory=utcnow) + critical: TicketMetrics = Field(default_factory=TicketMetrics) + high: TicketMetrics = Field(default_factory=TicketMetrics) + medium: TicketMetrics = Field(default_factory=TicketMetrics) + low: TicketMetrics = 
Field(default_factory=TicketMetrics) + + +class TicketCloseMetrics(BaseModel): + """The model for closed ticket metrics.""" + + model_config = ConfigDict(extra="forbid") + + # Numbers in this section only include tix that closed AT/AFTER this date/time + tix_closed_after_date: datetime = Field(default_factory=utcnow) + critical: TicketMetrics = Field(default_factory=TicketMetrics) + high: TicketMetrics = Field(default_factory=TicketMetrics) + medium: TicketMetrics = Field(default_factory=TicketMetrics) + low: TicketMetrics = Field(default_factory=TicketMetrics) + + +class SnapshotDoc(Document): + """The snapshot document model.""" + + model_config = ConfigDict(extra="forbid") + + addresses_scanned: int = Field(default=0) + cvss_average_all: float = Field(default=0.0) + cvss_average_vulnerable: float = Field(default=0.0) + descendants_included: List[str] = Field(default=[]) + end_time: datetime = Field(...) + host_count: int = Field(default=0) + last_change: datetime = Field(default_factory=utcnow) + latest: bool = Field(default=True) + networks: List[IPv4Network] = Field(default=[]) + owner: str = Field(...) + port_count: int = Field(default=0) + services: Dict = Field(default_factory=dict) + start_time: datetime = Field(...) 
+ tix_msec_open: TicketOpenMetrics = Field(default_factory=TicketOpenMetrics) + tix_msec_to_close: TicketCloseMetrics = Field(default_factory=TicketCloseMetrics) + unique_operating_systems: int = Field(default=0) + unique_port_count: int = Field(default=0) + unique_vulnerabilities: VulnerabilityCounts = Field( + default_factory=VulnerabilityCounts + ) + vulnerabilities: VulnerabilityCounts = Field(default_factory=VulnerabilityCounts) + vulnerable_host_count: int = Field(default=0) + world: WorldData = Field(default_factory=WorldData) + + class Settings: + """Beanie settings.""" + + name = "snapshots" + indexes = [ + IndexModel( + [ + ("owner", ASCENDING), + ("start_time", ASCENDING), + ("end_time", ASCENDING), + ], + name="uniques", + unique=True, + ), + IndexModel( + [("latest", ASCENDING), ("owner", ASCENDING)], name="latest_owner" + ), + ] diff --git a/src/cyhy_db/models/system_control_doc.py b/src/cyhy_db/models/system_control_doc.py new file mode 100644 index 0000000..6486bdb --- /dev/null +++ b/src/cyhy_db/models/system_control_doc.py @@ -0,0 +1,49 @@ +"""The model for CyHy system control documents.""" + +# Standard Python Libraries +import asyncio +from datetime import datetime +from typing import Optional + +# Third-Party Libraries +from beanie import Document +from pydantic import ConfigDict, Field + +from ..utils import utcnow +from .enum import ControlAction, ControlTarget + +CONTROL_DOC_POLL_INTERVAL = 5 # seconds + + +class SystemControlDoc(Document): + """The system control document model.""" + + model_config = ConfigDict(extra="forbid") + + action: ControlAction + completed: bool = False # Set to True when after the action has occurred + reason: str # Free-form, for UI / Logging + sender: str # Free-form, for UI / Logging + target: ControlTarget + time: datetime = Field(default_factory=utcnow) # creation time + + class Settings: + """Beanie settings.""" + + name = "control" + + @classmethod + async def wait_for_completion(cls, document_id, timeout: 
Optional[int] = None): + """Wait for this control action to complete. + + If a timeout is set, only wait a maximum of timeout seconds. + Returns True if the document was completed, False otherwise. + """ + start_time = utcnow() + while True: + doc = await cls.get(document_id) + if doc and doc.completed: + return True + if timeout and (utcnow() - start_time).total_seconds() > timeout: + return False + await asyncio.sleep(CONTROL_DOC_POLL_INTERVAL) diff --git a/src/cyhy_db/models/tally_doc.py b/src/cyhy_db/models/tally_doc.py new file mode 100644 index 0000000..54522b2 --- /dev/null +++ b/src/cyhy_db/models/tally_doc.py @@ -0,0 +1,55 @@ +"""The model for CyHy tally documents.""" + +# Standard Python Libraries +from datetime import datetime + +# Third-Party Libraries +from beanie import Document, Insert, Replace, ValidateOnSave, before_event +from pydantic import BaseModel, ConfigDict, Field + +from ..utils import utcnow + + +class StatusCounts(BaseModel): + """The model for host status counts.""" + + model_config = ConfigDict(extra="forbid") + + DONE: int = 0 + READY: int = 0 + RUNNING: int = 0 + WAITING: int = 0 + + +class Counts(BaseModel): + """The model for scan stage counts.""" + + model_config = ConfigDict(extra="forbid") + + BASESCAN: StatusCounts = Field(default_factory=StatusCounts) + NETSCAN1: StatusCounts = Field(default_factory=StatusCounts) + NETSCAN2: StatusCounts = Field(default_factory=StatusCounts) + PORTSCAN: StatusCounts = Field(default_factory=StatusCounts) + VULNSCAN: StatusCounts = Field(default_factory=StatusCounts) + + +class TallyDoc(Document): + """The tally document model.""" + + model_config = ConfigDict(extra="forbid") + + counts: Counts = Field(default_factory=Counts) + # See: https://github.com/cisagov/cyhy-db/issues/7 + # Owner ID string + id: str # type: ignore[assignment] + last_change: datetime = Field(default_factory=utcnow) + + @before_event(Insert, Replace, ValidateOnSave) + async def before_save(self): + """Set data just prior 
to saving a tally document.""" + self.last_change = utcnow() + + class Settings: + """Beanie settings.""" + + name = "tallies" diff --git a/src/cyhy_db/models/ticket_doc.py b/src/cyhy_db/models/ticket_doc.py new file mode 100644 index 0000000..f621216 --- /dev/null +++ b/src/cyhy_db/models/ticket_doc.py @@ -0,0 +1,269 @@ +"""The model for CyHy ticket documents.""" + +# Standard Python Libraries +from datetime import datetime, timedelta +from ipaddress import IPv4Address +from typing import List, Optional, Tuple + +# Third-Party Libraries +from beanie import ( + BeanieObjectId, + Document, + Insert, + Link, + Replace, + ValidateOnSave, + before_event, +) +from beanie.operators import In, Pull, Push +from pydantic import BaseModel, ConfigDict, Field +from pymongo import ASCENDING, IndexModel + +# cisagov Libraries +from cyhy_db.utils.time import utcnow + +from . import PortScanDoc, SnapshotDoc, VulnScanDoc +from .enum import Protocol, TicketAction +from .exceptions import PortScanNotFoundException, VulnScanNotFoundException + + +class EventDelta(BaseModel): + """The event delta model.""" + + model_config = ConfigDict(populate_by_name=True) + + from_: Optional[bool | float | int | str] = Field(..., alias="from") + key: str = Field(...) + to: Optional[bool | float | int | str] = Field(...) + + +class TicketEvent(BaseModel): + """The ticket event model.""" + + action: TicketAction + delta: Optional[EventDelta] = Field(default=None) + reason: str = Field(...) + reference: Optional[BeanieObjectId] = Field(default=None) + time: datetime + + +class TicketDoc(Document): + """The ticket document model.""" + + model_config = ConfigDict(extra="forbid") + + details: dict = Field(default_factory=dict) + events: list[TicketEvent] = Field(default_factory=list) + false_positive: bool = Field(default=False) + fp_expiration_date: Optional[datetime] = Field(default=None) + ip_int: int = Field(...) + ip: IPv4Address = Field(...) 
+ last_change: datetime = Field(default_factory=utcnow) + loc: Optional[Tuple[float, float]] = Field(default=None) + open: bool = Field(default=True) + owner: str = Field(...) + port: int = Field(...) + protocol: Protocol = Field(...) + snapshots: Optional[List[Link[SnapshotDoc]]] = Field(default_factory=list) + source_id: int = Field(...) + source: str = Field(...) + time_closed: Optional[datetime] = Field(default=None) + time_opened: datetime = Field(default_factory=utcnow) + + class Settings: + """Beanie settings.""" + + name = "tickets" + + indexes = [ + IndexModel( + [ + ("ip_int", ASCENDING), + ("port", ASCENDING), + ("protocol", ASCENDING), + ("source", ASCENDING), + ("source_id", ASCENDING), + ("open", ASCENDING), + ("false_positive", ASCENDING), + ], + name="ip_port_protocol_source_open_false_positive", + ), + IndexModel( + [("ip_int", ASCENDING), ("open", ASCENDING)], + name="ip_open", + ), + IndexModel( + [("open", ASCENDING), ("owner", ASCENDING)], + name="open_owner", + ), + IndexModel( + [("time_opened", ASCENDING), ("open", ASCENDING)], + name="time_opened", + ), + IndexModel( + [("last_change", ASCENDING)], + name="last_change", + ), + IndexModel( + [("time_closed", ASCENDING)], + name="time_closed", + sparse=True, + ), + ] + + @before_event(Insert, Replace, ValidateOnSave) + async def before_save(self): + """Do a false positive sanity check and set data just prior to saving a ticket document.""" + if self.false_positive and not self.open: + raise Exception("A ticket marked as a false positive cannot be closed.") + self.last_change = utcnow() + + def add_event(self, action, reason, reference=None, time=None, delta=None): + """Add an event to the list of ticket events.""" + try: + action = TicketAction(action) + # If action is not in the enumerated TicketAction class, Python 3.11 + # throws a TypeError, while Python 3.12 throws a ValueError + except (TypeError, ValueError): + raise Exception( + 'Invalid action "' + action + '" cannot be added to 
ticket events.' + ) + if not time: + time = utcnow() + event = TicketEvent( + action=action, reason=reason, reference=reference, time=time + ) + if delta: + event.delta = delta + self.events.append(event) + + def false_positive_dates(self): + """Return most recent false positive effective and expiration dates (if any).""" + if self.false_positive: + for event in reversed(self.events): + if not event.delta: + continue + if ( + event.action == TicketAction.CHANGED + and event.delta.key == "false_positive" + ): + return (event.time, self.fp_expiration_date) + return None + + def last_detection_date(self): + """Return date of most recent detection of a ticket's finding.""" + for event in reversed(self.events): + if event.action in [ + TicketAction.OPENED, + TicketAction.VERIFIED, + TicketAction.REOPENED, + ]: + return event.time + # This should never happen, but if we don't find any OPENED/VERIFIED/REOPENED events above, gracefully return time_opened + return self.time_opened + + async def latest_port(self): + """Return the last referenced port scan in the event list. + + This should only be used for tickets generated by portscans. + """ + for event in self.events[::-1]: + reference_id = event.reference + if reference_id: + break + else: + raise Exception("No references found in ticket events: " + str(self.id)) + port = await PortScanDoc.get(reference_id) + if not port: + # This can occur when a port_scan has been archived + # Raise an exception with the info we have for this port_scan from the ticket + raise PortScanNotFoundException( + ticket_id=self.id, + port_scan_id=reference_id, + port_scan_time=event.time, + ) + return port + + async def latest_vuln(self): + """Return the last referenced vulnerability in the event list. + + This should only be used for tickets generated by vulnscans. 
+ """ + for event in self.events[::-1]: + reference_id = event.reference + if reference_id: + break + else: + raise Exception("No references found in ticket events: " + str(self.id)) + vuln = await VulnScanDoc.get(reference_id) + if not vuln: + # This can occur when a vuln_scan has been archived + # Raise an exception with the info we have for this vuln_scan from the ticket + raise VulnScanNotFoundException( + ticket_id=self.id, + vuln_scan_id=reference_id, + vuln_scan_time=event.time, + ) + return vuln + + def set_false_positive(self, new_state: bool, reason: str, expire_days: int): + """Mark a ticket as a false positive.""" + if self.false_positive == new_state: + return + + # Define the event delta + delta = EventDelta( + from_=self.false_positive, to=new_state, key="false_positive" + ) + + # Update ticket state + self.false_positive = new_state + now = utcnow() + expiration_date = None + + if new_state: + # Only include the expiration date when setting false_positive to + # True + expiration_date = now + timedelta(days=expire_days) + + # If ticket is not open, re-open it; false positive tix should + # always be open + if not self.open: + self.open = True + self.time_closed = None + self.add_event( + action=TicketAction.REOPENED, + reason="setting false positive", + time=now, + ) + + # Add the change event + self.add_event( + action=TicketAction.CHANGED, reason=reason, time=now, delta=delta + ) + + # Set ticket expiration date if applicable + self.fp_expiration_date = expiration_date + + @classmethod + async def tag_open(cls, owners, snapshot_oid): + """Add a snapshot object ID to the snapshots field of all open tickets belonging to the specified owners.""" + # flake8 E712 is "comparison to True should be 'if cond is True:' or 'if + # cond:'" but this is unavoidable due to Beanie syntax. 
+ await cls.find( + cls.open == True, In(cls.owner, owners) # noqa E712 + ).update_many(Push({cls.snapshots: snapshot_oid})) + + @classmethod + async def tag_matching(cls, existing_snapshot_oids, new_snapshot_oid): + """Add a new snapshot object ID to the snapshots field of all tickets whose snapshots field contain any of specified existing snapshot object IDs.""" + await cls.find(In(cls.snapshots, existing_snapshot_oids)).update_many( + Push({cls.snapshots: new_snapshot_oid}) + ) + + @classmethod + async def remove_tag(cls, snapshot_oid): + """Remove the specified snapshot object ID from the snapshots field of all tickets whose snapshots field contain that snapshot object ID.""" + await cls.find(In(cls.snapshots, [snapshot_oid])).update_many( + Pull({cls.snapshots: snapshot_oid}) + ) diff --git a/src/cyhy_db/models/vuln_scan_doc.py b/src/cyhy_db/models/vuln_scan_doc.py new file mode 100644 index 0000000..9df40c1 --- /dev/null +++ b/src/cyhy_db/models/vuln_scan_doc.py @@ -0,0 +1,46 @@ +"""The model for CyHy vulnerability scan documents.""" + +# Standard Python Libraries +from datetime import datetime + +# Third-Party Libraries +from pydantic import ConfigDict +from pymongo import ASCENDING, IndexModel + +from . 
import ScanDoc +from .enum import Protocol + + +class VulnScanDoc(ScanDoc): + """The vulnerability scan document model.""" + + model_config = ConfigDict(extra="forbid") + + cvss_base_score: float + cvss_vector: str + description: str + fname: str + plugin_family: str + plugin_id: int + plugin_modification_date: datetime + plugin_name: str + plugin_publication_date: datetime + plugin_type: str + port: int + protocol: Protocol + risk_factor: str + service: str + severity: int + solution: str + synopsis: str + + class Settings: + """Beanie settings.""" + + name = "vuln_scans" + indexes = ScanDoc.Abstract_Settings.indexes + [ + IndexModel( + [("owner", ASCENDING), ("latest", ASCENDING), ("severity", ASCENDING)], + name="owner_latest_severity", + ), + ] diff --git a/src/cyhy_db/py.typed b/src/cyhy_db/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/cyhy_db/utils/__init__.py b/src/cyhy_db/utils/__init__.py new file mode 100644 index 0000000..72d8685 --- /dev/null +++ b/src/cyhy_db/utils/__init__.py @@ -0,0 +1,6 @@ +"""Utility functions for cyhy_db.""" + +from .decorators import deprecated +from .time import utcnow + +__all__ = ["deprecated", "utcnow"] diff --git a/src/cyhy_db/utils/decorators.py b/src/cyhy_db/utils/decorators.py new file mode 100644 index 0000000..3c64523 --- /dev/null +++ b/src/cyhy_db/utils/decorators.py @@ -0,0 +1,22 @@ +"""Decorators for the cyhy_db package.""" + +# Standard Python Libraries +import warnings + + +def deprecated(reason): + """Mark a function as deprecated.""" + + def decorator(func): + if isinstance(reason, str): + message = f"{func.__name__} is deprecated and will be removed in a future version. {reason}" + else: + message = f"{func.__name__} is deprecated and will be removed in a future version." 
+ + def wrapper(*args, **kwargs): + warnings.warn(message, DeprecationWarning, stacklevel=2) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/src/cyhy_db/utils/time.py b/src/cyhy_db/utils/time.py new file mode 100644 index 0000000..3997df2 --- /dev/null +++ b/src/cyhy_db/utils/time.py @@ -0,0 +1,15 @@ +"""Utility functions for working with time and dates.""" + +# Standard Python Libraries +from datetime import datetime, timezone + + +def utcnow() -> datetime: + """Return a timezone-aware datetime object with the current time in UTC. + + This is useful for default value factories in Beanie models. + + Returns: + datetime: The current time in UTC + """ + return datetime.now(timezone.utc) diff --git a/src/example/__main__.py b/src/example/__main__.py deleted file mode 100644 index 11a3238..0000000 --- a/src/example/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Code to run if this package is used as a Python module.""" - -from .example import main - -main() diff --git a/src/example/_version.py b/src/example/_version.py deleted file mode 100644 index 3252c71..0000000 --- a/src/example/_version.py +++ /dev/null @@ -1,3 +0,0 @@ -"""This file defines the version of this module.""" - -__version__ = "0.2.1" diff --git a/src/example/data/secret.txt b/src/example/data/secret.txt deleted file mode 100644 index c40a49b..0000000 --- a/src/example/data/secret.txt +++ /dev/null @@ -1 +0,0 @@ -Three may keep a secret, if two of them are dead. diff --git a/src/example/example.py b/src/example/example.py deleted file mode 100644 index d3eda19..0000000 --- a/src/example/example.py +++ /dev/null @@ -1,103 +0,0 @@ -"""example is an example Python library and tool. - -Divide one integer by another and log the result. Also log some information -from an environment variable and a package resource. - -EXIT STATUS - This utility exits with one of the following values: - 0 Calculation completed successfully. - >0 An error occurred. 
- -Usage: - example [--log-level=LEVEL] - example (-h | --help) - -Options: - -h --help Show this message. - --log-level=LEVEL If specified, then the log level will be set to - the specified value. Valid values are "debug", "info", - "warning", "error", and "critical". [default: info] -""" - -# Standard Python Libraries -import logging -import os -import sys -from typing import Any, Dict - -# Third-Party Libraries -import docopt -import pkg_resources -from schema import And, Schema, SchemaError, Use - -from ._version import __version__ - -DEFAULT_ECHO_MESSAGE: str = "Hello World from the example default!" - - -def example_div(dividend: int, divisor: int) -> float: - """Print some logging messages.""" - logging.debug("This is a debug message") - logging.info("This is an info message") - logging.warning("This is a warning message") - logging.error("This is an error message") - logging.critical("This is a critical message") - return dividend / divisor - - -def main() -> None: - """Set up logging and call the example function.""" - args: Dict[str, str] = docopt.docopt(__doc__, version=__version__) - # Validate and convert arguments as needed - schema: Schema = Schema( - { - "--log-level": And( - str, - Use(str.lower), - lambda n: n in ("debug", "info", "warning", "error", "critical"), - error="Possible values for --log-level are " - + "debug, info, warning, error, and critical.", - ), - "": Use(int, error=" must be an integer."), - "": And( - Use(int), - lambda n: n != 0, - error=" must be an integer that is not 0.", - ), - str: object, # Don't care about other keys, if any - } - ) - - try: - validated_args: Dict[str, Any] = schema.validate(args) - except SchemaError as err: - # Exit because one or more of the arguments were invalid - print(err, file=sys.stderr) - sys.exit(1) - - # Assign validated arguments to variables - dividend: int = validated_args[""] - divisor: int = validated_args[""] - log_level: str = validated_args["--log-level"] - - # Set up logging - 
logging.basicConfig( - format="%(asctime)-15s %(levelname)s %(message)s", level=log_level.upper() - ) - - logging.info("%d / %d == %f", dividend, divisor, example_div(dividend, divisor)) - - # Access some data from an environment variable - message: str = os.getenv("ECHO_MESSAGE", DEFAULT_ECHO_MESSAGE) - logging.info('ECHO_MESSAGE="%s"', message) - - # Access some data from our package data (see the setup.py) - secret_message: str = ( - pkg_resources.resource_string("example", "data/secret.txt") - .decode("utf-8") - .strip() - ) - logging.info('Secret="%s"', secret_message) - - # Stop logging and clean up - logging.shutdown() diff --git a/tests/conftest.py b/tests/conftest.py index ba89c85..a1197c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,15 +3,176 @@ https://docs.pytest.org/en/latest/writing_plugins.html#conftest-py-plugins """ +# Standard Python Libraries +import asyncio +import os +import time + # Third-Party Libraries +import docker +from motor.core import AgnosticClient import pytest +# cisagov Libraries +from cyhy_db import initialize_db + +MONGO_INITDB_ROOT_USERNAME = os.environ.get("MONGO_INITDB_ROOT_USERNAME", "mongoadmin") +MONGO_INITDB_ROOT_PASSWORD = os.environ.get("MONGO_INITDB_ROOT_PASSWORD", "secret") +DATABASE_NAME = os.environ.get("DATABASE_NAME", "test") +MONGO_EXPRESS_PORT = os.environ.get("MONGO_EXPRESS_PORT", 8081) + +# Set the default event loop policy to be compatible with asyncio +AgnosticClient.get_io_loop = asyncio.get_running_loop + + +@pytest.fixture(autouse=True) +def group_github_log_lines(request): + """Group log lines when running in GitHub actions.""" + # Group output from each test with workflow log groups + # https://help.github.com/en/actions/reference/workflow-commands-for-github-actions#grouping-log-lines + + if os.environ.get("GITHUB_ACTIONS") != "true": + # Not running in GitHub actions + yield + return + # Group using the current test name + print() + print(f"::group::{request.node.name}") + yield + 
print() + print("::endgroup::") + + +@pytest.fixture(scope="session") +def docker_client(): + """Fixture for the Docker client.""" + yield docker.from_env() + + +@pytest.fixture(scope="session") +def mongodb_container(docker_client, mongo_image_tag): + """Fixture for the MongoDB test container.""" + container = docker_client.containers.run( + mongo_image_tag, + detach=True, + environment={ + "MONGO_INITDB_ROOT_USERNAME": MONGO_INITDB_ROOT_USERNAME, + "MONGO_INITDB_ROOT_PASSWORD": MONGO_INITDB_ROOT_PASSWORD, + }, + name="mongodb", + ports={"27017/tcp": None}, + volumes={}, + healthcheck={ + "test": ["CMD", "mongosh", "--eval", "'db.runCommand(\"ping\").ok'"], + "interval": 1000000000, # ns -> 1 second + "timeout": 1000000000, # ns -> 1 second + "retries": 5, + "start_period": 3000000000, # ns -> 3 seconds + }, + ) + TIMEOUT = 180 + # Wait for container to be healthy + for _ in range(TIMEOUT): + # Verify the container is still running + container.reload() + assert container.status == "running", "The container unexpectedly exited." + status = container.attrs["State"]["Health"]["Status"] + if status == "healthy": + break + time.sleep(1) + else: + assert ( + False + ), f"Container status did not transition to 'healthy' within {TIMEOUT} seconds." 
+ + yield container + container.stop() + container.remove(force=True) + + +@pytest.fixture(autouse=True, scope="session") +def mongo_express_container(docker_client, db_uri, request): + """Fixture for the Mongo Express test container.""" + if not request.config.getoption("--mongo-express"): + yield None + return + + # Configuration for Mongo Express + mongo_express_container = docker_client.containers.run( + "mongo-express", + environment={ + "ME_CONFIG_MONGODB_ADMINUSERNAME": MONGO_INITDB_ROOT_USERNAME, + "ME_CONFIG_MONGODB_ADMINPASSWORD": MONGO_INITDB_ROOT_PASSWORD, + "ME_CONFIG_MONGODB_SERVER": "mongodb", + "ME_CONFIG_MONGODB_ENABLE_ADMIN": "true", + }, + links={"mongodb": "mongodb"}, + ports={"8081/tcp": 8081}, + detach=True, + ) + + def fin(): + if request.config.getoption("--mongo-express"): + print( + f'\n\nMongoDB is accessible at {db_uri} with database named "{DATABASE_NAME}"' + ) + print( + f"Mongo Express is accessible at http://admin:pass@localhost:{MONGO_EXPRESS_PORT}\n" + ) + input("Press Enter to stop Mongo Express and MongoDB containers...") + mongo_express_container.stop() + mongo_express_container.remove(force=True) + + request.addfinalizer(fin) + yield mongo_express_container + + +@pytest.fixture(scope="session") +def db_uri(mongodb_container): + """Fixture for the database URI.""" + mongo_port = mongodb_container.attrs["NetworkSettings"]["Ports"]["27017/tcp"][0][ + "HostPort" + ] + uri = f"mongodb://{MONGO_INITDB_ROOT_USERNAME}:{MONGO_INITDB_ROOT_PASSWORD}@localhost:{mongo_port}" + yield uri + + +@pytest.fixture(scope="session") +def db_name(mongodb_container): + """Fixture for the database name.""" + yield DATABASE_NAME + + +@pytest.fixture(autouse=True, scope="session") +async def db_client(db_uri): + """Fixture for client init.""" + print(f"Connecting to {db_uri}") + await initialize_db(db_uri, DATABASE_NAME) + def pytest_addoption(parser): """Add new commandline options to pytest.""" parser.addoption( "--runslow", action="store_true", 
default=False, help="run slow tests" ) + parser.addoption( + "--mongo-image-tag", + action="store", + default="docker.io/mongo:latest", + help="mongodb image tag to use for testing", + ) + parser.addoption( + "--mongo-express", + action="store_true", + default=False, + help="run Mongo Express for database inspection", + ) + + +@pytest.fixture(scope="session") +def mongo_image_tag(request): + """Get the image tag to test.""" + return request.config.getoption("--mongo-image-tag") def pytest_configure(config): diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..3ddda07 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,64 @@ +"""Test database connection.""" + +# Standard Python Libraries +from unittest.mock import AsyncMock, patch + +# Third-Party Libraries +from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase +import pytest + +# cisagov Libraries +from cyhy_db.db import ALL_MODELS, initialize_db +from cyhy_db.models import CVEDoc + + +async def test_connection_motor(db_uri, db_name): + """Test connectivity to database.""" + client = AsyncIOMotorClient(db_uri) + db = client[db_name] + server_info = await db.command("ping") + assert server_info["ok"] == 1.0, "Direct database ping failed" + + +async def test_connection_beanie(): + """Test a simple database query.""" + # Attempt to find a document in the empty CVE collection + result = await CVEDoc.get("CVE-2024-DOES-NOT-EXIST") + assert result is None, "Expected no document to be found" + + +# Confused about the order of patch statements relative to the order of the test +# function parameters? 
See here: +# https://docs.python.org/3/library/unittest.mock.html#patch +@patch("cyhy_db.db.init_beanie", return_value=None) # mock_init_beanie +@patch("cyhy_db.db.AsyncIOMotorClient") # mock_async_iomotor_client +@pytest.mark.asyncio +async def test_initialize_db_success(mock_async_iomotor_client, mock_init_beanie): + """Test a success case of the initialize_db function.""" + db_uri = "mongodb://localhost:27017" + db_name = "test_db" + + mock_client = AsyncMock() + mock_db = AsyncMock(AsyncIOMotorDatabase) + mock_client.__getitem__.return_value = mock_db + mock_async_iomotor_client.return_value = mock_client + + db = await initialize_db(db_uri, db_name) + assert db == mock_db + mock_async_iomotor_client.assert_called_once_with(db_uri) + mock_init_beanie.assert_called_once_with( + database=mock_db, document_models=ALL_MODELS + ) + + +@pytest.mark.asyncio +async def test_initialize_db_failure(): + """Test a failure case of the initialize_db function.""" + db_uri = "mongodb://localhost:27017" + db_name = "test_db" + + with patch( + "cyhy_db.db.AsyncIOMotorClient", side_effect=Exception("Connection error") + ): + with pytest.raises(Exception, match="Connection error"): + await initialize_db(db_uri, db_name) diff --git a/tests/test_cve_doc.py b/tests/test_cve_doc.py new file mode 100644 index 0000000..5dfe5a9 --- /dev/null +++ b/tests/test_cve_doc.py @@ -0,0 +1,50 @@ +"""Test CVE model functionality.""" + +# Third-Party Libraries +from pydantic import ValidationError +import pytest + +# cisagov Libraries +from cyhy_db.models import CVEDoc +from cyhy_db.models.enum import CVSSVersion + +severity_params = [ + (CVSSVersion.V2, 10, 4), + (CVSSVersion.V2, 7.0, 3), + (CVSSVersion.V2, 4.0, 2), + (CVSSVersion.V2, 0.0, 1), + (CVSSVersion.V3, 9.0, 4), + (CVSSVersion.V3, 7.0, 3), + (CVSSVersion.V3, 4.0, 2), + (CVSSVersion.V3, 0.0, 1), + (CVSSVersion.V3_1, 9.0, 4), + (CVSSVersion.V3_1, 7.0, 3), + (CVSSVersion.V3_1, 4.0, 2), + (CVSSVersion.V3_1, 0.0, 1), +] + + 
+@pytest.mark.parametrize("version, score, expected_severity", severity_params) +def test_calculate_severity(version, score, expected_severity): + """Test that the severity is calculated correctly.""" + cve = CVEDoc(id="CVE-2024-0128", cvss_version=version, cvss_score=score) + assert ( + cve.severity == expected_severity + ), f"Failed for CVSS {version} with score {score}" + + +@pytest.mark.parametrize("bad_score", [-1.0, 11.0]) +def test_invalid_cvss_score(bad_score): + """Test that an invalid CVSS score raises a ValueError.""" + with pytest.raises(ValidationError): + CVEDoc(cvss_version=CVSSVersion.V3_1, cvss_score=bad_score, id="test-cve") + + +async def test_save(): + """Test that the severity is calculated correctly on save.""" + cve = CVEDoc(cvss_version=CVSSVersion.V3_1, cvss_score=9.0, id="test-cve") + await cve.save() # Saving the object + saved_cve = await CVEDoc.get("test-cve") # Retrieving the object + + assert saved_cve is not None, "CVE not saved correctly" + assert saved_cve.severity == 4, "Severity not calculated correctly on save" diff --git a/tests/test_data_generator.py b/tests/test_data_generator.py new file mode 100644 index 0000000..a61604a --- /dev/null +++ b/tests/test_data_generator.py @@ -0,0 +1,242 @@ +""" +This module generates test data for CyHy reports using factory classes. + +It includes factories for creating instances of various models such as CVE, +Agency, Contact, Location, Window, and RequestDoc. Additionally, it provides a +custom provider for generating specific data like CVE IDs and IPv4 networks. 
+""" + +# Standard Python Libraries +from datetime import datetime +import ipaddress +import random + +# Third-Party Libraries +import factory +from mimesis import Generic +from mimesis.locales import DEFAULT_LOCALE +from mimesis.providers.base import BaseProvider +from pytest_factoryboy import register + +# cisagov Libraries +from cyhy_db.models import CVEDoc, RequestDoc +from cyhy_db.models.enum import ( + AgencyType, + CVSSVersion, + DayOfWeek, + PocType, + ReportPeriod, + ScanType, + Scheduler, + Stage, +) +from cyhy_db.models.request_doc import Agency, Contact, Location, Window +from cyhy_db.utils import utcnow + + +class CyHyProvider(BaseProvider): + """Custom provider for generating specific CyHy data.""" + + class Meta: + """Meta class for CyHyProvider.""" + + name = "cyhy_provider" + + def cve_id(self, year=None): + """ + Generate a CVE ID. + + Args: + year (int, optional): The year for the CVE ID. If None, a random + year between 1999 and the current year is used. + + Returns: + str: A CVE ID in the format CVE-YYYY-NNNNN. + """ + if year is None: + year = self.random.randint(1999, datetime.now().year) + number = self.random.randint(1, 99999) + return f"CVE-{year}-{number:05d}" + + def network_ipv4(self): + """ + Generate an IPv4 network. + + Returns: + ipaddress.IPv4Network: A randomly generated IPv4 network. + """ + base_ip = generic.internet.ip_v4() + # The following line generates a warning from bandit about "Standard + # pseudo-random generators are not suitable for security/cryptographic + # purposes." We aren't using Random() for the purposes of cryptography + # here, so we can safely ignore that warning. 
+ cidr = random.randint(24, 30) # nosec B311 + network = ipaddress.IPv4Network(f"{base_ip}/{cidr}", strict=False) + return network + + +generic = Generic(locale=DEFAULT_LOCALE) +generic.add_provider(CyHyProvider) + + +@register +class CVEFactory(factory.Factory): + """Factory for creating CVE instances.""" + + class Meta: + """Meta class for CVEFactory.""" + + model = CVEDoc + + id = factory.LazyFunction(lambda: generic.cyhy_provider.cve_id()) + # The following lines generate warnings from bandit about "Standard + # pseudo-random generators are not suitable for security/cryptographic + # purposes." We aren't using Random() for the purposes of cryptography + # here, so we can safely ignore those warnings. + cvss_score = factory.LazyFunction( + lambda: round(random.uniform(0, 10), 1) # nosec B311 + ) + cvss_version = factory.LazyFunction( + lambda: random.choice(list(CVSSVersion)) # nosec B311 + ) + severity = factory.LazyFunction(lambda: random.randint(1, 4)) # nosec B311 + + +class AgencyFactory(factory.Factory): + """Factory for creating Agency instances.""" + + class Meta: + """Meta class for AgencyFactory.""" + + model = Agency + + name = factory.Faker("company") + acronym = factory.LazyAttribute( + lambda o: "".join(word[0].upper() for word in o.name.split()) + ) + # The following lines generate warnings from bandit about "Standard + # pseudo-random generators are not suitable for security/cryptographic + # purposes." We aren't using Random() for the purposes of cryptography + # here, so we can safely ignore those warnings. 
+ type = factory.LazyFunction(lambda: random.choice(list(AgencyType))) # nosec B311 + contacts = factory.LazyFunction( + lambda: [ContactFactory() for _ in range(random.randint(1, 5))] # nosec B311 + ) + location = factory.LazyFunction(lambda: LocationFactory()) + + +class ContactFactory(factory.Factory): + """Factory for creating Contact instances.""" + + class Meta: + """Meta class for ContactFactory.""" + + model = Contact + + email = factory.Faker("email") + name = factory.Faker("name") + phone = factory.Faker("phone_number") + # The following line generates a warning from bandit about "Standard + # pseudo-random generators are not suitable for security/cryptographic + # purposes." We aren't using Random() for the purposes of cryptography + # here, so we can safely ignore that warning. + type = factory.LazyFunction(lambda: random.choice(list(PocType))) # nosec B311 + + +class LocationFactory(factory.Factory): + """Factory for creating Location instances.""" + + class Meta: + """Meta class for LocationFactory.""" + + model = Location + + country_name = factory.Faker("country") + country = factory.Faker("country_code") + county_fips = factory.Faker("numerify", text="##") + county = factory.Faker("city") + gnis_id = factory.Faker("numerify", text="#######") + name = factory.Faker("city") + state_fips = factory.Faker("numerify", text="##") + state_name = factory.Faker("state") + state = factory.Faker("state_abbr") + + +class WindowFactory(factory.Factory): + """Factory for creating Window instances.""" + + class Meta: + """Meta class for WindowFactory.""" + + model = Window + + # The following lines generate warnings from bandit about "Standard + # pseudo-random generators are not suitable for security/cryptographic + # purposes." We aren't using Random() for the purposes of cryptography + # here, so we can safely ignore those warnings. 
+ day = factory.LazyFunction(lambda: random.choice(list(DayOfWeek))) # nosec B311 + duration = factory.LazyFunction(lambda: random.randint(0, 168)) # nosec B311 + start = factory.Faker("time", pattern="%H:%M:%S") + + +class RequestDocFactory(factory.Factory): + """Factory for creating RequestDoc instances.""" + + class Meta: + """Meta class for RequestDocFactory.""" + + model = RequestDoc + + # The following lines generate warnings from bandit about "Standard + # pseudo-random generators are not suitable for security/cryptographic + # purposes." We aren't using Random() for the purposes of cryptography + # here, so we can safely ignore those warnings. + id = factory.LazyAttribute( + lambda o: o.agency.acronym + "-" + str(random.randint(1, 1000)) # nosec B311 + ) + agency = factory.SubFactory(AgencyFactory) + enrolled = factory.LazyFunction(utcnow) + init_stage = factory.LazyFunction(lambda: random.choice(list(Stage))) # nosec B311 + key = factory.Faker("password") + period_start = factory.LazyFunction(utcnow) + report_period = factory.LazyFunction( + lambda: random.choice(list(ReportPeriod)) # nosec B311 + ) + retired = factory.LazyFunction(lambda: random.choice([True, False])) # nosec B311 + scheduler = factory.LazyFunction( + lambda: random.choice(list(Scheduler)) # nosec B311 + ) + stakeholder = factory.LazyFunction( + lambda: random.choice([True, False]) # nosec B311 + ) + windows = factory.LazyFunction( + lambda: [WindowFactory() for _ in range(random.randint(1, 5))] # nosec B311 + ) + networks = factory.LazyFunction( + lambda: [ + generic.cyhy_provider.network_ipv4() + for _ in range(random.randint(1, 5)) # nosec B311 + ] + ) + scan_types = factory.LazyFunction( + lambda: { + random.choice(list(ScanType)) # nosec B311 + for _ in range(random.randint(1, 3)) # nosec B311 + } + ) + + +async def test_create_cves(): + """Test function to create and save 100 CVE instances.""" + for _ in range(100): + cve = CVEFactory() + print(cve) + await cve.save() + + +async def 
test_create_request_docs(): + """Test function to create and save 100 RequestDoc instances.""" + for _ in range(100): + request_doc = RequestDocFactory() + print(request_doc) + await request_doc.save() diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 0000000..3f02638 --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,39 @@ +"""Test utils/decorators.py functionality.""" + +# Third-Party Libraries +import pytest + +# cisagov Libraries +from cyhy_db.utils.decorators import deprecated + + +def test_deprecated_decorator_with_reason(): + """Test the deprecated decorator with a reason.""" + + @deprecated("Use another function") + def old_function(): + """Impersonate a deprecated function.""" + return "result" + + with pytest.warns( + DeprecationWarning, + match="old_function is deprecated and will be removed in a future version. Use another function", + ): + result = old_function() + assert result == "result" + + +def test_deprecated_decorator_without_reason(): + """Test the deprecated decorator without a reason.""" + + @deprecated(None) + def old_function(): + """Impersonate a deprecated function.""" + return "result" + + with pytest.warns( + DeprecationWarning, + match="old_function is deprecated and will be removed in a future version.", + ): + result = old_function() + assert result == "result" diff --git a/tests/test_example.py b/tests/test_example.py deleted file mode 100644 index f8dea67..0000000 --- a/tests/test_example.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env pytest -vs -"""Tests for example.""" - -# Standard Python Libraries -import logging -import os -import sys -from unittest.mock import patch - -# Third-Party Libraries -import pytest - -# cisagov Libraries -import example - -div_params = [ - (1, 1, 1), - (2, 2, 1), - (0, 1, 0), - (8, 2, 4), -] - -log_levels = ( - "debug", - "info", - "warning", - "error", - "critical", -) - -# define sources of version strings -RELEASE_TAG = os.getenv("RELEASE_TAG") 
-PROJECT_VERSION = example.__version__ - - -def test_stdout_version(capsys): - """Verify that version string sent to stdout agrees with the module version.""" - with pytest.raises(SystemExit): - with patch.object(sys, "argv", ["bogus", "--version"]): - example.example.main() - captured = capsys.readouterr() - assert ( - captured.out == f"{PROJECT_VERSION}\n" - ), "standard output by '--version' should agree with module.__version__" - - -def test_running_as_module(capsys): - """Verify that the __main__.py file loads correctly.""" - with pytest.raises(SystemExit): - with patch.object(sys, "argv", ["bogus", "--version"]): - # F401 is a "Module imported but unused" warning. This import - # emulates how this project would be run as a module. The only thing - # being done by __main__ is importing the main entrypoint of the - # package and running it, so there is nothing to use from this - # import. As a result, we can safely ignore this warning. - # cisagov Libraries - import example.__main__ # noqa: F401 - captured = capsys.readouterr() - assert ( - captured.out == f"{PROJECT_VERSION}\n" - ), "standard output by '--version' should agree with module.__version__" - - -@pytest.mark.skipif( - RELEASE_TAG in [None, ""], reason="this is not a release (RELEASE_TAG not set)" -) -def test_release_version(): - """Verify that release tag version agrees with the module version.""" - assert ( - RELEASE_TAG == f"v{PROJECT_VERSION}" - ), "RELEASE_TAG does not match the project version" - - -@pytest.mark.parametrize("level", log_levels) -def test_log_levels(level): - """Validate commandline log-level arguments.""" - with patch.object(sys, "argv", ["bogus", f"--log-level={level}", "1", "1"]): - with patch.object(logging.root, "handlers", []): - assert ( - logging.root.hasHandlers() is False - ), "root logger should not have handlers yet" - return_code = None - try: - example.example.main() - except SystemExit as sys_exit: - return_code = sys_exit.code - assert return_code is None, 
"main() should return success" - assert ( - logging.root.hasHandlers() is True - ), "root logger should now have a handler" - assert ( - logging.getLevelName(logging.root.getEffectiveLevel()) == level.upper() - ), f"root logger level should be set to {level.upper()}" - assert return_code is None, "main() should return success" - - -def test_bad_log_level(): - """Validate bad log-level argument returns error.""" - with patch.object(sys, "argv", ["bogus", "--log-level=emergency", "1", "1"]): - return_code = None - try: - example.example.main() - except SystemExit as sys_exit: - return_code = sys_exit.code - assert return_code == 1, "main() should exit with error" - - -@pytest.mark.parametrize("dividend, divisor, quotient", div_params) -def test_division(dividend, divisor, quotient): - """Verify division results.""" - result = example.example_div(dividend, divisor) - assert result == quotient, "result should equal quotient" - - -@pytest.mark.slow -def test_slow_division(): - """Example of using a custom marker. - - This test will only be run if --runslow is passed to pytest. - Look in conftest.py to see how this is implemented. 
- """ - # Standard Python Libraries - import time - - result = example.example_div(256, 16) - time.sleep(4) - assert result == 16, "result should equal be 16" - - -def test_zero_division(): - """Verify that division by zero throws the correct exception.""" - with pytest.raises(ZeroDivisionError): - example.example_div(1, 0) - - -def test_zero_divisor_argument(): - """Verify that a divisor of zero is handled as expected.""" - with patch.object(sys, "argv", ["bogus", "1", "0"]): - return_code = None - try: - example.example.main() - except SystemExit as sys_exit: - return_code = sys_exit.code - assert return_code == 1, "main() should exit with error" diff --git a/tests/test_host_doc.py b/tests/test_host_doc.py new file mode 100644 index 0000000..ef20b54 --- /dev/null +++ b/tests/test_host_doc.py @@ -0,0 +1,89 @@ +"""Test HostDoc model functionality.""" + +# Standard Python Libraries +from ipaddress import ip_address + +# cisagov Libraries +from cyhy_db.models import HostDoc +from cyhy_db.models.host_doc import State + +VALID_IP_1_STR = "0.0.0.1" +VALID_IP_2_STR = "0.0.0.2" +VALID_IP_1_INT = int(ip_address(VALID_IP_1_STR)) +VALID_IP_2_INT = int(ip_address(VALID_IP_2_STR)) + + +def test_host_doc_init(): + """Test HostDoc object initialization.""" + # Create a HostDoc object + host_doc = HostDoc( + ip=ip_address(VALID_IP_1_STR), + owner="YOUR_MOM", + ) + # Check that the HostDoc object was created correctly + assert host_doc.ip == ip_address(VALID_IP_1_STR) + + +async def test_save(): + """Test saving a HostDoc object to the database.""" + # Create a HostDoc object + host_doc = HostDoc( + ip=ip_address(VALID_IP_1_STR), + owner="YOUR_MOM", + ) + # Save the HostDoc object to the database + await host_doc.save() + assert host_doc.id == VALID_IP_1_INT + + +async def test_get_by_ip(): + """Test finding a HostDoc object by its IP address.""" + # Find a HostDoc object by its IP address + host_doc = await HostDoc.get_by_ip(ip_address(VALID_IP_1_STR)) + assert host_doc.ip == 
ip_address(VALID_IP_1_STR) + + +async def test_set_state_open_ports(): + """Test setting HostDoc state with open ports.""" + # Find a HostDoc object by its IP address + host_doc = await HostDoc.get_by_ip(ip_address(VALID_IP_1_STR)) + host_doc.set_state(nmap_says_up=None, has_open_ports=True) + assert host_doc.state == State(up=True, reason="open-port") + + +async def test_set_state_no_open_ports(): + """Test setting HostDoc state with no open ports.""" + # Find a HostDoc object by its IP address + host_doc = await HostDoc.get_by_ip(ip_address(VALID_IP_1_STR)) + host_doc.set_state(nmap_says_up=None, has_open_ports=False) + assert host_doc.state == State(up=False, reason="no-open") + + +async def test_set_state_nmap_says_down(): + """Test setting HostDoc state when nmap says the host is down.""" + # Find a HostDoc object by its IP address + host_doc = await HostDoc.get_by_ip(ip_address(VALID_IP_1_STR)) + host_doc.set_state(nmap_says_up=False, has_open_ports=None, reason="no-reply") + assert host_doc.state == State(up=False, reason="no-reply") + + +async def test_set_state_no_op(): + """Test setting HostDoc state when inputs are supplied that results in no state change.""" + # Create a HostDoc object + host_doc = HostDoc( + ip=ip_address(VALID_IP_2_STR), + owner="NO-OP", + ) + # Save the HostDoc object to the database + await host_doc.save() + assert host_doc.id == VALID_IP_2_INT + + # Find HostDoc object by its IP address + host_doc = await HostDoc.get_by_ip(ip_address(VALID_IP_2_INT)) + assert host_doc.state == State(up=False, reason="new") + + host_doc.set_state(nmap_says_up=True, has_open_ports=None, reason="no-op-test-1") + assert host_doc.state == State(up=False, reason="new") + + host_doc.set_state(nmap_says_up=None, has_open_ports=None, reason="no-op-test-2") + assert host_doc.state == State(up=False, reason="new") diff --git a/tests/test_request_doc.py b/tests/test_request_doc.py new file mode 100644 index 0000000..5837437 --- /dev/null +++ 
b/tests/test_request_doc.py @@ -0,0 +1,83 @@ +"""Test RequestDoc model functionality.""" + +# Standard Python Libraries +from datetime import time + +# Third-Party Libraries +import pytest + +# cisagov Libraries +from cyhy_db.models import RequestDoc +from cyhy_db.models.enum import ScanType +from cyhy_db.models.request_doc import Agency, ScanLimit, Window + + +async def test_init(): + """Test RequestDoc object initialization.""" + # Create a RequestDoc object + request_doc = RequestDoc( + agency=Agency( + name="Cybersecurity and Infrastructure Security Agency", acronym="CISA" + ) + ) + + await request_doc.save() + + # Verify that the id was set to the acronym + assert ( + request_doc.id == request_doc.agency.acronym + ), "id was not correctly set to agency acronym" + + +def test_parse_time_valid_time_str(): + """Test the parse_time validator with valid string input.""" + valid_time_str = "12:34:56" + parsed_time = Window.parse_time(valid_time_str) + assert parsed_time == time(12, 34, 56), "Failed to parse valid time string" + + +def test_parse_time_invalid_time_str(): + """Test the parse_time validator with invalid string input.""" + invalid_time_str = "invalid_time" + with pytest.raises( + ValueError, + match="time data 'invalid_time' does not match format '%H:%M:%S'", + ): + Window.parse_time(invalid_time_str) + + +def test_parse_time_valid_time_obj(): + """Test the parse_time validator with valid time input.""" + valid_time_obj = time(12, 34, 56) + parsed_time = Window.parse_time(valid_time_obj) + assert parsed_time == valid_time_obj, "Failed to parse valid time object" + + +def test_parse_time_invalid_type(): + """Test the parse_time validator with an invalid input type.""" + invalid_time_type = 12345 + with pytest.raises( + ValueError, + match="Invalid time format. 
Expected a string in '%H:%M:%S' format or datetime.time instance.", + ): + Window.parse_time(invalid_time_type) + + +async def test_scan_limit(): + """Test the ScanLimit model.""" + # Create a RequestDoc object + request_doc = RequestDoc( + agency=Agency(name="Office of Fragile Networking", acronym="OFN") + ) + + scan_limit = ScanLimit(scan_type=ScanType.CYHY, concurrent=1) + assert scan_limit.scan_type == ScanType.CYHY, "Scan type was not set correctly" + assert scan_limit.concurrent == 1, "Concurrent was not set correctly" + + request_doc.scan_limits.append(scan_limit) + assert ( + request_doc.scan_limits[0].scan_type == ScanType.CYHY + ), "Scan type was not set correctly" + await request_doc.save() + + # TODO complete this test diff --git a/tests/test_scan_doc.py b/tests/test_scan_doc.py new file mode 100644 index 0000000..4b69c97 --- /dev/null +++ b/tests/test_scan_doc.py @@ -0,0 +1,338 @@ +"""Test ScanDoc abstract base class model functionality.""" + +# Standard Python Libraries +import ipaddress + +# Third-Party Libraries +from pydantic import ValidationError +import pytest + +# cisagov Libraries +from cyhy_db.models import ScanDoc, SnapshotDoc +from cyhy_db.utils import utcnow + +VALID_IP_1_STR = "0.0.0.1" +VALID_IP_2_STR = "0.0.0.2" +VALID_IP_1_INT = int(ipaddress.ip_address(VALID_IP_1_STR)) +VALID_IP_2_INT = int(ipaddress.ip_address(VALID_IP_2_STR)) + +# Note: Running these tests will create a "ScanDoc" collection in the database. +# This collection is typically not created in a production environment since +# ScanDoc is an abstract base class. + + +def test_ip_int_init(): + """Test IP address integer conversion on initialization. + + This test verifies that the IP address is correctly converted to an + integer when a ScanDoc object is initialized. 
+ """ + # Create a ScanDoc object + scan_doc = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), + owner="YOUR_MOM", + source="nmap", + ) + + assert scan_doc.ip_int == int( + ipaddress.ip_address(VALID_IP_1_STR) + ), "IP address integer was not calculated correctly on init" + + +def test_ip_int_change(): + """Test IP address integer conversion on IP address change. + + This test verifies that the IP address is correctly converted to an + integer when the IP address of a ScanDoc object is changed. + """ + # Create a ScanDoc object + scan_doc = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), + owner="YOUR_MOM", + source="nmap", + ) + + scan_doc.ip = ipaddress.ip_address(VALID_IP_2_STR) + + assert scan_doc.ip_int == int( + ipaddress.ip_address(VALID_IP_2_STR) + ), "IP address integer was not calculated correctly on change" + + +def test_ip_string_set(): + """Test IP address string conversion and integer calculation. + + This test verifies that an IP address provided as a string is correctly + converted to an ipaddress.IPv4Address object and that the corresponding + integer value is calculated correctly. + """ + scan_doc = ScanDoc( + ip=VALID_IP_1_STR, + owner="YOUR_MOM", + source="nmap", + ) + + assert isinstance( + scan_doc.ip, ipaddress.IPv4Address + ), "IP address was not converted" + assert scan_doc.ip_int == VALID_IP_1_INT, "IP address integer was not calculated" + + +async def test_ip_address_field_fetch(): + """Test IP address retrieval from the database. + + This test verifies that the IP address of a ScanDoc object is correctly + retrieved from the database. 
+ """ + # Create a ScanDoc object + scan_doc = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), + owner="YOUR_MOM", + source="nmap", + ) + + # Save the ScanDoc object to the database + await scan_doc.save() + + # Retrieve the ScanDoc object from the database + retrieved_doc = await ScanDoc.get(scan_doc.id) + + # Assert that the retrieved IP address is equal to the one we saved + assert retrieved_doc.ip == ipaddress.ip_address( + VALID_IP_1_STR + ), "IP address does not match" + + assert retrieved_doc.ip_int == VALID_IP_1_INT, "IP address integer does not match" + + +def test_invalid_ip_address(): + """Test validation error for invalid IP addresses. + + This test verifies that a ValidationError is raised when an invalid IP + address is provided to a ScanDoc object. + """ + with pytest.raises(ValidationError): + ScanDoc( + ip="999.999.999.999", # This should be invalid + owner="owner_example", + source="source_example", + ) + + +async def test_reset_latest_flag_by_owner(): + """Test resetting the latest flag by owner. + + This test verifies that the latest flag of ScanDoc objects is correctly + reset when the reset_latest_flag_by_owner method is called. + """ + # Create a ScanDoc object + OWNER = "RESET_BY_OWNER" + scan_doc = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), owner=OWNER, source="nmap" + ) + await scan_doc.save() + # Check that the latest flag is set to True + assert scan_doc.latest is True + # Reset the latest flag + await ScanDoc.reset_latest_flag_by_owner(OWNER) + # Retrieve the ScanDoc object from the database + await scan_doc.sync() + # Check that the latest flag is set to False + assert scan_doc.latest is False + + +async def test_tag_latest_snapshot_doc(): + """Test tagging the latest scan with a SnapshotDoc. + + This test verifies that the latest ScanDoc object is correctly tagged with a + SnapshotDoc object when the tag_latest method is called with a SnapshotDoc. 
+ """ + # Create a SnapshotDoc object + owner = "TAG_LATEST_SNAPSHOT_DOC" + snapshot_doc = SnapshotDoc( + owner=owner, + start_time=utcnow(), + end_time=utcnow(), + ) + await snapshot_doc.save() + # Create a ScanDoc object + scan_doc = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), + owner=owner, + source="nmap", + ) + await scan_doc.save() + + # Tag the latest scan + await ScanDoc.tag_latest([owner], snapshot_doc) + + # Retrieve the ScanDoc object from the database + scan_doc = await ScanDoc.find_one(ScanDoc.id == scan_doc.id, fetch_links=True) + + # Check that the scan now has a snapshot + assert scan_doc.snapshots == [snapshot_doc], "Snapshot not added to scan" + + +async def test_tag_latest_snapshot_id(): + """Test tagging the latest scan with a snapshot ObjectId. + + This test verifies that the latest ScanDoc object is correctly tagged with a + SnapshotDoc object when the tag_latest method is called with a snapshot + ObjectId. + """ + # Create a SnapshotDoc object + owner = "TAG_LATEST_SNAPSHOT_ID" + snapshot_doc = SnapshotDoc( + owner=owner, + start_time=utcnow(), + end_time=utcnow(), + ) + await snapshot_doc.save() + # Create a ScanDoc object + scan_doc = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), + owner=owner, + source="nmap", + ) + await scan_doc.save() + + # Tag the latest scan with the snapshot id + await ScanDoc.tag_latest([owner], snapshot_doc.id) + + # Retrieve the ScanDoc object from the database + scan_doc = await ScanDoc.find_one(ScanDoc.id == scan_doc.id, fetch_links=True) + + # Check that the scan now has a snapshot + assert scan_doc.snapshots == [snapshot_doc], "Snapshot not added to scan" + + +async def test_tag_latest_snapshot_id_str(): + """Test tagging the latest scan with the string representation of a snapshot ObjectId. + + This test verifies that the latest ScanDoc object is correctly tagged with a + SnapshotDoc object when the tag_latest method is called with the string + representation of a snapshot ObjectId. 
+ """ + # Create a SnapshotDoc object + owner = "TAG_LATEST_SNAPSHOT_ID_STR" + snapshot_doc = SnapshotDoc( + owner=owner, + start_time=utcnow(), + end_time=utcnow(), + ) + await snapshot_doc.save() + # Create a ScanDoc object + scan_doc = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), + owner=owner, + source="nmap", + ) + await scan_doc.save() + + # Tag the latest scan with the string representation of the snapshot id + await ScanDoc.tag_latest([owner], str(snapshot_doc.id)) + + # Retrieve the ScanDoc object from the database + scan_doc = await ScanDoc.find_one(ScanDoc.id == scan_doc.id, fetch_links=True) + + # Check that the scan now has a snapshot + assert scan_doc.snapshots == [snapshot_doc], "Snapshot not added to scan" + + +async def test_tag_latest_invalid_type(): + """Test tagging the latest scan with an invalid object type.""" + owner = "TAG_LATEST_INVALID_TYPE" + scan_doc = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), + owner=owner, + source="nmap", + ) + await scan_doc.save() + + with pytest.raises(ValueError, match="Invalid snapshot type"): + # Attempt to tag the latest scan with an invalid object type + await ScanDoc.tag_latest([owner], 12345) + + # Retrieve the ScanDoc object from the database + scan_doc = await ScanDoc.find_one(ScanDoc.id == scan_doc.id, fetch_links=True) + + # Confirm that the scan does not have a snapshot + assert scan_doc.snapshots == [], "Scan should not have any snapshots" + + +async def test_reset_latest_flag_by_ip_single(): + """Test reset_latest_flag_by_ip with a single IP address.""" + owner = "RESET_FLAG_SINGLE_IP" + scan_doc = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), + owner=owner, + source="nmap", + ) + await scan_doc.save() + + # Reset the latest flag for a single IP address + await ScanDoc.reset_latest_flag_by_ip(scan_doc.ip) + + # Retrieve the ScanDoc object from the database + scan_doc = await ScanDoc.find_one(ScanDoc.id == scan_doc.id) + + # Check that the latest flag has been reset + assert ( + 
scan_doc.latest is False + ), "The latest flag was not reset for the single IP address" + + +async def test_reset_latest_flag_by_ip_list(): + """Test reset_latest_flag_by_ip with a list of IP addresses.""" + owner = "RESET_FLAG_IP_LIST" + scan_doc_1 = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), + owner=owner, + source="nmap", + ) + await scan_doc_1.save() + + scan_doc_2 = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_2_STR), + owner=owner, + source="nmap", + ) + await scan_doc_2.save() + + # Reset the latest flag for a list of IP addresses + await ScanDoc.reset_latest_flag_by_ip([scan_doc_1.ip, scan_doc_2.ip]) + + # Retrieve the ScanDoc objects from the database + scan_doc_1 = await ScanDoc.find_one(ScanDoc.id == scan_doc_1.id) + scan_doc_2 = await ScanDoc.find_one(ScanDoc.id == scan_doc_2.id) + + # Check that the latest flag has been reset for both IP addresses + assert ( + scan_doc_1.latest is False + ), "The latest flag was not reset for the first IP address" + assert ( + scan_doc_2.latest is False + ), "The latest flag was not reset for the second IP address" + + +@pytest.mark.asyncio +async def test_reset_latest_flag_by_ip_empty_iterable(): + """Test reset_latest_flag_by_ip with an empty iterable.""" + owner = "RESET_FLAG_EMPTY_ITERABLE" + scan_doc = ScanDoc( + ip=ipaddress.ip_address(VALID_IP_1_STR), + owner=owner, + source="nmap", + ) + await scan_doc.save() + + # Reset the latest flag for an empty list of IP addresses + await ScanDoc.reset_latest_flag_by_ip([]) + + # Retrieve the ScanDoc object from the database + scan_doc = await ScanDoc.find_one(ScanDoc.id == scan_doc.id) + + # Check that the latest flag has not been modified + assert ( + scan_doc.latest is True + ), "The latest flag should remain True for empty iterable input" diff --git a/tests/test_system_control_doc.py b/tests/test_system_control_doc.py new file mode 100644 index 0000000..a3cd1dc --- /dev/null +++ b/tests/test_system_control_doc.py @@ -0,0 +1,54 @@ +"""Test SystemControlDoc 
model functionality.""" + +# Standard Python Libraries +from datetime import timedelta +from unittest.mock import AsyncMock, patch + +# cisagov Libraries +from cyhy_db.models.system_control_doc import SystemControlDoc +from cyhy_db.utils import utcnow + + +async def test_wait_for_completion_completed(): + """Test wait_for_completion when the document is completed.""" + document_id = "test_id" + mock_doc = AsyncMock() + mock_doc.completed = True + + with patch.object(SystemControlDoc, "get", return_value=mock_doc): + result = await SystemControlDoc.wait_for_completion(document_id) + assert result is True + + +async def test_wait_for_completion_timeout(): + """Test wait_for_completion when the document is not completed before the timeout.""" + document_id = "test_id" + mock_doc = AsyncMock() + mock_doc.completed = False + + with patch.object(SystemControlDoc, "get", return_value=mock_doc): + with patch( + "cyhy_db.models.system_control_doc.utcnow", + side_effect=[utcnow(), utcnow() + timedelta(seconds=10)], + ): + result = await SystemControlDoc.wait_for_completion(document_id, timeout=5) + assert result is False + + +async def test_wait_for_completion_no_timeout(): + """Test wait_for_completion when a timeout is not set.""" + document_id = "test_id" + mock_doc = AsyncMock() + mock_doc.completed = False + + async def side_effect(*args, **kwargs): + if side_effect.call_count == 2: + mock_doc.completed = True + side_effect.call_count += 1 + return mock_doc + + side_effect.call_count = 0 + + with patch.object(SystemControlDoc, "get", side_effect=side_effect): + result = await SystemControlDoc.wait_for_completion(document_id) + assert result is True diff --git a/tests/test_tally_doc.py b/tests/test_tally_doc.py new file mode 100644 index 0000000..b573772 --- /dev/null +++ b/tests/test_tally_doc.py @@ -0,0 +1,29 @@ +"""Test TallyDoc model functionality.""" + +# Standard Python Libraries +from datetime import datetime + +# cisagov Libraries +from cyhy_db.models.tally_doc 
import Counts, TallyDoc + + +async def test_tally_doc_creation(): + """Test TallyDoc creation.""" + tally_doc = TallyDoc(id="TALLY-TEST-1") + await tally_doc.insert() + fetched_doc = await TallyDoc.get(tally_doc.id) + assert fetched_doc is not None + assert fetched_doc.id == "TALLY-TEST-1" + assert fetched_doc.counts == Counts() + assert isinstance(fetched_doc.last_change, datetime) + + +async def test_tally_doc_last_change(): + """Test TallyDoc last_change update.""" + tally_doc = TallyDoc(id="TALLY-TEST-2") + await tally_doc.save() + initial_last_change = tally_doc.last_change + + # Save TallyDoc again to force the last_change timestamp to update + await tally_doc.save() + assert tally_doc.last_change > initial_last_change diff --git a/tests/test_ticket_doc.py b/tests/test_ticket_doc.py new file mode 100644 index 0000000..fc2678a --- /dev/null +++ b/tests/test_ticket_doc.py @@ -0,0 +1,425 @@ +"""Test TicketDoc model functionality.""" + +# Standard Python Libraries +from datetime import timedelta +from ipaddress import IPv4Address +from unittest.mock import AsyncMock, patch + +# Third-Party Libraries +from beanie import PydanticObjectId +import pytest + +# cisagov Libraries +from cyhy_db.models.enum import Protocol, TicketAction +from cyhy_db.models.exceptions import ( + PortScanNotFoundException, + VulnScanNotFoundException, +) +from cyhy_db.models.port_scan_doc import PortScanDoc +from cyhy_db.models.snapshot_doc import SnapshotDoc +from cyhy_db.models.ticket_doc import EventDelta, TicketDoc, TicketEvent +from cyhy_db.models.vuln_scan_doc import VulnScanDoc +from cyhy_db.utils.time import utcnow + +VALID_IP_1_STR = "0.0.0.1" +VALID_IP_1_INT = int(IPv4Address(VALID_IP_1_STR)) + + +def sample_ticket(): + """Create a sample TicketDoc object.""" + return TicketDoc( + ip_int=VALID_IP_1_INT, + ip=IPv4Address(VALID_IP_1_STR), + owner="TICKET-TEST-1", + port=80, + protocol=Protocol.TCP, + source_id=1, + source="test", + ) + + +def test_init(): + """Test TicketDoc object 
initialization.""" + # Create a TicketDoc object + ticket_doc = sample_ticket() + + # Verify that default values are set correctly + assert ticket_doc.details == {}, "details was not set to an empty dict" + assert ticket_doc.events == [], "events was not set to an empty list" + assert ticket_doc.false_positive is False, "false_positive was not set to False" + assert ticket_doc.last_change is not None, "last_change was not set" + assert ticket_doc.open is True, "open was not set to True" + assert ticket_doc.snapshots == [], "snapshots was not set to an empty list" + assert ticket_doc.time_closed is None, "time_closed was not set to None" + assert ticket_doc.time_opened is not None, "time_opened was not set" + + +async def test_save(): + """Test saving a TicketDoc object to the database.""" + # Create a TicketDoc object and save it to the DB + ticket_doc = sample_ticket() + await ticket_doc.save() + + # Find ticket in DB and confirm it was saved correctly + ticket_doc_db = await TicketDoc.find_one(TicketDoc.ip_int == VALID_IP_1_INT) + assert ticket_doc_db is not None, "ticket_doc was not saved to the database" + + +async def test_save_with_event(): + """Test saving a ticket that contains an event.""" + ticket_doc = await TicketDoc.find_one(TicketDoc.ip_int == VALID_IP_1_INT) + ticket_doc.set_false_positive( + new_state=True, reason="Test set false positive", expire_days=30 + ) + await ticket_doc.save() + + # Find ticket in DB and confirm it was saved correctly + ticket_doc_db = await TicketDoc.find_one(TicketDoc.ip_int == VALID_IP_1_INT) + assert ticket_doc_db is not None, "ticket_doc was not saved to the database" + assert ticket_doc_db.events[0].action == TicketAction.CHANGED + assert ticket_doc_db.events[0].delta.from_ is False + assert ticket_doc_db.events[0].delta.to is True + + +async def test_before_save(): + """Test the before_save method.""" + ticket_doc = sample_ticket() + ticket_doc.false_positive = True + ticket_doc.open = False + with pytest.raises( + 
Exception, match="A ticket marked as a false positive cannot be closed." + ): + await ticket_doc.save() + + +def test_add_event(): + """Test adding an event to a ticket.""" + ticket_doc = sample_ticket() + ticket_doc.add_event(action=TicketAction.OPENED, reason="Test reason") + assert len(ticket_doc.events) == 1, "event was not added to the ticket" + assert ticket_doc.events[0].action == TicketAction.OPENED + assert ticket_doc.events[0].reason == "Test reason" + + +def test_add_event_exception(): + """Test adding an invalid event to a ticket.""" + ticket_doc = sample_ticket() + with pytest.raises( + Exception, match='Invalid action "INVALID" cannot be added to ticket events.' + ): + ticket_doc.add_event(action="INVALID", reason="Test reason") + + +def test_set_false_positive_true(): + """Test setting a ticket as false positive.""" + ticket_doc = sample_ticket() + ticket_doc.set_false_positive( + new_state=True, reason="Test set false positive", expire_days=30 + ) + assert ticket_doc.false_positive is True, "ticket was not set as false positive" + assert ( + ticket_doc.fp_expiration_date is not None + ), "false positive expiration date was not set" + + +def test_set_false_positive_no_change(): + """Test setting a ticket that was already false positive to false positive.""" + ticket_doc = sample_ticket() + + ticket_doc.set_false_positive( + new_state=True, reason="Test set false positive", expire_days=30 + ) + fp_expiration_date = ticket_doc.fp_expiration_date + + ticket_doc.set_false_positive( + new_state=True, reason="Test set false positive again", expire_days=60 + ) + assert ( + ticket_doc.false_positive is True + ), "ticket should have remained a false positive" + assert ( + ticket_doc.fp_expiration_date == fp_expiration_date + ), "false positive expiration date should not have changed" + + +def test_set_false_positive_false(): + """Test setting a ticket as false positive false.""" + ticket_doc = sample_ticket() + ticket_doc.set_false_positive( + new_state=True, 
reason="Test set false positive true", expire_days=30 + ) + + assert ( + ticket_doc.fp_expiration_date is not None + ), "false positive expiration date was not set" + + ticket_doc.set_false_positive( + new_state=False, reason="Test set false positive false", expire_days=0 + ) + assert ( + ticket_doc.false_positive is False + ), "ticket should not still be false positive" + assert ( + ticket_doc.fp_expiration_date is None + ), "false positive expiration date was not cleared" + + +def test_set_false_positive_on_closed_ticket(): + """Test setting a closed ticket as false positive.""" + ticket_doc = sample_ticket() + # Close the ticket + ticket_doc.open = False + ticket_doc.time_closed = utcnow() + + ticket_doc.set_false_positive( + expire_days=30, + new_state=True, + reason="Test set false positive on closed ticket", + ) + assert ticket_doc.open is True, "ticket should have been reopened" + assert ticket_doc.false_positive is True, "ticket should be false positive" + assert ticket_doc.time_closed is None, "ticket should not have a time_closed" + assert ticket_doc.events[-2].action == TicketAction.REOPENED + assert ticket_doc.events[-2].reason == "setting false positive" + assert ticket_doc.events[-2].time is not None + + +def test_false_positive_dates(): + """Test getting false positive dates.""" + ticket_doc = sample_ticket() + + # Set ticket as false positive + ticket_doc.set_false_positive( + new_state=True, reason="Test set false positive", expire_days=30 + ) + fp_dates = ticket_doc.false_positive_dates() + assert fp_dates is not None + + # Add another sample event + ticket_doc.add_event( + action=TicketAction.UNVERIFIED, reason="Test reason", time=utcnow() + ) + assert ticket_doc.false_positive_dates() == fp_dates + + # Unset ticket as false positive + ticket_doc.false_positive = False + event = TicketEvent( + action=TicketAction.CHANGED, + delta=EventDelta(from_=True, to=False, key="false_positive"), + reason="Test false positive expired", + reference=None, + 
time=utcnow(), + ) + ticket_doc.events.append(event) + assert ticket_doc.false_positive_dates() is None + + +def test_false_positive_dates_edge_cases(): + """Test getting false positive dates edge cases.""" + ticket_doc = sample_ticket() + ticket_doc.false_positive = True + assert ticket_doc.false_positive_dates() is None + + # Add a sample non-false-positive CHANGED event + test_delta = EventDelta(from_=False, to=True, key="test_key") + ticket_doc.add_event( + action=TicketAction.CHANGED, + delta=test_delta, + reason="Test reason", + time=utcnow(), + ) + assert ticket_doc.false_positive_dates() is None + + +def test_last_detection_date(): + """Test getting the last detection date.""" + ticket_doc = sample_ticket() + ticket_doc.add_event( + action=TicketAction.OPENED, reason="Test reason", time=utcnow() + ) + detection_date = ticket_doc.last_detection_date() + assert detection_date == ticket_doc.events[0].time + + +def test_last_detection_date_edge_case(): + """Test an edge case of last_detection_date.""" + ticket_doc = sample_ticket() + ticket_doc.add_event( + action=TicketAction.CLOSED, reason="Test reason", time=utcnow() + ) + detection_date = ticket_doc.last_detection_date() + assert detection_date == ticket_doc.time_opened + + +async def test_tagging(): + """Test tag_open, tag_matching, and remove_tag.""" + # Find our test ticket in the DB + ticket_doc_db = await TicketDoc.find_one(TicketDoc.ip_int == VALID_IP_1_INT) + test_owner = ticket_doc_db.owner + assert len(ticket_doc_db.snapshots) == 0 + + # Create a test snapshot and save it to the DB + snapshot_end_time = utcnow() + snapshot_start_time = snapshot_end_time - timedelta(days=1) + test_snapshot_1 = SnapshotDoc( + owner=test_owner, end_time=snapshot_end_time, start_time=snapshot_start_time + ) + await test_snapshot_1.save() + assert test_snapshot_1 not in ticket_doc_db.snapshots + + # Use tag_open() to tag the ticket with the snapshot ID + await TicketDoc.tag_open(owners=[test_owner], 
snapshot_oid=test_snapshot_1.id) + + updated_ticket = await TicketDoc.find_one(TicketDoc.ip_int == VALID_IP_1_INT) + # I'm not using fetch_links=True in the find_one() above because I can't get + # it to work correctly. Instead, I'm using fetch_all_links() below. + await updated_ticket.fetch_all_links() + assert len(updated_ticket.snapshots) == 1 + assert test_snapshot_1 in updated_ticket.snapshots + + # Create another test snapshot and save it to the DB + snapshot_end_time = utcnow() + snapshot_start_time = snapshot_end_time - timedelta(days=1) + test_snapshot_2 = SnapshotDoc( + owner=test_owner, end_time=snapshot_end_time, start_time=snapshot_start_time + ) + await test_snapshot_2.save() + assert test_snapshot_2 not in updated_ticket.snapshots + + # Use tag_matching() to tag the ticket with the new snapshot ID + await TicketDoc.tag_matching( + existing_snapshot_oids=[test_snapshot_1.id], + new_snapshot_oid=test_snapshot_2.id, + ) + + updated_ticket = await TicketDoc.find_one(TicketDoc.ip_int == VALID_IP_1_INT) + # I'm not using fetch_links=True in the find_one() above because I can't get + # it to work correctly. Instead, I'm using fetch_all_links() below. + await updated_ticket.fetch_all_links() + assert len(updated_ticket.snapshots) == 2 + assert test_snapshot_2 in updated_ticket.snapshots + + # Use remove_tag() to remove the test_snapshot_2.id from the ticket + await TicketDoc.remove_tag(snapshot_oid=test_snapshot_2.id) + + updated_ticket = await TicketDoc.find_one(TicketDoc.ip_int == VALID_IP_1_INT) + # I'm not using fetch_links=True in the find_one() above because I can't get + # it to work correctly. Instead, I'm using fetch_all_links() below. 
+ await updated_ticket.fetch_all_links() + assert len(updated_ticket.snapshots) == 1 + assert test_snapshot_2 not in updated_ticket.snapshots + + +async def test_latest_port(): + """Test the latest_port method.""" + ticket_doc = sample_ticket() + ticket_doc.id = PydanticObjectId() + reference_id = PydanticObjectId() + # Add an event with our test reference ID + ticket_doc.add_event( + action=TicketAction.OPENED, + reason="Test reason", + reference=reference_id, + time=utcnow(), + ) + # Add another event without a reference ID + ticket_doc.add_event( + action=TicketAction.VERIFIED, + reason="Test reason", + time=utcnow(), + ) + + # Create a dummy port scan document with the reference ID + mock_doc = AsyncMock() + mock_doc.id = reference_id + + with patch.object(PortScanDoc, "get", return_value=mock_doc): + port = await ticket_doc.latest_port() + assert ( + port.id == reference_id + ), "latest_port did not return the correct port scan" + + +async def test_latest_port_no_references(): + """Test the latest_port method when there are no references.""" + ticket_doc = sample_ticket() + ticket_doc.id = PydanticObjectId() + + with pytest.raises( + Exception, match=("No references found in ticket events: " + str(ticket_doc.id)) + ): + await ticket_doc.latest_port() + + +async def test_latest_port_not_found(): + """Test the latest_port method when the port scan is not found.""" + ticket_doc = sample_ticket() + reference_id = PydanticObjectId() + ticket_doc.add_event( + action=TicketAction.OPENED, + reason="Test reason", + reference=reference_id, + time=utcnow(), + ) + + with pytest.raises(PortScanNotFoundException): + # Mock PortScanDoc.get to return None + with patch.object(PortScanDoc, "get", return_value=None): + await ticket_doc.latest_port() + + +async def test_latest_vuln(): + """Test the latest_vuln method.""" + ticket_doc = sample_ticket() + reference_id = PydanticObjectId() + # Add an event with our test reference ID + ticket_doc.add_event( + 
action=TicketAction.OPENED,
+        reason="Test reason",
+        reference=reference_id,
+        time=utcnow(),
+    )
+    # Add another event without a reference ID
+    ticket_doc.add_event(
+        action=TicketAction.VERIFIED,
+        reason="Test reason",
+        time=utcnow(),
+    )
+
+    # Create a dummy vuln scan document with the reference ID
+    mock_doc = AsyncMock()
+    mock_doc._id = reference_id
+
+    with patch.object(VulnScanDoc, "get", return_value=mock_doc):
+        vuln = await ticket_doc.latest_vuln()
+        assert (
+            vuln._id == reference_id
+        ), "latest_vuln did not return the correct vuln scan"
+
+
+async def test_latest_vuln_no_references():
+    """Test the latest_vuln method when there are no references."""
+    ticket_doc = sample_ticket()
+    ticket_doc.id = PydanticObjectId()
+
+    with pytest.raises(
+        Exception, match=("No references found in ticket events: " + str(ticket_doc.id))
+    ):
+        await ticket_doc.latest_vuln()
+
+
+async def test_latest_vuln_not_found():
+    """Test the latest_vuln method when the vuln scan is not found."""
+    ticket_doc = sample_ticket()
+    reference_id = PydanticObjectId()
+    ticket_doc.add_event(
+        action=TicketAction.OPENED,
+        reason="Test reason",
+        reference=reference_id,
+        time=utcnow(),
+    )
+
+    with pytest.raises(VulnScanNotFoundException):
+        # Mock VulnScanDoc.get to return None
+        with patch.object(VulnScanDoc, "get", return_value=None):
+            await ticket_doc.latest_vuln()