Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt blocking #81

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 100 additions & 43 deletions libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ class PebbloRetrievalQA(Chain):
"""Flag to check if prompt payload has been sent."""
enable_prompt_gov: bool = True #: :meta private:
"""Flag to check if prompt governance is enabled or not"""
block_prompt: bool = True
"""
Flag to determine whether prompting should be blocked.
- True: Prompts are blocked.
- False: Prompts are allowed.
"""
prompt_message: str = "Your prompt is blocked as it has some sensitive information."
"""
Message displayed to the user when a prompt is blocked due to sensitive content.
"""

def _call(
self,
Expand All @@ -105,20 +115,34 @@ def _call(
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key, {})
semantic_context = inputs.get(self.semantic_context_key, {})
_, prompt_entities = self._check_prompt_validity(question)

accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
docs = self._get_docs(
question, auth_context, semantic_context, run_manager=_run_manager
is_prompt_valid: bool = True
is_prompt_blocked: bool = False
prompt_entities: dict = {}
if self.enable_prompt_gov:
is_prompt_valid, prompt_entities = self._check_prompt_validity(
question, semantic_context
)

if self.block_prompt and is_prompt_valid is False:
docs = []
is_prompt_blocked = True
answer = self.prompt_message
else:
docs = self._get_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
docs = self._get_docs(
question, auth_context, semantic_context, run_manager=_run_manager
)
else:
docs = self._get_docs(question, auth_context, semantic_context) # type: ignore[call-arg]

answer = self.combine_documents_chain.run(
input_documents=docs,
question=question,
callbacks=_run_manager.get_child(),
)

qa = {
"name": self.app_name,
Expand All @@ -142,7 +166,9 @@ def _call(
"data": question,
"entities": prompt_entities.get("entities", {}),
"entityCount": prompt_entities.get("entityCount", 0),
"prompt_gov_enabled": self.enable_prompt_gov,
"entityDetails": prompt_entities.get("entityDetails", {}),
"promptGovEnabled": self.enable_prompt_gov,
"promptBlocked": is_prompt_blocked,
},
"response": {
"data": answer,
Expand All @@ -154,7 +180,7 @@ def _call(
else [],
"classifier_location": self.classifier_location,
}

logger.info(f"QA------ {qa}")
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)

Expand Down Expand Up @@ -187,7 +213,9 @@ async def _acall(
"run_manager" in inspect.signature(self._aget_docs).parameters
)

_, prompt_entities = self._check_prompt_validity(question)
is_valid_prompt, prompt_entities = self._check_prompt_validity(
question, semantic_context
)

if accepts_run_manager:
docs = await self._aget_docs(
Expand Down Expand Up @@ -443,7 +471,7 @@ def _send_prompt(self, qa_payload: Qa) -> None:
json=payload,
timeout=20,
)
logger.debug("prompt-payload: %s", payload)

logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
Expand Down Expand Up @@ -525,64 +553,93 @@ def _send_prompt(self, qa_payload: Qa) -> None:
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")

def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
def _check_prompt_validity(
self, question: str, semantic_context: Optional[SemanticContext]
) -> Tuple[bool, Dict[str, Any]]:
"""
Check the validity of the given prompt using a remote classification service.

This method sends a prompt to a remote classifier service and return entities
present in prompt or not.
This method sends a prompt to a remote classifier service and returns entities
present in the prompt or not.

Args:
question (str): The prompt question to be validated.
semantic_context (SemanticContext): The semantic context containing
deny lists.

Returns:
bool: True if the prompt is valid (does not contain deny list entities),
False otherwise.
dict: The entities present in the prompt
dict: A dictionary containing details about the entities present
in the prompt.
"""

# NOTE(review): this span is rendered diff text with the +/- markers and
# indentation stripped, so removed and added lines sit side by side (the two
# signatures above, the paired docstring sentences, the paired log messages
# below). Presumably the first line of each pair is the removed one — confirm
# against the PR before treating any single line as current code.
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
prompt_payload = {"prompt": question}
is_valid_prompt: bool = True
prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"
pebblo_resp = None
prompt_entities: dict = {"entities": {}, "entityCount": 0}

# Defaults that are returned unchanged when the classifier is not local or
# when the request/parsing fails: the prompt counts as valid, with no entities.
is_valid_prompt = True
prompt_entities: dict = {
"entities": {},
"entityCount": 0,
"entityDetails": dict(),
}

# Extract deny lists from the semantic context
# A missing context, or a context without the relevant sub-filter, yields an
# empty deny list, so nothing can match it below.
group_deny_list = (
semantic_context.pebblo_entity_group.deny
if semantic_context and semantic_context.pebblo_entity_group
else []
)
entity_deny_list = (
semantic_context.pebblo_semantic_entities.deny
if semantic_context and semantic_context.pebblo_semantic_entities
else []
)

# Perform the check using the local classifier
# Any other classifier_location skips the request and falls through to the
# final return with the defaults set above.
if self.classifier_location == "local":
try:
logger.info(
f"Sending Prompt Governance request to {prompt_gov_api_url}"
)

pebblo_resp = requests.post(
prompt_gov_api_url,
headers=headers,
json=prompt_payload,
timeout=20,
)

logger.debug("prompt-payload: %s", prompt_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
prompt_entities["entityCount"] = pebblo_resp.json().get(
"entityCount", 0
pebblo_resp_data = pebblo_resp.json()
prompt_entities["entities"] = pebblo_resp_data.get("entities", {})
prompt_entities["entityCount"] = pebblo_resp_data.get("entityCount", 0)
prompt_entities["entityDetails"] = pebblo_resp_data.get(
"entityDetails", {}
)

# Deny-list matching only runs when block_prompt is set; with it disabled
# is_valid_prompt keeps its default of True regardless of the entities found.
if self.block_prompt:
# Check if any entities or entity groups are in the deny lists
for entity, details in prompt_entities["entityDetails"].items():
# Entity-level match: the entity name itself is on the deny list.
if entity in entity_deny_list:
is_valid_prompt = False
break
# Group-level match: compare each detail's entity_group (whitespace-trimmed)
# against the trimmed group deny list.
# NOTE(review): assuming each detail is a dict, `.get("entity_group")`
# returns None when the key is absent and `.strip()` then raises
# AttributeError. That lands in the `except Exception` handler below, which
# only logs, so the method returns whatever is_valid_prompt held at that point.
# NOTE(review): the `break` below leaves only the inner `for detail` loop; the
# outer loop moves on to the next entity. The flag is never set back to True,
# so the returned value is unaffected.
for detail in details:
if detail.get("entity_group").strip() in map(
str.strip, group_deny_list
):
is_valid_prompt = False
break

# Both handlers only log; no exception propagates to the caller, and the
# values accumulated so far (or the defaults) are returned.
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
logger.warning("Unable to reach Pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
logger.warning(f"Exception caught in prompt governance: {str(e)}")
logger.debug("Traceback: ", exc_info=True)

return is_valid_prompt, prompt_entities

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,18 @@ class SemanticTopics(BaseModel):
deny: List[str]


class SemanticEntityGroup(BaseModel):
    """Semantic filter that carries the list of entity groups to deny."""

    deny: List[str]


class SemanticContext(BaseModel):
"""Class for a semantic context."""

pebblo_semantic_entities: Optional[SemanticEntities] = None
pebblo_semantic_topics: Optional[SemanticTopics] = None
pebblo_entity_group: Optional[SemanticEntityGroup] = None

def __init__(self, **data: Any) -> None:
super().__init__(**data)
Expand All @@ -40,10 +47,11 @@ def __init__(self, **data: Any) -> None:
if (
self.pebblo_semantic_entities is None
and self.pebblo_semantic_topics is None
and self.pebblo_entity_group is None
):
raise ValueError(
"semantic_context must contain 'pebblo_semantic_entities' or "
"'pebblo_semantic_topics'"
"'pebblo_semantic_topics' or 'pebblo_entity_group'"
)


Expand Down Expand Up @@ -134,9 +142,11 @@ class Context(BaseModel):

class Prompt(BaseModel):
# Prompt portion of the QA payload. The field names match the keys of the
# nested dict built in `_call` before `Qa(**qa)` is constructed.
# NOTE(review): this is rendered diff text with +/- markers stripped —
# `entityCount` is listed twice and both `prompt_gov_enabled` and
# `promptGovEnabled` appear. Presumably the first `entityCount` and
# `prompt_gov_enabled` are the removed lines; confirm against the PR.
# `data`: `_call` passes the question here.
data: Optional[Union[list, str]]
entityCount: Optional[int]
# `entities` / `entityCount` / `entityDetails`: taken from the prompt
# governance result, defaulting to {} / 0 / {} in `_call`.
entities: Optional[dict]
prompt_gov_enabled: Optional[bool]
entityCount: Optional[int]
entityDetails: Optional[dict]
# `promptGovEnabled`: `_call` passes `self.enable_prompt_gov`.
promptGovEnabled: Optional[bool]
# `promptBlocked`: `_call` passes its `is_prompt_blocked` flag.
promptBlocked: Optional[bool]


class Qa(BaseModel):
Expand Down
Loading