diff --git a/genoss/llm/openai/openai_llm.py b/genoss/llm/openai/openai_llm.py index c463827..16b7cd7 100644 --- a/genoss/llm/openai/openai_llm.py +++ b/genoss/llm/openai/openai_llm.py @@ -30,7 +30,7 @@ def __init__(self, model_name: str, api_key, *args, **kwargs): def generate_answer(self, question: str) -> Dict: print("Generating Answer") - llm = OpenAIChat(model=self.model_name, openai_api_key=self.openai_api_key) + llm = OpenAIChat(model_name=self.model_name, openai_api_key=self.openai_api_key) llm_chain = LLMChain(llm=llm, prompt=prompt_template) response_text = llm_chain(question) diff --git a/genoss/services/model_factory.py b/genoss/services/model_factory.py index 5f27a04..1d777f3 100644 --- a/genoss/services/model_factory.py +++ b/genoss/services/model_factory.py @@ -1,8 +1,11 @@ from typing import Optional -from genoss.llm.base_genoss_llm import BaseGenossLLM +from genoss.llm.base_genoss import BaseGenossLLM from genoss.llm.fake_llm import FAKE_LLM_NAME, FakeLLM -from genoss.llm.local.gpt4all_llm import Gpt4AllLLM +from genoss.llm.hf_hub.falcon import HuggingFaceHubFalconLLM +from genoss.llm.hf_hub.gpt2 import HuggingFaceHubGPT2LLM +from genoss.llm.hf_hub.llama2 import HuggingFaceHubLlama2LLM +from genoss.llm.local.gpt4all import Gpt4AllLLM from genoss.llm.openai.openai_llm import OpenAILLM OPENAI_NAME_LIST = ["gpt-4", "gpt-3.5-turbo"] @@ -10,7 +13,7 @@ class ModelFactory: @staticmethod - def get_model_from_name(name: str) -> Optional[BaseGenossLLM]: + def get_model_from_name(name: str, api_key) -> Optional[BaseGenossLLM]: if name.lower() in OPENAI_NAME_LIST: return OpenAILLM(model_name=name, api_key=api_key) if name.lower() == "gpt4all": diff --git a/tests/services/test_model_factory.py b/tests/services/test_model_factory.py index 3399d3c..7d827ee 100644 --- a/tests/services/test_model_factory.py +++ b/tests/services/test_model_factory.py @@ -1,7 +1,7 @@ import unittest from genoss.llm.fake_llm import FAKE_LLM_NAME, FakeLLM -from genoss.llm.local.gpt4all_llm import Gpt4AllLLM +from genoss.llm.local.gpt4all import Gpt4AllLLM from genoss.llm.openai.openai_llm import OpenAILLM from genoss.services.model_factory import ModelFactory