Skip to content

Commit

Permalink
refactor: Add store_custom_services
Browse files Browse the repository at this point in the history
  • Loading branch information
Ramimashkouk committed Oct 11, 2024
1 parent ca9c92f commit 0b888b5
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 45 deletions.
33 changes: 17 additions & 16 deletions backend/chatsky_ui/schemas/front_graph_components/interface.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from pydantic import model_validator, RootModel
from pydantic import Field, model_validator
from typing import Any

from .base_component import BaseComponent
from typing import Optional, Dict

class Interface(BaseComponent):
telegram: Optional[Dict[str, Any]] = Field(default=None)
cli: Optional[Dict[str, Any]] = Field(default=None)

class Interface(BaseComponent, RootModel):
@model_validator(mode="before")
def validate_interface(cls, v):
if not isinstance(v, dict):
raise ValueError('interface must be a dictionary')
if "telegram" in v:
if not isinstance(v['telegram'], dict):
raise ValueError('telegram must be a dictionary')
if 'token' not in v['telegram'] or not isinstance(v['telegram']['token'], str):
raise ValueError('telegram dictionary must contain a string token')
elif "cli" in v:
pass
else:
raise ValueError('interface must contain either telegram or cli')
return v
@model_validator(mode='after')
def check_one_not_none(cls, values):
telegram, cli = values.telegram, values.cli
if (telegram is None) == (cli is None):
raise ValueError('Exactly one of "telegram" or "cli" must be provided.')
return values

@model_validator(mode='after')
def check_telegram_token(cls, values):
if values.telegram is not None and 'token' not in values.telegram:
raise ValueError('Telegram token must be provided.')
return values
2 changes: 1 addition & 1 deletion backend/chatsky_ui/schemas/front_graph_components/slot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Slot(BaseComponent):
class RegexpSlot(Slot):
id: str
regexp: str
match_group_idx: Optional[int]
match_group_idx: int


class GroupSlot(Slot):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,9 @@ def __init__(self, interface: dict):
self.interface = Interface(**interface)

def _convert(self):
return self.interface.model_dump()
if self.interface.cli is not None:
return {"chatsky.messengers.console.CLIMessengerInterface": {}}
elif self.interface.telegram is not None:
return {
"chatsky.messengers.telegram.LongpollingInterface": {"token": self.interface.telegram["token"]}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from ..consts import CUSTOM_FILE, CONDITIONS_FILE
from ..base_converter import BaseConverter
from ....schemas.front_graph_components.info_holders.condition import CustomCondition, SlotCondition
from ....core.config import settings
from .service_replacer import store_custom_service


class ConditionConverter(BaseConverter, ABC):
Expand All @@ -19,14 +21,8 @@ def __init__(self, condition: dict):
code=condition["data"]["python"]["action"],
)

def _parse_code(self):
condition_code = next(iter(ast.parse(self.condition.code).body))

if not isinstance(condition_code, ast.ClassDef):
raise ValueError("Condition python code is not a ClassDef")
return condition_code

def _convert(self):
store_custom_service(settings.conditions_path, [self.condition.code])
custom_cnd = {
f"{CUSTOM_FILE}.{CONDITIONS_FILE}.{self.condition.name}": None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from ..base_converter import BaseConverter
from ....schemas.front_graph_components.info_holders.response import TextResponse, CustomResponse
from ..consts import CUSTOM_FILE, RESPONSES_FILE
from ....core.config import settings
from .service_replacer import store_custom_service


class ResponseConverter(BaseConverter):
Expand All @@ -26,21 +28,13 @@ def _convert(self):

class CustomResponseConverter(ResponseConverter):
def __init__(self, response: dict):
# self.code =
self.response = CustomResponse(
name=response["name"],
code=next(iter(response["data"]))["python"]["action"],
)

def _parse_code(self):
response_code = next(iter(ast.parse(self.response.code).body))

if not isinstance(response_code, ast.ClassDef):
raise ValueError("Response python code is not a ClassDef")
return response_code

def _convert(self):
store_custom_service(settings.responses_path, [self.response.code])
return {
f"{CUSTOM_FILE}.{RESPONSES_FILE}.{self.response.name}": None
}

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
from ast import NodeTransformer
from typing import Dict, List
from pathlib import Path

from chatsky_ui.core.logger_config import get_logger

Expand All @@ -13,17 +14,18 @@ def __init__(self, new_services: List[str]):

def _get_classes_def(self, services_code: List[str]) -> Dict[str, ast.ClassDef]:
parsed_codes = [ast.parse(service_code) for service_code in services_code]
result_nodes = {}
for idx, parsed_code in enumerate(parsed_codes):
self._extract_class_defs(parsed_code, result_nodes, services_code[idx])
return result_nodes
classes = self._extract_class_defs(parsed_code, services_code[idx])
return classes

def _extract_class_defs(self, parsed_code: ast.Module, result_nodes: Dict[str, ast.ClassDef], service_code: str):
def _extract_class_defs(self, parsed_code: ast.Module, service_code: str):
classes = {}
for node in parsed_code.body:
if isinstance(node, ast.ClassDef):
result_nodes[node.name] = node
classes[node.name] = node
else:
logger.error("No class definition found in new_service: %s", service_code)
return classes

def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
logger.debug("Visiting class '%s' and comparing with: %s", node.name, self.new_services_classes.keys())
Expand All @@ -46,3 +48,14 @@ def _append_new_services(self, node: ast.Module):
logger.info("Services not found, appending new services: %s", list(self.new_services_classes.keys()))
for _, service in self.new_services_classes.items():
node.body.append(service)


def store_custom_service(services_path: Path, services: List[str]):
with open(services_path, "r", encoding="UTF-8") as file:
conditions_tree = ast.parse(file.read())

replacer = ServiceReplacer(services)
replacer.visit(conditions_tree)

with open(services_path, "w") as file:
file.write(ast.unparse(conditions_tree))
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ def __init__(self, pipeline_id: int):

def __call__(self, input_file: Path, output_dir: Path):
self.from_yaml(file_path=input_file)

self.pipeline = Pipeline(**self.graph)
self.converted_pipeline = super().__call__()

self.to_yaml(dir_path=output_dir)

def from_yaml(self, file_path: Path):
Expand All @@ -36,11 +38,15 @@ def to_yaml(self, dir_path: Path):

def _convert(self):
slots_converter = SlotsConverter(self.pipeline.flows)
script_converter = ScriptConverter(self.pipeline.flows)

slots_conf = slots_converter.map_slots()
start_label, fallback_label = script_converter.extract_start_fallback_labels()

return {
"script": ScriptConverter(self.pipeline.flows)(slots_conf=slots_conf),
"interface": InterfaceConverter(self.pipeline.interface)(),
"script": script_converter(slots_conf=slots_conf),
"messenger_interface": InterfaceConverter(self.pipeline.interface)(),
"slots": slots_converter(),
# "start_label": self.script.get_start_label(),
# "fallback_label": self.script.get_fallback_label(),
"start_label": start_label,
"fallback_label": fallback_label,
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,23 @@ def _map_flows(self):
mapped_flows[flow["name"]][node["id"]] = node
return mapped_flows

def extract_start_fallback_labels(self): #TODO: refactor this huge method
start_label, fallback_label = None, None

for flow in self.script.flows:
for node in flow["data"]["nodes"]:
flags = node["data"]["flags"]

if "start" in flags:
if start_label:
raise ValueError("Multiple start nodes found")
start_label = [flow["name"], node["data"]["name"]]
if "fallback" in flags:
if fallback_label:
raise ValueError("Multiple fallback nodes found")
fallback_label = [flow["name"], node["data"]["name"]]

if start_label and fallback_label:
return start_label, fallback_label

return None, None
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def __init__(self, slot: dict):
id=slot["id"],
name=slot["name"],
regexp=slot["value"],
match_group_idx=slot.get("match_group_idx", None),
match_group_idx=slot.get("match_group_idx", 1),
)

def _convert(self):
return {
self.slot.name: {
"chatksy.slots.RegexpSlot": {
"chatsky.slots.RegexpSlot": {
"regexp": self.slot.regexp,
"match_group_idx": self.slot.match_group_idx,
}
Expand Down

0 comments on commit 0b888b5

Please sign in to comment.