Commit 67f9608

* load models from local folder in the API
* updates to the documentation
asofter committed Apr 23, 2024
1 parent 42b16e1 commit 67f9608
Showing 6 changed files with 113 additions and 31 deletions.
16 changes: 16 additions & 0 deletions docs/api/overview.md
@@ -25,6 +25,22 @@ All configurations are stored in `config/scanners.yml`. It supports configuring
1. Enable `SCAN_FAIL_FAST` to avoid unnecessary scans.
2. Enable `CACHE_MAX_SIZE` and `CACHE_TTL` to cache results and avoid unnecessary scans.
3. Enable authentication and rate limiting to avoid abuse.
4. Enable lazy loading of models to avoid failed HTTP probes.
5. Enable loading models from a local directory to avoid downloading them each time the container starts.

### Load models from a directory

It's possible to load models from a local directory.
Set `model_path` in each supported scanner to the folder that contains the ONNX version of the model.

This way, the models won't be downloaded each time the container starts.

[Relevant notebook](../tutorials/notebooks/local_models.ipynb)
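
For example, a scanner entry in `config/scanners.yml` could point at a local folder like the sketch below; the scanner choice, threshold, and folder name are illustrative placeholders, and `model_path` is the key the API reads:

```yaml
input_scanners:
  - type: PromptInjection
    params:
      threshold: 0.9
      # hypothetical local folder containing the ONNX export of the model
      model_path: "./models/prompt-injection-onnx"
```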

### Lazy loading

You can enable `lazy_load` in the YAML config file to load models on the first request instead of at API startup.
That way, you avoid failed HTTP probes caused by long model loading times.
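
A minimal sketch of the setting, assuming it lives in the `app` section of the same YAML config (the surrounding keys are illustrative):

```yaml
app:
  name: "LLM Guard API"  # illustrative
  lazy_load: true  # initialize scanner models on the first request instead of at startup
```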

## Observability

1 change: 1 addition & 0 deletions docs/changelog.md
@@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Benchmarks on the AMD CPU.
- `API` has a new endpoint `POST /scan/prompt` to scan the prompt without sanitizing it. It is faster than the `POST /analyze/scan` endpoint.
- Example of running [LLM Guard with ChatGPT streaming mode](./tutorials/openai.md) enabled.
- `API` supports loading models from a local folder.

### Fixed
- `InvisibleText` scanner to allow control characters like `\n`, `\t`, etc.
1 change: 0 additions & 1 deletion llm_guard_api/Dockerfile
@@ -33,7 +33,6 @@ COPY --chown=user:user app ./app

# Install the project's dependencies
RUN pip install --no-cache-dir --upgrade pip && \
    pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu && \
    pip install --no-cache-dir ".[cpu]"

RUN python -m spacy download en_core_web_sm
4 changes: 2 additions & 2 deletions llm_guard_api/app/app.py
@@ -323,7 +323,7 @@ async def submit_analyze_prompt(
) -> AnalyzePromptResponse:
LOGGER.debug("Received analyze prompt request", request=request)

cached_result = cache.get(request.prompt)
cached_result = cache.get(f"analyze|{request.prompt}")
if cached_result:
LOGGER.debug("Response was found in cache")

@@ -388,7 +388,7 @@ async def submit_scan_prompt(
) -> ScanPromptResponse:
LOGGER.debug("Received scan prompt request", request=request)

cached_result = cache.get(request.prompt)
cached_result = cache.get(f"scan|{request.prompt}")
if cached_result:
LOGGER.debug("Response was found in cache")
response.headers["X-Cache-Hit"] = "true"
111 changes: 88 additions & 23 deletions llm_guard_api/app/scanner.py
@@ -7,13 +7,20 @@

from llm_guard import input_scanners, output_scanners
from llm_guard.input_scanners.anonymize_helpers import DISTILBERT_AI4PRIVACY_v2_CONF
from llm_guard.input_scanners.ban_code import MODEL_TINY as BAN_CODE_MODEL
from llm_guard.input_scanners.ban_code import MODEL_SM as BAN_CODE_MODEL
from llm_guard.input_scanners.ban_competitors import MODEL_SMALL as BAN_COMPETITORS_MODEL
from llm_guard.input_scanners.ban_topics import MODEL_ROBERTA_BASE_C_V2 as BAN_TOPICS_MODEL
from llm_guard.input_scanners.base import Scanner as InputScanner
from llm_guard.input_scanners.code import DEFAULT_MODEL as CODE_MODEL
from llm_guard.input_scanners.gibberish import DEFAULT_MODEL as GIBBERISH_MODEL
from llm_guard.input_scanners.language import DEFAULT_MODEL as LANGUAGE_MODEL
from llm_guard.input_scanners.prompt_injection import V2_MODEL as PROMPT_INJECTION_MODEL
from llm_guard.input_scanners.toxicity import DEFAULT_MODEL as TOXICITY_MODEL
from llm_guard.model import Model
from llm_guard.output_scanners.base import Scanner as OutputScanner
from llm_guard.output_scanners.bias import DEFAULT_MODEL as BIAS_MODEL
from llm_guard.output_scanners.malicious_urls import DEFAULT_MODEL as MALICIOUS_URLS_MODEL
from llm_guard.output_scanners.no_refusal import DEFAULT_MODEL as NO_REFUSAL_MODEL
from llm_guard.output_scanners.relevance import MODEL_EN_BGE_SMALL as RELEVANCE_MODEL
from llm_guard.vault import Vault

@@ -67,6 +74,16 @@ def get_output_scanners(scanners: List[ScannerConfig], vault: Vault) -> List[Out
    return output_scanners_loaded


def _use_local_model(model: Model, path: Optional[str]):
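    # Point both the base model path and the ONNX path at the given local folder
    # (no ONNX subfolder) and set local_files_only so the model is loaded offline.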
    if path is None:
        return

    model.path = path
    model.onnx_path = path
    model.onnx_subfolder = ""
    model.kwargs = {"local_files_only": True}


def _get_input_scanner(
    scanner_name: str,
    scanner_config: Optional[Dict],
@@ -92,25 +109,40 @@ def _get_input_scanner(
scanner_config["use_onnx"] = True

if scanner_name == "Anonymize":
_use_local_model(DISTILBERT_AI4PRIVACY_v2_CONF, scanner_config.get("model_path"))
scanner_config["recognizer_conf"] = DISTILBERT_AI4PRIVACY_v2_CONF

if scanner_name == "Language":
LANGUAGE_MODEL.onnx_filename = "model_optimized.onnx"
scanner_config["model"] = LANGUAGE_MODEL
if scanner_name == "BanCode":
_use_local_model(BAN_CODE_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = BAN_CODE_MODEL

if scanner_name == "PromptInjection":
PROMPT_INJECTION_MODEL.onnx_filename = "model_optimized.onnx"
PROMPT_INJECTION_MODEL.kwargs["max_length"] = 128
scanner_config["model"] = PROMPT_INJECTION_MODEL
if scanner_name == "BanTopics":
_use_local_model(BAN_TOPICS_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = BAN_TOPICS_MODEL

if scanner_name == "BanCompetitors":
_use_local_model(BAN_COMPETITORS_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = BAN_COMPETITORS_MODEL

if scanner_name == "BanTopics":
scanner_config["model"] = BAN_TOPICS_MODEL
if scanner_name == "Code":
_use_local_model(CODE_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = CODE_MODEL

if scanner_name == "BanCode":
scanner_config["model"] = BAN_CODE_MODEL
if scanner_name == "Gibberish":
_use_local_model(GIBBERISH_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = GIBBERISH_MODEL

if scanner_name == "Language":
_use_local_model(LANGUAGE_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = LANGUAGE_MODEL

if scanner_name == "PromptInjection":
_use_local_model(PROMPT_INJECTION_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = PROMPT_INJECTION_MODEL

if scanner_name == "Toxicity":
_use_local_model(TOXICITY_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = TOXICITY_MODEL

return input_scanners.get_scanner_by_name(scanner_name, scanner_config)

@@ -132,36 +164,69 @@ def _get_output_scanner(
"BanTopics",
"Bias",
"Code",
"FactualConsistency",
"Gibberish",
"Language",
"LanguageSame",
"MaliciousURLs",
"NoRefusal",
"FactualConsistency",
"Gibberish",
"Relevance",
"Sensitive",
"Toxicity",
]:
scanner_config["use_onnx"] = True

if scanner_name == "Sensitive":
scanner_config["recognizer_conf"] = DISTILBERT_AI4PRIVACY_v2_CONF

if scanner_name == "Language":
LANGUAGE_MODEL.onnx_filename = "model_optimized.onnx"
scanner_config["model"] = LANGUAGE_MODEL
if scanner_name == "BanCode":
_use_local_model(BAN_CODE_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = BAN_CODE_MODEL

if scanner_name == "BanCompetitors":
_use_local_model(BAN_COMPETITORS_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = BAN_COMPETITORS_MODEL

if scanner_name == "FactualConsistency" or scanner_name == "BanTopics":
if scanner_name == "BanTopics" or scanner_name == "FactualConsistency":
_use_local_model(BAN_TOPICS_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = BAN_TOPICS_MODEL

if scanner_name == "Bias":
_use_local_model(BIAS_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = BIAS_MODEL

if scanner_name == "Code":
_use_local_model(CODE_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = CODE_MODEL

if scanner_name == "Language":
_use_local_model(LANGUAGE_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = LANGUAGE_MODEL

if scanner_name == "LanguageSame":
_use_local_model(LANGUAGE_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = LANGUAGE_MODEL

if scanner_name == "MaliciousURLs":
_use_local_model(MALICIOUS_URLS_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = MALICIOUS_URLS_MODEL

if scanner_name == "NoRefusal":
_use_local_model(NO_REFUSAL_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = NO_REFUSAL_MODEL

if scanner_name == "Gibberish":
_use_local_model(GIBBERISH_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = GIBBERISH_MODEL

if scanner_name == "Relevance":
_use_local_model(RELEVANCE_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = RELEVANCE_MODEL

if scanner_name == "BanCode":
scanner_config["model"] = BAN_CODE_MODEL
if scanner_name == "Sensitive":
_use_local_model(DISTILBERT_AI4PRIVACY_v2_CONF, scanner_config.get("model_path"))
scanner_config["recognizer_conf"] = DISTILBERT_AI4PRIVACY_v2_CONF

if scanner_name == "Toxicity":
_use_local_model(TOXICITY_MODEL, scanner_config.get("model_path"))
scanner_config["model"] = TOXICITY_MODEL

return output_scanners.get_scanner_by_name(scanner_name, scanner_config)

11 changes: 6 additions & 5 deletions llm_guard_api/config/scanners.yml
@@ -36,9 +36,10 @@ input_scanners:
# preamble: ""
use_faker: false
threshold: 0.6
# model_path: "./distilbert_finetuned_ai4privacy_v2"
- type: BanCode
params:
threshold: 0.95
threshold: 0.97
- type: BanCompetitors
params:
competitors: ["facebook"]
@@ -82,7 +83,7 @@ input_scanners:
  - type: Sentiment
    params:
      # lexicon: "vader_lexicon"
      threshold: -0.1
      threshold: -0.5
  - type: TokenLimit
    params:
      limit: 4096
@@ -95,7 +96,7 @@
output_scanners:
  - type: BanCode
    params:
      threshold: 0.95
      threshold: 0.97
  - type: BanCompetitors
    params:
      competitors: ["facebook"]
@@ -113,7 +114,7 @@ output_scanners:
      threshold: 0.6
  - type: Bias
    params:
      threshold: 0.75
      threshold: 0.9
  # - type: Code
  #   params:
  #     languages: ["Python"]
@@ -161,7 +162,7 @@ output_scanners:
      threshold: 0.6
  - type: Sentiment
    params:
      threshold: -0.1
      threshold: -0.5
      # lexicon: "vader_lexicon"
  - type: Toxicity
    params: