Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uses a representative sample of related cases for the GenAI analysis #5283

Merged
merged 3 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/dispatch/plugins/dispatch_slack/case/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from slack_sdk.web.client import WebClient
from sqlalchemy.orm import Session

from dispatch.case.enums import CaseStatus
from dispatch.case.enums import CaseResolutionReason, CaseStatus
from dispatch.case.models import Case
from dispatch.config import DISPATCH_UI_URL
from dispatch.messaging.strings import CASE_STATUS_DESCRIPTIONS, CASE_VISIBILITY_DESCRIPTIONS
Expand Down Expand Up @@ -320,13 +320,17 @@ def create_genai_signal_analysis_message(
return signal_metadata_blocks

# Fetch related cases
related_cases = (
signal_service.get_cases_for_signal(
db_session=db_session, signal_id=first_instance_signal.id
related_cases = []
for resolution_reason in CaseResolutionReason:
related_cases.extend(
signal_service.get_cases_for_signal_by_resolution_reason(
db_session=db_session,
signal_id=first_instance_signal.id,
resolution_reason=resolution_reason,
)
.from_self() # NOTE: function deprecated in SQLAlchemy 1.4 and removed in 2.0
.filter(Case.id != case.id)
)
.from_self() # NOTE: function deprecated in SQLAlchemy 1.4 and removed in 2.0
.filter(Case.id != case.id)
)

# Prepare historical context
historical_context = []
Expand Down
46 changes: 46 additions & 0 deletions src/dispatch/signal/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,16 @@ def get_unprocessed_signal_instance_ids(session: Session) -> list[int]:


def get_instances_in_case(db_session: Session, case_id: int) -> Query:
"""
Retrieves signal instances associated with a given case.

Args:
db_session (Session): The database session.
case_id (int): The ID of the case.

Returns:
Query: A SQLAlchemy query object for the signal instances associated with the case.
"""
return (
db_session.query(SignalInstance, Signal)
.join(Signal)
Expand All @@ -771,10 +781,46 @@ def get_instances_in_case(db_session: Session, case_id: int) -> Query:


def get_cases_for_signal(db_session: Session, signal_id: int, limit: int = 10) -> Query:
"""
Retrieves cases associated with a given signal.

Args:
db_session (Session): The database session.
signal_id (int): The ID of the signal.
limit (int, optional): The maximum number of cases to retrieve. Defaults to 10.

Returns:
Query: A SQLAlchemy query object for the cases associated with the signal.
"""
return (
db_session.query(Case)
.join(SignalInstance)
.filter(SignalInstance.signal_id == signal_id)
.order_by(desc(Case.created_at))
.limit(limit)
)


def get_cases_for_signal_by_resolution_reason(
db_session: Session, signal_id: int, resolution_reason: str, limit: int = 10
) -> Query:
"""
Retrieves cases associated with a given signal and resolution reason.

Args:
db_session (Session): The database session.
signal_id (int): The ID of the signal.
resolution_reason (str): The resolution reason to filter cases by.
limit (int, optional): The maximum number of cases to retrieve. Defaults to 10.

Returns:
Query: A SQLAlchemy query object for the cases associated with the signal and resolution reason.
"""
return (
db_session.query(Case)
.join(SignalInstance)
.filter(SignalInstance.signal_id == signal_id)
.filter(Case.resolution_reason == resolution_reason)
.order_by(desc(Case.created_at))
.limit(limit)
)
Loading