diff --git a/docs/reference/models/transformers_vision.md b/docs/reference/models/transformers_vision.md
new file mode 100644
index 000000000..5d95b456c
--- /dev/null
+++ b/docs/reference/models/transformers_vision.md
@@ -0,0 +1,114 @@
+# Transformers Vision
+
+Outlines allows seamless use of [vision models](https://huggingface.co/learn/computer-vision-course/en/unit4/multimodal-models/tasks-models-part1).
+
+`outlines.models.transformers_vision` shares interfaces with, and is based on, [`outlines.models.transformers`](./transformers.md).
+
+Tasks supported include:
+
+- image + text -> text
+- video + text -> text
+
+
+## Example: Using [Llava-Next](https://huggingface.co/docs/transformers/en/model_doc/llava_next) Vision Models
+
+Install the dependencies:
+`pip install torchvision pillow flash-attn`
+
+Create the model:
+```python
+import outlines
+
+model = outlines.models.transformers_vision(
+    "llava-hf/llava-v1.6-mistral-7b-hf",
+    device="cuda",
+)
+```
+
+Create a convenience function to load a `PIL.Image` from a URL:
+```python
+from PIL import Image
+from io import BytesIO
+from urllib.request import urlopen
+
+def img_from_url(url):
+    img_byte_stream = BytesIO(urlopen(url).read())
+    return Image.open(img_byte_stream).convert("RGB")
+```
+
+### Describing an image
+
+```python
+description_generator = outlines.generate.text(model)
+description_generator(
+    "<image> detailed description:",
+    [img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg")]
+)
+```
+
+> This is a color photograph featuring a Siamese cat with striking blue eyes. The cat has a creamy coat and a light eye color, which is typical for the Siamese breed. Its features include elongated ears, a long, thin tail, and a striking coat pattern. The cat is sitting in an indoor setting, possibly on a cat tower or a similar raised platform, which is covered with a beige fabric, providing a comfortable and soft surface for the cat to rest or perch. The surface of the wall behind the cat appears to be a light-colored stucco or plaster.
+
+#### Multiple Images
+
+To include multiple images in your prompt, simply add more `<image>` tokens to the prompt:
+
+```python
+image_urls = [
+    "https://cdn1.byjus.com/wp-content/uploads/2020/08/ShapeArtboard-1-copy-3.png",  # triangle
+    "https://cdn1.byjus.com/wp-content/uploads/2020/08/ShapeArtboard-1-copy-11.png",  # hexagon
+]
+description_generator = outlines.generate.text(model)
+description_generator(
+    "<image><image>What shapes are present?",
+    list(map(img_from_url, image_urls)),
+)
+```
+
+> There are two shapes present. One shape is a hexagon and the other shape is a triangle.
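+
+#### Batching Prompts
+
+The vision generators also accept a batch: pass a list of prompts together with one list of images per prompt, and a list of completions is returned. The snippet below is a minimal sketch of that calling convention, reusing the images from the other examples; the prompts and images are purely illustrative.
+
+```python
+description_generator = outlines.generate.text(model)
+descriptions = description_generator(
+    [
+        "<image> detailed description:",
+        "<image> detailed description:",
+    ],
+    [
+        [img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg")],
+        [img_from_url("https://upload.wikimedia.org/wikipedia/commons/9/98/Aldrin_Apollo_11_original.jpg")],
+    ],
+)
+# `descriptions` is a list with one string per prompt
+```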
+
+
+### Classifying an Image
+
+```python
+pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto"
+planet_generator = outlines.generate.regex(model, pattern)
+
+planet_generator(
+    "What planet is this: <image>",
+    [img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg")]
+)
+```
+
+> Saturn
+
+
+### Extracting Structured Image Data
+
+```python
+from pydantic import BaseModel
+from typing import List
+
+class ImageData(BaseModel):
+    caption: str
+    tags_list: List[str]
+    object_list: List[str]
+    is_photo: bool
+
+image_data_generator = outlines.generate.json(model, ImageData)
+
+image_data_generator(
+    "<image> detailed JSON metadata:",
+    [img_from_url("https://upload.wikimedia.org/wikipedia/commons/9/98/Aldrin_Apollo_11_original.jpg")]
+)
+```
+
+> `ImageData(caption='An astronaut on the moon', tags_list=['moon', 'space', 'nasa', 'americanflag'], object_list=['moon', 'moon_surface', 'space_suit', 'americanflag'], is_photo=True)`
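+
+
+### Choosing Between Labels
+
+`outlines.generate.choice` constrains the output to one of a fixed set of options and works with vision models in the same way as the other generators. The example below is a sketch under the same setup as above; the labels are illustrative.
+
+```python
+label_generator = outlines.generate.choice(model, ["photograph", "painting", "illustration"])
+
+label_generator(
+    "What kind of image is this? <image>",
+    [img_from_url("https://upload.wikimedia.org/wikipedia/commons/9/98/Aldrin_Apollo_11_original.jpg")]
+)
+```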
+
+
+## Resources
+
+### Choosing a model
+- https://mmbench.opencompass.org.cn/leaderboard
+- https://huggingface.co/spaces/WildVision/vision-arena
diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 4104e3080..ad01377c0 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -1,6 +1,6 @@
 import datetime
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Iterator, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
 
 from outlines.generate.generator import sequence_generator
 from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
@@ -479,6 +479,13 @@ def format_sequence(self, sequence: str) -> FormattedOutput:
         """
         return sequence
 
+    def _format(self, sequences):
+        """Apply formatting to every string in a completion."""
+        if isinstance(sequences, list):
+            return [self._format(sequence) for sequence in sequences]
+        else:
+            return self.format_sequence(sequences)
+
     def __call__(
         self,
         prompts: Union[str, List[str]],
@@ -489,13 +496,6 @@
     ):
         """Generate text from a prompt of list of prompts."""
 
-        def format(sequences):
-            """Apply formatting to every string in a completion."""
-            if isinstance(sequences, list):
-                return [format(sequence) for sequence in sequences]
-            else:
-                return self.format_sequence(sequences)
-
         generation_params = self.prepare_generation_parameters(
             max_tokens, stop_at, seed
         )
@@ -508,7 +508,7 @@ def format(sequences):
             **model_specific_params,
         )
 
-        return format(completions)
+        return self._format(completions)
 
     def stream(
         self,
@@ -529,3 +529,94 @@ def stream(
             self.sampling_params,
             **model_specific_params,
         )
+
+
+class VisionSequenceGeneratorAdapter(SequenceGeneratorAdapter):
+    def __call__(  # type: ignore
+        self,
+        prompts: Union[str, List[str]],
+        media: Union[str, Any],
+        max_tokens: Optional[int] = None,
+        stop_at: Optional[Union[str, List[str]]] = None,
+        seed: Optional[int] = None,
+        **model_specific_params,
+    ):
+        """
+        Generate text from a prompt or list of prompts.
+
+        Media: a URI from which to construct the media, or the media object
+        itself. Passed as an argument to the `AutoProcessor`.
+        """
+        prompts, media = self._validate_prompt_media_types(prompts, media)
+
+        generation_params = self.prepare_generation_parameters(
+            max_tokens, stop_at, seed
+        )
+
+        completions = self.model.generate(
+            prompts,
+            media,
+            generation_params,
+            self.logits_processor,
+            self.sampling_params,
+            **model_specific_params,
+        )
+
+        return self._format(completions)
+
+    def stream(  # type: ignore
+        self,
+        prompts: Union[str, List[str]],
+        media: List[Union[str, Any, List[Union[str, Any]]]],
+        max_tokens: Optional[int] = None,
+        stop_at: Optional[Union[str, List[str]]] = None,
+        seed: Optional[int] = None,
+        **model_specific_params,
+    ):
+        """Return a text generator from a prompt or a list of prompts."""
+        prompts, media = self._validate_prompt_media_types(prompts, media)
+        generation_params = self.prepare_generation_parameters(
+            max_tokens, stop_at, seed
+        )
+        return self.model.stream(
+            prompts,
+            media,
+            generation_params,
+            self.logits_processor,
+            self.sampling_params,
+            **model_specific_params,
+        )
+
+    @classmethod
+    def _validate_prompt_media_types(
+        cls,
+        prompts: Union[str, List[str]],
+        media: Union[str, Any, List[Union[str, Any]]],
+    ) -> Union[Any, List[Any]]:
+        """
+        Ensure media are provided as PIL.Image objects and that for every
+        prompt str there is one List[PIL.Image].
+        """
+
+        def valid_types(prompts, media):
+            from PIL import Image  # type: ignore
+
+            if isinstance(prompts, list):
+                if not isinstance(media, list) or len(prompts) != len(media):
+                    return False
+                for subprompt, submedia in zip(prompts, media):
+                    if not isinstance(subprompt, str) or not all(
+                        isinstance(m, Image.Image) for m in submedia
+                    ):
+                        return False
+            elif isinstance(prompts, str):
+                if not all(isinstance(m, Image.Image) for m in media):
+                    return False
+            return True
+
+        if not valid_types(prompts, media):
+            raise TypeError(
+                "Expected (prompts, media) to be of type "
+                "(str, List[Image]) or (List[str], List[List[Image]]), "
+                f"instead got prompts={prompts}, media={media}"
+            )
+
+        return prompts, media
diff --git a/outlines/generate/fsm.py b/outlines/generate/fsm.py
index 832a154bd..47661b47f 100644
--- a/outlines/generate/fsm.py
+++ b/outlines/generate/fsm.py
@@ -3,8 +3,12 @@
 import interegular
 
 from outlines.fsm.guide import RegexGuide
-from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
-from outlines.models import MLXLM, LlamaCpp, Transformers
+from outlines.generate.api import (
+    SequenceGenerator,
+    SequenceGeneratorAdapter,
+    VisionSequenceGeneratorAdapter,
+)
+from outlines.models import MLXLM, LlamaCpp, Transformers, TransformersVision
 from outlines.samplers import Sampler, multinomial
 
 
@@ -29,3 +33,12 @@ def fsm_unified(
     fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
     logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm)
     return SequenceGeneratorAdapter(model, logits_processor, sampler)
+
+
+@fsm.register(TransformersVision)
+def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()):
+    from outlines.processors import FSMLogitsProcessor
+
+    fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
+    logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm)
+    return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)
diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py
index 52b8c7dad..3aebcd429 100644
--- a/outlines/generate/regex.py
+++ b/outlines/generate/regex.py
@@ -1,12 +1,19 @@
 from functools import singledispatch
 
 from outlines.fsm.guide import RegexGuide
-from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
-from outlines.models import OpenAI
-from outlines.models.llamacpp import LlamaCpp
-from outlines.models.mlxlm import MLXLM
-from outlines.models.transformers import Transformers
-from outlines.models.vllm import VLLM
+from outlines.generate.api import (
+    SequenceGenerator,
+    SequenceGeneratorAdapter,
+    VisionSequenceGeneratorAdapter,
+)
+from outlines.models import (
+    MLXLM,
+    VLLM,
+    LlamaCpp,
+    OpenAI,
+    Transformers,
+    TransformersVision,
+)
 from outlines.samplers import Sampler, multinomial
 
 
@@ -53,6 +60,18 @@ def regex_unified(
     return SequenceGeneratorAdapter(model, logits_processor, sampler)
 
 
+@regex.register(TransformersVision)
+def regex_vision(
+    model,
+    regex_str: str,
+    sampler: Sampler = multinomial(),
+):
+    from outlines.processors import RegexLogitsProcessor
+
+    logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
+    return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)
+
+
 @regex.register(VLLM)
 def regex_vllm(
     model: VLLM,
diff --git a/outlines/generate/text.py b/outlines/generate/text.py
index 6da187e0b..b0b4e10c7 100644
--- a/outlines/generate/text.py
+++ b/outlines/generate/text.py
@@ -1,8 +1,19 @@
 from functools import singledispatch
 
 from outlines.fsm.guide import StopAtEOSGuide
-from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
-from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI, Transformers
+from outlines.generate.api import (
+    SequenceGenerator,
+    SequenceGeneratorAdapter,
+    VisionSequenceGeneratorAdapter,
+)
+from outlines.models import (
+    MLXLM,
+    VLLM,
+    LlamaCpp,
+    OpenAI,
+    Transformers,
+    TransformersVision,
+)
 from outlines.samplers import Sampler, multinomial
 
 
@@ -43,6 +54,11 @@ def text_unified(model, sampler: Sampler = multinomial()):
     return SequenceGeneratorAdapter(model, None, sampler)
 
 
+@text.register(TransformersVision)
+def text_vision(model, sampler: Sampler = multinomial()):
+    return VisionSequenceGeneratorAdapter(model, None, sampler)
+
+
 @text.register(VLLM)
 def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
     return SequenceGeneratorAdapter(model, None, sampler)
diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py
index c161215d1..d28fcb2d7 100644
--- a/outlines/models/__init__.py
+++ b/outlines/models/__init__.py
@@ -5,6 +5,7 @@
 codebase.
 
 """
+
 from typing import Union
 
 from .exllamav2 import ExLlamaV2Model, exl2
@@ -12,6 +13,7 @@
 from .mlxlm import MLXLM, mlxlm
 from .openai import OpenAI, azure_openai, openai
 from .transformers import Transformers, TransformerTokenizer, mamba, transformers
+from .transformers_vision import TransformersVision, transformers_vision
 from .vllm import VLLM, vllm
 
 LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, MLXLM, VLLM]
diff --git a/outlines/models/transformers_vision.py b/outlines/models/transformers_vision.py
new file mode 100644
index 000000000..876f9bff5
--- /dev/null
+++ b/outlines/models/transformers_vision.py
@@ -0,0 +1,139 @@
+from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
+
+from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.models import Transformers
+
+if TYPE_CHECKING:
+    from outlines.processors import OutlinesLogitsProcessor
+
+
+class TransformersVision(Transformers):
+    def __init__(self, model, tokenizer, processor):
+        super().__init__(model, tokenizer)
+        self.processor = processor
+
+    def generate(  # type: ignore
+        self,
+        prompts: Union[str, List[str]],
+        media: Union[List[Any], List[List[Any]]],
+        generation_parameters: GenerationParameters,
+        logits_processor: Optional["OutlinesLogitsProcessor"],
+        sampling_parameters: SamplingParameters,
+    ) -> Union[str, List[str], List[List[str]]]:
+        """Generate text using `transformers`.
+
+        Arguments
+        ---------
+        prompts
+            A prompt or list of prompts.
+        media
+            A List[PIL.Image] or List[List[PIL.Image]].
+        generation_parameters
+            An instance of `GenerationParameters` that contains the prompt,
+            the maximum number of tokens, stop sequences and seed. All the
+            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
+        logits_processor
+            The logits processor to use when generating text.
+        sampling_parameters
+            An instance of `SamplingParameters`, a dataclass that contains
+            the name of the sampler to use and related parameters as available
+            in Outlines.
+
+        Returns
+        -------
+        The generated text.
+        """
+        inputs = self.processor(prompts, media, padding=True, return_tensors="pt").to(
+            self.model.device
+        )
+
+        generation_kwargs = self._get_generation_kwargs(
+            prompts,
+            generation_parameters,
+            logits_processor,
+            sampling_parameters,
+        )
+        generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs)
+
+        # if single str input and single sample per input, convert to a 1D output
+        if isinstance(prompts, str):
+            # Should always be true until NotImplementedError above is fixed
+            generated_ids = generated_ids.squeeze(0)
+
+        return self._decode_generation(generated_ids)
+
+    def stream(  # type: ignore
+        self,
+        prompts: Union[str, List[str]],
+        media: Union[Any, List[Any]],  # TODO: docstring
+        generation_parameters: GenerationParameters,
+        logits_processor: Optional["OutlinesLogitsProcessor"],
+        sampling_parameters: SamplingParameters,
+    ) -> Iterator[Union[str, List[str]]]:
+        raise NotImplementedError
+
+
+def transformers_vision(
+    model_name: str,
+    device: Optional[str] = None,
+    model_kwargs: dict = {},
+    processor_kwargs: dict = {},
+    model_class=None,
+    tokenizer_class=None,
+    processor_class=None,
+):
+    """Instantiate a vision model from the `transformers` library, along with its
+    processor and tokenizer.
+
+    Parameters
+    ----------
+    model_name
+        The name of the model as listed on Hugging Face's model page.
+    device
+        The device(s) on which the model should be loaded. This overrides
+        the `device_map` entry in `model_kwargs` when provided.
+    model_kwargs
+        A dictionary that contains the keyword arguments to pass to the
+        `from_pretrained` method when loading the model.
+    processor_kwargs
+        A dictionary that contains the keyword arguments to pass to the
+        `from_pretrained` method when loading the processor.
+
+    Returns
+    -------
+    A `TransformersVision` model instance.
+
+    """
+    if model_class is None or tokenizer_class is None:
+        try:
+            from transformers import (
+                AutoTokenizer,
+                LlavaNextForConditionalGeneration,
+                LlavaNextProcessor,
+            )
+        except ImportError:
+            raise ImportError(
+                "The `transformers` library needs to be installed in order to use `transformers` models."
+            )
+    if model_class is None:
+        model_class = LlavaNextForConditionalGeneration
+    if processor_class is None:
+        processor_class = LlavaNextProcessor
+
+    if device is not None:
+        model_kwargs["device_map"] = device
+
+    model = model_class.from_pretrained(model_name, **model_kwargs)
+
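+    # Left padding keeps each prompt flush against its generated tokens when
+    # batching with a decoder-only model; a default [PAD] token is supplied.
+    # Callers can override either value through `processor_kwargs`.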
+    processor_kwargs.setdefault("padding_side", "left")
+    processor_kwargs.setdefault("pad_token", "[PAD]")
+    processor = processor_class.from_pretrained(model_name, **processor_kwargs)
+
+    if tokenizer_class is None:
+        if getattr(processor, "tokenizer", None):
+            tokenizer = processor.tokenizer
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, **processor_kwargs)
+    else:
+        tokenizer = tokenizer_class.from_pretrained(model_name, **processor_kwargs)
+
+    return TransformersVision(model, tokenizer, processor)
diff --git a/pyproject.toml b/pyproject.toml
index aa88fcbbc..f94b3c84d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,6 +62,7 @@ test = [
     "vllm; sys_platform != 'darwin'",
     "torch",
     "transformers",
+    "pillow",
 ]
 serve = [
     "vllm>=0.3.0",
diff --git a/tests/generate/test_api.py b/tests/generate/test_api.py
new file mode 100644
index 000000000..7188022f5
--- /dev/null
+++ b/tests/generate/test_api.py
@@ -0,0 +1,33 @@
+from io import BytesIO
+from urllib.request import urlopen
+
+import pytest
+from PIL import Image  # type: ignore
+
+from outlines.generate.api import VisionSequenceGeneratorAdapter
+
+IMG_URI = "https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg"
+PIL_IMG = Image.open(BytesIO(urlopen(IMG_URI).read())).convert("RGB")
+
+
+@pytest.mark.parametrize(
+    "prompts,media,type_error",
+    [
+        ("single prompt", [PIL_IMG], False),
+        (["prompt0", "prompt1"], [[PIL_IMG], [PIL_IMG]], False),
+        ("single prompt", [PIL_IMG, PIL_IMG], False),
+        (["prompt0", "prompt1"], [[PIL_IMG, PIL_IMG], [PIL_IMG]], False),
+        ("single prompt", "this isn't an image, it's a string", True),
+        ("single prompt", PIL_IMG, True),
+        (["prompt0", "prompt1"], [PIL_IMG], True),
+        (["prompt0", "prompt1"], [[PIL_IMG]], True),
+        (["prompt0", "prompt1"], [[[PIL_IMG]], [[PIL_IMG]]], True),
+    ],
+)
+def test_vision_sequence_generator_validate_types(prompts, media, type_error):
+    """Ensure inputs are validated correctly"""
+    if type_error:
+        with pytest.raises(TypeError):
+            VisionSequenceGeneratorAdapter._validate_prompt_media_types(prompts, media)
+    else:
+        VisionSequenceGeneratorAdapter._validate_prompt_media_types(prompts, media)
diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py
index 06d311f5b..b40e8b002 100644
--- a/tests/generate/test_generate.py
+++ b/tests/generate/test_generate.py
@@ -7,6 +7,10 @@
 import outlines.models as models
 import outlines.samplers as samplers
 
+##########################################
+# Model Fixtures
+##########################################
+
 
 @pytest.fixture(scope="session")
 def model_llamacpp(tmp_path_factory):
@@ -50,6 +54,17 @@ def model_bart(tmp_path_factory):
     )
 
 
+@pytest.fixture(scope="session")
+def model_transformers_vision(tmp_path_factory):
+    import torch
+
+    return models.transformers_vision(
+        "llava-hf/llava-v1.6-mistral-7b-hf",
+        device="cuda",
+        model_kwargs=dict(torch_dtype=torch.bfloat16),
+    )
+
+
 # TODO: exllamav2 failing in main, address in https://github.com/outlines-dev/outlines/issues/808
 # TODO: t5 tokenizer doesn't work with streaming
 """
@@ -78,15 +93,42 @@ def model_t5(tmp_path_factory):
     "model_transformers_opt125m",
     "model_mamba",
     "model_bart",
+    # "model_transformers_vision",  # tests pass, but awaiting a tiny model for CI
 )
 
-NOT_IMPLEMENTED = {
-    "stream": [],
-    "batch": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
-    "beam_search": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
-    "multiple_samples": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
-}
+##########################################
+# Structured Generation Inputs
+##########################################
+
+
+@pytest.fixture()
+def sample_schema():
+    from pydantic import BaseModel, conint, conlist, constr
+
+    class SampleSchema(BaseModel):
+        title: constr(max_length=10)
+        numbers: conlist(conint(strict=True), min_length=3, max_length=3)
+        labels: conlist(constr(min_length=1, max_length=5), min_length=3, max_length=3)
+
+    return SampleSchema
+
+
+@pytest.fixture()
+def sample_choices():
+    return ["foo", "bar", "baz"]
+
+
+REGEX_PATTERNS = [
+    "a b c d e",  # ensure proper tokenizer whitespace prefix handling
+    "(123456789)|(abcdefghijklmnop)",  # ensure consistent correct sequence handling during batch
+    r"([a-z]{10})@([a-z]{5})\.([a-z]{3})",  # email example
+]
+
+
+###########################################
+# Model/Generator Pair Behavior Definitions
+###########################################
 
 
 def enforce_not_implemented(model_fixture, *task_names):
@@ -94,6 +136,12 @@ def enforce_not_implemented(model_fixture, *task_names):
     """
     Per `NOT_IMPLEMENTED`, mapping, if a model hasn't implemented a task,
     assert an NotImplementedError is raised. Otherwise, run normally
     """
+    NOT_IMPLEMENTED = {
+        "stream": ["model_transformers_vision"],
+        "batch": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
+        "beam_search": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
+        "multiple_samples": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
+    }
     for task_name in task_names:
         if model_fixture in NOT_IMPLEMENTED.get(task_name, []):
             return pytest.raises(NotImplementedError)
@@ -101,11 +149,35 @@ def enforce_not_implemented(model_fixture, *task_names):
     return contextlib.nullcontext()
 
 
-REGEX_PATTERNS = [
-    "a b c d e",  # ensure proper tokenizer whitespace prefix handling
-    "(123456789)|(abcdefghijklmnop)",  # ensure consistent correct sequence handling during batch
-    r"([a-z]{10})@([a-z]{5})\.([a-z]{3})",  # email example
-]
+def get_inputs(fixture_name, batch_size=None):
+    """Get generator kwargs: just the prompt by default, but include images for transformers_vision."""
+    from io import BytesIO
+    from urllib.request import urlopen
+
+    from PIL import Image  # type: ignore
+
+    prompts = ["abcd", "efgh", "1234", "5678", "foo", "bar", "baz", "bif"]
+    prompts = prompts[0] if batch_size is None else prompts[:batch_size]
+
+    if fixture_name.endswith("_vision"):
+        img_url = "https://python-pillow.org/pillow-perf/static/space_pil_lanczos.png"
+        img = Image.open(BytesIO(urlopen(img_url).read())).convert("RGB")
+
+        if batch_size is None:
+            return {"prompts": f"<image>{prompts}", "media": [img]}
+        else:
+            return {
+                "prompts": [f"<image>{p}" for p in prompts],
+                "media": [[img] for _ in range(batch_size)],
+            }
+
+    else:
+        return {"prompts": prompts}
+
+
+###########################################
+# Tests
+###########################################
 
 
 @pytest.mark.parametrize("sampler_name", ("greedy", "multinomial", "beam_search"))
@@ -114,27 +186,17 @@ def test_generate_text(request, model_fixture, sampler_name):
     model = request.getfixturevalue(model_fixture)
     generator = generate.text(model, getattr(samplers, sampler_name)())
     with enforce_not_implemented(model_fixture, sampler_name):
-        res = generator("test", max_tokens=10)
+        res = generator(**get_inputs(model_fixture), max_tokens=10)
     assert isinstance(res, str)
 
 
+@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
-def test_generate_batch_text(request, model_fixture):
-    model = request.getfixturevalue(model_fixture)
-    generator = generate.text(model)
-    with enforce_not_implemented(model_fixture, "batch"):
-        res = generator(["test", "test2"], max_tokens=10)
-    assert isinstance(res, list)
-    assert isinstance(res[0], str)
-
-
-@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
-def test_generate_text_stream(request, model_fixture):
+def test_generate_regex(request, model_fixture, pattern):
     model = request.getfixturevalue(model_fixture)
-    generator = generate.text(model)
-    with enforce_not_implemented(model_fixture, "stream"):
-        for token in generator.stream("a b c ", max_tokens=10):
-            assert isinstance(token, str)
+    generator = generate.regex(model, pattern)
+    res = generator(**get_inputs(model_fixture), max_tokens=20)
+    assert re.fullmatch(pattern, res) is not None, res
 
 
 @pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@@ -144,17 +206,41 @@ def test_generate_fsm(request, model_fixture, pattern):
     model = request.getfixturevalue(model_fixture)
     generator = generate.fsm(model, interegular.parse_pattern(pattern).to_fsm())
 
-    res = generator("test")
+    res = generator(**get_inputs(model_fixture))
     assert re.fullmatch(pattern, res) is not None, res
 
 
-@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
-def test_generate_regex(request, model_fixture, pattern):
+def test_generate_json(request, model_fixture, sample_schema):
     model = request.getfixturevalue(model_fixture)
-    generator = generate.regex(model, pattern)
-    res = generator("foobarbaz", max_tokens=20)
-    assert re.fullmatch(pattern, res) is not None, res
+    generator = generate.json(model, sample_schema)
+    # asserts valid within call
+    generator(**get_inputs(model_fixture), max_tokens=100)
+
+
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_choice(request, model_fixture, sample_choices):
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.choice(model, sample_choices)
+    res = generator(**get_inputs(model_fixture))
+    assert res in sample_choices
+
+
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_format_bool(request, model_fixture):
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.format(model, bool)
+    res = generator(**get_inputs(model_fixture))
+    assert isinstance(res, bool)
+
+
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_text_stream(request, model_fixture):
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.text(model)
+    with enforce_not_implemented(model_fixture, "stream"):
+        for token in generator.stream(**get_inputs(model_fixture), max_tokens=10):
+            assert isinstance(token, str)
 
 
 @pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@@ -164,23 +250,19 @@ def test_generate_regex_stream(request, model_fixture, pattern):
     generator = generate.regex(model, pattern)
     with enforce_not_implemented(model_fixture, "stream"):
         output = ""
-        for token in generator.stream("output:", max_tokens=20):
+        for token in generator.stream(**get_inputs(model_fixture), max_tokens=20):
             output += token
     assert re.fullmatch(pattern, output) is not None, output
 
 
-@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
-def test_generate_regex_batch_stream(request, model_fixture, pattern):
+def test_generate_batch_text(request, model_fixture):
     model = request.getfixturevalue(model_fixture)
-    generator = generate.regex(model, pattern)
-    with enforce_not_implemented(model_fixture, "batch", "stream"):
-        outputs = ["", ""]
-        for tokens in generator.stream(["input 0", "input 1"], max_tokens=20):
-            outputs[0] += tokens[0]
-            outputs[1] += tokens[1]
-    for output in outputs:
-        assert re.fullmatch(pattern, output) is not None, output
+    generator = generate.text(model)
+    with enforce_not_implemented(model_fixture, "batch"):
+        res = generator(**get_inputs(model_fixture, 2), max_tokens=10)
+    assert isinstance(res, list)
+    assert isinstance(res[0], str)
 
 
 @pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@@ -190,44 +272,50 @@ def test_generate_regex_batch(request, model_fixture, pattern):
     model = request.getfixturevalue(model_fixture)
     generator = generate.regex(model, pattern)
     with enforce_not_implemented(model_fixture, "batch"):
-        outputs = generator(["abc", "123", "123bce", "33aa"], max_tokens=20)
+        outputs = generator(**get_inputs(model_fixture, 4), max_tokens=20)
     for output in outputs:
         assert re.fullmatch(pattern, output) is not None, output
 
 
 @pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
-def test_generate_regex_single_multinomial(request, model_fixture, pattern):
-    """Ensure batch requests work and fsm order is maintained"""
+def test_generate_regex_batch_stream(request, model_fixture, pattern):
     model = request.getfixturevalue(model_fixture)
-    generator = generate.regex(model, pattern, sampler=samplers.multinomial(4))
-    with enforce_not_implemented(model_fixture, "multiple_samples"):
-        output_sample_groups = generator("single input", max_tokens=40)
-        for output in output_sample_groups:
+    generator = generate.regex(model, pattern)
+    with enforce_not_implemented(model_fixture, "batch", "stream"):
+        outputs = ["", ""]
+        for tokens in generator.stream(**get_inputs(model_fixture, 2), max_tokens=20):
+            outputs[0] += tokens[0]
+            outputs[1] += tokens[1]
+    for output in outputs:
         assert re.fullmatch(pattern, output) is not None, output
 
 
 @pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
-def test_generate_regex_batch_multinomial(request, model_fixture, pattern):
+def test_generate_regex_single_multinomial(request, model_fixture, pattern):
     """Ensure batch requests work and fsm order is maintained"""
     model = request.getfixturevalue(model_fixture)
     generator = generate.regex(model, pattern, sampler=samplers.multinomial(4))
-    with enforce_not_implemented(model_fixture, "batch", "multiple_samples"):
-        output_batch_groups = generator(["abc", "123", "123bce", "33aa"], max_tokens=40)
-        for output_sample_groups in output_batch_groups:
-            for output in output_sample_groups:
-                assert re.fullmatch(pattern, output) is not None, output
+    with enforce_not_implemented(model_fixture, "multiple_samples"):
+        output_sample_groups = generator(**get_inputs(model_fixture), max_tokens=40)
+        for output in output_sample_groups:
+            assert re.fullmatch(pattern, output) is not None, output
 
 
 @pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
-def test_generate_regex_batch_beam_search(request, model_fixture, pattern):
+@pytest.mark.parametrize("sampler_name", ("multinomial", "beam_search"))
+def test_generate_regex_batch_multi_sample(
+    request, model_fixture, pattern, sampler_name
+):
     """Ensure batch requests work and fsm order is maintained"""
     model = request.getfixturevalue(model_fixture)
-    generator = generate.regex(model, pattern, sampler=samplers.beam_search(4))
+    generator = generate.regex(
+        model, pattern, sampler=getattr(samplers, sampler_name)(4)
+    )
     with enforce_not_implemented(model_fixture, "batch", "multiple_samples"):
-        output_batch_groups = generator(["abc", "123", "123bce", "33aa"], max_tokens=40)
+        output_batch_groups = generator(**get_inputs(model_fixture, 4), max_tokens=40)
     for output_sample_groups in output_batch_groups:
         for output in output_sample_groups:
             assert re.fullmatch(pattern, output) is not None, output