Skip to content

Commit

Permalink
enhancement(cli/consume): improve thread exception handling in cli co…
Browse files Browse the repository at this point in the history
…nsumer (#5118)

* enhancement(cli/consume): improve thread exception handling in cli consumer

* enhancement(cli/consume): improve thread exception handling in cli consumer
  • Loading branch information
wssheldon authored Aug 22, 2024
1 parent c58d2b3 commit 6ede625
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 51 deletions.
155 changes: 124 additions & 31 deletions src/dispatch/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
import os
from threading import Thread, Event
import time
from typing import Any

import click
import uvicorn
Expand Down Expand Up @@ -791,52 +794,124 @@ def signals_group():
pass


def _run_consume(plugin_slug: str, organization_slug: str, project_id: int, running: bool):
def _run_consume(
plugin_slug: str,
organization_slug: str,
project_id: int,
running: Event,
) -> None:
"""
Runs the consume method of a plugin instance.
Args:
plugin_slug (str): The slug of the plugin to run.
organization_slug (str): The slug of the organization.
project_id (int): The ID of the project.
running (Event): An event to signal when the thread should stop running.
Returns:
None
"""
from dispatch.database.core import get_organization_session
from dispatch.plugin import service as plugin_service
from dispatch.project import service as project_service
from dispatch.common.utils.cli import install_plugins

install_plugins()

with get_organization_session(organization_slug) as session:
plugin = plugin_service.get_active_instance_by_slug(
db_session=session, slug=plugin_slug, project_id=project_id
)
project = project_service.get(db_session=session, project_id=project_id)
while True:
if not running:
break

while running.is_set():
plugin.instance.consume(db_session=session, project=project)


def _run_consume_with_exception_handling(
plugin_slug: str,
organization_slug: str,
project_id: int,
running: Event,
) -> None:
"""
Runs the consume method of a plugin instance with exception handling.
Args:
plugin_slug (str): The slug of the plugin to run.
organization_slug (str): The slug of the organization.
project_id (int): The ID of the project.
running (Event): An event to signal when the thread should stop running.
Returns:
None
"""
while running.is_set():
try:
_run_consume(plugin_slug, organization_slug, project_id, running)
except Exception as e:
log.error(f"Exception in thread for plugin {plugin_slug}: {e}", exc_info=True)
time.sleep(1) # Optional: Add a small delay before retrying


def _create_consumer_thread(
plugin_slug: str,
organization_slug: str,
project_id: int,
running: Event,
) -> Thread:
"""
Creates a new consumer thread for a plugin.
Args:
plugin_slug (str): The slug of the plugin to run.
organization_slug (str): The slug of the organization.
project_id (int): The ID of the project.
running (Event): An event to signal when the thread should stop running.
Returns:
Thread: A new daemon thread that will run the plugin's consume method.
"""
return Thread(
target=_run_consume_with_exception_handling,
args=(
plugin_slug,
organization_slug,
project_id,
running,
),
daemon=True,
)


@signals_group.command("consume")
def consume_signals():
"""Runs a continuous process that consumes signals from the specified plugin."""
import time
from threading import Thread, Event
import logging
"""
Runs a continuous process that consumes signals from the specified plugins.
This function sets up consumer threads for all active signal-consumer plugins
across all organizations and projects. It monitors these threads and restarts
them if they die. The process can be terminated using SIGINT or SIGTERM.
Returns:
None
"""
import signal
from types import FrameType

from dispatch.common.utils.cli import install_plugins
from dispatch.project import service as project_service
from dispatch.plugin import service as plugin_service

from dispatch.organization.service import get_all as get_all_organizations
from dispatch.database.core import get_session, get_organization_session

install_plugins()
with get_session() as session:
organizations = get_all_organizations(db_session=session)

log = logging.getLogger(__name__)

# Replace manager dictionary with an Event
worker_configs: list[dict[str, Any]] = []
workers: list[Thread] = []
running = Event()
running.set()

workers = []
with get_session() as session:
organizations = get_all_organizations(db_session=session)

for organization in organizations:
with get_organization_session(organization.slug) as session:
Expand All @@ -850,33 +925,51 @@ def consume_signals():
log.warning(
f"No signals consumed. No signal-consumer plugins enabled. Project: {project.name}. Organization: {project.organization.name}"
)
continue

for plugin in plugins:
log.debug(f"Consuming signals for plugin: {plugin.plugin.slug}")
for _ in range(5): # TODO add plugin.instance.concurrency
t = Thread(
target=_run_consume,
args=(plugin.plugin.slug, organization.slug, project.id, running),
daemon=True, # Set thread to daemon
)
worker_config = {
"plugin_slug": plugin.plugin.slug,
"organization_slug": organization.slug,
"project_id": project.id,
}
worker_configs.append(worker_config)
t = _create_consumer_thread(**worker_config, running=running)
t.start()
workers.append(t)

def terminate_processes(signum, frame):
print("Terminating main process...")
def terminate_processes(signum: int, frame: FrameType) -> None:
"""
Signal handler to terminate all processes.
Args:
signum (int): The signal number.
frame (FrameType): The current stack frame.
Returns:
None
"""
log.info("Terminating main process...")
running.clear() # stop all threads
for worker in workers:
worker.join()

signal.signal(signal.SIGINT, terminate_processes)
signal.signal(signal.SIGTERM, terminate_processes)

# Keep the main thread running
while True:
if not running.is_set():
print("Main process terminating.")
break
time.sleep(1)
while running.is_set():
for i, worker in enumerate(workers):
if not worker.is_alive():
log.warning(f"Thread {i} died. Restarting...")
config = worker_configs[i]
new_worker = _create_consumer_thread(**config, running=running)
new_worker.start()
workers[i] = new_worker
time.sleep(1) # Check every second

log.info("Main process terminating.")


@signals_group.command("process")
Expand Down
32 changes: 17 additions & 15 deletions src/dispatch/plugin/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from typing import List, Optional

from sqlalchemy.orm import Session

from dispatch.exceptions import InvalidConfigurationError
from dispatch.plugins.bases import OncallPlugin
from dispatch.project import service as project_service
Expand All @@ -20,12 +22,12 @@
log = logging.getLogger(__name__)


def get(*, db_session, plugin_id: int) -> Optional[Plugin]:
def get(*, db_session: Session, plugin_id: int) -> Optional[Plugin]:
"""Returns a plugin based on the given plugin id."""
return db_session.query(Plugin).filter(Plugin.id == plugin_id).one_or_none()


def get_by_slug(*, db_session, slug: str) -> Plugin:
def get_by_slug(*, db_session: Session, slug: str) -> Plugin:
"""Fetches a plugin by slug."""
return db_session.query(Plugin).filter(Plugin.slug == slug).one_or_none()

Expand All @@ -35,12 +37,12 @@ def get_all(*, db_session) -> List[Optional[Plugin]]:
return db_session.query(Plugin).all()


def get_by_type(*, db_session, plugin_type: str) -> List[Optional[Plugin]]:
def get_by_type(*, db_session: Session, plugin_type: str) -> List[Optional[Plugin]]:
"""Fetches all plugins for a given type."""
return db_session.query(Plugin).filter(Plugin.type == plugin_type).all()


def get_instance(*, db_session, plugin_instance_id: int) -> Optional[PluginInstance]:
def get_instance(*, db_session: Session, plugin_instance_id: int) -> Optional[PluginInstance]:
"""Returns a plugin instance based on the given instance id."""
return (
db_session.query(PluginInstance)
Expand All @@ -50,7 +52,7 @@ def get_instance(*, db_session, plugin_instance_id: int) -> Optional[PluginInsta


def get_active_instance(
*, db_session, plugin_type: str, project_id=None
*, db_session: Session, plugin_type: str, project_id=None
) -> Optional[PluginInstance]:
"""Fetches the current active plugin for the given type."""
return (
Expand All @@ -64,7 +66,7 @@ def get_active_instance(


def get_active_instances(
*, db_session, plugin_type: str, project_id=None
*, db_session: Session, plugin_type: str, project_id=None
) -> Optional[PluginInstance]:
"""Fetches the current active plugin for the given type."""
return (
Expand All @@ -78,7 +80,7 @@ def get_active_instances(


def get_active_instance_by_slug(
*, db_session, slug: str, project_id=None
*, db_session: Session, slug: str, project_id: int | None = None
) -> Optional[PluginInstance]:
"""Fetches the current active plugin for the given type."""
return (
Expand All @@ -92,7 +94,7 @@ def get_active_instance_by_slug(


def get_enabled_instances_by_type(
*, db_session, project_id: int, plugin_type: str
*, db_session: Session, project_id: int, plugin_type: str
) -> List[Optional[PluginInstance]]:
"""Fetches all enabled plugins for a given type."""
return (
Expand All @@ -105,7 +107,7 @@ def get_enabled_instances_by_type(
)


def create_instance(*, db_session, plugin_instance_in: PluginInstanceCreate) -> PluginInstance:
def create_instance(*, db_session: Session, plugin_instance_in: PluginInstanceCreate) -> PluginInstance:
"""Creates a new plugin instance."""
project = project_service.get_by_name_or_raise(
db_session=db_session, project_in=plugin_instance_in.project
Expand All @@ -124,7 +126,7 @@ def create_instance(*, db_session, plugin_instance_in: PluginInstanceCreate) ->


def update_instance(
*, db_session, plugin_instance: PluginInstance, plugin_instance_in: PluginInstanceUpdate
*, db_session: Session, plugin_instance: PluginInstance, plugin_instance_in: PluginInstanceUpdate
) -> PluginInstance:
"""Updates a plugin instance."""
plugin_instance_data = plugin_instance.dict()
Expand Down Expand Up @@ -169,28 +171,28 @@ def update_instance(
return plugin_instance


def delete_instance(*, db_session, plugin_instance_id: int):
def delete_instance(*, db_session: Session, plugin_instance_id: int):
"""Deletes a plugin instance."""
db_session.query(PluginInstance).filter(PluginInstance.id == plugin_instance_id).delete()
db_session.commit()


def get_plugin_event_by_id(*, db_session, plugin_event_id: int) -> Optional[PluginEvent]:
def get_plugin_event_by_id(*, db_session: Session, plugin_event_id: int) -> Optional[PluginEvent]:
"""Returns a plugin event based on the plugin event id."""
return db_session.query(PluginEvent).filter(PluginEvent.id == plugin_event_id).one_or_none()


def get_plugin_event_by_slug(*, db_session, slug: str) -> Optional[PluginEvent]:
def get_plugin_event_by_slug(*, db_session: Session, slug: str) -> Optional[PluginEvent]:
"""Returns a project based on the plugin event slug."""
return db_session.query(PluginEvent).filter(PluginEvent.slug == slug).one_or_none()


def get_all_events_for_plugin(*, db_session, plugin_id: int) -> List[Optional[PluginEvent]]:
def get_all_events_for_plugin(*, db_session: Session, plugin_id: int) -> List[Optional[PluginEvent]]:
"""Returns all plugin events for a given plugin."""
return db_session.query(PluginEvent).filter(PluginEvent.plugin_id == plugin_id).all()


def create_plugin_event(*, db_session, plugin_event_in: PluginEventCreate) -> PluginEvent:
def create_plugin_event(*, db_session: Session, plugin_event_in: PluginEventCreate) -> PluginEvent:
"""Creates a new plugin event."""
plugin_event = PluginEvent(**plugin_event_in.dict(exclude={"plugin"}))
plugin_event.plugin = get(db_session=db_session, plugin_id=plugin_event_in.plugin.id)
Expand Down
11 changes: 6 additions & 5 deletions src/dispatch/project/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@
from pydantic.error_wrappers import ErrorWrapper
from dispatch.exceptions import NotFoundError

from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import true

from .models import Project, ProjectCreate, ProjectUpdate, ProjectRead


def get(*, db_session, project_id: int) -> Optional[Project]:
def get(*, db_session: Session, project_id: int) -> Project | None:
"""Returns a project based on the given project id."""
return db_session.query(Project).filter(Project.id == project_id).first()


def get_default(*, db_session) -> Optional[Project]:
def get_default(*, db_session: Session) -> Optional[Project]:
"""Returns the default project."""
return db_session.query(Project).filter(Project.default == true()).one_or_none()


def get_default_or_raise(*, db_session) -> Project:
def get_default_or_raise(*, db_session: Session) -> Project:
"""Returns the default project or raise a ValidationError if one doesn't exist."""
project = get_default(db_session=db_session)

Expand All @@ -36,12 +37,12 @@ def get_default_or_raise(*, db_session) -> Project:
return project


def get_by_name(*, db_session, name: str) -> Optional[Project]:
def get_by_name(*, db_session: Session, name: str) -> Optional[Project]:
"""Returns a project based on the given project name."""
return db_session.query(Project).filter(Project.name == name).one_or_none()


def get_by_name_or_raise(*, db_session, project_in=ProjectRead) -> Project:
def get_by_name_or_raise(*, db_session: Session, project_in=ProjectRead) -> Project:
"""Returns the project specified or raises ValidationError."""
project = get_by_name(db_session=db_session, name=project_in.name)

Expand Down

0 comments on commit 6ede625

Please sign in to comment.