Skip to content

Commit

Permalink
docs: add setfit differences, change number of samples
Browse files Browse the repository at this point in the history
  • Loading branch information
sdiazlor committed May 27, 2024
1 parent df3997e commit a6b4b14
Showing 1 changed file with 77 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,28 @@
"Let's get started! 🚀"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Introduction\n",
"\n",
"[**FastFit**](https://github.com/IBM/fastfit) is a library that allows you to train a multi-class classifier with few-shot learning. It is based on the [**transformers**](https://huggingface.co/transformers/) library and uses a pre-trained model to fine-tune it on a small dataset. This is particularly useful when you have a small dataset and you want to train a model quickly. However, [**SetFit**](https://github.com/huggingface/setfit) is another well-know library that also allows few-shot learning with Sentence Transformers.\n",
"\n",
"So, why using one and not the other? Based on this [article](https://medium.com/@meetgandhi586/comparing-setfit-fastfit-and-semantic-router-finding-the-best-nlp-chatbot-intent-detection-d8161a7ad117), where the author compares FastFit, SetFit, and Semantic Router, we can determine some distinctions.\n",
"\n",
"| **Aspect** | **FastFit** | **SetFit** |\n",
"|---------------------------|-----------------------------------------------|--------------------------------------------|\n",
"| **Accuracy** | High, but may sacrifice accuracy for speed | Consistently high |\n",
"| **Training Speed** | Fast | Slow |\n",
"| **Inference Speed** | Slow | Fast |\n",
"| **Deployment** | Easy, minimal expertise needed | Requires knowledge of transformers |\n",
"| **Dataset Handling** | Struggles with highly complex datasets | Can be fine-tuned for various datasets |\n",
"| **Computational Costs** | Lower | Higher |\n",
"\n",
"In this tutorial, we will focus on FastFit, but you can also try SetFit and compare the results. To know how to use SetFit, you can check this [tutorial](../../tutorials/feedback/labelling-feedback-setfit.ipynb)."
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -279,7 +301,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Since this is few-shot learning, we don't need to use all the examples in the training set. Therefore, we will utilize the `sample_dataset` method from FastFit to select only 5 examples per class. Additionally, we will rename the `val` split to `validation` to comply with FastFit requirements."
"Since this is few-shot learning, we don't need to use all the examples in the training set. Therefore, we will utilize the `sample_dataset` method from FastFit to select 10 examples per class (since FastFit is faster to train, we can afford to include more samples without worrying about significantly increased training times). Additionally, we will rename the `val` split to `validation` to comply with FastFit requirements."
]
},
{
Expand All @@ -293,7 +315,7 @@
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['text', 'intent'],\n",
" num_rows: 750\n",
" num_rows: 1500\n",
" })\n",
" test: Dataset({\n",
" features: ['text', 'intent'],\n",
Expand Down Expand Up @@ -339,7 +361,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case, we have chosen to use the [avsolatorio/GIST-small-Embedding-v0](https://huggingface.co/avsolatorio/NoInstruct-small-Embedding-v0) model to train the intent classifier due to its size and performance. However, you can explore other models available on Hugging Face and find the most appropriate one by consulting the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard). Most of the arguments set in the `FastFitTrainer` are the default values, but you can change them according to your needs."
"In this case, we have chosen to use the [sentence-transformers/paraphrase-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-mpnet-base-v2) model to train the intent classifier due to its size and performance. However, you can explore other models available on Hugging Face and find the most appropriate one by consulting the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard). Most of the arguments set in the `FastFitTrainer` are the default values, but you can change them according to your needs."
]
},
{
Expand All @@ -350,7 +372,7 @@
"source": [
"# Initialize the FastFitTrainer\n",
"trainer = FastFitTrainer(\n",
" model_name_or_path=\"avsolatorio/GIST-small-Embedding-v0\",\n",
" model_name_or_path=\"sentence-transformers/paraphrase-mpnet-base-v2\",\n",
" label_column_name=\"intent\",\n",
" text_column_name=\"text\",\n",
" num_train_epochs=25,\n",
Expand All @@ -377,8 +399,8 @@
"\n",
" <div>\n",
" \n",
" <progress value='600' max='600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [600/600 01:07, Epoch 25/25]\n",
" <progress value='1175' max='1175' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [1175/1175 02:58, Epoch 25/25]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
Expand All @@ -390,7 +412,11 @@
" <tbody>\n",
" <tr>\n",
" <td>500</td>\n",
" <td>2.665300</td>\n",
" <td>2.676200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1000</td>\n",
" <td>2.590300</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
Expand All @@ -409,11 +435,11 @@
"***** train metrics *****\n",
" epoch = 25.0\n",
" total_flos = 0GF\n",
" train_loss = 2.6493\n",
" train_runtime = 0:01:08.63\n",
" train_samples = 750\n",
" train_samples_per_second = 273.189\n",
" train_steps_per_second = 8.742\n"
" train_loss = 2.6261\n",
" train_runtime = 0:02:59.32\n",
" train_samples = 1500\n",
" train_samples_per_second = 209.121\n",
" train_steps_per_second = 6.552\n"
]
}
],
Expand All @@ -426,26 +452,44 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can evaluate the model and save it to disk for the inference step."
"As we can see, the training time took only 3 minutes, which is quite quick. Now, let’s evaluate the model to check its accuracy. After evaluation, we will save the model to disk for the inference step."
]
},
{
"cell_type": "code",
"execution_count": 234,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='47' max='47' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [47/47 00:03]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"***** eval metrics *****\n",
" epoch = 25.0\n",
" eval_accuracy = 0.9063\n",
" eval_loss = 3.1428\n",
" eval_runtime = 0:00:02.76\n",
" eval_accuracy = 0.947\n",
" eval_loss = 2.9615\n",
" eval_runtime = 0:00:04.30\n",
" eval_samples = 3000\n",
" eval_samples_per_second = 1085.348\n",
" eval_steps_per_second = 17.004\n"
" eval_samples_per_second = 697.334\n",
" eval_steps_per_second = 10.925\n"
]
}
],
Expand Down Expand Up @@ -486,50 +530,37 @@
"source": [
"# Load the model and tokenizer\n",
"model = FastFit.from_pretrained(\"intent_fastfit_model\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"avsolatorio/GIST-small-Embedding-v0\")"
"tokenizer = AutoTokenizer.from_pretrained(\"sentence-transformers/paraphrase-mpnet-base-v2\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we will prepare our pipeline and create a function to make predictions since our model is based on BERT, including `token_type_ids`, which are not needed in our case as noted [here](https://github.com/IBM/fastfit/issues/6).\n",
"Next, we will prepare our pipeline.\n",
"\n",
"> Note: Due to this fact, a warning is raised, but the classifier will work properly.\n",
"> An error warning is raised when initializing the pipeline: `The model 'FastFit' is not supported for text-classification`. However, the classifier will work properly.\n",
"\n",
"> If you have changed the base model, you may encounter a `token_type_ids` error. You can solve it as noted [here](https://github.com/IBM/fastfit/issues/6).\n",
"\n",
"We will set `top_k=1` to get the most likely class for each prediction. If you want to get all the predicted classes, set it to `None`."
]
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The model 'FastFit' is not supported for text-classification. Supported models are ['AlbertForSequenceClassification', 'BartForSequenceClassification', 'BertForSequenceClassification', 'BigBirdForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BioGptForSequenceClassification', 'BloomForSequenceClassification', 'CamembertForSequenceClassification', 'CanineForSequenceClassification', 'LlamaForSequenceClassification', 'ConvBertForSequenceClassification', 'CTRLForSequenceClassification', 'Data2VecTextForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'DistilBertForSequenceClassification', 'ElectraForSequenceClassification', 'ErnieForSequenceClassification', 'ErnieMForSequenceClassification', 'EsmForSequenceClassification', 'FalconForSequenceClassification', 'FlaubertForSequenceClassification', 'FNetForSequenceClassification', 'FunnelForSequenceClassification', 'GemmaForSequenceClassification', 'GPT2ForSequenceClassification', 'GPT2ForSequenceClassification', 'GPTBigCodeForSequenceClassification', 'GPTNeoForSequenceClassification', 'GPTNeoXForSequenceClassification', 'GPTJForSequenceClassification', 'IBertForSequenceClassification', 'JambaForSequenceClassification', 'JetMoeForSequenceClassification', 'LayoutLMForSequenceClassification', 'LayoutLMv2ForSequenceClassification', 'LayoutLMv3ForSequenceClassification', 'LEDForSequenceClassification', 'LiltForSequenceClassification', 'LlamaForSequenceClassification', 'LongformerForSequenceClassification', 'LukeForSequenceClassification', 'MarkupLMForSequenceClassification', 'MBartForSequenceClassification', 'MegaForSequenceClassification', 'MegatronBertForSequenceClassification', 'MistralForSequenceClassification', 'MixtralForSequenceClassification', 'MobileBertForSequenceClassification', 'MPNetForSequenceClassification', 'MptForSequenceClassification', 'MraForSequenceClassification', 'MT5ForSequenceClassification', 'MvpForSequenceClassification', 'NezhaForSequenceClassification', 'NystromformerForSequenceClassification', 'OpenLlamaForSequenceClassification', 'OpenAIGPTForSequenceClassification', 'OPTForSequenceClassification', 'PerceiverForSequenceClassification', 'PersimmonForSequenceClassification', 'PhiForSequenceClassification', 'Phi3ForSequenceClassification', 'PLBartForSequenceClassification', 'QDQBertForSequenceClassification', 'Qwen2ForSequenceClassification', 'Qwen2MoeForSequenceClassification', 'ReformerForSequenceClassification', 'RemBertForSequenceClassification', 'RobertaForSequenceClassification', 'RobertaPreLayerNormForSequenceClassification', 'RoCBertForSequenceClassification', 'RoFormerForSequenceClassification', 'SqueezeBertForSequenceClassification', 'StableLmForSequenceClassification', 'Starcoder2ForSequenceClassification', 'T5ForSequenceClassification', 'TapasForSequenceClassification', 'TransfoXLForSequenceClassification', 'UMT5ForSequenceClassification', 'XLMForSequenceClassification', 'XLMRobertaForSequenceClassification', 'XLMRobertaXLForSequenceClassification', 'XLNetForSequenceClassification', 'XmodForSequenceClassification', 'YosoForSequenceClassification'].\n"
]
}
],
"outputs": [],
"source": [
"# Define the pipeline\n",
"classifier = pipeline(\"text-classification\", model=model, tokenizer=tokenizer)\n",
"\n",
"# Define the adapted predict function\n",
"def predict(text):\n",
" preprocessed_ip = classifier.preprocess(text)\n",
" model_op = classifier.forward({'input_ids': preprocessed_ip['input_ids'],'attention_mask': preprocessed_ip['attention_mask']})\n",
" prediction = classifier.postprocess(model_op, top_k=1)\n",
" return prediction"
"classifier = pipeline(\"text-classification\", model=model, tokenizer=tokenizer, top_k=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, let's make some predictions! We will use the `validation` split for the tutorial's purpose and only the first 100 samples."
"So, let's make some predictions! For the purpose of this tutorial, we will use the validation split and only the first 100 samples. As observed during our initial run, making predictions took over a minute. Therefore, as noted in the introduction, if we needed to make more predictions, it would be slow."
]
},
{
Expand All @@ -543,7 +574,7 @@
" {\n",
" \"text\": sample[\"text\"],\n",
" \"true_intent\": sample[\"intent\"],\n",
" \"predicted_intent\": predict(sample[\"text\"])\n",
" \"predicted_intent\": classifier(sample[\"text\"])\n",
" }\n",
" for sample in dataset['validation'].to_list()[:100]\n",
"]"
Expand All @@ -558,11 +589,11 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'text': \"what's the weather today\", 'true_intent': 'utility:weather', 'predicted_intent': {'label': 'utility:weather', 'score': 0.8008840084075928}}\n",
"{'text': 'what are the steps required for making a vacation request', 'true_intent': 'work:pto_request', 'predicted_intent': {'label': 'work:pto_request', 'score': 0.9565843939781189}}\n",
"{'text': 'help me set a timer please', 'true_intent': 'utility:timer', 'predicted_intent': {'label': 'utility:timer', 'score': 0.7600076794624329}}\n",
"{'text': 'what is the mpg for this car', 'true_intent': 'auto_and_commute:mpg', 'predicted_intent': {'label': 'auto_and_commute:mpg', 'score': 0.9399768114089966}}\n",
"{'text': 'was my last transaction at walmart', 'true_intent': 'banking:transactions', 'predicted_intent': {'label': 'banking:transactions', 'score': 0.32635951042175293}}\n"
"{'text': \"what's the weather today\", 'true_intent': 'utility:weather', 'predicted_intent': [[{'label': 'utility:weather', 'score': 0.8783325552940369}]]}\n",
"{'text': 'what are the steps required for making a vacation request', 'true_intent': 'work:pto_request', 'predicted_intent': [[{'label': 'work:pto_request', 'score': 0.9850018620491028}]]}\n",
"{'text': 'help me set a timer please', 'true_intent': 'utility:timer', 'predicted_intent': [[{'label': 'utility:timer', 'score': 0.8667100071907043}]]}\n",
"{'text': 'what is the mpg for this car', 'true_intent': 'auto_and_commute:mpg', 'predicted_intent': [[{'label': 'auto_and_commute:mpg', 'score': 0.9732610583305359}]]}\n",
"{'text': 'was my last transaction at walmart', 'true_intent': 'banking:transactions', 'predicted_intent': [[{'label': 'banking:transactions', 'score': 0.9001362919807434}]]}\n"
]
}
],
Expand Down

0 comments on commit a6b4b14

Please sign in to comment.