Skip to content

Commit

Permalink
update class to use __init__ to validate environment instead of `ro…
Browse files Browse the repository at this point in the history
…ot_validator`
  • Loading branch information
Sheepsta300 committed Sep 11, 2024
1 parent f3e6cfa commit c18d000
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,11 @@
from langchain_community.tools.azure_cognitive_services.text_analytics_health import (
AzureCogsTextAnalyticsHealthTool,
)
from langchain_community.tools.azure_cognitive_services.content_safety import (
AzureContentSafetyTextTool,
)

__all__ = [
"AzureCogsImageAnalysisTool",
"AzureCogsFormRecognizerTool",
"AzureCogsSpeech2TextTool",
"AzureCogsText2SpeechTool",
"AzureCogsTextAnalyticsHealthTool",
"AzureContentSafetyTextTool",
]
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations

import logging
import os
from typing import Any, Dict, Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)

Expand All @@ -30,10 +29,6 @@ class AzureContentSafetyTextTool(BaseTool):
An instance of the Azure Content Safety Client used for making API requests.
Methods:
validate_environment(values: Dict) -> Dict:
Validates the presence of API key and endpoint in the environment
and initializes the Content Safety Client.
_sentiment_analysis(text: str) -> Dict:
Analyzes the provided text to assess its sentiment and safety,
returning the analysis results.
Expand All @@ -56,22 +51,44 @@ class AzureContentSafetyTextTool(BaseTool):
"Input should be text."
)

@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and endpoint exists in environment."""
content_safety_key = get_from_dict_or_env(
values, "content_safety_key", "CONTENT_SAFETY_API_KEY"
)
def __init__(
self,
*,
content_safety_key: Optional[str] = None,
content_safety_endpoint: Optional[str] = None,
) -> None:
"""
Initialize the AzureContentSafetyTextTool with the given API key and endpoint.
This constructor sets up the API key and endpoint, and initializes
the Azure Content Safety Client. If API key or endpoint is not provided,
they are fetched from environment variables.
content_safety_endpoint = get_from_dict_or_env(
values, "content_safety_endpoint", "CONTENT_SAFETY_ENDPOINT"
)
Args:
content_safety_key (Optional[str]):
The API key for Azure Content Safety API. If not provided,
it will be fetched from the environment
variable 'CONTENT_SAFETY_API_KEY'.
content_safety_endpoint (Optional[str]):
The endpoint URL for Azure Content Safety API. If not provided,
it will be fetched from the environment
variable 'CONTENT_SAFETY_ENDPOINT'.
Raises:
ImportError: If the 'azure-ai-contentsafety' package is not installed.
ValueError:
If API key or endpoint is not provided
and environment variables are missing.
"""
content_safety_key = (content_safety_key or
os.environ['CONTENT_SAFETY_API_KEY'])
content_safety_endpoint = (content_safety_endpoint or
os.environ['CONTENT_SAFETY_ENDPOINT'])
try:
import azure.ai.contentsafety as sdk
from azure.core.credentials import AzureKeyCredential

values["content_safety_client"] = sdk.ContentSafetyClient(
content_safety_client = sdk.ContentSafetyClient(
endpoint=content_safety_endpoint,
credential=AzureKeyCredential(content_safety_key),
)
Expand All @@ -81,8 +98,9 @@ def validate_environment(cls, values: Dict) -> Dict:
"azure-ai-contentsafety is not installed. "
"Run `pip install azure-ai-contentsafety` to install."
)

return values
super().__init__(content_safety_key=content_safety_key,
content_safety_endpoint=content_safety_endpoint,
content_safety_client=content_safety_client)

def _sentiment_analysis(self, text: str) -> Dict:
"""
Expand Down

0 comments on commit c18d000

Please sign in to comment.