Skip to content

Commit

Permalink
abstract model class
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Sep 11, 2023
1 parent 4bd8064 commit 50150c5
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 8 deletions.
8 changes: 5 additions & 3 deletions pipeline/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
sys.path.append("../..")


from pipeline.evaluation.models.model import load_model
from pipeline.evaluation.evaluator.mmbench import MMBenchEvaluator
from pipeline.evaluation.models.idefics import Idefics


if __name__ == "__main__":
# model = Otter("/data/pufanyi/training_data/checkpoints/OTTER-Image-MPT7B")
model_info = {
"model_path": "/data/pufanyi/training_data/checkpoints/idefics-9b-instruct",
}
evaluator = MMBenchEvaluator("/data/pufanyi/training_data/MMBench/mmbench_test_20230712.tsv")
model = Idefics("/data/pufanyi/training_data/checkpoints/idefics-9b-instruct")
model = load_model("idefics", model_info)
evaluator.evaluate(model)

# pip install otter_ai
Expand Down
8 changes: 5 additions & 3 deletions pipeline/evaluation/models/idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List
from transformers import IdeficsForVisionText2Text, AutoProcessor
from PIL import Image
from .model import Model


def get_formatted_prompt(prompt: str, image: Image.Image) -> List[str]:
Expand All @@ -13,10 +14,11 @@ def get_formatted_prompt(prompt: str, image: Image.Image) -> List[str]:
]


class Idefics(object):
def __init__(self, model_name_or_path: str = "HuggingFaceM4/idefics-9b-instruct"):
class Idefics(Model):
def __init__(self, model_path: str = "HuggingFaceM4/idefics-9b-instruct"):
super().__init__("idefics", model_path)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = model_name_or_path
checkpoint = model_path
self.model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to(self.device)
self.processor = AutoProcessor.from_pretrained(checkpoint)

Expand Down
3 changes: 2 additions & 1 deletion pipeline/evaluation/models/idefics_otter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from pipeline.train.train_utils import get_image_attention_mask


class OtterIdefics:
class OtterIdefics(Model):
def __init__(
self,
checkpoint: str = "/data/pufanyi/training_data/checkpoints/otter_idefics9b_0830",
processor: str = "/data/pufanyi/training_data/checkpoints/idefics-80b-instruct",
) -> None:
super().__init__("idefics_otter", checkpoint)
kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16}
self.model = IdeficsForVisionText2Text.from_pretrained(checkpoint, **kwargs)
self.processor = AutoProcessor.from_pretrained(processor)
Expand Down
31 changes: 31 additions & 0 deletions pipeline/evaluation/models/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from PIL import Image
from typing import Dict

import importlib

# Registry mapping a model key (the name passed to ``load_model``) to the
# class name exported by the module ``pipeline.evaluation.models.<key>``.
AVAILABLE_MODELS: Dict[str, str] = {
    "otter": "Otter",
    "idefics": "Idefics",
    # Bug fix: the class defined in idefics_otter.py is ``OtterIdefics``,
    # not ``IdeficsOtter`` — the old entry made load_model("idefics_otter", ...)
    # fail with AttributeError in getattr().
    "idefics_otter": "OtterIdefics",
}


class Model(ABC):
    """Abstract base class for evaluation models.

    Concrete models (e.g. ``Otter``, ``Idefics``) subclass this and
    implement :meth:`generate`; ``load_model`` looks them up by their
    registry ``name`` in ``AVAILABLE_MODELS``.
    """

    def __init__(self, name: str, model_path: str):
        """Record the model's registry key and its checkpoint location.

        Args:
            name: Registry key, e.g. ``"idefics"``.
            model_path: Local checkpoint directory or hub model id
                (e.g. ``"HuggingFaceM4/idefics-9b-instruct"``).
        """
        self.model_path = model_path
        self.name = name

    @abstractmethod
    def generate(self, question: str, raw_image_data: "Image.Image"):
        """Produce the model's answer to *question* about the given image."""
        pass


def load_model(model_name: str, dataset_args: Dict[str, str]) -> Model:
    """Dynamically import and instantiate the model registered under *model_name*.

    Args:
        model_name: Key into ``AVAILABLE_MODELS``, e.g. ``"idefics"``.
        dataset_args: Keyword arguments forwarded to the model class
            constructor, e.g. ``{"model_path": ...}``.
            NOTE(review): the name is a misnomer (these are model kwargs,
            not dataset kwargs); kept for backward compatibility with
            keyword callers.

    Returns:
        A constructed ``Model`` subclass instance.

    Raises:
        ValueError: If *model_name* is not registered. (Was an ``assert``,
            which is silently stripped under ``python -O``.)
    """
    if model_name not in AVAILABLE_MODELS:
        raise ValueError(
            f"{model_name} is not an available model. "
            f"Available models: {', '.join(AVAILABLE_MODELS)}"
        )
    module_path = "pipeline.evaluation.models." + model_name
    class_name = AVAILABLE_MODELS[model_name]
    imported_module = importlib.import_module(module_path)
    model_class = getattr(imported_module, class_name)
    print(f"Imported class: {model_class}")
    return model_class(**dataset_args)
4 changes: 3 additions & 1 deletion pipeline/evaluation/models/otter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from PIL import Image

from otter_ai import OtterForConditionalGeneration
from .model import Model


# Disable warnings
Expand Down Expand Up @@ -55,8 +56,9 @@ def get_response(image: Image.Image, prompt: str, model=None, image_processor=No
return parsed_output


class Otter(object):
class Otter(Model):
def __init__(self, model_name_or_path="luodian/OTTER-Image-MPT7B", load_bit="bf16"):
super().__init__("otter", model_name_or_path)
precision = {}
if load_bit == "bf16":
precision["torch_dtype"] = torch.bfloat16
Expand Down

0 comments on commit 50150c5

Please sign in to comment.