add single query input/output guardrails
Signed-off-by: Tyler Wilbers <[email protected]>
Tyler Wilbers committed Aug 6, 2024
1 parent 7749ce3 commit 4871435
Showing 4 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion comps/guardrails/llama_guard/README.md
@@ -36,7 +36,7 @@ pip install -r requirements.txt
 export HF_TOKEN=${your_hf_api_token}
 export LANGCHAIN_TRACING_V2=true
 export LANGCHAIN_API_KEY=${your_langchain_api_key}
-export LANGCHAIN_PROJECT="opea/gaurdrails"
+export LANGCHAIN_PROJECT="opea/guardrails"
 volume=$PWD/data
 model_id="meta-llama/Meta-Llama-Guard-2-8B"
 docker pull ghcr.io/huggingface/tgi-gaudi:2.0.1
16 changes: 11 additions & 5 deletions comps/guardrails/llama_guard/guardrails_tgi.py
@@ -2,13 +2,14 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
+from typing import List, Union

 from langchain_community.utilities.requests import JsonRequestsWrapper
 from langchain_huggingface import ChatHuggingFace
 from langchain_huggingface.llms import HuggingFaceEndpoint
 from langsmith import traceable

-from comps import ServiceType, TextDoc, opea_microservices, register_microservice
+from comps import GeneratedDoc, ServiceType, TextDoc, opea_microservices, register_microservice

 DEFAULT_MODEL = "meta-llama/LlamaGuard-7b"
@@ -59,12 +60,18 @@ def get_tgi_service_model_id(endpoint_url, default=DEFAULT_MODEL):
     endpoint="/v1/guardrails",
     host="0.0.0.0",
     port=9090,
-    input_datatype=TextDoc,
+    input_datatype=Union[GeneratedDoc, TextDoc],
     output_datatype=TextDoc,
 )
 @traceable(run_type="llm")
-def safety_guard(input: TextDoc) -> TextDoc:
-    response_input_guard = llm_engine_hf.invoke([{"role": "user", "content": input.text}]).content
+def safety_guard(input: Union[GeneratedDoc, TextDoc]) -> TextDoc:
+    if isinstance(input, GeneratedDoc):
+        messages = [{"role": "user", "content": input.prompt}, {"role": "assistant", "content": input.text}]
+    else:
+        messages = [{"role": "user", "content": input.text}]
+    response_input_guard = llm_engine_hf.invoke(messages).content
+
+    # response_input_guard = llm_engine_hf.invoke([{"role": input.role, "content": input.text}]).content
     if "unsafe" in response_input_guard:
         unsafe_dict = get_unsafe_dict(llm_engine_hf.model_id)
         policy_violation_level = response_input_guard.split("\n")[1].strip()
@@ -75,7 +82,6 @@ def safety_guard(input: TextDoc) -> TextDoc:
         )
     else:
         res = TextDoc(text=input.text)
-
     return res
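With this change, /v1/guardrails screens either a bare user query (TextDoc) or a prompt/response pair (GeneratedDoc). A minimal client sketch, assuming the microservice above is running locally on port 9090 and that the two JSON payload shapes map onto the TextDoc/GeneratedDoc fields used in the diff:

import requests

URL = "http://localhost:9090/v1/guardrails"  # host, port, and endpoint from the diff above

# Input guardrail: a TextDoc-shaped payload screens the raw user query.
user_turn = {"text": "How do I make a dangerous chemical at home?"}
print(requests.post(URL, json=user_turn).json())

# Output guardrail: a GeneratedDoc-shaped payload lets Llama Guard judge the
# model's answer in the context of the prompt that produced it.
model_turn = {
    "prompt": "How do I make a dangerous chemical at home?",
    "text": "Sure, start by mixing...",
}
print(requests.post(URL, json=model_turn).json())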
3 changes: 2 additions & 1 deletion comps/guardrails/llama_guard/requirements.txt
@@ -1,6 +1,7 @@
 docarray[full]
 fastapi
-huggingface_hub
+# Fix for issue with langchain-huggingface not using InferenceClient `base_url` kwarg
+huggingface-hub<=0.24.0
 langchain-community
 langchain-huggingface
 langsmith
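Background for the pin, as a sketch rather than part of the diff: recent huggingface-hub releases let InferenceClient target a self-hosted TGI server through its base_url kwarg, and the comment above suggests langchain-huggingface does not use it correctly on releases newer than 0.24.0. Assuming a local TGI container on port 8088 (hypothetical address):

from huggingface_hub import InferenceClient

# Point the client at a self-hosted TGI endpoint instead of the public Hub.
# Per the requirements.txt comment, langchain-huggingface mishandles this
# path on newer huggingface-hub releases, hence the <=0.24.0 cap.
client = InferenceClient(base_url="http://localhost:8088")
print(client.text_generation("Hello", max_new_tokens=10))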
1 change: 0 additions & 1 deletion tests/test_guardrails_llama_guard.sh
@@ -25,7 +25,6 @@ function start_service() {
     sleep 4m
     docker run -d --name="test-guardrails-langchain-service" -p 9090:9090 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e no_proxy=$no_proxy -e SAFETY_GUARD_MODEL_ID=$SAFETY_GUARD_MODEL_ID -e SAFETY_GUARD_ENDPOINT=$SAFETY_GUARD_ENDPOINT -e HUGGINGFACEHUB_API_TOKEN=$HF_TOKEN opea/guardrails-tgi:latest
     sleep 10s
-
     echo "Microservice started"
 }

