Skip to content

Commit

Permalink
Merge pull request #106 from iryna-kondr/cot_classifier
Browse files Browse the repository at this point in the history
Added the chain of thought classifier
  • Loading branch information
iryna-kondr authored Jul 6, 2024
2 parents fd3a4ac + fab8f78 commit d0fe2b2
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies = [
"google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0"
]
name = "scikit-llm"
version = "1.2.0"
version = "1.3.0"
authors = [
{ name="Oleh Kostromin", email="[email protected]" },
{ name="Iryna Kondrashchenko", email="[email protected]" },
Expand Down
2 changes: 1 addition & 1 deletion skllm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = '1.2.0'
__version__ = '1.3.0'
__author__ = 'Iryna Kondrashchenko, Oleh Kostromin'
24 changes: 24 additions & 0 deletions skllm/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
## GPT

from skllm.models.gpt.classification.zero_shot import (
ZeroShotGPTClassifier,
MultiLabelZeroShotGPTClassifier,
CoTGPTClassifier,
)
from skllm.models.gpt.classification.few_shot import (
FewShotGPTClassifier,
DynamicFewShotGPTClassifier,
MultiLabelFewShotGPTClassifier,
)
from skllm.models.gpt.classification.tunable import (
GPTClassifier as TunableGPTClassifier,
)

## Vertex
from skllm.models.vertex.classification.zero_shot import (
ZeroShotVertexClassifier,
MultiLabelZeroShotVertexClassifier,
)
from skllm.models.vertex.classification.tunable import (
VertexClassifier as TunableVertexClassifier,
)
74 changes: 67 additions & 7 deletions skllm/models/_base/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
FEW_SHOT_MLCLF_PROMPT_TEMPLATE,
ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE,
ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE,
COT_CLF_PROMPT_TEMPLATE,
COT_MLCLF_PROMPT_TEMPLATE,
)
from skllm.prompts.builders import (
build_zero_shot_prompt_slc,
Expand All @@ -33,7 +35,8 @@
from skllm.memory.base import IndexConstructor
from skllm.memory._sklearn_nn import SklearnMemoryIndex
from skllm.models._base.vectorizer import BaseVectorizer as _BaseVectorizer
import ast
from skllm.utils import re_naive_json_extractor
import json

_TRAINING_SAMPLE_PROMPT_TEMPLATE = """
Sample input:
Expand Down Expand Up @@ -221,7 +224,7 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int =
----------
X : Union[np.ndarray, pd.Series, List[str]]
The input data to predict the class of.
num_workers : int
number of workers to use for multithreaded prediction, default 1
Expand All @@ -231,12 +234,16 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int =
The predicted classes as a numpy array.
"""
X = _to_numpy(X)

if num_workers > 1:
warnings.warn("Passing num_workers to predict is temporary and will be removed in the future.")
warnings.warn(
"Passing num_workers to predict is temporary and will be removed in the future."
)
with ThreadPoolExecutor(max_workers=num_workers) as executor:
predictions = list(tqdm(executor.map(self._predict_single, X), total=len(X)))

predictions = list(
tqdm(executor.map(self._predict_single, X), total=len(X))
)

return np.array(predictions)

def _get_unique_targets(self, y: Any):
Expand Down Expand Up @@ -286,6 +293,47 @@ def _get_prompt(self, x: str) -> dict:
return {"messages": prompt, "system_message": self.system_msg}


class BaseCoTClassifier(BaseClassifier):
def _get_prompt_template(self) -> str:
"""Returns the prompt template to use for a single input."""
if self.prompt_template is not None:
return self.prompt_template
elif isinstance(self, SingleLabelMixin):
return COT_CLF_PROMPT_TEMPLATE
return COT_MLCLF_PROMPT_TEMPLATE

def _get_prompt(self, x: str) -> dict:
"""Returns the prompt to use for a single input."""
if isinstance(self, SingleLabelMixin):
prompt = build_zero_shot_prompt_slc(
x, repr(self.classes_), template=self._get_prompt_template()
)
else:
prompt = build_zero_shot_prompt_mlc(
x,
repr(self.classes_),
self.max_labels,
template=self._get_prompt_template(),
)
return {"messages": prompt, "system_message": self.system_msg}

def _predict_single(self, x: Any) -> Any:
prompt_dict = self._get_prompt(x)
# this will be inherited from the LLM
completion = self._get_chat_completion(model=self.model, **prompt_dict)
completion = self._convert_completion_to_str(completion)
try:
as_dict = json.loads(re_naive_json_extractor(completion))
label = as_dict["label"]
explanation = str(as_dict["explanation"])
except Exception as e:
label = "None"
explanation = "Explanation is not available."
# this will be inherited from the sl/ml mixin
prediction = self.validate_prediction(label)
return [prediction, explanation]


class BaseFewShotClassifier(BaseClassifier):
def _get_prompt_template(self) -> str:
"""Returns the prompt template to use for a single input."""
Expand Down Expand Up @@ -427,6 +475,18 @@ def _get_prompt_template(self) -> str:
return self.prompt_template
return FEW_SHOT_CLF_PROMPT_TEMPLATE

def _reorder_examples(self, examples):
n_classes = len(self.classes_)
n_examples = self.n_examples

shuffled_list = []

for i in range(n_examples):
for cls in range(n_classes):
shuffled_list.append(cls * n_examples + i)

return [examples[i] for i in shuffled_list]

def _get_prompt(self, x: str) -> dict:
"""
Generates the prompt for the given input.
Expand Down Expand Up @@ -455,7 +515,7 @@ def _get_prompt(self, x: str) -> dict:
]
)

training_data_str = "\n".join(training_data)
training_data_str = "\n".join(self._reorder_examples(training_data))

msg = build_few_shot_prompt_slc(
x=x,
Expand Down
36 changes: 36 additions & 0 deletions skllm/models/gpt/classification/zero_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
SingleLabelMixin as _SingleLabelMixin,
MultiLabelMixin as _MultiLabelMixin,
BaseZeroShotClassifier as _BaseZeroShotClassifier,
BaseCoTClassifier as _BaseCoTClassifier,
)
from skllm.llm.gpt.mixin import GPTClassifierMixin as _GPTClassifierMixin
from typing import Optional
Expand Down Expand Up @@ -44,6 +45,41 @@ def __init__(
self._set_keys(key, org)


class CoTGPTClassifier(_BaseCoTClassifier, _GPTClassifierMixin, _SingleLabelMixin):
def __init__(
self,
model: str = "gpt-3.5-turbo",
default_label: str = "Random",
prompt_template: Optional[str] = None,
key: Optional[str] = None,
org: Optional[str] = None,
**kwargs,
):
"""
Chain-of-thought text classifier using OpenAI/GPT API-compatible models.
Parameters
----------
model : str, optional
model to use, by default "gpt-3.5-turbo"
default_label : str, optional
default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
prompt_template : Optional[str], optional
custom prompt template to use, by default None
key : Optional[str], optional
estimator-specific API key; if None, retrieved from the global config, by default None
org : Optional[str], optional
estimator-specific ORG key; if None, retrieved from the global config, by default None
"""
super().__init__(
model=model,
default_label=default_label,
prompt_template=prompt_template,
**kwargs,
)
self._set_keys(key, org)


class MultiLabelZeroShotGPTClassifier(
_BaseZeroShotClassifier, _GPTClassifierMixin, _MultiLabelMixin
):
Expand Down
37 changes: 36 additions & 1 deletion skllm/prompts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@
Your JSON response:
"""

COT_CLF_PROMPT_TEMPLATE = """
You are tasked with classifying a given text sample based on a list of potential categories. Please adhere to the following guidelines:
1. The text intended for classification is presented between triple backticks.
2. The possible categories are enumerated in square brackets, with each category enclosed in single quotes and separated by commas.
Tasks:
1. Examine the text and provide detailed justifications for the possibility of the text belonging or not belonging to each category listed.
2. Determine and select the most appropriate category for the text based on your comprehensive justifications.
3. Format your decision into a JSON object containing two keys: `explanation` and `label`. The `explanation` should concisely capture the rationale for each category before concluding with the chosen category.
Category List: {labels}
Text Sample: ```{x}```
Provide your JSON response below, ensuring that justifications for all categories are clearly detailed:
"""

ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE = """
Classify the following text into one of the following classes: {labels}. Provide your response in a JSON format containing a single key `label`.
Text: ```{x}```
Expand Down Expand Up @@ -84,6 +102,24 @@
Your JSON response:
"""

COT_MLCLF_PROMPT_TEMPLATE = """
You are tasked with classifying a given text sample based on a list of potential categories. Please adhere to the following guidelines:
1. The text intended for classification is presented between triple backticks.
2. The possible categories are enumerated in square brackets, with each category enclosed in quotes and separated by commas.
Tasks:
1. Examine the text and provide detailed justifications for the possibility of the text belonging or not belonging to each category listed.
2. Determine and select at most {max_cats} most appropriate categories for the text based on your comprehensive justifications.
3. Format your decision into a JSON object containing two keys: `explanation` and `label`. The `explanation` should concisely capture the rationale for each category before concluding with the chosen category. The `label` should contain an array of the chosen categories.
Category List: {labels}
Text Sample: ```{x}```
Provide your JSON response below, ensuring that justifications for all categories are clearly detailed:
"""

SUMMARY_PROMPT_TEMPLATE = """
Your task is to generate a summary of the text sample.
Summarize the text sample provided below, delimited by triple backticks, in at most {max_words} words.
Expand Down Expand Up @@ -204,4 +240,3 @@
Output json:
"""

8 changes: 8 additions & 0 deletions skllm/text2text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
## GPT
from skllm.models.gpt.text2text.summarization import GPTSummarizer
from skllm.models.gpt.text2text.translation import GPTTranslator
from skllm.models.gpt.text2text.tunable import TunableGPTText2Text

## Vertex

from skllm.models.vertex.text2text.tunable import TunableVertexText2Text
1 change: 1 addition & 0 deletions skllm/vectorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from skllm.models.gpt.vectorization import GPTVectorizer

0 comments on commit d0fe2b2

Please sign in to comment.