Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow task and fleet updates over ROS 2 #1003

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions packages/api-server/api_server/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import asyncio
import base64
import hashlib
import json
import logging
from datetime import datetime
from typing import Any, cast
from uuid import uuid4

import rclpy
import rclpy.client
Expand Down Expand Up @@ -35,19 +37,27 @@
from rmf_task_msgs.srv import SubmitTask as RmfSubmitTask
from rosidl_runtime_py.convert import message_to_ordereddict
from std_msgs.msg import Bool as BoolMsg
from std_msgs.msg import String as StringMsg
from tortoise.exceptions import IntegrityError

from api_server.exceptions import AlreadyExistsError, InvalidInputError, NotFoundError
from api_server.fast_io.singleton_dep import singleton_dep
from api_server.logging import default_logger
from api_server.models.user import User
from api_server.repositories.alerts import AlertRepository
from api_server.repositories.cached_files import get_cached_file_repo
from api_server.repositories.fleets import FleetRepository
from api_server.repositories.rmf import RmfRepository
from api_server.repositories.tasks import TaskRepository
from api_server.rmf_io.events import (
AlertEvents,
FleetEvents,
RmfEvents,
TaskEvents,
get_alert_events,
get_fleet_events,
get_rmf_events,
get_task_events,
)
from api_server.ros import get_ros_node

Expand All @@ -57,11 +67,17 @@
BeaconState,
BuildingMap,
DeliveryAlert,
DispatchStatus,
DispenserState,
DoorState,
FireAlarmTriggerState,
FleetLog,
FleetState,
IngestorState,
LiftState,
TaskEventLog,
TaskState,
TaskStatus,
)
from .repositories import CachedFilesRepository

Expand All @@ -75,6 +91,10 @@ def __init__(
alert_repo: AlertRepository,
rmf_events: RmfEvents,
rmf_repo: RmfRepository,
task_events: TaskEvents,
task_repo: TaskRepository,
fleet_events: FleetEvents,
fleet_repo: FleetRepository,
loop: asyncio.AbstractEventLoop,
*,
logger: logging.Logger | None = None,
Expand All @@ -85,6 +105,10 @@ def __init__(
self._alert_repo = alert_repo
self._rmf_events = rmf_events
self._rmf_repo = rmf_repo
self._task_events = task_events
self._task_repo = task_repo
self._fleet_events = fleet_events
self._fleet_repo = fleet_repo
self._loop = loop
self._logger = logger or logging.getLogger()

Expand Down Expand Up @@ -474,6 +498,116 @@ def handle_fire_alarm_trigger(msg):
)
self._subscriptions.append(fire_alarm_trigger_sub)

def handle_task_state_update(msg):
msg = cast(StringMsg, msg)
json_msg = json.loads(msg.data)

async def save(task_state: TaskState):
await self._task_repo.save_task_state(task_state)
self._task_events.task_states.on_next(task_state)

task_state = TaskState.model_validate(json_msg["data"])
self._loop.create_task(save(task_state))

async def save_alert(alert_request: AlertRequest):
try:
created_alert = await self._alert_repo.create_new_alert(
alert_request
)
except AlreadyExistsError as e:
self._logger.error(e)
return
self._alert_events.alert_requests.on_next(created_alert)

if task_state.status == TaskStatus.completed:
alert_request = AlertRequest(
id=str(uuid4()),
unix_millis_alert_time=round(datetime.now().timestamp() * 1000),
title="Task completed",
subtitle=f"ID: {task_state.booking.id}",
message="",
display=True,
tier=AlertRequest.Tier.Info,
responses_available=["Acknowledge"],
alert_parameters=[],
task_id=task_state.booking.id,
)
self._loop.create_task(save_alert(alert_request))
elif task_state.status == TaskStatus.failed:
errorMessage = ""
if (
task_state.dispatch is not None
and task_state.dispatch.status == DispatchStatus.failed_to_assign
):
errorMessage += "Failed to assign\n"
if task_state.dispatch.errors is not None:
for error in task_state.dispatch.errors:
errorMessage += error.json() + "\n"

alert_request = AlertRequest(
id=str(uuid4()),
unix_millis_alert_time=round(datetime.now().timestamp() * 1000),
title="Task failed",
subtitle=f"ID: {task_state.booking.id}",
message=errorMessage,
display=True,
tier=AlertRequest.Tier.Error,
responses_available=["Acknowledge"],
alert_parameters=[],
task_id=task_state.booking.id,
)
self._loop.create_task(save_alert(alert_request))

task_state_update_sub = self._ros_node.create_subscription(
StringMsg, "task_state_update", handle_task_state_update, 10
)
self._subscriptions.append(task_state_update_sub)

def handle_task_log_update(msg):
msg = cast(StringMsg, msg)
json_msg = json.loads(msg.data)

async def save(task_event_log: TaskEventLog):
await self._task_repo.save_task_log(task_event_log)
self._task_events.task_event_logs.on_next(task_event_log)

self._loop.create_task(save(TaskEventLog.model_validate(json_msg["data"])))

task_log_update_sub = self._ros_node.create_subscription(
StringMsg, "task_log_update", handle_task_log_update, 10
)
self._subscriptions.append(task_log_update_sub)

def handle_fleet_state_update(msg):
msg = cast(StringMsg, msg)
json_msg = json.loads(msg.data)

async def save(fleet_state: FleetState):
await self._fleet_repo.save_fleet_state(fleet_state)
self._fleet_events.fleet_states.on_next(fleet_state)

self._loop.create_task(save(FleetState.model_validate(json_msg["data"])))

fleet_state_update_sub = self._ros_node.create_subscription(
StringMsg, "fleet_state_update", handle_fleet_state_update, 10
)
self._subscriptions.append(fleet_state_update_sub)

def handle_fleet_log_update(msg):
msg = cast(StringMsg, msg)
json_msg = json.loads(msg.data)

async def save(fleet_log: FleetLog):
await self._fleet_repo.save_fleet_log(fleet_log)
self._fleet_events.fleet_logs.on_next(fleet_log)

self._loop.create_task(save(FleetLog.model_validate(json_msg["data"])))

fleet_log_update_sub = self._ros_node.create_subscription(
StringMsg, "fleet_log_update", handle_fleet_log_update, 10
)
self._subscriptions.append(fleet_log_update_sub)

Comment on lines +501 to +610
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refactor so both the internal route and this share the same code.

async def __aexit__(self, *exc):
for sub in self._subscriptions:
sub.destroy()
Expand Down Expand Up @@ -562,5 +696,9 @@ def get_rmf_gateway():
AlertRepository(),
get_rmf_events(),
RmfRepository(User.get_system_user()),
get_task_events(),
TaskRepository(User.get_system_user(), default_logger),
get_fleet_events(),
FleetRepository(User.get_system_user(), default_logger),
asyncio.get_event_loop(),
)
Loading