-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from OpenGenenerativeAI/add-hf-inference-endpoint
Add hf inference endpoint
- Loading branch information
Showing
14 changed files
with
174 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
HUGGINGFACEHUB_API_TOKEN=<your token> | ||
OPENAI_API_KEY=<your token> | ||
CUSTOM_HF_ENDPOINT_URL=<your url> |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import openai | ||
from pydantic import BaseModel, SecretStr | ||
|
||
from demo.constants.paths import GENOSS_URL | ||
from demo.constants.settings import SETTINGS | ||
|
||
|
||
class ModelConfig(BaseModel): | ||
display_name: str | ||
model_name: str | ||
api_key: SecretStr | ||
endpoint_url: str | ||
|
||
def configure_open_ai_module(self) -> None: | ||
openai.api_key = self.api_key.get_secret_value() | ||
openai.api_base = self.endpoint_url | ||
|
||
|
||
AVAILABLE_MODELS = [ | ||
ModelConfig( | ||
display_name="OpenAI-GPT-4", | ||
model_name="gpt-4", | ||
api_key=SETTINGS.openai_api_key, | ||
endpoint_url=openai.api_base, | ||
), | ||
ModelConfig( | ||
display_name="OpenAI-GPT-4 (through Genoss)", | ||
model_name="gpt-4", | ||
api_key=SETTINGS.openai_api_key, | ||
endpoint_url=GENOSS_URL, | ||
), | ||
ModelConfig( | ||
display_name="hf-gpt2", | ||
model_name="hf-gpt2", | ||
api_key=SETTINGS.huggingfacehub_api_token, | ||
endpoint_url=GENOSS_URL, | ||
), | ||
ModelConfig( | ||
display_name="hf-llama2", | ||
model_name="hf-llama2", | ||
api_key=SETTINGS.huggingfacehub_api_token, | ||
endpoint_url=GENOSS_URL, | ||
), | ||
ModelConfig( | ||
display_name="hf-custom/llama", | ||
model_name=f"hf-inference-endpoint/{SETTINGS.custom_hf_endpoint_url}", | ||
api_key=SETTINGS.huggingfacehub_api_token, | ||
endpoint_url=GENOSS_URL, | ||
), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from pathlib import Path | ||
|
||
ROOT_FOLDER = Path(__file__).parent.parent.parent | ||
GENOSS_URL = "http://localhost:4321" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from pydantic import BaseSettings, HttpUrl, SecretStr | ||
|
||
from demo.constants.paths import ROOT_FOLDER | ||
|
||
|
||
class Settings(BaseSettings): | ||
class Config: | ||
env_file = ROOT_FOLDER / "demo" / ".env" | ||
|
||
huggingfacehub_api_token: SecretStr | ||
openai_api_key: SecretStr | ||
custom_hf_endpoint_url: HttpUrl | ||
|
||
|
||
SETTINGS = Settings() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from abc import ABC | ||
from typing import Any, Literal | ||
from unittest import mock | ||
|
||
from langchain import LLMChain | ||
from langchain.llms import HuggingFaceEndpoint | ||
|
||
from genoss.entities.chat.chat_completion import ChatCompletion | ||
from genoss.llm.base_genoss import BaseGenossLLM | ||
from genoss.prompts.prompt_template import prompt_template | ||
|
||
|
||
class HuggingFaceInferenceEndpointLLM(BaseGenossLLM, ABC): | ||
"""Class for interacting with Hugging Face Inference APIs.""" | ||
|
||
# Subclasses must define these | ||
name = "HF Inference Endpoint" | ||
api_key: str | None = None | ||
endpoint_url: str | ||
description: str = "Hugging Face Inference API custom endpoint." | ||
task: Literal[ | ||
"text-generation", "text-generation", "summarization" | ||
] = "text-generation" | ||
|
||
@mock.patch( | ||
"huggingface_hub.inference_api.INFERENCE_ENDPOINT", "http://0.0.0.0:8080" | ||
) | ||
def generate_answer(self, question: str) -> dict[str, Any]: | ||
"""Generate answer from prompt.""" | ||
llm = HuggingFaceEndpoint( | ||
endpoint_url=self.endpoint_url, | ||
huggingfacehub_api_token=self.api_key, | ||
task=self.task, | ||
) | ||
llm_chain = LLMChain(prompt=prompt_template, llm=llm) | ||
|
||
response_text = llm_chain(question) | ||
|
||
answer = response_text["text"] | ||
|
||
chat_completion = ChatCompletion( | ||
model=self.name, question=question, answer=answer | ||
) | ||
|
||
return chat_completion.to_dict() | ||
|
||
def generate_embedding(self, text: str) -> list[float]: | ||
"""Dummy method to satisfy base class requirement.""" | ||
# TODO: why is this necessary? Architecture issue? | ||
raise NotImplementedError( | ||
"This method is not used for Hugging Face Inference API." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters