implement llava & fix bugs for fi
J1shen committed Oct 10, 2024
1 parent 1145b16 commit 0e66a95
Showing 7 changed files with 912 additions and 321 deletions.
120 changes: 120 additions & 0 deletions clients/python/client_llava.py
@@ -0,0 +1,120 @@
#Edited by: Junyi Shen

import requests
from typing import Dict, Optional, List
from text_generation.types import Parameters, Grammar, Response
from text_generation.errors import parse_error
import base64
import argparse

class Client_Llava:
    def __init__(
        self,
        base_url: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
        timeout: int = 10,
    ):
        self.base_url = base_url
        self.headers = headers
        self.cookies = cookies
        self.timeout = timeout

    def generate(
        self,
        prompt: str,
        input_image: Optional[str] = None,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = False,
        top_n_tokens: Optional[int] = None,
        grammar: Optional[Grammar] = None,
        lora_id: Optional[str] = None,
    ) -> Response:

        parameters = {
            "best_of": best_of if best_of is not None else 1,
            "details": True,
            "do_sample": do_sample,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "frequency_penalty": frequency_penalty,
            "return_full_text": return_full_text,
            "seed": seed,
            "stop": stop_sequences if stop_sequences is not None else [],
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "truncate": truncate,
            "typical_p": typical_p,
            "watermark": watermark,
            "decoder_input_details": decoder_input_details,
            "top_n_tokens": top_n_tokens,
            "grammar": grammar,
            "lora_id": lora_id,
        }

        if input_image is None:
            raise ValueError("input_image is required for LLaVA generation.")

        # Embed the image in the prompt as a base64 data URI, using markdown
        # image syntax, so the server can extract it from the prompt text.
        with open(input_image, "rb") as f:
            image = base64.b64encode(f.read()).decode("utf-8")
        image = f"data:image/png;base64,{image}"

        prompt = f"![]({image}){prompt}\n\n"

        # Single, non-streaming request body for the /generate endpoint.
        request = {
            "inputs": prompt,
            "parameters": parameters,
        }

        response = requests.post(
            f"{self.base_url}/generate",
            json=request,
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
        )

        if response.status_code == 404:
            raise parse_error(
                response.status_code,
                {"error": "Service not found.", "error_type": "generation"},
            )

        payload = response.json()
        if response.status_code != 200:
            raise parse_error(response.status_code, payload)
        print(payload)  # Debug: show the raw response payload
        return Response(**payload)

if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--url", type=str, default="http://127.0.0.1:3000")
    argparser.add_argument("--prompt", type=str, default="What is in the picture?")
    argparser.add_argument("--input_image", type=str, default="server/examples/images/4.png")
    args = argparser.parse_args()

    client = Client_Llava(args.url)

    response = client.generate(
        prompt=args.prompt,
        input_image=args.input_image,
    )

    print(response.generated_text)
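
The client above always labels the uploaded image as image/png, whatever file it is handed. As a hedged sketch (not part of this commit; image_to_data_uri is a hypothetical helper), the data URI could instead be built with the MIME type guessed from the file name via the standard-library mimetypes module:

import base64
import mimetypes

def image_to_data_uri(path: str) -> str:
    # Guess the MIME type from the file extension; fall back to PNG.
    mime, _ = mimetypes.guess_type(path)
    mime = mime or "image/png"
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{encoded}"

# Same markdown-style prompt format the client sends to /generate.
prompt = f"![]({image_to_data_uri('server/examples/images/4.png')})What is in the picture?\n\n"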
Binary file added server/examples/images/4.png
14 changes: 7 additions & 7 deletions server/examples/test_llava.py
@@ -1,11 +1,11 @@
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models_flashinfer.flashinfer_llava import LlavaLM, LlavaBatch
-import random, torch
+from text_generation_server.models_flashinfer.flashinfer_llava import FlashinferLlava
+import random
 import base64
 from collections import defaultdict

-service = LlavaLM(model_id="llava-hf/llava-v1.6-vicuna-7b-hf")
-tokenizer = service.language_model.tokenizer
+service = FlashinferLlava(model_id="llava-hf/llava-v1.6-vicuna-7b-hf")
+tokenizer = service.tokenizer

 prompts = [
     'How many people are in the image?',
@@ -51,15 +51,15 @@ def make_input(id = 0, prompt=None, image = None):
 requests = [make_input(i) for i in range(5)]
 batch = generate_pb2.Batch(id = 0, requests = requests, size = len(requests))
 display_results = defaultdict(lambda: [])
-
+service.warmup(batch)
 # Iterative generation: each step generates a token for each input in the batch
 isPrefill = True
 while True:
     if isPrefill:
-        generations, next_batch, _ = service.prefill_batch(batch)
+        generations, next_batch, info = service.prefill_batch(batch)
         isPrefill = False
     else:
-        generations, next_batch, _, _ = service.decode_batch([next_batch.to_pb()])
+        generations, next_batch, info = service.decode_batch([next_batch.to_pb()])

     for gen in generations:
         if gen.prefill_tokens:
4 changes: 2 additions & 2 deletions server/text_generation_server/models_flashinfer/__init__.py
@@ -16,7 +16,7 @@
 from text_generation_server.models_flashinfer.flashinfer_chatglm import (
     FlashinferChatGLM,
 )
-from text_generation_server.models_flashinfer.flashinfer_llava import LlavaLM
+from text_generation_server.models_flashinfer.flashinfer_llava import FlashinferLlava

 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -188,7 +188,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
     elif model_type == LLAVA_NEXT:
-        return LlavaLM(
+        return FlashinferLlava(
             model_id,
             revision=revision,
             quantize=quantize,
@@ -462,7 +462,7 @@ def _decode_text(
             output_text = self.tokenizer.decode(
                 request_context.output_ids[request_context.prompt_len :],
                 clean_up_tokenization_spaces=False,
-                skip_special_tokens=False,
+                skip_special_tokens=True,
             )
             generated_text = GeneratedText(
                 output_text,
@@ -492,7 +492,7 @@ def _decode_text(
                 None,
             )
             generations.append(generation)
-            return generations, all_stop
+        return generations, all_stop

     @tracer.start_as_current_span("generate_token")
     @torch.no_grad()
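
The skip_special_tokens change above keeps special tokens such as the end-of-sequence marker out of the decoded text returned to the client. A minimal sketch of the difference, assuming the vicuna-based tokenizer used elsewhere in this commit (the decoded strings are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
ids = tokenizer("A cat on a sofa", add_special_tokens=False).input_ids + [tokenizer.eos_token_id]

print(tokenizer.decode(ids, skip_special_tokens=False))  # e.g. "A cat on a sofa</s>"
print(tokenizer.decode(ids, skip_special_tokens=True))   # "A cat on a sofa"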
Expand Down
