Commit 35c7de2: Format

tgaddair committed Nov 1, 2024
1 parent b2de54f
Showing 36 changed files with 372 additions and 503 deletions.
17 changes: 13 additions & 4 deletions server/lorax_server/adapters/medusa.py
@@ -1,16 +1,16 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type
 
-from loguru import logger
 import torch
 import torch.distributed
+from loguru import logger
 
 from lorax_server.adapters.config import AdapterConfig, ModuleMap
 from lorax_server.adapters.types import MEDUSA
 from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
 from lorax_server.layers import FastLinear, TensorParallelColumnLinear
-from lorax_server.utils.segments import find_segments
 from lorax_server.utils.punica import segmented_matmul
+from lorax_server.utils.segments import find_segments
 from lorax_server.utils.state import LORAX_SPECULATION_MAX_BATCH_SIZE, get_speculative_tokens
 from lorax_server.utils.weights import AbstractWeights, InMemoryWeights
@@ -22,6 +22,7 @@
 
 _MEDUSA_ENABLED = False
 
+
 @dataclass
 class MedusaConfig(AdapterConfig):
     medusa_num_heads: int
@@ -312,11 +313,19 @@ def load(
             default_medusa=default_medusa,
             segments=MedusaSegments(
                 w=[
-                    (adapter_weights[idx].model.medusa.linear.linear.weight.data if idx in adapter_weights else EMPTY_TENSOR)
+                    (
+                        adapter_weights[idx].model.medusa.linear.linear.weight.data
+                        if idx in adapter_weights
+                        else EMPTY_TENSOR
+                    )
                     for idx in segment_indices
                 ],
                 b=[
-                    (adapter_weights[idx].model.medusa.linear.linear.bias.data if idx in adapter_weights else EMPTY_TENSOR)
+                    (
+                        adapter_weights[idx].model.medusa.linear.linear.bias.data
+                        if idx in adapter_weights
+                        else EMPTY_TENSOR
+                    )
                     for idx in segment_indices
                 ],
                 s_start=segments[indices],
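Note: MedusaSegments collects one weight tensor and one bias tensor per batch segment, with EMPTY_TENSOR standing in for segments that have no Medusa adapter, so that segmented_matmul can apply a different speculative head to each contiguous slice of the batch. A minimal sketch of that idea, assuming a (y, x, w, b, s_start, s_end) calling convention for illustration rather than the actual lorax_server.utils.punica signature:

import torch
import torch.nn.functional as F

EMPTY_TENSOR = torch.tensor([])

def segmented_matmul_sketch(y, x, w, b, s_start, s_end):
    # For each segment i, apply that segment's own projection:
    # y[s_start[i]:s_end[i]] = x[slice] @ w[i].T + b[i]
    for i in range(len(w)):
        if w[i].numel() == 0:
            continue  # EMPTY_TENSOR: no Medusa adapter for this segment
        if s_start[i] == s_end[i]:
            continue  # empty slice
        y[s_start[i]:s_end[i]] = F.linear(x[s_start[i]:s_end[i]], w[i], b[i])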
8 changes: 4 additions & 4 deletions server/lorax_server/adapters/weights.py
@@ -136,10 +136,10 @@ def from_meta(
         if layer_weights:
             data[k] = layer_weights
         return AdapterBatchData(
-                meta=meta,
-                data=data,
-                layer_to_lora_weights=layer_to_lora_weights,
-                punica_wrapper=punica_wrapper,
+            meta=meta,
+            data=data,
+            layer_to_lora_weights=layer_to_lora_weights,
+            punica_wrapper=punica_wrapper,
             prefill=prefill,
         )
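For orientation, the keyword arguments being re-indented above are the fields of the per-batch adapter container. A stand-in dataclass with the same field names, where the types are assumptions inferred from this diff rather than the real definitions:

from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class AdapterBatchDataSketch:
    meta: Any                              # AdapterBatchMetadata for the batch
    data: Dict[str, Any]                   # layer name -> adapter weights
    layer_to_lora_weights: Dict[str, Any]  # precomputed LoRA weight lookup
    punica_wrapper: Any                    # wrapper around the Punica kernels
    prefill: bool                          # whether this batch is a prefill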

@@ -539,11 +539,11 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            # FIXME: simply running the LM head is not sufficient since we also need to scale the logits
            return hidden_states, None

        logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
        logits *= self.logit_scale
        if speculative_logits is not None:
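The FIXME above records why skip_lm_head is only an approximation for this model: forward() normally multiplies the logits by self.logit_scale after the LM head, so a caller that takes the raw hidden states must reproduce both steps. A minimal caller-side sketch, assuming a plain linear head and a scalar logit_scale (illustrative names, not part of this commit):

import torch
import torch.nn.functional as F

def project_skipped_hidden_states(hidden_states, lm_head_weight, logit_scale):
    # hidden_states: returned by forward() when skip_lm_head=True
    # lm_head_weight: [vocab_size, hidden_size] projection the model skipped
    logits = F.linear(hidden_states, lm_head_weight)  # the skipped LM head
    return logits * logit_scale  # the skipped model-specific scaling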
@@ -1024,9 +1024,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
        return logits, speculative_logits
@@ -555,9 +555,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
        return logits, speculative_logits
@@ -616,9 +616,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
        return logits, speculative_logits
@@ -636,9 +636,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
        return logits, speculative_logits
@@ -988,9 +988,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
        return logits, speculative_logits
@@ -371,9 +371,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits = self.embed_out(hidden_states)
        return logits, None
@@ -521,9 +521,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
        return logits, speculative_logits
@@ -403,9 +403,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
        return logits, speculative_logits
@@ -537,10 +537,10 @@ def forward(

        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
        return logits, speculative_logits

@@ -522,9 +522,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
        return logits, speculative_logits
@@ -606,9 +606,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits = self.lm_head(hidden_states)
        return logits, None
@@ -437,9 +437,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits = self.lm_head(hidden_states)
        return logits, None
4 changes: 2 additions & 2 deletions server/lorax_server/models/custom_modeling/llava_next.py
@@ -265,9 +265,9 @@ def forward(
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        if skip_lm_head:
            return hidden_states, None

        logits, speculative_logits = self.text_model.lm_head(hidden_states, adapter_data)
        return logits, speculative_logits