Skip to content

Commit

Permalink
Add Neva Template for NV-DPO Models (#8358)
Browse files Browse the repository at this point in the history
* add/rename from nvgpt to nv_steerlm, add nv_dpo template

Signed-off-by: HuiyingLi <[email protected]>

* add nv_dpo conversation to accomendate empty system message

Signed-off-by: HuiyingLi <[email protected]>

* handle nv_dpo template text generation

Signed-off-by: HuiyingLi <[email protected]>

* add prompt string to nvgpt

Signed-off-by: HuiyingLi <[email protected]>

* bugfix for inference prompt template

Signed-off-by: HuiyingLi <[email protected]>

* bug fix for grabbing clean text

Signed-off-by: Huiying Li <[email protected]>

* fix code format

Signed-off-by: Huiying Li <[email protected]>

---------

Signed-off-by: HuiyingLi <[email protected]>
Signed-off-by: Huiying Li <[email protected]>
  • Loading branch information
HuiyingLi authored and yaoyu-33 committed Feb 26, 2024
1 parent 8f2c5c3 commit 5406b84
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 5 deletions.
13 changes: 13 additions & 0 deletions nemo/collections/multimodal/data/neva/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,17 @@ def dict(self):
sep2=f"{DEFAULT_SYSTEM_TOKEN}System\n",
)

conv_nv_dpo = Conversation(
system="\n",
roles=("User", "Assistant"),
version="nv_dpo",
messages=(),
offset=0,
sep_style=SeparatorStyle.NVGPT,
sep=DEFAULT_SEPARATOR_TOKEN,
sep2=f"{DEFAULT_SYSTEM_TOKEN}System\n",
)

conv_vicuna_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
Expand Down Expand Up @@ -400,6 +411,8 @@ def dict(self):
"v1_mmtag": conv_llava_v1_mmtag,
"llava_llama_2": conv_llava_llama_2,
"nvgpt": conv_nvgpt,
"nv_steerlm": conv_nvgpt,
"nv_dpo": conv_nv_dpo,
}


Expand Down
105 changes: 104 additions & 1 deletion nemo/collections/multimodal/data/neva/neva_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ def preprocess_nvgpt(sources: dict, tokenizer, cfg,) -> Dict:
- The function asserts that each message in a conversation alternates between the defined roles and skips messages not starting with the 'human' role.
"""

"""<extra_id_0>System\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n<extra_id_1>User\n{user input}\n<extra_id_1>Assistant\n<extra_id_2>quality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4\n"""

conv = conversation_lib.conv_nvgpt.copy()

# Apply prompt templates
Expand Down Expand Up @@ -462,6 +464,105 @@ def preprocess_nvgpt(sources: dict, tokenizer, cfg,) -> Dict:
return dict(tokens=tokens, labels=labels,)


def preprocess_nv_dpo(sources: dict, tokenizer, cfg,) -> Dict:
"""
Preprocess a given set of conversational sources using nvgpt conversation template
This function processes conversations by first ensuring the conversation starts with a 'human' role, then tokenizes the conversations, applies specific token replacements, and finally masks labels for training purposes.
Parameters:
- sources: A dictionary containing conversational data. Expected format is a dict of conversations, where each conversation is a list of messages, and each message is a dict with 'from' (role) and 'value' (message text).
- tokenizer: A tokenizer from the Hugging Face Transformers library used for tokenizing the conversations.
- cfg: Configuration settings which include 'add_extra_token' (bool) to determine if an extra token should be added to the tokenized output, and 'context_length' for specifying the tokenization context length.
Returns:
- Dict: A dictionary containing two keys:
- 'tokens': A tensor of tokenized conversation data.
- 'labels': A tensor of labels for the conversation data, used for training models. Labels are masked based on the conversation structure.
Note:
- The function includes specific token replacements (e.g., DEFAULT_IMAGE_PATCH_TOKEN, <s>, </s>) and masking techniques for labels.
- It is designed to work with conversational data where messages alternate between a 'human' and a 'gpt' role.
- The function asserts that each message in a conversation alternates between the defined roles and skips messages not starting with the 'human' role.
"""

"""<extra_id_0>System\n\n<extra_id_1>User\n{user input}\n<extra_id_1>Assistant\n"""

conv = conversation_lib.conv_nv_dpo.copy()

# Apply prompt templates
conversations = []
for source in sources:
conv.messages = []
conv.system = source.get('system', conv.system)

strip_end_for_inference = False
for i, turn in enumerate(source['conversations']):

if i % 2 == 1:
turn['from'] = conv.roles[1]
conv.append_message(turn['from'], turn['value'])
if not turn["value"]:
strip_end_for_inference = (
True # in inference, current turn is empty, thus end tokens need to striped.
)
else:
turn['from'] = conv.roles[0]
conv.append_message(turn['from'], turn['value'])
context = conv.get_prompt()
if strip_end_for_inference:
if context.endswith("\n<extra_id_1>"):
context = context[: -len("\n<extra_id_1>")] + "\n"
conversations.append(context)

add_extra_token = cfg.get("add_extra_token")
# Tokenize conversations
tokens = tokenize(
texts=conversations,
tokenizer=tokenizer,
context_length=cfg.get("context_length"),
add_extra_token=add_extra_token,
)

labels = tokens.clone().detach()

# Mask targets
sep = conv.sep + conv.roles[1] + "\n"
for conversation, target in zip(conversations, labels):
rounds = conversation.split(conv.sep)
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt

for conv_idx in range(3, len(rounds), 2):
re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt

cur_len = 0
for i, rou in enumerate(re_rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break

instruction_len = len(tokenizer.text_to_ids(parts[0] + sep))
round_len = len(tokenizer.text_to_ids(rou + conv.sep))
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

cur_len += round_len
target[cur_len:] = IGNORE_INDEX

# Check if masking working correctly
# print([x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())])

if add_extra_token:
tokens = tokens[:, :-1].contiguous()
labels = labels[:, 1:].contiguous()
else:
labels = torch.roll(labels, shifts=-1, dims=-1)
labels[:, -1] = IGNORE_INDEX

return dict(tokens=tokens, labels=labels,)


def preprocess_plain(sources, tokenizer, cfg,) -> Dict:
"""
Preprocesses plain text sources (no template) for tokenization and label generation.
Expand Down Expand Up @@ -604,8 +705,10 @@ def expand2square(pil_img, background_color):
images_tensors = torch.tensor([])
sources = copy.deepcopy(sources)

if self.conv_template == "nvgpt":
if self.conv_template in ["nvgpt", "nv_steerlm"]:
data_dict = preprocess_nvgpt(sources, self.tokenizer, self.multimodal_cfg,)
elif self.conv_template == "nv_dpo":
data_dict = preprocess_nv_dpo(sources, self.tokenizer, self.multimodal_cfg,)
elif self.conv_template == "v1":
data_dict = preprocess_v1(sources, self.tokenizer, self.multimodal_cfg,)
elif self.conv_template == "llama_2":
Expand Down
12 changes: 9 additions & 3 deletions nemo/collections/nlp/modules/common/text_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,17 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c
DEFAULT_IMAGE_TOKEN,
preprocess_llama_2,
preprocess_multimodal,
preprocess_nv_dpo,
preprocess_nvgpt,
preprocess_v1,
)

list_data_dict = []
if multimodal_cfg["conv_template"] == "nvgpt":
if multimodal_cfg["conv_template"] in ["nvgpt", "nv_steerlm", "nv_dpo"]:
record = {
'system': 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\n',
'system': '\n'
if multimodal_cfg["conv_template"] == 'nv_dpo'
else 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.\n\n',
'conversations': [{'from': 'User', 'value': prompt}, {'from': 'Assistant', 'value': '',},],
}

Expand All @@ -348,7 +351,10 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c
sources = preprocess_multimodal(
copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents
) # HARDCODED FOR NOW
data_dict = preprocess_nvgpt(sources, tokenizer, multimodal_cfg)
if multimodal_cfg["conv_template"] in ["nvgpt", "nv_steerlm"]:
data_dict = preprocess_nvgpt(sources, tokenizer, multimodal_cfg)
else:
data_dict = preprocess_nv_dpo(sources, tokenizer, multimodal_cfg)

elif multimodal_cfg["conv_template"] == "llama_2":
record = {
Expand Down
4 changes: 3 additions & 1 deletion nemo/collections/nlp/modules/common/text_generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,16 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para

clean_response = clean_text

if conv_template == "nvgpt":
if conv_template in ["nvgpt", "nv_steerlm"]:
labels_str_regexp = re.compile(f"<extra_id_2>quality:.*\n")
last_match_end_position = None
for match in re.finditer(labels_str_regexp, clean_response):
last_match_end_position = match.end()
if last_match_end_position is not None:
clean_response = clean_response[last_match_end_position:]
clean_response = clean_response.strip("<extra_id_1>")
elif conv_template == 'nv_dpo':
clean_response = clean_response.split("<extra_id_1>")[-2][10:] # [10:] for removing "Assistant\n"
elif conv_template == "llama_2":
clean_response = clean_response.rsplit("[/INST] ", 1)[-1]
elif conv_template == "v1":
Expand Down

0 comments on commit 5406b84

Please sign in to comment.