def _check_prompt_validity(
    self, question: str, semantic_context: "SemanticContext"
) -> bool:
    """
    Check whether a prompt mentions any entity on the semantic deny list.

    The prompt is posted to the prompt-governance endpoint of the classifier
    (``classifier_url`` + ``PROMPT_GOV_URL``) and the ``"entities"`` field of
    the JSON response is compared against
    ``semantic_context.pebblo_semantic_entities.deny``.

    The check is best-effort and fails open: if there is no deny list, if the
    classifier is not configured as ``"local"``, or if the request or the
    response handling raises, the prompt is treated as valid.

    Args:
        question (str): The prompt question to be validated.
        semantic_context (SemanticContext): The semantic context containing
            the deny list entities. ``None`` or a context without a deny list
            is accepted and yields ``True``.

    Returns:
        bool: False if at least one deny-listed entity is reported for the
            prompt, True otherwise.
    """
    # Nothing to enforce without a deny list, so skip the network round-trip.
    semantic_entities = getattr(semantic_context, "pebblo_semantic_entities", None)
    deny_list = getattr(semantic_entities, "deny", None)
    if not deny_list:
        return True

    # Only the local classifier is queried by this method.
    if self.classifier_location != "local":
        return True

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    prompt_payload = {"prompt": question}
    prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"

    try:
        pebblo_resp = requests.post(
            prompt_gov_api_url,
            headers=headers,
            json=prompt_payload,
            timeout=20,
        )
        # Parse the body once and reuse it for logging and for the check.
        resp_json = pebblo_resp.json()
        request_body = pebblo_resp.request.body

        logger.debug("prompt-payload: %s", prompt_payload)
        logger.debug(
            "check_prompt_validity[local]: request url %s, body %s len %s "
            "response status %s body %s",
            pebblo_resp.request.url,
            str(request_body),
            len(request_body) if request_body else 0,
            pebblo_resp.status_code,
            resp_json,
        )

        prompt_entities = resp_json.get("entities")
        if prompt_entities is None:
            return True
        # Membership works for both a list of names and a name-keyed mapping.
        return not any(entity in prompt_entities for entity in deny_list)
    except requests.exceptions.RequestException:
        logger.warning("Unable to reach pebblo server.")
    except Exception as e:
        logger.warning(
            "An Exception caught in _check_prompt_validity: local %s", e
        )
    # Fail open: a classifier problem must not block the chain.
    return True
os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai") PROMPT_URL = "/v1/prompt" +PROMPT_GOV_URL = "/v1/promptgov" APP_DISCOVER_URL = "/v1/app/discover"