diff --git a/src/dispatch/cli.py b/src/dispatch/cli.py index 855b7f8bb73a..2bb7d4d74e3e 100644 --- a/src/dispatch/cli.py +++ b/src/dispatch/cli.py @@ -1,5 +1,8 @@ import logging import os +from threading import Thread, Event +import time +from typing import Any import click import uvicorn @@ -791,52 +794,124 @@ def signals_group(): pass -def _run_consume(plugin_slug: str, organization_slug: str, project_id: int, running: bool): +def _run_consume( + plugin_slug: str, + organization_slug: str, + project_id: int, + running: Event, +) -> None: + """ + Runs the consume method of a plugin instance. + + Args: + plugin_slug (str): The slug of the plugin to run. + organization_slug (str): The slug of the organization. + project_id (int): The ID of the project. + running (Event): An event to signal when the thread should stop running. + + Returns: + None + """ from dispatch.database.core import get_organization_session from dispatch.plugin import service as plugin_service from dispatch.project import service as project_service - from dispatch.common.utils.cli import install_plugins - - install_plugins() with get_organization_session(organization_slug) as session: plugin = plugin_service.get_active_instance_by_slug( db_session=session, slug=plugin_slug, project_id=project_id ) project = project_service.get(db_session=session, project_id=project_id) - while True: - if not running: - break + + while running.is_set(): plugin.instance.consume(db_session=session, project=project) +def _run_consume_with_exception_handling( + plugin_slug: str, + organization_slug: str, + project_id: int, + running: Event, +) -> None: + """ + Runs the consume method of a plugin instance with exception handling. + + Args: + plugin_slug (str): The slug of the plugin to run. + organization_slug (str): The slug of the organization. + project_id (int): The ID of the project. + running (Event): An event to signal when the thread should stop running. + + Returns: + None + """ + while running.is_set(): + try: + _run_consume(plugin_slug, organization_slug, project_id, running) + except Exception as e: + log.error(f"Exception in thread for plugin {plugin_slug}: {e}", exc_info=True) + time.sleep(1) # Optional: Add a small delay before retrying + + +def _create_consumer_thread( + plugin_slug: str, + organization_slug: str, + project_id: int, + running: Event, +) -> Thread: + """ + Creates a new consumer thread for a plugin. + + Args: + plugin_slug (str): The slug of the plugin to run. + organization_slug (str): The slug of the organization. + project_id (int): The ID of the project. + running (Event): An event to signal when the thread should stop running. + + Returns: + Thread: A new daemon thread that will run the plugin's consume method. + """ + return Thread( + target=_run_consume_with_exception_handling, + args=( + plugin_slug, + organization_slug, + project_id, + running, + ), + daemon=True, + ) + + @signals_group.command("consume") def consume_signals(): - """Runs a continuous process that consumes signals from the specified plugin.""" - import time - from threading import Thread, Event - import logging + """ + Runs a continuous process that consumes signals from the specified plugins. + This function sets up consumer threads for all active signal-consumer plugins + across all organizations and projects. It monitors these threads and restarts + them if they die. The process can be terminated using SIGINT or SIGTERM. + + Returns: + None + """ import signal + from types import FrameType from dispatch.common.utils.cli import install_plugins from dispatch.project import service as project_service from dispatch.plugin import service as plugin_service - from dispatch.organization.service import get_all as get_all_organizations from dispatch.database.core import get_session, get_organization_session install_plugins() - with get_session() as session: - organizations = get_all_organizations(db_session=session) - - log = logging.getLogger(__name__) - # Replace manager dictionary with an Event + worker_configs: list[dict[str, Any]] = [] + workers: list[Thread] = [] running = Event() running.set() - workers = [] + with get_session() as session: + organizations = get_all_organizations(db_session=session) for organization in organizations: with get_organization_session(organization.slug) as session: @@ -850,20 +925,33 @@ def consume_signals(): log.warning( f"No signals consumed. No signal-consumer plugins enabled. Project: {project.name}. Organization: {project.organization.name}" ) + continue for plugin in plugins: log.debug(f"Consuming signals for plugin: {plugin.plugin.slug}") for _ in range(5): # TODO add plugin.instance.concurrency - t = Thread( - target=_run_consume, - args=(plugin.plugin.slug, organization.slug, project.id, running), - daemon=True, # Set thread to daemon - ) + worker_config = { + "plugin_slug": plugin.plugin.slug, + "organization_slug": organization.slug, + "project_id": project.id, + } + worker_configs.append(worker_config) + t = _create_consumer_thread(**worker_config, running=running) t.start() workers.append(t) - def terminate_processes(signum, frame): - print("Terminating main process...") + def terminate_processes(signum: int, frame: FrameType) -> None: + """ + Signal handler to terminate all processes. + + Args: + signum (int): The signal number. + frame (FrameType): The current stack frame. + + Returns: + None + """ + log.info("Terminating main process...") running.clear() # stop all threads for worker in workers: worker.join() @@ -871,12 +959,17 @@ def terminate_processes(signum, frame): signal.signal(signal.SIGINT, terminate_processes) signal.signal(signal.SIGTERM, terminate_processes) - # Keep the main thread running - while True: - if not running.is_set(): - print("Main process terminating.") - break - time.sleep(1) + while running.is_set(): + for i, worker in enumerate(workers): + if not worker.is_alive(): + log.warning(f"Thread {i} died. Restarting...") + config = worker_configs[i] + new_worker = _create_consumer_thread(**config, running=running) + new_worker.start() + workers[i] = new_worker + time.sleep(1) # Check every second + + log.info("Main process terminating.") @signals_group.command("process") diff --git a/src/dispatch/plugin/service.py b/src/dispatch/plugin/service.py index 9334b12167cb..ff36bf5a6a4a 100644 --- a/src/dispatch/plugin/service.py +++ b/src/dispatch/plugin/service.py @@ -2,6 +2,8 @@ from pydantic.error_wrappers import ErrorWrapper, ValidationError from typing import List, Optional +from sqlalchemy.orm import Session + from dispatch.exceptions import InvalidConfigurationError from dispatch.plugins.bases import OncallPlugin from dispatch.project import service as project_service @@ -20,12 +22,12 @@ log = logging.getLogger(__name__) -def get(*, db_session, plugin_id: int) -> Optional[Plugin]: +def get(*, db_session: Session, plugin_id: int) -> Optional[Plugin]: """Returns a plugin based on the given plugin id.""" return db_session.query(Plugin).filter(Plugin.id == plugin_id).one_or_none() -def get_by_slug(*, db_session, slug: str) -> Plugin: +def get_by_slug(*, db_session: Session, slug: str) -> Plugin: """Fetches a plugin by slug.""" return db_session.query(Plugin).filter(Plugin.slug == slug).one_or_none() @@ -35,12 +37,12 @@ def get_all(*, db_session) -> List[Optional[Plugin]]: return db_session.query(Plugin).all() -def get_by_type(*, db_session, plugin_type: str) -> List[Optional[Plugin]]: +def get_by_type(*, db_session: Session, plugin_type: str) -> List[Optional[Plugin]]: """Fetches all plugins for a given type.""" return db_session.query(Plugin).filter(Plugin.type == plugin_type).all() -def get_instance(*, db_session, plugin_instance_id: int) -> Optional[PluginInstance]: +def get_instance(*, db_session: Session, plugin_instance_id: int) -> Optional[PluginInstance]: """Returns a plugin instance based on the given instance id.""" return ( db_session.query(PluginInstance) @@ -50,7 +52,7 @@ def get_instance(*, db_session, plugin_instance_id: int) -> Optional[PluginInsta def get_active_instance( - *, db_session, plugin_type: str, project_id=None + *, db_session: Session, plugin_type: str, project_id=None ) -> Optional[PluginInstance]: """Fetches the current active plugin for the given type.""" return ( @@ -64,7 +66,7 @@ def get_active_instance( def get_active_instances( - *, db_session, plugin_type: str, project_id=None + *, db_session: Session, plugin_type: str, project_id=None ) -> Optional[PluginInstance]: """Fetches the current active plugin for the given type.""" return ( @@ -78,7 +80,7 @@ def get_active_instances( def get_active_instance_by_slug( - *, db_session, slug: str, project_id=None + *, db_session: Session, slug: str, project_id: int | None = None ) -> Optional[PluginInstance]: """Fetches the current active plugin for the given type.""" return ( @@ -92,7 +94,7 @@ def get_active_instance_by_slug( def get_enabled_instances_by_type( - *, db_session, project_id: int, plugin_type: str + *, db_session: Session, project_id: int, plugin_type: str ) -> List[Optional[PluginInstance]]: """Fetches all enabled plugins for a given type.""" return ( @@ -105,7 +107,7 @@ def get_enabled_instances_by_type( ) -def create_instance(*, db_session, plugin_instance_in: PluginInstanceCreate) -> PluginInstance: +def create_instance(*, db_session: Session, plugin_instance_in: PluginInstanceCreate) -> PluginInstance: """Creates a new plugin instance.""" project = project_service.get_by_name_or_raise( db_session=db_session, project_in=plugin_instance_in.project @@ -124,7 +126,7 @@ def create_instance(*, db_session, plugin_instance_in: PluginInstanceCreate) -> def update_instance( - *, db_session, plugin_instance: PluginInstance, plugin_instance_in: PluginInstanceUpdate + *, db_session: Session, plugin_instance: PluginInstance, plugin_instance_in: PluginInstanceUpdate ) -> PluginInstance: """Updates a plugin instance.""" plugin_instance_data = plugin_instance.dict() @@ -169,28 +171,28 @@ def update_instance( return plugin_instance -def delete_instance(*, db_session, plugin_instance_id: int): +def delete_instance(*, db_session: Session, plugin_instance_id: int): """Deletes a plugin instance.""" db_session.query(PluginInstance).filter(PluginInstance.id == plugin_instance_id).delete() db_session.commit() -def get_plugin_event_by_id(*, db_session, plugin_event_id: int) -> Optional[PluginEvent]: +def get_plugin_event_by_id(*, db_session: Session, plugin_event_id: int) -> Optional[PluginEvent]: """Returns a plugin event based on the plugin event id.""" return db_session.query(PluginEvent).filter(PluginEvent.id == plugin_event_id).one_or_none() -def get_plugin_event_by_slug(*, db_session, slug: str) -> Optional[PluginEvent]: +def get_plugin_event_by_slug(*, db_session: Session, slug: str) -> Optional[PluginEvent]: """Returns a project based on the plugin event slug.""" return db_session.query(PluginEvent).filter(PluginEvent.slug == slug).one_or_none() -def get_all_events_for_plugin(*, db_session, plugin_id: int) -> List[Optional[PluginEvent]]: +def get_all_events_for_plugin(*, db_session: Session, plugin_id: int) -> List[Optional[PluginEvent]]: """Returns all plugin events for a given plugin.""" return db_session.query(PluginEvent).filter(PluginEvent.plugin_id == plugin_id).all() -def create_plugin_event(*, db_session, plugin_event_in: PluginEventCreate) -> PluginEvent: +def create_plugin_event(*, db_session: Session, plugin_event_in: PluginEventCreate) -> PluginEvent: """Creates a new plugin event.""" plugin_event = PluginEvent(**plugin_event_in.dict(exclude={"plugin"})) plugin_event.plugin = get(db_session=db_session, plugin_id=plugin_event_in.plugin.id) diff --git a/src/dispatch/project/service.py b/src/dispatch/project/service.py index 62d6944e7c82..15aa24c04f17 100644 --- a/src/dispatch/project/service.py +++ b/src/dispatch/project/service.py @@ -4,22 +4,23 @@ from pydantic.error_wrappers import ErrorWrapper from dispatch.exceptions import NotFoundError +from sqlalchemy.orm import Session from sqlalchemy.sql.expression import true from .models import Project, ProjectCreate, ProjectUpdate, ProjectRead -def get(*, db_session, project_id: int) -> Optional[Project]: +def get(*, db_session: Session, project_id: int) -> Project | None: """Returns a project based on the given project id.""" return db_session.query(Project).filter(Project.id == project_id).first() -def get_default(*, db_session) -> Optional[Project]: +def get_default(*, db_session: Session) -> Optional[Project]: """Returns the default project.""" return db_session.query(Project).filter(Project.default == true()).one_or_none() -def get_default_or_raise(*, db_session) -> Project: +def get_default_or_raise(*, db_session: Session) -> Project: """Returns the default project or raise a ValidationError if one doesn't exist.""" project = get_default(db_session=db_session) @@ -36,12 +37,12 @@ def get_default_or_raise(*, db_session) -> Project: return project -def get_by_name(*, db_session, name: str) -> Optional[Project]: +def get_by_name(*, db_session: Session, name: str) -> Optional[Project]: """Returns a project based on the given project name.""" return db_session.query(Project).filter(Project.name == name).one_or_none() -def get_by_name_or_raise(*, db_session, project_in=ProjectRead) -> Project: +def get_by_name_or_raise(*, db_session: Session, project_in=ProjectRead) -> Project: """Returns the project specified or raises ValidationError.""" project = get_by_name(db_session=db_session, name=project_in.name)