The inference result is all like a mess #63

Open
lucasjinreal opened this issue Jun 7, 2024 · 2 comments

@lucasjinreal

Hi, I'm trying to use llava to run inference with the pllava model, and the result is really hard to debug.

The output:

USER: bfxs
/data/miniconda3/envs/env-3.9.2/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:520: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.
  warnings.warn(
ASSISTANT: hidden_states: torch.Size([16, 576, 5120])
input torch.Size([16, 576, 5120]) num_videos 1 frame_shape [24, 24] 
Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.
 USER: 
bfxs ASSISTANT:
AI:   Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.
 USER: 
bfxs ASSISTANT:
:
USER: 

It simply repeats the input text.


def load_pllava(
    repo_id,
    num_frames,
    use_lora=False,
    weight_dir=None,
    lora_alpha=32,
    use_multi_gpus=False,
    pooling_shape=(16, 12, 12),
):
    kwargs = {
        "num_frames": num_frames,
    }
    # print("===============>pooling_shape", pooling_shape)
    if num_frames == 0:
        kwargs.update(
            pooling_shape=(0, 12, 12)
        )  # produces a bug if the pooling projector is ever used
    config = PllavaConfig.from_pretrained(
        repo_id if not use_lora else weight_dir,
        pooling_shape=pooling_shape,
        **kwargs,
    )

    with torch.no_grad():
        model = PllavaForConditionalGeneration.from_pretrained(
            repo_id, config=config, torch_dtype=dtype
        )

    try:
        processor = PllavaProcessor.from_pretrained(repo_id)
    except Exception:
        processor = PllavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    # load weights
    if weight_dir is not None:
        state_dict = {}
        save_fnames = os.listdir(weight_dir)
        if "model.safetensors" in save_fnames:
            use_full = False
            for fn in save_fnames:
                if fn.startswith("model-0"):
                    use_full = True
                    break
        else:
            use_full = True

        if not use_full:
            print("Loading weight from", weight_dir, "model.safetensors")
            with safe_open(
                f"{weight_dir}/model.safetensors", framework="pt", device="cpu"
            ) as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
        else:
            print("Loading weight from", weight_dir)
            for fn in save_fnames:
                if fn.startswith("model-0"):
                    with safe_open(
                        f"{weight_dir}/{fn}", framework="pt", device="cpu"
                    ) as f:
                        for k in f.keys():
                            state_dict[k] = f.get_tensor(k)

        if "model" in state_dict.keys():
            msg = model.load_state_dict(state_dict["model"], strict=False)
        else:
            msg = model.load_state_dict(state_dict, strict=False)
        print(msg)
    model.to("cuda")
    model.eval()
    return model, processor


def pllava_answer(
    conv,
    model,
    processor,
    img_list,
    do_sample=True,
    max_new_tokens=200,
    num_beams=1,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.0,
    length_penalty=1,
    temperature=1.0,
    stop_criteria_keywords=None,
    print_res=False,
):
    # torch.cuda.empty_cache()
    prompt = conv.get_prompt()
    inputs = processor(text=prompt, images=img_list, return_tensors="pt")
    if inputs["pixel_values"] is None:
        inputs.pop("pixel_values")
    inputs = inputs.to(model.device)

    # set up stopping criteria
    if stop_criteria_keywords is not None:
        stopping_criteria = [
            KeywordsStoppingCriteria(
                stop_criteria_keywords, processor.tokenizer, inputs["input_ids"]
            )
        ]
    else:
        stopping_criteria = None

    with torch.no_grad():
        output_token = model.generate(
            **inputs,
            media_type="video",
            # media_type="image",
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            stopping_criteria=stopping_criteria,
        )
        output_text = processor.batch_decode(
            output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        print(output_text)

    if print_res:  # debug usage
        # print("### PROMPTING LM WITH: ", prompt)
        print("AI:  ", output_text)
    if conv.roles[-1] == "<|im_start|>assistant\n":
        split_tag = "<|im_start|> assistant\n"
    else:
        split_tag = conv.roles[-1]
    output_text = output_text.split(split_tag)[-1]
    ending = conv.sep if isinstance(conv.sep, str) else conv.sep[1]
    output_text = output_text.removesuffix(ending).strip()
    conv.messages[-1][1] = output_text
    return output_text, conv


def main(args):
    disable_torch_init()
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, processor = load_model_auto(args.model_path, dtype=dtype)
    model = model.to(dtype)
    print(f"using dtype: {dtype}")

    image = load_image(args.image_file)
    image_tensor = (
        processor(images=image, return_tensors="pt")["pixel_values"]
        .to(model.device)
        .to(dtype)
    )

    while True:
        conv = conv_templates["vicuna_v1"].copy()
        conv.system = "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n"

        try:
            inp = input(f"{conv.roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        if is_image(inp):
            image = load_image(inp)
            image_tensor = (
                image_processor(images=image, return_tensors="pt")["pixel_values"]
                .to(model.device)
                .to(dtype)
            )
            # print('updated new image')
            # clear conv history
            conv.messages = []
            print("Updated image, start new chat session.")
            continue

        print(f"{conv.roles[1]}: ", end="")

        # conv.user_query("Describe the video in details.", is_mm=True)

        if image is not None:
            # first message
            inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
            conv.append_message(conv.roles[0], inp)
            # image = None
        else:
            # later messages
            conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        # prompt = conv.get_prompt()
        img_list = [image] * 16
        llm_response, conv = pllava_answer(
            conv=conv,
            model=model,
            processor=processor,
            do_sample=False,
            img_list=img_list,
            max_new_tokens=256,
            print_res=True,
        )
        print(llm_response)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path", type=str, default="checkpoints/llava-qwen-4b-finetune/"
    )
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, default="images/kobe.jpg")
    parser.add_argument("--num-gpus", type=int, default=1)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    main(args)

Above is the simplest inference code, which borrows from pllava, but the result is never right.

I repeat a single image into 16 frames and feed them in as a single video.
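
To make that setup concrete, here is a minimal sketch of what I mean (the image path is only an example; `prompt` and `processor` are the variables from the snippet above):

from PIL import Image

image = Image.open("images/kobe.jpg").convert("RGB")
img_list = [image] * 16  # one still image duplicated into 16 identical "frames"

# All 16 frames go to the processor as `images`, matching the default
# pooling_shape=(16, 12, 12) set in load_pllava above.
inputs = processor(text=prompt, images=img_list, return_tensors="pt")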

Any help?

@ermu2001
Collaborator

ermu2001 commented Jun 7, 2024

Can you share the running terminal log for this code, and also the execution script?

@lucasjinreal
Author

The log is just what gets printed when the response comes out; it only repeats the prompt.

The args are: python run.py --model-path checkpoints/pllava_7b/

I don't know why it only repeats the prompt.
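
One check I can think of to narrow this down (a sketch against pllava_answer above; `output_token` and `inputs` are the variables from that function): generate() on decoder-only models returns the prompt tokens followed by the newly generated tokens, so decoding only the suffix should show whether the model actually produced anything beyond the prompt.

# Inside pllava_answer, right after model.generate(...):
# generate() returns [prompt tokens | new tokens] for decoder-only models,
# so slicing off the prompt before decoding shows only what the model produced.
prompt_len = inputs["input_ids"].shape[1]
new_tokens = output_token[:, prompt_len:]
generated_only = processor.batch_decode(
    new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print("GENERATED ONLY:", generated_only)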
