Plumb skip_lm_head
tgaddair committed Nov 1, 2024
1 parent 1c70ec6 commit 3ebcbea
Showing 16 changed files with 70 additions and 0 deletions.
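
Every modified forward() follows the same pattern: a new skip_lm_head keyword argument, defaulting to False, that short-circuits the method right after the optional lm_head_indices slicing and returns the hidden states instead of logits. Below is a minimal, runnable sketch of that pattern; the TinyCausalLM class, its module layout, and the sizes are made up for illustration and are not the actual LoRAX model classes.

from typing import Optional, Tuple

import torch
from torch import nn


class TinyCausalLM(nn.Module):
    """Illustrative stand-in for the model classes touched by this commit."""

    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.model = nn.Linear(hidden_size, hidden_size)  # stands in for the transformer stack
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(self.embed_tokens(input_ids))
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            # Short-circuit: hand back hidden states, no logits computed.
            return hidden_states, None

        logits = self.lm_head(hidden_states)
        return logits, None  # second slot mirrors the speculative logits


# Caller side: pass skip_lm_head=True to get hidden states instead of logits.
model = TinyCausalLM()
hidden_states, _ = model(torch.tensor([0, 1, 2]), skip_lm_head=True)
logits, _ = model(torch.tensor([0, 1, 2]))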
@@ -524,6 +524,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
@@ -538,6 +539,11 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
# FIXME: simply running the LM head is not sufficient since we also need to scale the logits
return hidden_states, None

logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
logits *= self.logit_scale
if speculative_logits is not None:
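
The FIXME above points at a real caveat for this model: forward() scales its logits by logit_scale after the LM head, so when skip_lm_head=True returns raw hidden states, a caller that later applies the LM head itself must also reproduce that scaling. A hedged caller-side sketch follows; every value and shape below is a made-up placeholder, not the model's real configuration.

import torch

# Placeholder values; in the real model these come from the config and checkpoint.
logit_scale = 0.5
lm_head_weight = torch.randn(128, 16)  # (vocab_size, hidden_size)
hidden_states = torch.randn(4, 16)     # what forward() returns when skip_lm_head=True

# Mirror what the skipped tail of forward() would have done:
logits = hidden_states @ lm_head_weight.T
logits *= logit_scale  # without this, the logits differ from the non-skipped path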
@@ -1009,6 +1009,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
@@ -1023,5 +1024,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
return logits, speculative_logits
@@ -539,6 +539,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
@@ -554,5 +555,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
return logits, speculative_logits
@@ -538,6 +538,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
@@ -367,6 +367,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.transformer(
input_ids,
@@ -598,6 +598,7 @@ def forward(
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
@@ -615,5 +616,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
return logits, speculative_logits
@@ -610,6 +610,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
@@ -635,5 +636,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
return logits, speculative_logits
@@ -963,6 +963,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
@@ -987,5 +988,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
return logits, speculative_logits
@@ -357,6 +357,7 @@ def forward(
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.gpt_neox(
input_ids,
@@ -370,5 +371,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits = self.embed_out(hidden_states)
return logits, None
@@ -506,6 +506,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
@@ -520,5 +521,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
return logits, speculative_logits
@@ -388,6 +388,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
@@ -402,5 +403,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
return logits, speculative_logits
@@ -507,6 +507,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.transformer(
input_ids,
@@ -521,5 +522,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
return logits, speculative_logits
@@ -592,6 +592,7 @@ def forward(
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.transformer(
input_ids,
@@ -605,5 +606,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits = self.lm_head(hidden_states)
return logits, None
@@ -423,6 +423,7 @@ def forward(
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.transformer(
input_ids,
@@ -436,5 +437,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits = self.lm_head(hidden_states)
return logits, None
5 changes: 5 additions & 0 deletions server/lorax_server/models/custom_modeling/llava_next.py
@@ -178,6 +178,7 @@ def forward(
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional["AdapterBatchData"] = None,
skip_lm_head: bool = False,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
@@ -264,5 +265,9 @@
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits, speculative_logits = self.text_model.lm_head(hidden_states, adapter_data)
return logits, speculative_logits
2 changes: 2 additions & 0 deletions server/lorax_server/models/custom_modeling/mllama.py
@@ -884,6 +884,7 @@ def forward(
# XXX: Putting these as optional so that the cuda warmup calls can go through.
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
skip_lm_head: bool = False,
):
if cross_attention_states is not None:
seqlen_q = len(image_indices)
@@ -954,6 +955,7 @@
prefill_cache_indices=prefill_cache_indices,
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
skip_lm_head=skip_lm_head,
)

return outputs
