Commit

Added zero-shot-classification pipeline
ankane committed Sep 16, 2024
1 parent c5fc777 commit a8e44ad
Showing 5 changed files with 108 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,5 +1,6 @@
## 1.0.4 (unreleased)

- Added `zero-shot-classification` pipeline
- Added `fill-mask` pipeline

## 1.0.3 (2024-08-29)
7 changes: 7 additions & 0 deletions README.md
@@ -282,6 +282,13 @@ extractor = Informers.pipeline("feature-extraction")
extractor.("We are very happy to show you the 🤗 Transformers library.")
```

Zero-shot classification [unreleased]

```ruby
classifier = Informers.pipeline("zero-shot-classification")
classifier.("text", ["label1", "label2", "label3"])
```
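
The call also accepts `hypothesis_template` and `multi_label` keyword options; a quick sketch with the defaults made explicit:

```ruby
classifier = Informers.pipeline("zero-shot-classification")
classifier.(
  "text",
  ["label1", "label2", "label3"],
  hypothesis_template: "This example is {}.", # {} is replaced by each candidate label
  multi_label: false                          # when true, each label is scored independently
)
# => {sequence: "text", labels: [...], scores: [...]}
```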

Fill mask [unreleased]

```ruby
3 changes: 2 additions & 1 deletion lib/informers/configs.rb
@@ -1,13 +1,14 @@
module Informers
class PretrainedConfig
- attr_reader :model_type, :problem_type, :id2label
+ attr_reader :model_type, :problem_type, :id2label, :label2id

def initialize(config_json)
@is_encoder_decoder = false

@model_type = config_json["model_type"]
@problem_type = config_json["problem_type"]
@id2label = config_json["id2label"]
@label2id = config_json["label2id"]
end

def [](key)
88 changes: 88 additions & 0 deletions lib/informers/pipelines.rb
@@ -277,6 +277,85 @@ def call(texts, top_k: 5)
end
end

class ZeroShotClassificationPipeline < Pipeline
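# Zero-shot classification via natural language inference: each candidate
# label is inserted into a hypothesis template and scored against the input
# text with a sequence classification (entailment/contradiction) model.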
def initialize(**options)
super(**options)

@label2id = @model.config.label2id.transform_keys(&:downcase)

@entailment_id = @label2id["entailment"]
if @entailment_id.nil?
warn "Could not find 'entailment' in label2id mapping. Using 2 as entailment_id."
@entailment_id = 2
end

@contradiction_id = @label2id["contradiction"] || @label2id["not_entailment"]
if @contradiction_id.nil?
warn "Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id."
@contradiction_id = 0
end
end

def call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false)
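# texts: a single string or an array of strings to classify
# candidate_labels: a single label or an array of labels to score
# hypothesis_template: template whose {} is replaced by each candidate label
# multi_label: when true, each label is scored independently
# Returns a hash (or an array of hashes for batched input) with :sequence,
# :labels (sorted by descending score), and :scores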
is_batched = texts.is_a?(Array)
if !is_batched
texts = [texts]
end
if !candidate_labels.is_a?(Array)
candidate_labels = [candidate_labels]
end

# Insert labels into hypothesis template
hypotheses = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }

# How to perform the softmax over the logits:
# - true: softmax over the entailment vs. contradiction dim for each label independently
# - false: softmax the "entailment" logits over all candidate labels
softmax_each = multi_label || candidate_labels.length == 1

to_return = []
texts.each do |premise|
entails_logits = []

hypotheses.each do |hypothesis|
inputs = @tokenizer.(
premise,
text_pair: hypothesis,
padding: true,
truncation: true
)
outputs = @model.(inputs)

if softmax_each
entails_logits << [
outputs.logits[0][@contradiction_id],
outputs.logits[0][@entailment_id]
]
else
entails_logits << outputs.logits[0][@entailment_id]
end
end

scores =
if softmax_each
entails_logits.map { |x| Utils.softmax(x)[1] }
else
Utils.softmax(entails_logits)
end

# Sort by scores (desc) and return scores with indices
scores_sorted = scores.map.with_index { |x, i| [x, i] }.sort_by { |v| -v[0] }

to_return << {
sequence: premise,
labels: scores_sorted.map { |x| candidate_labels[x[1]] },
scores: scores_sorted.map { |x| x[0] }
}
end
is_batched ? to_return : to_return[0]
end
end

class FeatureExtractionPipeline < Pipeline
def call(
texts,
@@ -418,6 +497,15 @@ def call(
},
type: "text"
},
"zero-shot-classification" => {
tokenizer: AutoTokenizer,
pipeline: ZeroShotClassificationPipeline,
model: AutoModelForSequenceClassification,
default: {
model: "Xenova/distilbert-base-uncased-mnli"
},
type: "text"
},
"feature-extraction" => {
tokenizer: AutoTokenizer,
pipeline: FeatureExtractionPipeline,
10 changes: 10 additions & 0 deletions test/pipeline_test.rb
@@ -50,6 +50,16 @@ def test_question_answering
assert_equal 46, result[:end]
end

def test_zero_shot_classification
classifier = Informers.pipeline("zero-shot-classification")
text = "Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app."
labels = ["mobile", "billing", "website", "account access"]
result = classifier.(text, labels)
assert_equal text, result[:sequence]
assert_equal ["mobile", "billing", "account access", "website"], result[:labels]
assert_elements_in_delta [0.516, 0.179, 0.167, 0.138], result[:scores]
end

def test_fill_mask
unmasker = Informers.pipeline("fill-mask")
result = unmasker.("Paris is the [MASK] of France.")
