Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Neva template #8358

Merged
merged 7 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions nemo/collections/multimodal/data/neva/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,17 @@ def dict(self):
sep2=f"{DEFAULT_SYSTEM_TOKEN}System\n",
)

# Conversation template registered under version "nv_dpo": it uses the NVGPT
# separator style with a system prompt that is only a newline.
conv_nv_dpo = Conversation(
    version="nv_dpo",
    sep_style=SeparatorStyle.NVGPT,
    system="\n",
    roles=("User", "Assistant"),
    sep=DEFAULT_SEPARATOR_TOKEN,
    sep2=f"{DEFAULT_SYSTEM_TOKEN}System\n",
    messages=(),
    offset=0,
)

conv_vicuna_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
Expand Down Expand Up @@ -400,6 +411,8 @@ def dict(self):
"v1_mmtag": conv_llava_v1_mmtag,
"llava_llama_2": conv_llava_llama_2,
"nvgpt": conv_nvgpt,
"nv_steerlm": conv_nvgpt,
"nv_dpo": conv_nv_dpo,
}


Expand Down
105 changes: 104 additions & 1 deletion nemo/collections/multimodal/data/neva/neva_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ def preprocess_nvgpt(sources: dict, tokenizer, cfg,) -> Dict:
- The function asserts that each message in a conversation alternates between the defined roles and skips messages not starting with the 'human' role.
"""

"""<extra_id_0>System\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n<extra_id_1>User\n{user input}\n<extra_id_1>Assistant\n<extra_id_2>quality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4\n"""

conv = conversation_lib.conv_nvgpt.copy()

# Apply prompt templates
Expand Down Expand Up @@ -462,6 +464,105 @@ def preprocess_nvgpt(sources: dict, tokenizer, cfg,) -> Dict:
return dict(tokens=tokens, labels=labels,)


def preprocess_nv_dpo(sources: dict, tokenizer, cfg,) -> Dict:
    """
    Preprocess conversational sources using the nv_dpo conversation template.

    Each source is rendered into a prompt string with the nv_dpo template, the
    prompts are tokenized, and a label tensor is derived from the tokens in
    which the instruction part of every round (everything up to and including
    the assistant header) is replaced by IGNORE_INDEX.

    Parameters:
    - sources: Iterable of per-sample dicts. Each one must provide
      'conversations' (a list of turn dicts with a 'value' entry) and may
      provide 'system' to override the template's default system prompt for
      that sample only. Turns are assigned roles by position: even indices get
      the first template role, odd indices the second. NOTE: each turn's
      'from' entry is overwritten in place with the assigned role.
    - tokenizer: Tokenizer object; this function calls tokenizer.text_to_ids
      and forwards the tokenizer to tokenize().
    - cfg: Mapping-like config read via cfg.get for 'add_extra_token' and
      'context_length'.

    Returns:
    - Dict with two keys:
        - 'tokens': tensor of tokenized conversations.
        - 'labels': tensor of next-token labels with masked positions set to
          IGNORE_INDEX.

    Note:
    - A turn at an odd index with an empty 'value' is treated as an inference
      request: a trailing separator emitted by the template is then replaced
      with a newline so the prompt ends right after the assistant header.
    """
    # Prompt layout documented by the original authors:
    #   <extra_id_0>System\n\n<extra_id_1>User\n{user input}\n<extra_id_1>Assistant\n

    conv = conversation_lib.conv_nv_dpo.copy()

    # Capture the template default once. Falling back to conv.system inside the
    # loop would reuse whatever the previous sample wrote there, so one
    # sample's custom system prompt would leak into later samples.
    default_system = conv.system
    user_role, assistant_role = conv.roles[0], conv.roles[1]
    inference_tail = "\n<extra_id_1>"

    # Apply prompt templates
    conversations = []
    for source in sources:
        conv.messages = []
        conv.system = source.get('system', default_system)

        strip_end_for_inference = False
        for i, turn in enumerate(source['conversations']):
            is_assistant_turn = i % 2 == 1
            # Roles are assigned by position, overriding whatever the data says.
            turn['from'] = assistant_role if is_assistant_turn else user_role
            conv.append_message(turn['from'], turn['value'])
            if is_assistant_turn and not turn['value']:
                # In inference the current turn is empty, so the end tokens
                # need to be stripped.
                strip_end_for_inference = True

        context = conv.get_prompt()
        if strip_end_for_inference and context.endswith(inference_tail):
            context = context[: -len(inference_tail)] + "\n"
        conversations.append(context)

    add_extra_token = cfg.get("add_extra_token")
    # Tokenize conversations
    tokens = tokenize(
        texts=conversations,
        tokenizer=tokenizer,
        context_length=cfg.get("context_length"),
        add_extra_token=add_extra_token,
    )

    labels = tokens.clone().detach()

    # Mask targets: within each round, everything up to and including the
    # assistant header is ignored.
    sep = conv.sep + assistant_role + "\n"
    for conversation, target in zip(conversations, labels):
        rounds = conversation.split(conv.sep)
        # The first group holds system + user + assistant; the rest are
        # user + assistant pairs.
        re_rounds = [conv.sep.join(rounds[:3])]
        re_rounds.extend(conv.sep.join(rounds[idx : idx + 2]) for idx in range(3, len(rounds), 2))

        cur_len = 0
        for rou in re_rounds:
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break

            instruction_len = len(tokenizer.text_to_ids(parts[0] + sep))
            round_len = len(tokenizer.text_to_ids(rou + conv.sep))
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the last processed round is ignored as well.
        target[cur_len:] = IGNORE_INDEX

    # Shift labels so that position t is trained to predict token t + 1.
    if add_extra_token:
        tokens = tokens[:, :-1].contiguous()
        labels = labels[:, 1:].contiguous()
    else:
        labels = torch.roll(labels, shifts=-1, dims=-1)
        labels[:, -1] = IGNORE_INDEX

    return dict(tokens=tokens, labels=labels,)


def preprocess_plain(sources, tokenizer, cfg,) -> Dict:
"""
Preprocesses plain text sources (no template) for tokenization and label generation.
Expand Down Expand Up @@ -604,8 +705,10 @@ def expand2square(pil_img, background_color):
images_tensors = torch.tensor([])
sources = copy.deepcopy(sources)

if self.conv_template == "nvgpt":
if self.conv_template in ["nvgpt", "nv_steerlm"]:
data_dict = preprocess_nvgpt(sources, self.tokenizer, self.multimodal_cfg,)
elif self.conv_template == "nv_dpo":
data_dict = preprocess_nv_dpo(sources, self.tokenizer, self.multimodal_cfg,)
elif self.conv_template == "v1":
data_dict = preprocess_v1(sources, self.tokenizer, self.multimodal_cfg,)
elif self.conv_template == "llama_2":
Expand Down
12 changes: 9 additions & 3 deletions nemo/collections/nlp/modules/common/text_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,17 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c
DEFAULT_IMAGE_TOKEN,
preprocess_llama_2,
preprocess_multimodal,
preprocess_nv_dpo,
preprocess_nvgpt,
preprocess_v1,
)

list_data_dict = []
if multimodal_cfg["conv_template"] == "nvgpt":
if multimodal_cfg["conv_template"] in ["nvgpt", "nv_steerlm", "nv_dpo"]:
record = {
'system': 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\n',
'system': '\n'
if multimodal_cfg["conv_template"] == 'nv_dpo'
else 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\n',
'conversations': [{'from': 'User', 'value': prompt}, {'from': 'Assistant', 'value': '',},],
}

Expand All @@ -348,7 +351,10 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c
sources = preprocess_multimodal(
copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents
) # HARDCODED FOR NOW
data_dict = preprocess_nvgpt(sources, tokenizer, multimodal_cfg)
if multimodal_cfg["conv_template"] in ["nvgpt", "nv_steerlm"]:
data_dict = preprocess_nvgpt(sources, tokenizer, multimodal_cfg)
else:
data_dict = preprocess_nv_dpo(sources, tokenizer, multimodal_cfg)

elif multimodal_cfg["conv_template"] == "llama_2":
record = {
Expand Down
4 changes: 3 additions & 1 deletion nemo/collections/nlp/modules/common/text_generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,16 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para

clean_response = clean_text

if conv_template == "nvgpt":
if conv_template in ["nvgpt", "nv_steerlm"]:
labels_str_regexp = re.compile(f"<extra_id_2>quality:.*\n")
last_match_end_position = None
for match in re.finditer(labels_str_regexp, clean_response):
last_match_end_position = match.end()
if last_match_end_position is not None:
clean_response = clean_response[last_match_end_position:]
clean_response = clean_response.strip("<extra_id_1>")
elif conv_template == 'nv_dpo':
clean_response = clean_response.split("<extra_id_1>")[-2][10:] # [10:] for removing "Assistant\n"
elif conv_template == "llama_2":
clean_response = clean_response.rsplit("[/INST] ", 1)[-1]
elif conv_template == "v1":
Expand Down
Loading