Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

781 prompt governance detect pii phi #61

Open
wants to merge 4 commits into
base: daxa_3.1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 103 additions & 1 deletion libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
CLASSIFIER_URL,
PEBBLO_CLOUD_URL,
PLUGIN_VERSION,
PROMPT_GOV_URL,
PROMPT_URL,
get_runtime,
)
Expand Down Expand Up @@ -79,6 +80,8 @@ class PebbloRetrievalQA(Chain):
"""Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
_enable_prompt_gov: bool = True #: :meta private:
"""Flag to enable the prompt governance (deny-list entity) check on the prompt."""

def _call(
self,
Expand All @@ -98,10 +101,27 @@ def _call(
"""
prompt_time = datetime.datetime.now().isoformat()
PebbloRetrievalQA.set_prompt_sent(value=False)
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]

auth_context = inputs.get(self.auth_context_key, {})
semantic_context = inputs.get(self.semantic_context_key, {})

if (
self._enable_prompt_gov
and semantic_context is not None
and semantic_context.pebblo_semantic_entities is not None
and semantic_context.pebblo_semantic_entities.deny is not None
and len(semantic_context.pebblo_semantic_entities.deny) > 0
):
is_valid_prompt = self._check_prompt_validity(question, semantic_context)
logger.info(f"is_valid_prompt {is_valid_prompt}")
if is_valid_prompt is False:
return {
self.output_key: """Your prompt has some sensitive information.
Remove the senstitive information and try again !!!"""
}

_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
Expand Down Expand Up @@ -170,6 +190,20 @@ async def _acall(
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
)
if (
self._enable_prompt_gov
and semantic_context is not None
and semantic_context.pebblo_semantic_entities is not None
and semantic_context.pebblo_semantic_entities.deny is not None
and len(semantic_context.pebblo_semantic_entities.deny) > 0
):
is_valid_prompt = self._check_prompt_validity(question, semantic_context)
logger.info(f"is_valid_prompt {is_valid_prompt}")
if is_valid_prompt is False:
return {
self.output_key: """Your prompt has some sensitive information.
Remove the senstitive information and try again !!!"""
}
if accepts_run_manager:
docs = await self._aget_docs(
question, auth_context, semantic_context, run_manager=_run_manager
Expand Down Expand Up @@ -443,6 +477,7 @@ def _send_prompt(self, qa_payload: Qa) -> None:
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
logger.info(f"pebblo_resp.json() {pebblo_resp.json()}")
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
Expand Down Expand Up @@ -512,6 +547,73 @@ def _send_prompt(self, qa_payload: Qa) -> None:
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")

def _check_prompt_validity(
    self, question: str, semantic_context: SemanticContext
) -> bool:
    """
    Check whether the prompt contains any entity from the deny list.

    When the classifier runs locally, the prompt is posted to the prompt
    governance endpoint (``classifier_url`` + ``PROMPT_GOV_URL``) and the
    ``entities`` field of the JSON response is compared against
    ``semantic_context.pebblo_semantic_entities.deny``.

    The check is best-effort (fail-open): if the classifier is not local,
    the deny list is missing or empty, the server cannot be reached, or the
    response cannot be interpreted, the prompt is treated as valid.

    Args:
        question (str): The prompt question to be validated.
        semantic_context (SemanticContext): The semantic context
            containing deny list entities.

    Returns:
        bool: True if the prompt is valid (no deny list entity was reported
            for it, or the check could not be performed), False otherwise.
    """
    # Only the local classifier is queried; any other location skips the check.
    if self.classifier_location != "local":
        return True

    # Resolve the deny list defensively so a missing/partial semantic context
    # cannot raise, and so no request is made when there is nothing to deny.
    semantic_entities = getattr(semantic_context, "pebblo_semantic_entities", None)
    deny_list = getattr(semantic_entities, "deny", None)
    if not deny_list:
        return True

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    prompt_payload = {"prompt": question}
    prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"

    try:
        pebblo_resp = requests.post(
            prompt_gov_api_url,
            headers=headers,
            json=prompt_payload,
            timeout=20,
        )
        # Parse the body once and reuse it for both logging and the check.
        resp_body = pebblo_resp.json()
        request_body = pebblo_resp.request.body

        # Prompt and detected entities may be sensitive: keep them at DEBUG.
        logger.debug("prompt-payload: %s", prompt_payload)
        logger.debug(
            "check_prompt_validity[local]: request url %s, body %s len %s "
            "response status %s body %s",
            pebblo_resp.request.url,
            str(request_body),
            str(len(request_body if request_body else [])),
            str(pebblo_resp.status_code),
            resp_body,
        )

        prompt_entities = resp_body.get("entities")
        if prompt_entities is None:
            return True
        # The prompt is invalid as soon as one denied entity is reported.
        return not any(entity in prompt_entities for entity in deny_list)
    except requests.exceptions.RequestException:
        logger.warning("Unable to reach pebblo server.")
    except Exception as e:
        logger.warning("An Exception caught in _check_prompt_validity: local %s", e)
    # Fail open: an unreachable or misbehaving classifier must not block prompts.
    return True

@classmethod
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
llm_dict = llm.__dict__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# Base URL of the Pebblo cloud service; can be overridden through the
# PEBBLO_CLOUD_URL environment variable.
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")

# API endpoint paths. PROMPT_GOV_URL is appended to the classifier base URL by
# the prompt governance check; the others are presumably joined to a base URL
# the same way -- verify at their call sites.
PROMPT_URL = "/v1/prompt"
PROMPT_GOV_URL = "/v1/promptgov"
APP_DISCOVER_URL = "/v1/app/discover"


Expand Down
Loading