Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Causal Classification Problem Type #449

Merged
merged 10 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Using CLI for fine-tuning LLMs:

## What's New

- [PR 449](https://github.com/h2oai/h2o-llmstudio/pull/449) A new problem type for Causal Classification Modeling allows training binary and multiclass classification models using LLMs.
- [PR 364](https://github.com/h2oai/h2o-llmstudio/pull/364) User secrets are now handled more securely and flexibly. Support for handling secrets using the 'keyring' library was added. An automatic migration of existing user settings is attempted.
- [PR 328](https://github.com/h2oai/h2o-llmstudio/pull/328) RLHF is now a separate problem type. Note that starting a new RLHF experiment from an old experiment that used RLHF is no longer supported. To continue from a previous experiment, please start a new experiment and enter the settings from the previous experiment manually.
- [PR 308](https://github.com/h2oai/h2o-llmstudio/pull/308) Sequence to sequence models have been added as a new problem type.
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
The column in the dataset containing the expected output.
The column in the dataset containing the expected output.

For classification, this needs to be an integer column containing the class label.
1 change: 1 addition & 0 deletions documentation/docs/tooltips/experiments/_num-classes.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The number of possible classes for the classification task. For binary classification, select a single class (i.e., set this value to 1).
6 changes: 5 additions & 1 deletion documentation/docs/tooltips/experiments/_problem-type.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,8 @@ Defines the problem type of the experiment, which also defines the settings H2O

- Causal Language Modeling: Used to fine-tune large language models

- Sequence To Sequence Modeling: Used to fine-tune large sequence to sequence models
- Rlhf Language Modeling: Used to fine-tune RLHF language models

- Sequence To Sequence Modeling: Used to fine-tune large sequence to sequence models

- Causal Classification Modeling: Used to fine-tune causal classification models
1 change: 1 addition & 0 deletions llm_studio/app_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def get_size(x):
"text_causal_language_modeling_config",
"text_rlhf_language_modeling_config",
"text_sequence_to_sequence_modeling_config",
"text_causal_classification_modeling_config",
],
"problem_categories": ["text"],
"dataset_keys": [
Expand Down
13 changes: 12 additions & 1 deletion llm_studio/app_utils/hugging_face_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,19 @@ def publish_model_to_hugging_face(
repo_id=repo_id, repo_type="model", commit_message="Upload model card"
)

# push config to hub
api = huggingface_hub.HfApi()

# push classification head to hub
if cfg.problem_type == "text_causal_classification_modeling":
api.upload_file(
path_or_fileobj=f"{path_to_experiment}/classification_head.pth",
path_in_repo="classification_head.pth",
repo_id=repo_id,
repo_type="model",
commit_message="Upload classification_head.pth",
)

# push config to hub
api.upload_file(
path_or_fileobj=os.path.join(path_to_experiment, "cfg.yaml"),
path_in_repo="cfg.yaml",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import os
from dataclasses import dataclass, field
from typing import Any, Tuple

from llm_studio.python_configs.base import DefaultConfigProblemBase
from llm_studio.python_configs.text_causal_language_modeling_config import (
ConfigNLPAugmentation,
ConfigNLPCausalLMArchitecture,
ConfigNLPCausalLMDataset,
ConfigNLPCausalLMEnvironment,
ConfigNLPCausalLMLogging,
ConfigNLPCausalLMPrediction,
ConfigNLPCausalLMTokenizer,
ConfigNLPCausalLMTraining,
)
from llm_studio.src import possible_values
from llm_studio.src.losses import text_causal_classification_modeling_losses
from llm_studio.src.metrics import text_causal_classification_modeling_metrics
from llm_studio.src.models import text_causal_classification_modeling_model
from llm_studio.src.utils.modeling_utils import generate_experiment_name


@dataclass
class ConfigNLPCausalClassificationDataset(ConfigNLPCausalLMDataset):
    """Dataset settings for the causal classification problem type.

    Declares the column names, text separators and EOS-token switches used
    for classification, plus ``num_classes``, the number of target classes.
    """

    system_column: str = "None"
    prompt_column: Tuple[str, ...] = ("instruction", "input")
    answer_column: str = "label"
    num_classes: int = 1
    parent_id_column: str = "None"

    text_system_start: str = ""
    text_prompt_start: str = ""
    text_answer_separator: str = ""

    add_eos_token_to_system: bool = False
    add_eos_token_to_prompt: bool = False
    add_eos_token_to_answer: bool = False

    _allowed_file_extensions: Tuple[str, ...] = ("csv", "pq", "parquet")

    def __post_init__(self):
        """Normalize ``prompt_column`` to a tuple and register setting metadata."""
        # A single column name given as a plain string must become a
        # one-element tuple. ``tuple("abc")`` would iterate the string and
        # yield ("a", "b", "c"), so the tuple is built literally instead.
        if isinstance(self.prompt_column, str):
            self.prompt_column = (self.prompt_column,)
        else:
            self.prompt_column = tuple(self.prompt_column)
        super().__post_init__()

        # NOTE(review): presumably (min, max, step) bounds for the setting --
        # confirm against the base config class.
        self._possible_values["num_classes"] = (1, 100, 1)

        # NOTE(review): -1 presumably hides the setting in the UI -- confirm
        # against the base config class.
        for setting in (
            "personalize",
            "chatbot_name",
            "chatbot_author",
            "mask_prompt_labels",
            "add_eos_token_to_answer",
        ):
            self._visibility[setting] = -1


@dataclass
class ConfigNLPCausalClassificationTraining(ConfigNLPCausalLMTraining):
    """Training settings for the causal classification problem type.

    Uses the classification loss factory and defaults to
    ``BinaryCrossEntropyLoss``, with ``classification_head`` preselected as a
    differential learning rate layer.
    """

    loss_class: Any = text_causal_classification_modeling_losses.Losses
    loss_function: str = "BinaryCrossEntropyLoss"

    learning_rate: float = 0.0001
    differential_learning_rate_layers: Tuple[str, ...] = ("classification_head",)
    differential_learning_rate: float = 0.00001

    def __post_init__(self):
        """Register the selectable loss functions and differential LR layers."""
        super().__post_init__()

        # The offered loss names come from whatever factory ``loss_class`` is.
        loss_names = self.loss_class.names()
        self._possible_values["loss_function"] = loss_names

        # Only the listed layer groups may be chosen; custom values are not
        # accepted (allow_custom=False).
        layer_choices = possible_values.String(
            values=("backbone", "embed", "classification_head"),
            allow_custom=False,
            placeholder="Select optional layers...",
        )
        self._possible_values["differential_learning_rate_layers"] = layer_choices


@dataclass
class ConfigNLPCausalClassificationTokenizer(ConfigNLPCausalLMTokenizer):
    """Tokenizer settings for the causal classification problem type.

    Both the prompt length limit and the overall length limit default to 512.
    """

    max_length_prompt: int = 512
    max_length: int = 512

    def __post_init__(self):
        """Run the parent setup, then adjust the answer-length visibility."""
        super().__post_init__()

        # NOTE(review): -1 presumably hides the setting in the UI -- confirm
        # against the base config class.
        hidden_setting = "max_length_answer"
        self._visibility[hidden_setting] = -1


@dataclass
class ConfigNLPCausalClassificationArchitecture(ConfigNLPCausalLMArchitecture):
    """Architecture settings for the causal classification problem type.

    Identical to the parent settings except that ``model_class`` points at
    the classification model.
    """

    model_class: Any = text_causal_classification_modeling_model.Model

    def __post_init__(self):
        """Delegate entirely to the parent; nothing extra is registered here."""
        super().__post_init__()


@dataclass
class ConfigNLPCausalClassificationPrediction(ConfigNLPCausalLMPrediction):
    """Prediction settings for the causal classification problem type.

    Uses the classification metrics factory, with ``AUC`` as the default
    metric.
    """

    metric_class: Any = text_causal_classification_modeling_metrics.Metrics
    metric: str = "AUC"

    def __post_init__(self):
        """Register the selectable metrics and adjust setting visibility."""
        super().__post_init__()
        self._possible_values["metric"] = self.metric_class.names()

        # NOTE(review): -1 presumably hides these settings in the UI --
        # confirm against the base config class.
        for k in (
            "min_length_inference",
            "max_length_inference",
            "do_sample",
            "num_beams",
            "temperature",
            "repetition_penalty",
            "stop_tokens",
            "top_k",
            "top_p",
        ):
            self._visibility[k] = -1


@dataclass
class ConfigNLPCausalClassificationEnvironment(ConfigNLPCausalLMEnvironment):
    """Environment settings for the causal classification problem type.

    Differs from the parent only in the file names of the model card and
    experiment summary card templates.
    """

    _model_card_template: str = "text_causal_classification_model_card_template.md"
    _summary_card_template: str = (
        "text_causal_classification_experiment_summary_card_template.md"
    )

    def __post_init__(self):
        """Delegate entirely to the parent; nothing extra is registered here."""
        super().__post_init__()


@dataclass
class ConfigProblemBase(DefaultConfigProblemBase):
    """Top-level configuration for the causal classification problem type.

    Bundles the per-section settings (dataset, tokenizer, architecture,
    training, augmentation, prediction, environment, logging) together with
    the experiment name, output directory and LLM backbone.
    """

    # The default output directory is derived from this module's file name.
    output_directory: str = f"output/{os.path.basename(__file__).split('.')[0]}"
    experiment_name: str = field(default_factory=generate_experiment_name)
    _parent_experiment: str = ""
    llm_backbone: str = "h2oai/h2ogpt-4096-llama2-7b"

    dataset: ConfigNLPCausalClassificationDataset = field(
        default_factory=ConfigNLPCausalClassificationDataset
    )
    # Use the classification tokenizer settings defined in this module, so
    # that its 512-token defaults and its ``max_length_answer`` visibility
    # change actually take effect for this problem type.
    tokenizer: ConfigNLPCausalClassificationTokenizer = field(
        default_factory=ConfigNLPCausalClassificationTokenizer
    )
    architecture: ConfigNLPCausalClassificationArchitecture = field(
        default_factory=ConfigNLPCausalClassificationArchitecture
    )
    training: ConfigNLPCausalClassificationTraining = field(
        default_factory=ConfigNLPCausalClassificationTraining
    )
    augmentation: ConfigNLPAugmentation = field(default_factory=ConfigNLPAugmentation)
    prediction: ConfigNLPCausalClassificationPrediction = field(
        default_factory=ConfigNLPCausalClassificationPrediction
    )
    environment: ConfigNLPCausalClassificationEnvironment = field(
        default_factory=ConfigNLPCausalClassificationEnvironment
    )
    logging: ConfigNLPCausalLMLogging = field(default_factory=ConfigNLPCausalLMLogging)

    def __post_init__(self):
        """Adjust output-directory visibility and register the backbone choices."""
        super().__post_init__()

        # NOTE(review): -1 presumably hides the setting in the UI -- confirm
        # against the base config class.
        self._visibility["output_directory"] = -1

        # Suggested backbones; other model names are accepted as well
        # (allow_custom=True).
        self._possible_values["llm_backbone"] = possible_values.String(
            values=(
                "h2oai/h2ogpt-4096-llama2-70b",
                "h2oai/h2ogpt-4096-llama2-70b-chat",
                "h2oai/h2ogpt-4096-llama2-13b",
                "h2oai/h2ogpt-4096-llama2-13b-chat",
                "h2oai/h2ogpt-4096-llama2-7b",
                "h2oai/h2ogpt-4096-llama2-7b-chat",
                "tiiuae/falcon-40b",
                "tiiuae/falcon-7b",
                "openlm-research/open_llama_13b",
                "openlm-research/open_llama_7b",
                "openlm-research/open_llama_3b",
                "EleutherAI/gpt-j-6B",
                "facebook/opt-125m",
            ),
            allow_custom=True,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def __post_init__(self):

self._visibility["limit_chained_samples"] = -1
self._visibility["mask_prompt_labels"] = -1
self._visibility["dataset_class"] = -1


@dataclass
Expand Down
34 changes: 33 additions & 1 deletion llm_studio/src/datasets/text_causal_language_modeling_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.datasets.text_utils import TEXT_SEPARATOR, get_tokenizer
from llm_studio.src.utils.exceptions import LLMDataException

logger = logging.getLogger(__name__)

Expand All @@ -31,6 +32,22 @@ def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
self.tokenizer = get_tokenizer(self.cfg)
self.conversation_chain_handler = ConversationChainHandler(self.df, cfg)

if cfg.problem_type == "text_causal_classification_modeling":
self.answers_int = df[cfg.dataset.answer_column].astype(int).values.tolist()
if (
cfg.dataset.num_classes > 1
and max(self.answers_int) >= cfg.dataset.num_classes
):
raise LLMDataException(
"Number of classes is smaller than max label "
f"{max(self.answers_int)}. Please increase the setting accordingly."
)
elif cfg.dataset.num_classes == 1 and max(self.answers_int) > 1:
raise LLMDataException(
"For binary classification, max label should be 1 but is "
f"{max(self.answers_int)}."
)

def __len__(self) -> int:
    """Return the length of the conversation chain handler."""
    return len(self.conversation_chain_handler)

Expand Down Expand Up @@ -107,6 +124,10 @@ def __getitem__(self, idx: int) -> Dict:
sample["labels"][: len(system_encoding)] = -100
if sample["prompt_input_ids"][0] != self.tokenizer.pad_token_id:
sample["prompt_input_ids"][: len(system_encoding)] = system_encoding

if self.cfg.problem_type == "text_causal_classification_modeling":
sample["class_label"] = self.answers_int[idx]

return sample

@staticmethod
Expand Down Expand Up @@ -254,7 +275,15 @@ def clean_output(
return output

def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
if not cfg.prediction.metric == "Perplexity":
if cfg.problem_type == "text_causal_classification_modeling":
if cfg.dataset.num_classes == 1:
preds = output["logits"]
preds = np.array((preds > 0.0)).astype(int).astype(str).reshape(-1)
else:
preds = output["logits"]
preds = np.array(torch.argmax(preds, dim=1)).astype(str).reshape(-1)
output["predicted_text"] = preds
elif not cfg.prediction.metric == "Perplexity":
output = self.clean_output(output, cfg)

output["target_text"] = self.conversation_chain_handler.answers
Expand Down Expand Up @@ -297,6 +326,9 @@ def format_output(
if "predicted_text" in output.keys():
output["predicted_text"] = np.array(output["predicted_text"])

if "logits" in output.keys():
output["logits"] = np.array(output["logits"].float())

if isinstance(cfg.dataset.prompt_column, tuple):
for col in cfg.dataset.prompt_column:
output[col] = df.loc[end_conversation_ids, col].values
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging
from typing import Any, KeysView

from torch import nn

__all__ = ["Losses"]


logger = logging.getLogger(__name__)


class CrossEntropyLoss(nn.Module):
    """Multiclass loss wrapping ``torch.nn.CrossEntropyLoss``."""

    def __init__(self, cfg: Any):
        """Store the experiment config and create the underlying loss.

        Args:
            cfg: experiment configuration; kept on the instance as ``cfg``
        """
        super().__init__()
        self.cfg = cfg
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Compute the cross-entropy loss.

        Args:
            logits: model outputs passed unchanged to the wrapped loss
            labels: class labels; flattened to 1-D and cast to ``long``
                before being passed to the wrapped loss
        Returns:
            The value returned by ``torch.nn.CrossEntropyLoss``
        """
        # Labels may arrive as floats and/or with a trailing dimension;
        # the wrapped loss is given 1-D integer class indices.
        return self.loss_fn(logits, labels.reshape(-1).long())


class BinaryCrossEntropyLoss(nn.Module):
    """Binary loss wrapping ``torch.nn.BCEWithLogitsLoss``."""

    def __init__(self, cfg: Any):
        """Store the experiment config and create the underlying loss.

        Args:
            cfg: experiment configuration; kept on the instance as ``cfg``
        """
        super().__init__()
        self.cfg = cfg
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, logits, labels):
        """Compute the binary cross-entropy loss on raw logits.

        Args:
            logits: model outputs passed unchanged to the wrapped loss
            labels: targets passed unchanged to the wrapped loss
        Returns:
            The value returned by ``torch.nn.BCEWithLogitsLoss``
        """
        # NOTE(review): no reshape or cast is applied here, so callers
        # presumably supply float targets shaped like ``logits`` -- confirm.
        loss = self.loss_fn(logits, labels)
        return loss


class Losses:
    """Losses factory."""

    _losses = {
        "CrossEntropyLoss": CrossEntropyLoss,
        "BinaryCrossEntropyLoss": BinaryCrossEntropyLoss,
    }

    @classmethod
    def names(cls) -> KeysView:
        """Return a view of the registered loss names."""
        registry = cls._losses
        return registry.keys()

    @classmethod
    def get(cls, name: str) -> Any:
        """Look up a loss class by name.

        Args:
            name: losses name
        Returns:
            The registered loss class, or ``CrossEntropyLoss`` when ``name``
            is not registered
        """
        try:
            return cls._losses[name]
        except KeyError:
            # Unknown names fall back to the multiclass loss.
            return CrossEntropyLoss
Loading