From c18d000387885a97aa300c048875d188087dde67 Mon Sep 17 00:00:00 2001 From: Sheepsta300 <128811766+Sheepsta300@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:01:17 +1200 Subject: [PATCH] update class to use `__init__` to validate environment instead of `root_validator` --- .../azure_cognitive_services/__init__.py | 4 -- .../content_safety.py | 54 ++++++++++++------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/libs/community/langchain_community/tools/azure_cognitive_services/__init__.py b/libs/community/langchain_community/tools/azure_cognitive_services/__init__.py index b495cd3ecd96c..1121e4e89d1f4 100644 --- a/libs/community/langchain_community/tools/azure_cognitive_services/__init__.py +++ b/libs/community/langchain_community/tools/azure_cognitive_services/__init__.py @@ -15,9 +15,6 @@ from langchain_community.tools.azure_cognitive_services.text_analytics_health import ( AzureCogsTextAnalyticsHealthTool, ) -from langchain_community.tools.azure_cognitive_services.content_safety import ( - AzureContentSafetyTextTool, -) __all__ = [ "AzureCogsImageAnalysisTool", @@ -25,5 +22,4 @@ "AzureCogsSpeech2TextTool", "AzureCogsText2SpeechTool", "AzureCogsTextAnalyticsHealthTool", - "AzureContentSafetyTextTool", ] diff --git a/libs/community/langchain_community/tools/azure_cognitive_services/content_safety.py b/libs/community/langchain_community/tools/azure_cognitive_services/content_safety.py index afc2690960ad1..9be95c1a8acc0 100644 --- a/libs/community/langchain_community/tools/azure_cognitive_services/content_safety.py +++ b/libs/community/langchain_community/tools/azure_cognitive_services/content_safety.py @@ -1,12 +1,11 @@ from __future__ import annotations import logging +import os from typing import Any, Dict, Optional from langchain_core.callbacks import CallbackManagerForToolRun -from langchain_core.pydantic_v1 import root_validator from langchain_core.tools import BaseTool -from langchain_core.utils import get_from_dict_or_env logger = logging.getLogger(__name__) @@ -30,10 +29,6 @@ class AzureContentSafetyTextTool(BaseTool): An instance of the Azure Content Safety Client used for making API requests. Methods: - validate_environment(values: Dict) -> Dict: - Validates the presence of API key and endpoint in the environment - and initializes the Content Safety Client. - _sentiment_analysis(text: str) -> Dict: Analyzes the provided text to assess its sentiment and safety, returning the analysis results. @@ -56,22 +51,44 @@ class AzureContentSafetyTextTool(BaseTool): "Input should be text." ) - @root_validator(pre=True) - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and endpoint exists in environment.""" - content_safety_key = get_from_dict_or_env( - values, "content_safety_key", "CONTENT_SAFETY_API_KEY" - ) + def __init__( + self, + *, + content_safety_key: Optional[str] = None, + content_safety_endpoint: Optional[str] = None, + ) -> None: + """ + Initialize the AzureContentSafetyTextTool with the given API key and endpoint. + + This constructor sets up the API key and endpoint, and initializes + the Azure Content Safety Client. If API key or endpoint is not provided, + they are fetched from environment variables. - content_safety_endpoint = get_from_dict_or_env( - values, "content_safety_endpoint", "CONTENT_SAFETY_ENDPOINT" - ) + Args: + content_safety_key (Optional[str]): + The API key for Azure Content Safety API. If not provided, + it will be fetched from the environment + variable 'CONTENT_SAFETY_API_KEY'. + content_safety_endpoint (Optional[str]): + The endpoint URL for Azure Content Safety API. If not provided, + it will be fetched from the environment + variable 'CONTENT_SAFETY_ENDPOINT'. + Raises: + ImportError: If the 'azure-ai-contentsafety' package is not installed. + ValueError: + If API key or endpoint is not provided + and environment variables are missing. + """ + content_safety_key = (content_safety_key or + os.environ['CONTENT_SAFETY_API_KEY']) + content_safety_endpoint = (content_safety_endpoint or + os.environ['CONTENT_SAFETY_ENDPOINT']) try: import azure.ai.contentsafety as sdk from azure.core.credentials import AzureKeyCredential - values["content_safety_client"] = sdk.ContentSafetyClient( + content_safety_client = sdk.ContentSafetyClient( endpoint=content_safety_endpoint, credential=AzureKeyCredential(content_safety_key), ) @@ -81,8 +98,9 @@ def validate_environment(cls, values: Dict) -> Dict: "azure-ai-contentsafety is not installed. " "Run `pip install azure-ai-contentsafety` to install." ) - - return values + super().__init__(content_safety_key=content_safety_key, + content_safety_endpoint=content_safety_endpoint, + content_safety_client=content_safety_client) def _sentiment_analysis(self, text: str) -> Dict: """