implement basic consumer push model #78098

Merged: 11 commits (Oct 1, 2024)
56 changes: 52 additions & 4 deletions src/sentry/runner/commands/run.py
@@ -256,10 +256,39 @@ def worker(ignore_unknown_queues: bool, **options: Any) -> None:
)
@click.option("--autoreload", is_flag=True, default=False, help="Enable autoreloading.")
@click.option("--max-tasks-per-child", default=10000)
@click.option("--port", "-P", help=("The port number the sever runs on"), default=50051)
@log_options()
@configuration
def taskworker(**options: Any) -> None:
from sentry.taskworker.worker import Worker
def taskworker_push(port: int, **options: Any) -> None:
from sentry.taskworker.worker_push import serve

with managed_bgtasks(role="taskworker"):
serve(port, **options)


@run.command()
@click.option(
"--hostname",
"-n",
help=(
"Set custom hostname, e.g. 'w1.%h'. Expands: %h" "(hostname), %n (name) and %d, (domain)."
),
)
@click.option(
"--namespace",
"-N",
help=(
"The task namespace, or namespaces to consume from. "
"Can be a comma separated list, or * "
"Example: -N video,image"
),
)
@click.option("--autoreload", is_flag=True, default=False, help="Enable autoreloading.")
@click.option("--max-tasks-per-child", default=10000)
@log_options()
@configuration
def taskworker_pull(**options: Any) -> None:
from sentry.taskworker.worker_pull import Worker

with managed_bgtasks(role="taskworker"):
worker = Worker(
@@ -269,11 +298,30 @@ def taskworker(**options: Any) -> None:
raise SystemExit(worker.exitcode)


@run.command()
@click.option(
"--worker-addrs",
"-W",
help=(
"The address of the workers, in the form of <IP>:<PORT>. "
"Can be a comma separated list"
"Example: -W 127.0.0.1:50051,127.0.0.1:50052"
),
)
@log_options()
@configuration
def kafka_task_grpc_push(worker_addrs: str) -> None:
from sentry.taskworker.consumer_grpc_push import start

with managed_bgtasks(role="taskworker"):
start(worker_addrs.split(","))


@run.command()
@log_options()
@configuration
def kafka_task_grpc_server(**options: Any) -> None:
from sentry.taskworker.grpc_server import serve
def kafka_task_grpc_pull(**options: Any) -> None:
from sentry.taskworker.consumer_grpc_pull import serve

with managed_bgtasks(role="taskworker"):
serve()
@@ -1,3 +1,7 @@
"""
This module is a gRPC server that listens for task requests from the taskworker.
"""

import logging
from concurrent.futures import ThreadPoolExecutor

93 changes: 93 additions & 0 deletions src/sentry/taskworker/consumer_grpc_push.py
@@ -0,0 +1,93 @@
"""
This module is a gRPC client that pushes tasks to the taskworker.
"""

import logging
import time
from collections import deque
from collections.abc import Iterable
from concurrent.futures import FIRST_COMPLETED, wait
from concurrent.futures.thread import ThreadPoolExecutor

import grpc
from sentry_protos.sentry.v1alpha.taskworker_pb2 import (
TASK_ACTIVATION_STATUS_PENDING,
DispatchRequest,
InflightActivation,
)
from sentry_protos.sentry.v1alpha.taskworker_pb2_grpc import WorkerServiceStub

from sentry.taskworker.pending_task_store import PendingTaskStore

logger = logging.getLogger("sentry.taskworker.grpc_server")


class ConsumerGrpc:
def __init__(self, worker_addrs: Iterable[str]) -> None:
self.pending_task_store = PendingTaskStore()
self.available_stubs = deque(
[WorkerServiceStub(grpc.insecure_channel(worker_addr)) for worker_addr in worker_addrs]
)
self.current_connections = set()

def start(self):
with ThreadPoolExecutor(max_workers=len(self.available_stubs)) as executor:
logger.info("Starting consumer grpc with %s threads", len(self.available_stubs))
while True:
if len(self.available_stubs) == 0:
done, not_done = wait(self.current_connections, return_when=FIRST_COMPLETED)
self.available_stubs.extend([future.result() for future in done])
self.current_connections = not_done

self.current_connections.add(
executor.submit(
self._dispatch_activation,
self.available_stubs.popleft(),
self._poll_pending_task(),
)
)

def _poll_pending_task(self) -> InflightActivation:
while True:
inflight_activation = self.pending_task_store.get_pending_task()
if inflight_activation:
logger.info("Polled task %s", inflight_activation.activation.id)
return inflight_activation
logger.info("No tasks")
time.sleep(1)

def _dispatch_activation(
self,
stub: WorkerServiceStub,
inflight_activation: InflightActivation,
) -> WorkerServiceStub:
try:
timeout_in_sec = inflight_activation.processing_deadline.seconds - time.time()
dispatch_task_response = stub.Dispatch(
DispatchRequest(task_activation=inflight_activation.activation),
timeout=timeout_in_sec,
)
@john-z-yang (Member, Author) commented on Sep 26, 2024:

Thinking about this a bit more, a timeout might be better handled on the server (worker) side instead of the client side.

Here's my reasoning: when a task times out, it should be treated as if the task had thrown an exception, because it is not a platform problem (like failing to connect to a worker) but an issue with the execution of the task itself. So it should go into the same flow in the worker that determines the next state of the activation (here), instead of requeuing the task into the store the way we do right now. (A rough sketch of this approach follows this file's diff.)

@enochtangg what do you think?

Reply (Member):

Discussed offline: we will move deadline timeout handling to the worker side. To do this, we will evaluate a multiprocessing pool to handle interrupts.

self.pending_task_store.set_task_status(
task_id=inflight_activation.activation.id,
task_status=dispatch_task_response.status,
)
except grpc.RpcError as rpc_error:
logger.exception(
"gRPC failed, code: %s, details: %s",
rpc_error.code(),
rpc_error.details(),
)
self.pending_task_store.set_task_status(
task_id=inflight_activation.activation.id,
task_status=TASK_ACTIVATION_STATUS_PENDING,
)
self.pending_task_store.set_task_deadline(
task_id=inflight_activation.activation.id, task_deadline=None
)
time.sleep(1)
return stub


def start(worker_addrs: Iterable[str]):
consumer_grpc = ConsumerGrpc(worker_addrs)
consumer_grpc.start()
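The review thread above lands on enforcing the processing deadline inside the worker rather than in this consumer. A minimal sketch of that idea, assuming the task is run in a child process that can be terminated once the deadline passes (the helper below is hypothetical and not part of this PR; task_func and its arguments must be picklable):

import multiprocessing
import time


def execute_with_deadline(task_func, args: tuple, kwargs: dict, deadline_epoch_seconds: float) -> None:
    # Run the task in a child process so an overrunning task can be interrupted.
    proc = multiprocessing.Process(target=task_func, args=args, kwargs=kwargs)
    proc.start()
    proc.join(timeout=max(deadline_epoch_seconds - time.time(), 0))
    if proc.is_alive():
        # Deadline exceeded: kill the task and surface it as an ordinary task error
        # so it flows through the same retry/failure logic as a raised exception.
        proc.terminate()
        proc.join()
        raise TimeoutError("processing deadline exceeded")
    if proc.exitcode != 0:
        raise RuntimeError(f"task exited with code {proc.exitcode}")

Raising here would land in the existing except Exception branch of Dispatch, so should_retry still decides between TASK_ACTIVATION_STATUS_RETRY and TASK_ACTIVATION_STATUS_FAILURE, as the comment suggests.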
15 changes: 13 additions & 2 deletions src/sentry/taskworker/pending_task_store.py
@@ -1,6 +1,6 @@
import logging
from collections.abc import Sequence
from datetime import timedelta
from datetime import datetime, timedelta

from django.db.models import Max
from django.utils import timezone
@@ -33,7 +33,7 @@ def get_pending_task(self) -> InflightActivation | None:
return None

# TODO this duration should be a tasknamespace setting, or with an option
deadline = task.added_at + timedelta(minutes=3)
deadline = datetime.now() + timedelta(minutes=3)
Reviewer comment (Member):

The processing deadline should be the timestamp at which the task was pulled out of the datastore to be sent to (and received by) the worker, plus some adjustable duration. The added_at timestamp is first captured when the message was read from Kafka and inserted into the datastore. In the scenario where the worker does not pick up the task before task.added_at + <adjustable duration>, the task can never be completed; more simply, the current time will always be ahead of processing_deadline. (A small illustration follows this file's diff.)


task.update(
status=InflightActivationModel.Status.PROCESSING, processing_deadline=deadline
@@ -55,6 +55,17 @@ def set_task_status(self, task_id: str, task_status: TaskActivationStatus.ValueT
if task_status == InflightActivationModel.Status.RETRY:
task.update(retry_attempts=task.retry_attempts + 1)

def set_task_deadline(self, task_id: str, task_deadline: datetime | None):
from django.db import router, transaction

from sentry.taskworker.models import InflightActivationModel

with transaction.atomic(using=router.db_for_write(InflightActivationModel)):
        # Take a select_for_update lock on the row while we mutate the deadline
task = InflightActivationModel.objects.select_for_update().filter(id=task_id).get()

task.update(deadline=task_deadline)

def handle_retry_state_tasks(self) -> None:
from sentry.taskworker.config import taskregistry
from sentry.taskworker.models import InflightActivationModel
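To make the reviewer's point above concrete, a small illustration (the helper name and the fixed 3-minute duration are assumptions; per the TODO above, the duration should eventually come from the task namespace):

from datetime import datetime, timedelta

PROCESSING_DURATION = timedelta(minutes=3)  # assumed; should be a task namespace setting


def processing_deadline(pulled_at: datetime) -> datetime:
    # Anchor the deadline to the time the task is pulled for dispatch. Anchoring it
    # to added_at can yield a deadline that is already in the past: a task added at
    # 12:00 with a 3-minute window but only pulled at 12:05 would carry a 12:03
    # deadline, making the timeout computed in _dispatch_activation negative.
    return pulled_at + PROCESSING_DURATION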
@@ -86,7 +86,9 @@ def process_tasks(self, namespace: TaskNamespace) -> None:
task_latency = execution_time - task_added_time

# Dump results to a log file that is CSV shaped
result_logger.info(f"task.complete,{task_added_time},{execution_time},{task_latency}")
result_logger.info(
"task.complete, %s, %s, %s", task_added_time, execution_time, task_latency
)

if next_state == TASK_ACTIVATION_STATUS_COMPLETE:
logger.info(
87 changes: 87 additions & 0 deletions src/sentry/taskworker/worker_push.py
@@ -0,0 +1,87 @@
from __future__ import annotations

import logging
import time
from concurrent import futures

import grpc
import orjson
from django.conf import settings
from sentry_protos.sentry.v1alpha.taskworker_pb2 import (
TASK_ACTIVATION_STATUS_COMPLETE,
TASK_ACTIVATION_STATUS_FAILURE,
TASK_ACTIVATION_STATUS_RETRY,
DispatchRequest,
DispatchResponse,
)
from sentry_protos.sentry.v1alpha.taskworker_pb2_grpc import (
WorkerServiceServicer as BaseWorkerServiceServicer,
)
from sentry_protos.sentry.v1alpha.taskworker_pb2_grpc import add_WorkerServiceServicer_to_server

from sentry.taskworker.config import TaskNamespace, taskregistry

logger = logging.getLogger("sentry.taskworker")
result_logger = logging.getLogger("taskworker.results")


class WorkerServicer(BaseWorkerServiceServicer):
__namespace: TaskNamespace | None = None

def __init__(self, **options) -> None:
super().__init__()
self.options = options
self.do_imports()

@property
def namespace(self) -> TaskNamespace:
if self.__namespace:
return self.__namespace

name = self.options["namespace"]
self.__namespace = taskregistry.get(name)
return self.__namespace

def do_imports(self) -> None:
for module in settings.TASKWORKER_IMPORTS:
__import__(module)

def Dispatch(self, request: DispatchRequest, _) -> DispatchResponse:
activation = request.task_activation
try:
task_meta = self.namespace.get(activation.taskname)
except KeyError:
logger.exception("Could not resolve task with name %s", activation.taskname)
return

# TODO: Check idempotency
task_added_time = activation.received_at.seconds
execution_time = time.time()
next_state = TASK_ACTIVATION_STATUS_FAILURE
try:
task_data_parameters = orjson.loads(activation.parameters)
task_meta(*task_data_parameters["args"], **task_data_parameters["kwargs"])
next_state = TASK_ACTIVATION_STATUS_COMPLETE
except Exception as err:
logger.info("taskworker.task_errored", extra={"error": str(err)})
# TODO check retry policy
if task_meta.should_retry(activation.retry_state, err):
logger.info("taskworker.task.retry", extra={"task": activation.taskname})
next_state = TASK_ACTIVATION_STATUS_RETRY
task_latency = execution_time - task_added_time

# Dump results to a log file that is CSV shaped
result_logger.info(
"task.complete, %s, %s, %s", task_added_time, execution_time, task_latency
)

return DispatchResponse(status=next_state)


def serve(port: int, **options):
logger.info("Starting server on: %s", port)
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
add_WorkerServiceServicer_to_server(WorkerServicer(**options), server)
server.add_insecure_port(f"[::]:{port}")
server.start()
server.wait_for_termination()
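A rough sketch of how the new Dispatch RPC could be exercised from a throwaway test client (hypothetical: it assumes the proto module also exposes a TaskActivation message with the fields used above, that parameters is a JSON string, and that a worker started via taskworker_push is already listening on the given port):

import time

import grpc
import orjson
from google.protobuf.timestamp_pb2 import Timestamp
from sentry_protos.sentry.v1alpha.taskworker_pb2 import DispatchRequest, TaskActivation
from sentry_protos.sentry.v1alpha.taskworker_pb2_grpc import WorkerServiceStub


def dispatch_once(addr: str = "127.0.0.1:50051") -> None:
    # Build a single activation and push it to the worker, mirroring what
    # ConsumerGrpc._dispatch_activation does for tasks read from the store.
    stub = WorkerServiceStub(grpc.insecure_channel(addr))
    activation = TaskActivation(
        id="example-id",
        taskname="examples.say_hello",  # assumed name; must be registered in the worker's namespace
        parameters=orjson.dumps({"args": ["hello"], "kwargs": {}}).decode("utf-8"),
        received_at=Timestamp(seconds=int(time.time())),
    )
    response = stub.Dispatch(DispatchRequest(task_activation=activation), timeout=5)
    print(response.status)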