-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* revert transformers dependency temporary because spacy-transformers…
… don't support it * bump version to 0.1.3 * remove jailbreak scanner as it copies prompt injection one * updated no_refusal scanner to use transformers to classify
- Loading branch information
Showing
18 changed files
with
273 additions
and
434 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,46 @@ | ||
import logging | ||
import os | ||
|
||
from sentence_transformers import SentenceTransformer, util | ||
|
||
from llm_guard.util import read_json_file | ||
from transformers import pipeline | ||
|
||
from .base import Scanner | ||
|
||
log = logging.getLogger(__name__) | ||
_dataset_path = os.path.join( | ||
os.path.dirname(os.path.abspath(__file__)), "..", "resources", "refusal.json" | ||
) | ||
_transformer_name = "sentence-transformers/all-MiniLM-L6-v2" | ||
_model_path = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7" | ||
_categories = ["refusal"] | ||
|
||
|
||
class NoRefusal(Scanner): | ||
""" | ||
A class used to detect and handle refusals in language model output. | ||
This class uses a transformer model to encode the output and compare it to a dataset of refusal | ||
sentences to identify any refusals. | ||
This class uses the zero-shot-classification model from Hugging Face to classify refusals. | ||
Refusals are common when the prompt breaches policies defined by the model. | ||
""" | ||
|
||
def __init__(self, dataset_path: str = _dataset_path, threshold: float = 0.5): | ||
def __init__(self, threshold: float = 0.5): | ||
""" | ||
Initializes an instance of the NoRefusal class. | ||
Parameters: | ||
dataset_path (str): The path to the refusal dataset. Default is refusal.json. | ||
threshold (float): The similarity threshold to consider an output as refusal. | ||
""" | ||
|
||
self._threshold = threshold | ||
self._transformer_model = SentenceTransformer(_transformer_name) | ||
self._dataset_embeddings = [ | ||
self._transformer_model.encode(s, convert_to_tensor=True) | ||
for s in read_json_file(dataset_path)["refusal"] | ||
] | ||
self._classifier = pipeline("zero-shot-classification", model=_model_path) | ||
|
||
def scan(self, prompt: str, output: str) -> (str, bool, float): | ||
similarities = [] | ||
text_embedding = self._transformer_model.encode(output, convert_to_tensor=True) | ||
for embedding in self._dataset_embeddings: | ||
similarity = util.pytorch_cos_sim(text_embedding, embedding) | ||
similarities.append(similarity.item()) | ||
if output.strip() == "": | ||
return output, True, 0.0 | ||
|
||
classifier_output = self._classifier(output, _categories, multi_label=False) | ||
|
||
max_score = round(max(similarities) if similarities else 0, 2) | ||
max_score = round(max(classifier_output["scores"]) if classifier_output["scores"] else 0, 2) | ||
if max_score > self._threshold: | ||
log.warning(f"Detected refusal result with similarity score: {max_score}") | ||
|
||
return output, False, max_score | ||
|
||
log.debug(f"No refusals. Max similarity with the known refusal results: {max_score}") | ||
|
||
return output, True, 0.0 |
Oops, something went wrong.