Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
teowu authored Dec 28, 2023
1 parent 1b1113c commit 8eace7d
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions boost_qa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
class QInstructScorer(nn.Module):
def __init__(self, boost=True, device="cuda:0"):
super().__init__()
tokenizer, model, image_processor, _ = load_pretrained_model("teowu/mplug_owl2_7b_448_qinstruct_preview_v0.2", None, "mplug_owl2", device=device)
tokenizer, model, image_processor, _ = load_pretrained_model("q-future/q-instruct-mplug-owl2", None, "mplug_owl2", device=device)
prompt = "USER: <|image|>Rate the quality of the image.\nASSISTANT: "

if not boost:
self.boost = False
self.preferential_ids_ = [id_[1] for id_ in tokenizer(["good", "average", "poor"])["input_ids"]]
self.weight_tensor = torch.Tensor([1, 0.5, 0]).half().to(model.device)
else:
self.boost = True
self.preferential_ids_ = [id_[1] for id_ in tokenizer(["good", "average", "poor", "high", "medium", "low", "fine", "acceptable", "bad"])["input_ids"]]
self.weight_tensor = torch.Tensor([1, 0.5, 0, 1, 0.5, 0, 1, 0.5, 0]).half().to(model.device) / 3.
self.weight_tensor = torch.Tensor([1, 0.5, 0]).half().to(model.device)

self.tokenizer = tokenizer
self.model = model
Expand All @@ -35,10 +37,11 @@ def forward(self, image: List[Image.Image]):
image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
images=image_tensor)["logits"][:,-1, self.preferential_ids_]

if self.boost:
output_logits = output_logits.reshape(-1, 3, 3).mean(1)
return torch.softmax(output_logits, -1) @ self.weight_tensor


if __name__ == "__main__":
scorer = QInstructScorer(boost=False)
print(scorer([Image.open("fig/examples_211.jpg"),Image.open("fig/sausage.jpg")]).tolist())
print(scorer([Image.open("fig/examples_211.jpg"),Image.open("fig/sausage.jpg")]).tolist())

0 comments on commit 8eace7d

Please sign in to comment.