From d34cddb78aeb7e829ea3831beea31b085d946af4 Mon Sep 17 00:00:00 2001 From: Philipp Singer Date: Mon, 16 Oct 2023 12:33:20 +0000 Subject: [PATCH 1/7] first implementation --- .../tooltips/experiments/_answer-column.mdx | 4 +- .../tooltips/experiments/_num-classes.mdx | 1 + .../tooltips/experiments/_problem-type.mdx | 6 +- llm_studio/app_utils/config.py | 1 + ...t_causal_classification_modeling_config.py | 175 ++++++++++++++++++ .../text_causal_language_modeling_config.py | 1 + .../text_rlhf_language_modeling_config.py | 1 + ...xt_sequence_to_sequence_modeling_config.py | 2 +- .../text_causal_language_modeling_ds.py | 15 +- ...t_causal_classification_modeling_losses.py | 53 ++++++ ..._causal_classification_modeling_metrics.py | 101 ++++++++++ ...xt_causal_classification_modeling_model.py | 96 ++++++++++ llm_studio/src/utils/modeling_utils.py | 5 +- 13 files changed, 456 insertions(+), 5 deletions(-) create mode 100644 documentation/docs/tooltips/experiments/_num-classes.mdx create mode 100644 llm_studio/python_configs/text_causal_classification_modeling_config.py create mode 100644 llm_studio/src/losses/text_causal_classification_modeling_losses.py create mode 100644 llm_studio/src/metrics/text_causal_classification_modeling_metrics.py create mode 100644 llm_studio/src/models/text_causal_classification_modeling_model.py diff --git a/documentation/docs/tooltips/experiments/_answer-column.mdx b/documentation/docs/tooltips/experiments/_answer-column.mdx index 5cc247c8f..abd997c0c 100644 --- a/documentation/docs/tooltips/experiments/_answer-column.mdx +++ b/documentation/docs/tooltips/experiments/_answer-column.mdx @@ -1 +1,3 @@ -The column in the dataset containing the expected output. \ No newline at end of file +The column in the dataset containing the expected output. + +For classification, this needs to be an integer column containing the class label. \ No newline at end of file diff --git a/documentation/docs/tooltips/experiments/_num-classes.mdx b/documentation/docs/tooltips/experiments/_num-classes.mdx new file mode 100644 index 000000000..544720dcc --- /dev/null +++ b/documentation/docs/tooltips/experiments/_num-classes.mdx @@ -0,0 +1 @@ +The number of possible classes for the classification task. For binary classification, a single class should be selected. \ No newline at end of file diff --git a/documentation/docs/tooltips/experiments/_problem-type.mdx b/documentation/docs/tooltips/experiments/_problem-type.mdx index bd178a0c8..60d12e206 100644 --- a/documentation/docs/tooltips/experiments/_problem-type.mdx +++ b/documentation/docs/tooltips/experiments/_problem-type.mdx @@ -2,4 +2,8 @@ Defines the problem type of the experiment, which also defines the settings H2O - Causal Language Modeling: Used to fine-tune large language models -- Sequence To Sequence Modeling: Used to fine-tune large sequence to sequence models \ No newline at end of file +- Rlhf Language Modeling: Used to fine-tune RLHF language models + +- Sequence To Sequence Modeling: Used to fine-tune large sequence to sequence models + +- Causal Classification Modeling: Used to fine-tune causal classification models \ No newline at end of file diff --git a/llm_studio/app_utils/config.py b/llm_studio/app_utils/config.py index d7944cf9d..3223c1ce9 100644 --- a/llm_studio/app_utils/config.py +++ b/llm_studio/app_utils/config.py @@ -61,6 +61,7 @@ def get_size(x): "text_causal_language_modeling_config", "text_rlhf_language_modeling_config", "text_sequence_to_sequence_modeling_config", + "text_causal_classification_modeling_config", ], "problem_categories": ["text"], "dataset_keys": [ diff --git a/llm_studio/python_configs/text_causal_classification_modeling_config.py b/llm_studio/python_configs/text_causal_classification_modeling_config.py new file mode 100644 index 000000000..f466e3699 --- /dev/null +++ b/llm_studio/python_configs/text_causal_classification_modeling_config.py @@ -0,0 +1,175 @@ +import os +from dataclasses import dataclass, field +from typing import Any, Tuple + +from llm_studio.python_configs.base import DefaultConfigProblemBase +from llm_studio.python_configs.text_causal_language_modeling_config import ( + ConfigNLPAugmentation, + ConfigNLPCausalLMArchitecture, + ConfigNLPCausalLMDataset, + ConfigNLPCausalLMEnvironment, + ConfigNLPCausalLMLogging, + ConfigNLPCausalLMPrediction, + ConfigNLPCausalLMTokenizer, + ConfigNLPCausalLMTraining, +) +from llm_studio.src import possible_values +from llm_studio.src.losses import text_causal_classification_modeling_losses +from llm_studio.src.metrics import text_causal_classification_modeling_metrics +from llm_studio.src.models import text_causal_classification_modeling_model +from llm_studio.src.utils.modeling_utils import generate_experiment_name + + +@dataclass +class ConfigNLPCausalClassificationDataset(ConfigNLPCausalLMDataset): + system_column: str = "None" + prompt_column: Tuple[str, ...] = ("instruction", "input") + answer_column: str = "label" + num_classes: int = 1 + parent_id_column: str = "None" + + text_system_start: str = "" + text_prompt_start: str = "" + text_answer_separator: str = "" + + add_eos_token_to_system: bool = False + add_eos_token_to_prompt: bool = False + add_eos_token_to_answer: bool = False + + _allowed_file_extensions: Tuple[str, ...] = ("csv", "pq", "parquet") + + def __post_init__(self): + self.prompt_column = ( + tuple( + self.prompt_column, + ) + if isinstance(self.prompt_column, str) + else tuple(self.prompt_column) + ) + super().__post_init__() + + self._possible_values["num_classes"] = (1, 100, 1) + + self._visibility["personalize"] = -1 + self._visibility["chatbot_name"] = -1 + self._visibility["chatbot_author"] = -1 + self._visibility["mask_prompt_labels"] = -1 + self._visibility["add_eos_token_to_answer"] = -1 + + +@dataclass +class ConfigNLPCausalClassificationTraining(ConfigNLPCausalLMTraining): + loss_class: Any = text_causal_classification_modeling_losses.Losses + loss_function: str = "CrossEntropyLoss" + + learning_rate: float = 0.0001 + differential_learning_rate_layers: Tuple[str, ...] = ("classification_head",) + differential_learning_rate: float = 0.00001 + + def __post_init__(self): + super().__post_init__() + self._possible_values["loss_function"] = self.loss_class.names() + + self._possible_values[ + "differential_learning_rate_layers" + ] = possible_values.String( + values=("backbone", "embed", "classification_head"), + allow_custom=False, + placeholder="Select optional layers...", + ) + + +@dataclass +class ConfigNLPCausalClassificationTokenizer(ConfigNLPCausalLMTokenizer): + max_length_prompt: int = 512 + max_length: int = 512 + + def __post_init__(self): + super().__post_init__() + + self._visibility["max_length_answer"] = -1 + + +@dataclass +class ConfigNLPCausalClassificationArchitecture(ConfigNLPCausalLMArchitecture): + model_class: Any = text_causal_classification_modeling_model.Model + + def __post_init__(self): + super().__post_init__() + + +@dataclass +class ConfigNLPCausalClassificationPrediction(ConfigNLPCausalLMPrediction): + metric_class: Any = text_causal_classification_modeling_metrics.Metrics + metric: str = "Accuracy" + + def __post_init__(self): + super().__post_init__() + self._possible_values["metric"] = self.metric_class.names() + + for k in [ + "min_length_inference", + "max_length_inference", + "do_sample", + "num_beams", + "temperature", + "repetition_penalty", + "stop_tokens", + "top_k", + "top_p", + ]: + self._visibility[k] = -1 + + +@dataclass +class ConfigProblemBase(DefaultConfigProblemBase): + output_directory: str = f"output/{os.path.basename(__file__).split('.')[0]}" + experiment_name: str = field(default_factory=generate_experiment_name) + _parent_experiment: str = "" + llm_backbone: str = "h2oai/h2ogpt-4096-llama2-7b" + type: str = "causal_classification" + + dataset: ConfigNLPCausalClassificationDataset = field( + default_factory=ConfigNLPCausalClassificationDataset + ) + tokenizer: ConfigNLPCausalLMTokenizer = field( + default_factory=ConfigNLPCausalLMTokenizer + ) + architecture: ConfigNLPCausalClassificationArchitecture = field( + default_factory=ConfigNLPCausalClassificationArchitecture + ) + training: ConfigNLPCausalClassificationTraining = field( + default_factory=ConfigNLPCausalClassificationTraining + ) + augmentation: ConfigNLPAugmentation = field(default_factory=ConfigNLPAugmentation) + prediction: ConfigNLPCausalClassificationPrediction = field( + default_factory=ConfigNLPCausalClassificationPrediction + ) + environment: ConfigNLPCausalLMEnvironment = field( + default_factory=ConfigNLPCausalLMEnvironment + ) + logging: ConfigNLPCausalLMLogging = field(default_factory=ConfigNLPCausalLMLogging) + + def __post_init__(self): + super().__post_init__() + + self._visibility["output_directory"] = -1 + + self._possible_values["llm_backbone"] = possible_values.String( + values=( + "h2oai/h2ogpt-4096-llama2-70b", + "h2oai/h2ogpt-4096-llama2-70b-chat", + "h2oai/h2ogpt-4096-llama2-13b", + "h2oai/h2ogpt-4096-llama2-13b-chat", + "h2oai/h2ogpt-4096-llama2-7b", + "h2oai/h2ogpt-4096-llama2-7b-chat", + "tiiuae/falcon-40b", + "tiiuae/falcon-7b", + "openlm-research/open_llama_13b", + "openlm-research/open_llama_7b", + "openlm-research/open_llama_3b", + "EleutherAI/gpt-j-6B", + "facebook/opt-125m", + ), + allow_custom=True, + ) diff --git a/llm_studio/python_configs/text_causal_language_modeling_config.py b/llm_studio/python_configs/text_causal_language_modeling_config.py index 77f9549e0..359474663 100644 --- a/llm_studio/python_configs/text_causal_language_modeling_config.py +++ b/llm_studio/python_configs/text_causal_language_modeling_config.py @@ -407,6 +407,7 @@ class ConfigProblemBase(DefaultConfigProblemBase): experiment_name: str = field(default_factory=generate_experiment_name) _parent_experiment: str = "" llm_backbone: str = "h2oai/h2ogpt-4096-llama2-7b" + type: str = "causal_lm" dataset: ConfigNLPCausalLMDataset = field(default_factory=ConfigNLPCausalLMDataset) tokenizer: ConfigNLPCausalLMTokenizer = field( diff --git a/llm_studio/python_configs/text_rlhf_language_modeling_config.py b/llm_studio/python_configs/text_rlhf_language_modeling_config.py index bd9a0342f..4c04b3eea 100644 --- a/llm_studio/python_configs/text_rlhf_language_modeling_config.py +++ b/llm_studio/python_configs/text_rlhf_language_modeling_config.py @@ -178,6 +178,7 @@ class ConfigProblemBase(DefaultConfigProblemBase): _parent_experiment: str = "" llm_backbone: str = "h2oai/h2ogpt-4096-llama2-7b-chat" reward_model: str = "OpenAssistant/reward-model-deberta-v3-large-v2" + type: str = "rlhf" dataset: ConfigRLHFLMDataset = field(default_factory=ConfigRLHFLMDataset) tokenizer: ConfigNLPCausalLMTokenizer = field( diff --git a/llm_studio/python_configs/text_sequence_to_sequence_modeling_config.py b/llm_studio/python_configs/text_sequence_to_sequence_modeling_config.py index 7af5645c9..6cd872ee3 100644 --- a/llm_studio/python_configs/text_sequence_to_sequence_modeling_config.py +++ b/llm_studio/python_configs/text_sequence_to_sequence_modeling_config.py @@ -42,7 +42,6 @@ def __post_init__(self): self._visibility["limit_chained_samples"] = -1 self._visibility["mask_prompt_labels"] = -1 - self._visibility["dataset_class"] = -1 @dataclass @@ -75,6 +74,7 @@ class ConfigProblemBase(DefaultConfigProblemBase): experiment_name: str = field(default_factory=generate_experiment_name) _parent_experiment: str = "" llm_backbone: str = "t5-small" + type: str = "seq2seq" dataset: ConfigNLPSeq2SeqDataset = field(default_factory=ConfigNLPSeq2SeqDataset) tokenizer: ConfigNLPCausalLMTokenizer = field( diff --git a/llm_studio/src/datasets/text_causal_language_modeling_ds.py b/llm_studio/src/datasets/text_causal_language_modeling_ds.py index bdc94c7f7..67c58c91b 100644 --- a/llm_studio/src/datasets/text_causal_language_modeling_ds.py +++ b/llm_studio/src/datasets/text_causal_language_modeling_ds.py @@ -31,6 +31,9 @@ def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"): self.tokenizer = get_tokenizer(self.cfg) self.conversation_chain_handler = ConversationChainHandler(self.df, cfg) + if cfg.type == "causal_classification": + self.answers_int = df[cfg.dataset.answer_column].astype(int).values.tolist() + def __len__(self) -> int: return len(self.conversation_chain_handler) @@ -107,6 +110,10 @@ def __getitem__(self, idx: int) -> Dict: sample["labels"][: len(system_encoding)] = -100 if sample["prompt_input_ids"][0] != self.tokenizer.pad_token_id: sample["prompt_input_ids"][: len(system_encoding)] = system_encoding + + if self.cfg.type == "causal_classification": + sample["class_label"] = self.answers_int[idx] + return sample @staticmethod @@ -254,7 +261,10 @@ def clean_output( return output def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict: - if not cfg.prediction.metric == "Perplexity": + if ( + not cfg.prediction.metric == "Perplexity" + and not cfg.type == "causal_classification" + ): output = self.clean_output(output, cfg) output["target_text"] = self.conversation_chain_handler.answers @@ -298,6 +308,9 @@ def format_output( if "predicted_text" in output.keys(): output["predicted_text"] = np.array(output["predicted_text"]) + if "logits" in output.keys(): + output["logits"] = np.array(output["logits"].float()) + if isinstance(cfg.dataset.prompt_column, tuple): for col in cfg.dataset.prompt_column: output[col] = df.loc[end_conversation_ids, col].values diff --git a/llm_studio/src/losses/text_causal_classification_modeling_losses.py b/llm_studio/src/losses/text_causal_classification_modeling_losses.py new file mode 100644 index 000000000..62ab48020 --- /dev/null +++ b/llm_studio/src/losses/text_causal_classification_modeling_losses.py @@ -0,0 +1,53 @@ +import logging +from typing import Any, KeysView + +from torch import nn + +__all__ = ["Losses"] + + +logger = logging.getLogger(__name__) + + +class CrossEntropyLoss(nn.Module): + def __init__(self, cfg: Any): + super().__init__() + self.cfg = cfg + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + return self.loss_fn(logits, labels.reshape(-1).long()) + + +class BinaryCrossEntropyLoss(nn.Module): + def __init__(self, cfg: Any): + super().__init__() + self.cfg = cfg + self.loss_fn = nn.BCEWithLogitsLoss() + + def forward(self, logits, labels): + return self.loss_fn(logits, labels) + + +class Losses: + """Losses factory.""" + + _losses = { + "CrossEntropyLoss": CrossEntropyLoss, + "BinaryCrossEntropyLoss": BinaryCrossEntropyLoss, + } + + @classmethod + def names(cls) -> KeysView: + return cls._losses.keys() + + @classmethod + def get(cls, name: str) -> Any: + """Access to Losses. + + Args: + name: losses name + Returns: + A class to build the Losses + """ + return cls._losses.get(name, CrossEntropyLoss) diff --git a/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py b/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py new file mode 100644 index 000000000..08bd9539c --- /dev/null +++ b/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py @@ -0,0 +1,101 @@ +import logging +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from numpy.typing import NDArray +from sklearn.metrics import roc_auc_score + +logger = logging.getLogger(__name__) + + +def accuracy_score( + cfg: Any, + results: Dict, + val_df: pd.DataFrame, + raw_results: bool = False, +) -> Union[NDArray, Tuple[NDArray, List[str]]]: + if cfg.dataset.num_classes == 1: + logits = results["logits"] + logits = np.array((logits > 0.0)).astype(int).reshape(-1) + target_text = np.array([int(text) for text in results["target_text"]]) + correct_predictions = (logits == target_text).sum() + total_predictions = len(logits) + return correct_predictions / total_predictions + else: + logits = results["logits"] + logits = np.array(torch.argmax(logits, dim=1)).astype(int).reshape(-1) + target_text = np.array([int(text) for text in results["target_text"]]) + correct_predictions = (logits == target_text).sum() + total_predictions = len(logits) + return correct_predictions / total_predictions + + +def auc_score( + cfg: Any, + results: Dict, + val_df: pd.DataFrame, + raw_results: bool = False, +) -> Union[NDArray, Tuple[NDArray, List[str]]]: + logits = results["logits"] + target_text = np.array([int(text) for text in results["target_text"]]) + if cfg.dataset.num_classes > 1: + target_text = np.eye(cfg.dataset.num_classes)[target_text] + return roc_auc_score(target_text, logits, multi_class="ovr") + + +class Metrics: + """ + Metrics factory. Returns: + - metric value + - should it be maximized or minimized + - Reduce function + + Maximized or minimized is needed for early stopping (saving best checkpoint) + Reduce function to generate a single metric value, usually "mean" or "none" + """ + + _metrics = { + "AUC": (auc_score, "max", "mean"), + "Accuracy": (accuracy_score, "max", "mean"), + } + + @classmethod + def names(cls) -> List[str]: + return sorted(cls._metrics.keys()) + + @classmethod + def get(cls, name: str) -> Any: + """Access to Metrics. + + Args: + name: metrics name + Returns: + A class to build the Metrics + """ + return cls._metrics.get(name, "GPT") + + @classmethod + def suitable_metrics(cls, cfg: Any, results: Dict, val_df: pd.DataFrame) -> Dict: + """Access to All Suitable Metrics. For some problem types (e.g. classification) + there might be metrics (e.g. Micro Averaged F1) that are only suitable in + specific cases (multiclass not binary). There might also be additional + metrics returned, which are not possible to select as validation metrics, + e.g. threshold dependant metrics + + Returns: + A dictionary of all suitable metrics for current problem setup + """ + return cls._metrics + + @classmethod + def all_metrics(cls) -> Dict: + """Access to All Metrics. There might also be additional + metrics returned, which are not possible to select as validation metrics, + e.g. threshold dependant metrics + + Returns: + A dictionary of all metrics (including not suitable metrics). + """ + return cls._metrics diff --git a/llm_studio/src/models/text_causal_classification_modeling_model.py b/llm_studio/src/models/text_causal_classification_modeling_model.py new file mode 100644 index 000000000..cba117f26 --- /dev/null +++ b/llm_studio/src/models/text_causal_classification_modeling_model.py @@ -0,0 +1,96 @@ +import logging +from typing import Any, Dict + +from torch import nn +from transformers import AutoModelForCausalLM + +from llm_studio.src.utils.data_utils import batch_padding +from llm_studio.src.utils.modeling_utils import ( + create_nlp_backbone, + generate, + prepare_lora, +) + +logger = logging.getLogger(__name__) + + +class Model(nn.Module): + """ + Model for causal language modeling problem type. + """ + + def __init__(self, cfg: Any): + """ + Args: + cfg: config with all the hyperparameters + """ + + super(Model, self).__init__() + + self.cfg = cfg + self.backbone, self.backbone_config = create_nlp_backbone( + cfg, model_class=AutoModelForCausalLM + ) + + if cfg.training.lora: + self.backbone = prepare_lora(cfg, self.backbone) + + self.head = nn.Linear( + self.backbone_config.vocab_size, cfg.dataset.num_classes, bias=False + ) + + self.loss_fn = self.cfg.training.loss_class.get( + self.cfg.training.loss_function + )(self.cfg) + + def generate(self, batch: Dict, cfg: Any, streamer=None): + return generate(self.backbone, batch, cfg, streamer) + + def forward( + self, + batch: Dict, + padding: bool = True, + ) -> Dict: + # disable cache if gradient checkpointing is enabled + if self.cfg.architecture.gradient_checkpointing: + self.backbone.config.use_cache = False + + outputs: Dict = {} + mask_key = "prompt_attention_mask" + pad_keys = [ + "prompt_input_ids", + "prompt_attention_mask", + "special_tokens_mask", + "labels", + ] + + if padding: + batch = batch_padding( + self.cfg, + batch, + self.training, + mask_key=mask_key, + pad_keys=pad_keys, + padding_side=self.cfg.tokenizer._padding_side, + ) + + output = self.backbone( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + ) + + output.logits = self.head(output[0][:, -1].float()) + + if "labels" in batch: + loss = self.loss_fn( + output.logits, batch["class_label"].unsqueeze(1).float() + ) + outputs["loss"] = loss + + outputs["logits"] = output.logits + + # enable cache again if gradient checkpointing is enabled + if self.cfg.architecture.gradient_checkpointing: + self.backbone.config.use_cache = True + + return outputs diff --git a/llm_studio/src/utils/modeling_utils.py b/llm_studio/src/utils/modeling_utils.py index afadd4871..93c5285a7 100644 --- a/llm_studio/src/utils/modeling_utils.py +++ b/llm_studio/src/utils/modeling_utils.py @@ -419,7 +419,10 @@ def run_inference( with autocast(enabled=cfg.environment.mixed_precision): output = model.forward(batch) - if cfg.prediction.metric != "Perplexity": + if ( + cfg.prediction.metric != "Perplexity" + and cfg.type != "causal_classification" + ): output["predicted_answer_ids"] = ( unwrap_model(model).generate(batch, cfg).detach().cpu() ) From 4d07203153f683a3bc4ac6b225ba43957b4f20a1 Mon Sep 17 00:00:00 2001 From: Philipp Singer Date: Tue, 17 Oct 2023 14:54:19 +0000 Subject: [PATCH 2/7] updates --- llm_studio/app_utils/hugging_face_utils.py | 15 ++- ...t_causal_classification_modeling_config.py | 14 ++- ...xt_causal_classification_modeling_model.py | 4 +- llm_studio/src/utils/modeling_utils.py | 3 + ...cation_experiment_summary_card_template.md | 61 ++++++++++ ...usal_classification_model_card_template.md | 106 ++++++++++++++++++ 6 files changed, 198 insertions(+), 5 deletions(-) create mode 100644 model_cards/text_causal_classification_experiment_summary_card_template.md create mode 100644 model_cards/text_causal_classification_model_card_template.md diff --git a/llm_studio/app_utils/hugging_face_utils.py b/llm_studio/app_utils/hugging_face_utils.py index 81fcfe306..400fcc10a 100644 --- a/llm_studio/app_utils/hugging_face_utils.py +++ b/llm_studio/app_utils/hugging_face_utils.py @@ -116,14 +116,27 @@ def publish_model_to_hugging_face( # push tokenizer to hub tokenizer.push_to_hub(repo_id=repo_id, private=True) + + # push model card to hub card = get_model_card(cfg, model, repo_id) card.push_to_hub( repo_id=repo_id, repo_type="model", commit_message="Upload model card" ) - # push config to hub api = huggingface_hub.HfApi() + + # push classification head to hub + if cfg.type == "causal_classification": + api.upload_file( + path_or_fileobj=f"{path_to_experiment}/classification_head.pth", + path_in_repo="classification_head.pth", + repo_id=repo_id, + repo_type="model", + commit_message="Upload classification_head.pth", + ) + + # push config to hub api.upload_file( path_or_fileobj=os.path.join(path_to_experiment, "cfg.yaml"), path_in_repo="cfg.yaml", diff --git a/llm_studio/python_configs/text_causal_classification_modeling_config.py b/llm_studio/python_configs/text_causal_classification_modeling_config.py index f466e3699..304f21520 100644 --- a/llm_studio/python_configs/text_causal_classification_modeling_config.py +++ b/llm_studio/python_configs/text_causal_classification_modeling_config.py @@ -121,6 +121,16 @@ def __post_init__(self): self._visibility[k] = -1 +@dataclass +class ConfigNLPCausalClassificationEnvironment(ConfigNLPCausalLMEnvironment): + _model_card_template: str = "text_causal_classification_model_card_template.md" + _summary_card_template: str = ( + "text_causal_classification_experiment_summary_card_template.md" + ) + + def __post_init__(self): + super().__post_init__() + @dataclass class ConfigProblemBase(DefaultConfigProblemBase): output_directory: str = f"output/{os.path.basename(__file__).split('.')[0]}" @@ -145,8 +155,8 @@ class ConfigProblemBase(DefaultConfigProblemBase): prediction: ConfigNLPCausalClassificationPrediction = field( default_factory=ConfigNLPCausalClassificationPrediction ) - environment: ConfigNLPCausalLMEnvironment = field( - default_factory=ConfigNLPCausalLMEnvironment + environment: ConfigNLPCausalClassificationEnvironment = field( + default_factory=ConfigNLPCausalClassificationEnvironment ) logging: ConfigNLPCausalLMLogging = field(default_factory=ConfigNLPCausalLMLogging) diff --git a/llm_studio/src/models/text_causal_classification_modeling_model.py b/llm_studio/src/models/text_causal_classification_modeling_model.py index cba117f26..b492fa3b0 100644 --- a/llm_studio/src/models/text_causal_classification_modeling_model.py +++ b/llm_studio/src/models/text_causal_classification_modeling_model.py @@ -35,7 +35,7 @@ def __init__(self, cfg: Any): if cfg.training.lora: self.backbone = prepare_lora(cfg, self.backbone) - self.head = nn.Linear( + self.classification_head = nn.Linear( self.backbone_config.vocab_size, cfg.dataset.num_classes, bias=False ) @@ -79,7 +79,7 @@ def forward( attention_mask=batch["prompt_attention_mask"], ) - output.logits = self.head(output[0][:, -1].float()) + output.logits = self.classification_head(output[0][:, -1].float()) if "labels" in batch: loss = self.loss_fn( diff --git a/llm_studio/src/utils/modeling_utils.py b/llm_studio/src/utils/modeling_utils.py index 93c5285a7..b28ccf143 100644 --- a/llm_studio/src/utils/modeling_utils.py +++ b/llm_studio/src/utils/modeling_utils.py @@ -99,6 +99,9 @@ def save_checkpoint(model: torch.nn.Module, path: str, cfg: Any): if path is not None: torch.save(checkpoint, os.path.join(path, "checkpoint.pth")) + if cfg.type == "causal_classification": + torch.save(checkpoint['model']["classification_head.weight"], os.path.join(path, "classification_head.pth")) + def load_model_weights( model: torch.nn.Module, model_weights: Dict, strict: bool, cfg: Any diff --git a/model_cards/text_causal_classification_experiment_summary_card_template.md b/model_cards/text_causal_classification_experiment_summary_card_template.md new file mode 100644 index 000000000..aede52fe6 --- /dev/null +++ b/model_cards/text_causal_classification_experiment_summary_card_template.md @@ -0,0 +1,61 @@ +### Usage with HF transformers + +To use the model with the `transformers` library on a machine with GPUs: +- First, push the model to a huggingface repo by clicking the Push checkpoint to huggingface button below +- Make sure you have the `transformers` library installed in the machine's environment + +```bash +pip install transformers=={{transformers_version}} +``` + +Also make sure you are providing your huggingface token if the model is lying in a private repo. + - You can login to hugginface_hub by running + ```python + import huggingface_hub + huggingface_hub.login() + ``` + +You will also need to download the classification head, either manually, or by running the following code: + +```python +from huggingface_hub import hf_hub_download + +model_name = "{{repo_id}}" # either local folder or huggingface model name +hf_hub_download(repo_id=model_name, filename="classification_head.pth", local_dir="./") +``` + +You can make classification predictions by following the example below: + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = "{{repo_id}}" # either local folder or huggingface model name +# Important: The prompt needs to be in the same format the model was trained with. +# You can find an example prompt in the experiment logs. +prompt = "{{text_prompt_start}}How are you?{{end_of_sentence}}{{text_answer_separator}}" + +tokenizer = AutoTokenizer.from_pretrained( + model_name, + use_fast={{use_fast}}, + trust_remote_code={{trust_remote_code}}, +) +model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map={"": "cuda:0"}, + trust_remote_code={{trust_remote_code}}, +).cuda().eval() + +head_weights = torch.load("classification_head.pth", map_location="cuda") +# settings can be arbitrary here as we overwrite with saved weights +head = torch.nn.Linear(1, 1, bias=False).to("cuda") +head.weight.data = head_weights + +inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda") + +out = model(**inputs).logits + +logits = head(out[:,-1]) + +print(logits) +``` diff --git a/model_cards/text_causal_classification_model_card_template.md b/model_cards/text_causal_classification_model_card_template.md new file mode 100644 index 000000000..c5422b369 --- /dev/null +++ b/model_cards/text_causal_classification_model_card_template.md @@ -0,0 +1,106 @@ +--- +language: +- en +library_name: transformers +inference: false +thumbnail: https://h2o.ai/etc.clientlibs/h2o/clientlibs/clientlib-site/resources/images/favicon.ico +tags: +- gpt +- llm +- large language model +- h2o-llmstudio +--- +# Model Card +## Summary + +This model was trained using [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio). +- Base model: [{{base_model}}](https://huggingface.co/{{base_model}}) + + +## Usage + +To use the model with the `transformers` library on a machine with GPUs, first make sure you have the `transformers` library installed. + +```bash +pip install transformers=={{transformers_version}} +``` + +Also make sure you are providing your huggingface token if the model is lying in a private repo. + - You can login to hugginface_hub by running + ```python + import huggingface_hub + huggingface_hub.login() + ``` + +You will also need to download the classification head, either manually, or by running the following code: + +```python +from huggingface_hub import hf_hub_download + +model_name = "{{repo_id}}" # either local folder or huggingface model name +hf_hub_download(repo_id=model_name, filename="classification_head.pth", local_dir="./") +``` + +You can make classification predictions by following the example below: + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = "{{repo_id}}" # either local folder or huggingface model name +# Important: The prompt needs to be in the same format the model was trained with. +# You can find an example prompt in the experiment logs. +prompt = "{{text_prompt_start}}How are you?{{end_of_sentence}}{{text_answer_separator}}" + +tokenizer = AutoTokenizer.from_pretrained( + model_name, + use_fast={{use_fast}}, + trust_remote_code={{trust_remote_code}}, +) +model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map={"": "cuda:0"}, + trust_remote_code={{trust_remote_code}}, +).cuda().eval() + +head_weights = torch.load("classification_head.pth", map_location="cuda") +# settings can be arbitrary here as we overwrite with saved weights +head = torch.nn.Linear(1, 1, bias=False).to("cuda") +head.weight.data = head_weights + +inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda") + +out = model(**inputs).logits + +logits = head(out[:,-1]) + +print(logits) +``` + +## Quantization and sharding + +You can load the models using quantization by specifying ```load_in_8bit=True``` or ```load_in_4bit=True```. Also, sharding on multiple GPUs is possible by setting ```device_map=auto```. + +## Model Architecture + +``` +{{model_architecture}} +``` + +## Model Configuration + +This model was trained using H2O LLM Studio and with the configuration in [cfg.yaml](cfg.yaml). Visit [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio) to learn how to train your own large language models. + + +## Disclaimer + +Please read this disclaimer carefully before using the large language model provided in this repository. Your use of the model signifies your agreement to the following terms and conditions. + +- Biases and Offensiveness: The large language model is trained on a diverse range of internet text data, which may contain biased, racist, offensive, or otherwise inappropriate content. By using this model, you acknowledge and accept that the generated content may sometimes exhibit biases or produce content that is offensive or inappropriate. The developers of this repository do not endorse, support, or promote any such content or viewpoints. +- Limitations: The large language model is an AI-based tool and not a human. It may produce incorrect, nonsensical, or irrelevant responses. It is the user's responsibility to critically evaluate the generated content and use it at their discretion. +- Use at Your Own Risk: Users of this large language model must assume full responsibility for any consequences that may arise from their use of the tool. The developers and contributors of this repository shall not be held liable for any damages, losses, or harm resulting from the use or misuse of the provided model. +- Ethical Considerations: Users are encouraged to use the large language model responsibly and ethically. By using this model, you agree not to use it for purposes that promote hate speech, discrimination, harassment, or any form of illegal or harmful activities. +- Reporting Issues: If you encounter any biased, offensive, or otherwise inappropriate content generated by the large language model, please report it to the repository maintainers through the provided channels. Your feedback will help improve the model and mitigate potential issues. +- Changes to this Disclaimer: The developers of this repository reserve the right to modify or update this disclaimer at any time without prior notice. It is the user's responsibility to periodically review the disclaimer to stay informed about any changes. + +By using the large language model provided in this repository, you agree to accept and comply with the terms and conditions outlined in this disclaimer. If you do not agree with any part of this disclaimer, you should refrain from using the model and any content generated by it. From c3cbec57071c4f9d1c473b30628dfd7bcfcb298a Mon Sep 17 00:00:00 2001 From: Philipp Singer Date: Tue, 17 Oct 2023 14:54:31 +0000 Subject: [PATCH 3/7] format --- llm_studio/app_utils/hugging_face_utils.py | 2 -- .../text_causal_classification_modeling_config.py | 1 + llm_studio/src/utils/modeling_utils.py | 5 ++++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/llm_studio/app_utils/hugging_face_utils.py b/llm_studio/app_utils/hugging_face_utils.py index 400fcc10a..c8568a6af 100644 --- a/llm_studio/app_utils/hugging_face_utils.py +++ b/llm_studio/app_utils/hugging_face_utils.py @@ -116,8 +116,6 @@ def publish_model_to_hugging_face( # push tokenizer to hub tokenizer.push_to_hub(repo_id=repo_id, private=True) - - # push model card to hub card = get_model_card(cfg, model, repo_id) card.push_to_hub( diff --git a/llm_studio/python_configs/text_causal_classification_modeling_config.py b/llm_studio/python_configs/text_causal_classification_modeling_config.py index 304f21520..980fcd373 100644 --- a/llm_studio/python_configs/text_causal_classification_modeling_config.py +++ b/llm_studio/python_configs/text_causal_classification_modeling_config.py @@ -131,6 +131,7 @@ class ConfigNLPCausalClassificationEnvironment(ConfigNLPCausalLMEnvironment): def __post_init__(self): super().__post_init__() + @dataclass class ConfigProblemBase(DefaultConfigProblemBase): output_directory: str = f"output/{os.path.basename(__file__).split('.')[0]}" diff --git a/llm_studio/src/utils/modeling_utils.py b/llm_studio/src/utils/modeling_utils.py index b28ccf143..4c9b3bf35 100644 --- a/llm_studio/src/utils/modeling_utils.py +++ b/llm_studio/src/utils/modeling_utils.py @@ -100,7 +100,10 @@ def save_checkpoint(model: torch.nn.Module, path: str, cfg: Any): torch.save(checkpoint, os.path.join(path, "checkpoint.pth")) if cfg.type == "causal_classification": - torch.save(checkpoint['model']["classification_head.weight"], os.path.join(path, "classification_head.pth")) + torch.save( + checkpoint["model"]["classification_head.weight"], + os.path.join(path, "classification_head.pth"), + ) def load_model_weights( From e70f88b34c8a96bdace87a9292c0e4a03ad1d09c Mon Sep 17 00:00:00 2001 From: Philipp Singer Date: Tue, 17 Oct 2023 15:39:12 +0000 Subject: [PATCH 4/7] cfg --- .../text_causal_classification_modeling_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llm_studio/python_configs/text_causal_classification_modeling_config.py b/llm_studio/python_configs/text_causal_classification_modeling_config.py index 980fcd373..5b740e1a2 100644 --- a/llm_studio/python_configs/text_causal_classification_modeling_config.py +++ b/llm_studio/python_configs/text_causal_classification_modeling_config.py @@ -60,7 +60,7 @@ def __post_init__(self): @dataclass class ConfigNLPCausalClassificationTraining(ConfigNLPCausalLMTraining): loss_class: Any = text_causal_classification_modeling_losses.Losses - loss_function: str = "CrossEntropyLoss" + loss_function: str = "BinaryCrossEntropyLoss" learning_rate: float = 0.0001 differential_learning_rate_layers: Tuple[str, ...] = ("classification_head",) @@ -101,7 +101,7 @@ def __post_init__(self): @dataclass class ConfigNLPCausalClassificationPrediction(ConfigNLPCausalLMPrediction): metric_class: Any = text_causal_classification_modeling_metrics.Metrics - metric: str = "Accuracy" + metric: str = "AUC" def __post_init__(self): super().__post_init__() From 25244e5146293972e1f66a709b355034a814d815 Mon Sep 17 00:00:00 2001 From: Philipp Singer Date: Thu, 19 Oct 2023 12:24:38 +0000 Subject: [PATCH 5/7] feedback --- llm_studio/app_utils/hugging_face_utils.py | 2 +- ...t_causal_classification_modeling_config.py | 1 - .../text_causal_language_modeling_config.py | 1 - .../text_rlhf_language_modeling_config.py | 1 - ...xt_sequence_to_sequence_modeling_config.py | 1 - .../text_causal_language_modeling_ds.py | 31 +++++++++++++++---- ..._causal_classification_modeling_metrics.py | 17 ++-------- llm_studio/src/utils/modeling_utils.py | 4 +-- 8 files changed, 31 insertions(+), 27 deletions(-) diff --git a/llm_studio/app_utils/hugging_face_utils.py b/llm_studio/app_utils/hugging_face_utils.py index c8568a6af..9b2d007ee 100644 --- a/llm_studio/app_utils/hugging_face_utils.py +++ b/llm_studio/app_utils/hugging_face_utils.py @@ -125,7 +125,7 @@ def publish_model_to_hugging_face( api = huggingface_hub.HfApi() # push classification head to hub - if cfg.type == "causal_classification": + if cfg.problem_type == "text_causal_classification_modeling": api.upload_file( path_or_fileobj=f"{path_to_experiment}/classification_head.pth", path_in_repo="classification_head.pth", diff --git a/llm_studio/python_configs/text_causal_classification_modeling_config.py b/llm_studio/python_configs/text_causal_classification_modeling_config.py index 5b740e1a2..73f479eb5 100644 --- a/llm_studio/python_configs/text_causal_classification_modeling_config.py +++ b/llm_studio/python_configs/text_causal_classification_modeling_config.py @@ -138,7 +138,6 @@ class ConfigProblemBase(DefaultConfigProblemBase): experiment_name: str = field(default_factory=generate_experiment_name) _parent_experiment: str = "" llm_backbone: str = "h2oai/h2ogpt-4096-llama2-7b" - type: str = "causal_classification" dataset: ConfigNLPCausalClassificationDataset = field( default_factory=ConfigNLPCausalClassificationDataset diff --git a/llm_studio/python_configs/text_causal_language_modeling_config.py b/llm_studio/python_configs/text_causal_language_modeling_config.py index 359474663..77f9549e0 100644 --- a/llm_studio/python_configs/text_causal_language_modeling_config.py +++ b/llm_studio/python_configs/text_causal_language_modeling_config.py @@ -407,7 +407,6 @@ class ConfigProblemBase(DefaultConfigProblemBase): experiment_name: str = field(default_factory=generate_experiment_name) _parent_experiment: str = "" llm_backbone: str = "h2oai/h2ogpt-4096-llama2-7b" - type: str = "causal_lm" dataset: ConfigNLPCausalLMDataset = field(default_factory=ConfigNLPCausalLMDataset) tokenizer: ConfigNLPCausalLMTokenizer = field( diff --git a/llm_studio/python_configs/text_rlhf_language_modeling_config.py b/llm_studio/python_configs/text_rlhf_language_modeling_config.py index 4c04b3eea..bd9a0342f 100644 --- a/llm_studio/python_configs/text_rlhf_language_modeling_config.py +++ b/llm_studio/python_configs/text_rlhf_language_modeling_config.py @@ -178,7 +178,6 @@ class ConfigProblemBase(DefaultConfigProblemBase): _parent_experiment: str = "" llm_backbone: str = "h2oai/h2ogpt-4096-llama2-7b-chat" reward_model: str = "OpenAssistant/reward-model-deberta-v3-large-v2" - type: str = "rlhf" dataset: ConfigRLHFLMDataset = field(default_factory=ConfigRLHFLMDataset) tokenizer: ConfigNLPCausalLMTokenizer = field( diff --git a/llm_studio/python_configs/text_sequence_to_sequence_modeling_config.py b/llm_studio/python_configs/text_sequence_to_sequence_modeling_config.py index 6cd872ee3..aaa7fe2b0 100644 --- a/llm_studio/python_configs/text_sequence_to_sequence_modeling_config.py +++ b/llm_studio/python_configs/text_sequence_to_sequence_modeling_config.py @@ -74,7 +74,6 @@ class ConfigProblemBase(DefaultConfigProblemBase): experiment_name: str = field(default_factory=generate_experiment_name) _parent_experiment: str = "" llm_backbone: str = "t5-small" - type: str = "seq2seq" dataset: ConfigNLPSeq2SeqDataset = field(default_factory=ConfigNLPSeq2SeqDataset) tokenizer: ConfigNLPCausalLMTokenizer = field( diff --git a/llm_studio/src/datasets/text_causal_language_modeling_ds.py b/llm_studio/src/datasets/text_causal_language_modeling_ds.py index c1509e644..73480efec 100644 --- a/llm_studio/src/datasets/text_causal_language_modeling_ds.py +++ b/llm_studio/src/datasets/text_causal_language_modeling_ds.py @@ -10,6 +10,7 @@ from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler from llm_studio.src.datasets.text_utils import TEXT_SEPARATOR, get_tokenizer +from llm_studio.src.utils.exceptions import LLMDataException logger = logging.getLogger(__name__) @@ -31,8 +32,21 @@ def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"): self.tokenizer = get_tokenizer(self.cfg) self.conversation_chain_handler = ConversationChainHandler(self.df, cfg) - if cfg.type == "causal_classification": + if cfg.problem_type == "text_causal_classification_modeling": self.answers_int = df[cfg.dataset.answer_column].astype(int).values.tolist() + if ( + cfg.dataset.num_classes > 1 + and max(self.answers_int) >= cfg.dataset.num_classes + ): + raise LLMDataException( + "Number of classes is smaller than max label " + f"{max(self.answers_int)}. Please increase the setting accordingly." + ) + elif cfg.dataset.num_classes == 1 and max(self.answers_int) > 1: + raise LLMDataException( + "For binary classification, max label should be 1 but is " + f"{max(self.answers_int)}." + ) def __len__(self) -> int: return len(self.conversation_chain_handler) @@ -111,7 +125,7 @@ def __getitem__(self, idx: int) -> Dict: if sample["prompt_input_ids"][0] != self.tokenizer.pad_token_id: sample["prompt_input_ids"][: len(system_encoding)] = system_encoding - if self.cfg.type == "causal_classification": + if self.cfg.problem_type == "text_causal_classification_modeling": sample["class_label"] = self.answers_int[idx] return sample @@ -261,10 +275,15 @@ def clean_output( return output def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict: - if ( - not cfg.prediction.metric == "Perplexity" - and not cfg.type == "causal_classification" - ): + if cfg.problem_type == "text_causal_classification_modeling": + if cfg.dataset.num_classes == 1: + preds = output["logits"] + preds = np.array((preds > 0.0)).astype(int).astype(str).reshape(-1) + else: + preds = output["logits"] + preds = np.array(torch.argmax(preds, dim=1)).astype(str).reshape(-1) + output["predicted_text"] = preds + elif not cfg.prediction.metric == "Perplexity": output = self.clean_output(output, cfg) output["target_text"] = self.conversation_chain_handler.answers diff --git a/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py b/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py index 08bd9539c..5b39142a1 100644 --- a/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py +++ b/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py @@ -16,20 +16,9 @@ def accuracy_score( val_df: pd.DataFrame, raw_results: bool = False, ) -> Union[NDArray, Tuple[NDArray, List[str]]]: - if cfg.dataset.num_classes == 1: - logits = results["logits"] - logits = np.array((logits > 0.0)).astype(int).reshape(-1) - target_text = np.array([int(text) for text in results["target_text"]]) - correct_predictions = (logits == target_text).sum() - total_predictions = len(logits) - return correct_predictions / total_predictions - else: - logits = results["logits"] - logits = np.array(torch.argmax(logits, dim=1)).astype(int).reshape(-1) - target_text = np.array([int(text) for text in results["target_text"]]) - correct_predictions = (logits == target_text).sum() - total_predictions = len(logits) - return correct_predictions / total_predictions + predicted_text = np.array([int(text) for text in results["predicted_text"]]) + target_text = np.array([int(text) for text in results["target_text"]]) + return (predicted_text == target_text).astype("float") def auc_score( diff --git a/llm_studio/src/utils/modeling_utils.py b/llm_studio/src/utils/modeling_utils.py index 4c9b3bf35..34e50a0e9 100644 --- a/llm_studio/src/utils/modeling_utils.py +++ b/llm_studio/src/utils/modeling_utils.py @@ -99,7 +99,7 @@ def save_checkpoint(model: torch.nn.Module, path: str, cfg: Any): if path is not None: torch.save(checkpoint, os.path.join(path, "checkpoint.pth")) - if cfg.type == "causal_classification": + if cfg.problem_type == "text_causal_classification_modeling": torch.save( checkpoint["model"]["classification_head.weight"], os.path.join(path, "classification_head.pth"), @@ -427,7 +427,7 @@ def run_inference( output = model.forward(batch) if ( cfg.prediction.metric != "Perplexity" - and cfg.type != "causal_classification" + and cfg.problem_type != "text_causal_classification_modeling" ): output["predicted_answer_ids"] = ( unwrap_model(model).generate(batch, cfg).detach().cpu() From 88e232ed9be0b23f9647f9b8449448f3a9a6c08c Mon Sep 17 00:00:00 2001 From: Philipp Singer Date: Thu, 19 Oct 2023 12:33:35 +0000 Subject: [PATCH 6/7] import --- .../src/metrics/text_causal_classification_modeling_metrics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py b/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py index 5b39142a1..2497a703d 100644 --- a/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py +++ b/llm_studio/src/metrics/text_causal_classification_modeling_metrics.py @@ -3,7 +3,6 @@ import numpy as np import pandas as pd -import torch from numpy.typing import NDArray from sklearn.metrics import roc_auc_score From 8a1108fbc8a11c10402f2fdb2b50dd4a5af5db1c Mon Sep 17 00:00:00 2001 From: Philipp Singer Date: Mon, 23 Oct 2023 12:51:11 +0000 Subject: [PATCH 7/7] readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index d8f5c7deb..11cb89423 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ Using CLI for fine-tuning LLMs: ## What's New +- [PR 449](https://github.com/h2oai/h2o-llmstudio/pull/449) New problem type for Causal Classification Modeling allows to train binary and multiclass models using LLMs. - [PR 364](https://github.com/h2oai/h2o-llmstudio/pull/364) User secrets are now handled more securely and flexible. Support for handling secrets using the 'keyring' library was added. User settings are tried to be migrated automatically. - [PR 328](https://github.com/h2oai/h2o-llmstudio/pull/328) RLHF is now a separate problem type. Note that starting a new RLHF experiment from an old experiment that used RLHF is no longer supported. To continue from a previous experiment, please start a new experiment and enter the settings from the previous experiment manually. - [PR 308](https://github.com/h2oai/h2o-llmstudio/pull/308) Sequence to sequence models have been added as a new problem type.