Skip to content

Commit

Permalink
minor update
Browse files Browse the repository at this point in the history
  • Loading branch information
drcege committed Oct 29, 2024
1 parent 8daa6e1 commit 6c68e71
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 98 deletions.
6 changes: 3 additions & 3 deletions data_juicer/ops/mapper/calibrate_qa_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self,
reference_template: Optional[str] = None,
qa_pair_template: Optional[str] = None,
output_pattern: Optional[str] = None,
api_params: Dict = {},
api_params: Optional[Dict] = None,
**kwargs):
"""
Initialization method.
Expand All @@ -51,7 +51,7 @@ def __init__(self,
:param input_template: Template for building the model input.
:param reference_template: Template for formatting the reference text.
:param qa_pair_template: Template for formatting question-answer pairs.
:param output_pattern: Pattern for parsing model output.
:param output_pattern: Regular expression for parsing model output.
:param api_params: Extra API parameters.
:param kwargs: Extra keyword arguments.
"""
Expand All @@ -65,7 +65,7 @@ def __init__(self,
self.DEFAULT_QA_PAIR_TEMPLATE
self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN

self.api_params = api_params
self.api_params = api_params or {}
self.model_key = prepare_model(model_type='api',
api_model=api_model,
api_url=api_url,
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/mapper/calibrate_query_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@OPERATORS.register_module(OP_NAME)
class CalibrateQueryMapper(CalibrateQAMapper):
"""
Mapper to calibrate only query in question-answer pairs.
Mapper to calibrate query in question-answer pairs.
"""

DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【问题】进行校准,\
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/mapper/calibrate_response_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@OPERATORS.register_module(OP_NAME)
class CalibrateResponseMapper(CalibrateQAMapper):
"""
Mapper to calibrate only response in question-answer pairs.
Mapper to calibrate response in question-answer pairs.
"""

DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【回答】进行校准,\
Expand Down
186 changes: 93 additions & 93 deletions data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,98 @@ def check_model(model_name, force=False):
return cached_model_path


class APIModel:

def __init__(self,
*,
api_model,
api_url=None,
api_key=None,
response_path=None):
self.api_model = api_model

if api_url is None:
api_url = os.getenv('DJ_API_URL')
if api_url is None:
base_url = os.getenv('OPENAI_BASE_URL',
'https://api.openai.com/v1')
api_url = base_url.rstrip('/') + '/chat/completions'
self.api_url = api_url

if api_key is None:
api_key = os.getenv('DJ_API_KEY') or os.getenv('OPENAI_API_KEY')
self.api_key = api_key

if response_path is None:
response_path = 'choices.0.message.content'
self.response_path = response_path

self.headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}

def __call__(self, *, messages, **kwargs):
"""Sends messages to the configured API model and returns the parsed response.
:param messages: The messages to send to the API.
:param model: The model to be used for generating responses.
:param kwargs: Additional parameters for the API request.
:return: The parsed response from the API, or None if an error occurs.
"""
payload = {
'model': self.model,
'messages': messages,
**kwargs,
}
try:
response = requests.post(self.api_url,
json=payload,
headers=self.headers)
response.raise_for_status()
result = response.json()
return self.nested_access(result, self.response_path)
except Exception as e:
logger.exception(e)
return None

@staticmethod
def nested_access(data, path):
"""Access nested data using a dot-separated path.
:param data: The data structure to access.
:param path: A dot-separated string representing the path to access.
:return: The value at the specified path, if it exists.
"""
keys = path.split('.')
for key in keys:
# Convert string keys to integers if they are numeric
key = int(key) if key.isdigit() else key
data = data[key]
return data


def prepare_api_model(*, api_url=None, api_key=None, response_path=None):
"""Creates a callable API model for interacting with OpenAI-compatible API.
This callable object supports custom result parsing and is suitable for use
with incompatible proxy servers.
:param api_url: The URL of the API. If not provided, it will fallback to
the environment variables DJ_API_URL or OPENAI_BASE_URL.
:param api_key: The API key for authorization. If not provided, it will
fallback to the environment variables DJ_API_KEY or OPENAI_API_KEY.
:param response_path: The path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:return: A callable API model object that can be used to send messages and
receive responses.
"""
return APIModel(api_url=api_url,
api_key=api_key,
response_path=response_path)


def prepare_fasttext_model(model_name='lid.176.bin'):
"""
Prepare and load a fasttext model.
Expand Down Expand Up @@ -590,99 +682,8 @@ def prepare_opencv_classifier(model_path):
return model


class APIModel:

def __init__(self,
*,
api_model,
api_url=None,
api_key=None,
response_path=None):
self.api_model = api_model

if api_url is None:
api_url = os.getenv('DJ_API_URL')
if api_url is None:
base_url = os.getenv('OPENAI_BASE_URL',
'https://api.openai.com/v1')
api_url = base_url.rstrip('/') + '/chat/completions'
self.api_url = api_url

if api_key is None:
api_key = os.getenv('DJ_API_KEY') or os.getenv('OPENAI_API_KEY')
self.api_key = api_key

if response_path is None:
response_path = 'choices.0.message.content'
self.response_path = response_path

self.headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}

def __call__(self, *, messages, **kwargs):
"""Sends messages to the configured API model and returns the parsed response.
:param messages: The messages to send to the API.
:param model: The model to be used for generating responses.
:param kwargs: Additional parameters for the API request.
:return: The parsed response from the API, or None if an error occurs.
"""
payload = {
'model': self.model,
'messages': messages,
**kwargs,
}
try:
response = requests.post(self.api_url,
json=payload,
headers=self.headers)
response.raise_for_status()
result = response.json()
return self.nested_access(result, self.response_path)
except Exception as e:
logger.exception(e)
return None

@staticmethod
def nested_access(data, path):
"""Access nested data using a dot-separated path.
:param data: The data structure to access.
:param path: A dot-separated string representing the path to access.
:return: The value at the specified path, if it exists.
"""
keys = path.split('.')
for key in keys:
# Convert string keys to integers if they are numeric
key = int(key) if key.isdigit() else key
data = data[key]
return data


def prepare_api_model(*, api_url=None, api_key=None, response_path=None):
"""Creates a callable API model for interacting with the OpenAI-compatible API.
This callable object supports custom result parsing and is suitable for use
with incompatible proxy servers.
:param api_url: The URL of the API. If not provided, it will fallback
to the environment variable or a default OpenAI URL.
:param api_key: The API key for authorization. If not provided, it will
fallback to the environment variable.
:param response_path: The path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:return: A callable API model object that can be used to send messages
and receive responses.
"""
return APIModel(api_url=api_url,
api_key=api_key,
response_path=response_path)


MODEL_FUNCTION_MAPPING = {
'api': prepare_api_model,
'fasttext': prepare_fasttext_model,
'sentencepiece': prepare_sentencepiece_for_lang,
'kenlm': prepare_kenlm_model,
Expand All @@ -695,7 +696,6 @@ def prepare_api_model(*, api_url=None, api_key=None, response_path=None):
'recognizeAnything': prepare_recognizeAnything_model,
'vllm': prepare_vllm_model,
'opencv_classifier': prepare_opencv_classifier,
'api': prepare_api_model,
}


Expand Down

0 comments on commit 6c68e71

Please sign in to comment.