lint
NotBioWaste905 committed Sep 26, 2024
1 parent f2d6b68 commit 6fddaea
Showing 11 changed files with 136 additions and 46 deletions.
1 change: 1 addition & 0 deletions chatsky/llm/conditions.py
@@ -1,5 +1,6 @@
from chatsky.llm.methods import BaseMethod


def llm_condition(model_name: str, prompt: str, method: BaseMethod):
"""
Basic function for using LLM in condition cases.
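
The rest of llm_condition is collapsed in this diff. For orientation, a hedged usage sketch (illustrative only; the "my_model" registration name, the prompt text, and the variable name are assumptions, not repository code):

from chatsky.llm.conditions import llm_condition
from chatsky.llm.methods import Contains

# Hypothetical condition for a script transition: ask the model a yes/no
# question about the last user utterance and fire when the reply contains "YES".
wants_pizza = llm_condition(
    model_name="my_model",  # assumed name under which the LLM model is registered
    prompt="Reply YES or NO: is the user asking to order a pizza?",
    method=Contains(pattern="YES"),
)
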
3 changes: 3 additions & 0 deletions chatsky/llm/filters.py
@@ -14,6 +14,7 @@ class BaseFilter(BaseModel, abc.ABC):
"""
Base class for all message history filters.
"""

@abc.abstractmethod
def __call__(self, ctx: Context, request: Message, response: Message, model_name: str) -> bool:
"""
@@ -29,6 +30,7 @@ class IsImportant(BaseFilter):
"""
Filter that checks if the "important" field in a Message.misc is True.
"""

def __call__(
self, ctx: Context = None, request: Message = None, response: Message = None, model_name: str = None
) -> bool:
@@ -43,6 +45,7 @@ class FromTheModel(BaseFilter):
"""
Filter that checks if the message was sent by the model.
"""

def __call__(
self, ctx: Context = None, request: Message = None, response: Message = None, model_name: str = None
) -> bool:
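
A minimal sketch of a custom filter built on the BaseFilter interface shown above (illustrative; the ContainsKeyword class and its keyword-matching logic are assumptions, not part of this commit):

from chatsky.llm.filters import BaseFilter


class ContainsKeyword(BaseFilter):
    """Keep only request/response pairs that mention a given keyword (illustrative)."""

    keyword: str  # pydantic field, declared the same way Contains declares `pattern`

    def __call__(self, ctx=None, request=None, response=None, model_name=None) -> bool:
        # A pair is kept if either the request or the response mentions the keyword.
        request_text = request.text if request is not None and request.text else ""
        response_text = response.text if response is not None and response.text else ""
        return self.keyword.lower() in (request_text + " " + response_text).lower()
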
6 changes: 5 additions & 1 deletion chatsky/llm/llm_api.py
@@ -12,6 +12,7 @@
from langchain_mistralai import ChatMistralAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.language_models.chat_models import BaseChatModel

langchain_available = True
except ImportError:
langchain_available = False
@@ -78,7 +79,10 @@ async def respond(

async def condition(self, prompt: str, method: BaseMethod, return_schema=None):
async def process_input(ctx: Context, _: Pipeline) -> bool:
condition_history = [await message_to_langchain(Message(prompt), pipeline=_, source="system"), await message_to_langchain(ctx.last_request, pipeline=_, source="human")]
condition_history = [
await message_to_langchain(Message(prompt), pipeline=_, source="system"),
await message_to_langchain(ctx.last_request, pipeline=_, source="human"),
]
result = method(ctx, await self.model.agenerate([condition_history], logprobs=True, top_logprobs=10))
return result

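
The hunk above relies on the optional-import guard that sets langchain_available. A standalone sketch of the same pattern (illustrative; ChatOpenAI and the helper function are assumptions, not repository code):

try:
    from langchain_openai import ChatOpenAI  # optional heavy dependency
    langchain_available = True
except ImportError:
    ChatOpenAI = None
    langchain_available = False


def make_chat_model(model_name: str):
    """Fail early with a clear message when the optional dependency is missing."""
    if not langchain_available:
        raise ImportError("Install the langchain extras to use LLM features.")
    return ChatOpenAI(model=model_name)
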
6 changes: 5 additions & 1 deletion chatsky/llm/methods.py
@@ -15,6 +15,7 @@ class BaseMethod(BaseModel, abc.ABC):
"""
Base class to evaluate a model's response as a condition.
"""

@abc.abstractmethod
async def __call__(self, ctx: Context, model_result: LLMResult) -> bool:
raise NotImplementedError
@@ -35,6 +36,7 @@ class Contains(BaseMethod):
:return: True if pattern is contained in model result
:rtype: bool
"""

pattern: str

async def __call__(self, ctx: Context, model_result: LLMResult) -> bool:
@@ -52,11 +54,13 @@ class LogProb(BaseMethod):
:return: True if logprob is higher than threshold
:rtype: bool
"""

target_token: str
threshold: float = -0.5

async def __call__(self, ctx: Context, model_result: LLMResult) -> bool:
try:
result = model_result.generations[0][0].generation_info['logprobs']['content'][0]['top_logprobs']
result = model_result.generations[0][0].generation_info["logprobs"]["content"][0]["top_logprobs"]
except ValueError:
raise ValueError("LogProb method can only be applied to OpenAI models.")
for tok in result:
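
A hedged sketch of a custom method following the Contains/LogProb pattern above (the MatchesRegex class is an illustration, not part of this commit):

import re

from langchain_core.outputs import LLMResult

from chatsky.llm.methods import BaseMethod


class MatchesRegex(BaseMethod):
    """Return True if the model's text output matches a regular expression (illustrative)."""

    pattern: str

    async def __call__(self, ctx, model_result: LLMResult) -> bool:
        # Read the generated text from the first candidate, as Contains does.
        text = model_result.generations[0][0].text
        return re.search(self.pattern, text) is not None
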
26 changes: 19 additions & 7 deletions chatsky/llm/responses.py
@@ -7,7 +7,14 @@
from pydantic import BaseModel


def llm_response(model_name: str, prompt: str = "", history: int = 5, filter_func: Callable = lambda *args: True, message_schema: Union[None, Type[Message], Type[BaseModel]] = None, max_size: int=1000):
def llm_response(
model_name: str,
prompt: str = "",
history: int = 5,
filter_func: Callable = lambda *args: True,
message_schema: Union[None, Type[Message], Type[BaseModel]] = None,
max_size: int = 1000,
):
"""
Basic function for receiving LLM responses.
:param ctx: Context object. (Assigned automatically)
@@ -33,19 +40,25 @@ async def wrapped(ctx: Context, pipeline: Pipeline) -> Message:
# populate history with global and local prompts
if "global_prompt" in current_misc:
global_prompt = current_misc["global_prompt"]
history_messages.append(await message_to_langchain(Message(global_prompt), pipeline=pipeline, source="system"))
history_messages.append(
await message_to_langchain(Message(global_prompt), pipeline=pipeline, source="system")
)
if "local_prompt" in current_misc:
local_prompt = current_misc["local_prompt"]
history_messages.append(await message_to_langchain(Message(local_prompt), pipeline=pipeline, source="system"))
history_messages.append(
await message_to_langchain(Message(local_prompt), pipeline=pipeline, source="system")
)
if "prompt" in current_misc:
node_prompt = current_misc["prompt"]
history_messages.append(await message_to_langchain(Message(node_prompt), pipeline=pipeline, source="system"))
history_messages.append(
await message_to_langchain(Message(node_prompt), pipeline=pipeline, source="system")
)

# iterate over context to retrieve history messages
if not (history == 0 or len(ctx.responses) == 0 or len(ctx.requests) == 0):
pairs = zip(
[ctx.requests[x] for x in range(1, len(ctx.requests)+1)],
[ctx.responses[x] for x in range(1, len(ctx.responses)+1)],
[ctx.requests[x] for x in range(1, len(ctx.requests) + 1)],
[ctx.responses[x] for x in range(1, len(ctx.responses) + 1)],
)
# print(f"Pairs: {[p for p in pairs]}")
if history != -1:
@@ -64,4 +77,3 @@ async def wrapped(ctx: Context, pipeline: Pipeline) -> Message:
return await model.respond(history_messages, message_schema=message_schema)

return wrapped
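
A hedged usage sketch for the factory above (illustrative; the "my_model" registration name, prompt text, and history depth are assumptions):

from chatsky.llm.responses import llm_response

# Hypothetical response callback for a script node: the model registered as
# "my_model" answers using the node prompts plus the last three request/response pairs.
node_response = llm_response(
    model_name="my_model",
    prompt="You are a polite assistant for a pizza shop.",
    history=3,
)
# `node_response` is the async `wrapped(ctx, pipeline)` closure returned above.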

12 changes: 9 additions & 3 deletions chatsky/llm/utils.py
@@ -4,7 +4,7 @@
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage


async def message_to_langchain(message: Message, pipeline: Pipeline, source: str = "human", max_size: int=1000):
async def message_to_langchain(message: Message, pipeline: Pipeline, source: str = "human", max_size: int = 1000):
"""
Creates a langchain message from a ~chatsky.script.core.message.Message object.
@@ -19,13 +19,19 @@ async def message_to_langchain(message: Message, pipeline: Pipeline, source: str
if len(message.text) > max_size:
raise ValueError("Message is too long.")

if message.text is None: message.text = ""
if message.text is None:
message.text = ""
content = [{"type": "text", "text": message.text}]

if message.attachments:
for image in message.attachments:
if isinstance(image, Image):
content.append({"type": "image_url", "image_url": {"url": await __attachment_to_content(image, pipeline.messenger_interface)}})
content.append(
{
"type": "image_url",
"image_url": {"url": await __attachment_to_content(image, pipeline.messenger_interface)},
}
)

if source == "human":
return HumanMessage(content=content)
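
A brief usage sketch for message_to_langchain (illustrative; the surrounding coroutine is an assumption, not repository code):

from chatsky.llm.utils import message_to_langchain


async def last_request_as_langchain(ctx, pipeline):
    # Convert the latest user request into a langchain HumanMessage,
    # rejecting texts longer than the default max_size of 1000 characters.
    return await message_to_langchain(ctx.last_request, pipeline=pipeline, source="human")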