Determine trace set from preloaded adapter set
tgaddair committed Nov 1, 2024
1 parent 43c129b commit 1c70ec6
Showing 3 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion server/lorax_server/models/flash_causal_lm.py
@@ -1389,7 +1389,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
             self.device,
             self.kv_cache,
             self.adapter_layers,
-            self.default_traced_adapter_layers,
+            self.traced_adapter_layers,
             self._forward_context,
             max_total_tokens,
             self.num_heads,
8 changes: 7 additions & 1 deletion server/lorax_server/models/model.py
@@ -233,6 +233,12 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
     @property
     def adapter_layers(self) -> List[str]:
         return []
+
+    @property
+    def traced_adapter_layers(self) -> List[str]:
+        if self.layer_to_adapter_weights:
+            return list(self.layer_to_adapter_weights.keys())
+        return self.default_traced_adapter_layers
 
     @property
     def default_traced_adapter_layers(self) -> List[str]:
@@ -279,7 +285,7 @@ def register_preloaded_adapters(
         for layer_name, layer_adapter_weights in self.layer_to_adapter_weights.items():
             layer_id_to_lora_a_weights = defaultdict(list)
             layer_id_to_lora_b_weights = defaultdict(list)
-            for i, adapter in enumerate(preloaded_adapters):
+            for adapter in preloaded_adapters:
                 adapter_index = adapter.adapter_index
                 adapter_weights = layer_adapter_weights.adapter_weights.get(adapter_index)
                 if not isinstance(adapter_weights, LoraWeights):
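
The new traced_adapter_layers property means that when adapters are preloaded at startup, the CUDA graph trace set is taken from the layers that actually carry adapter weights, falling back to default_traced_adapter_layers only when nothing was preloaded. A minimal, self-contained sketch of that selection logic (TraceSetExample and its sample layer names are illustrative stand-ins, not the actual lorax Model class):

from typing import Dict, List


class TraceSetExample:
    """Illustrative stand-in for the fallback logic added to Model.traced_adapter_layers."""

    def __init__(self, layer_to_adapter_weights: Dict[str, object], default_layers: List[str]):
        self.layer_to_adapter_weights = layer_to_adapter_weights
        self.default_traced_adapter_layers = default_layers

    @property
    def traced_adapter_layers(self) -> List[str]:
        # Prefer the layers that actually hold preloaded adapter weights.
        if self.layer_to_adapter_weights:
            return list(self.layer_to_adapter_weights.keys())
        # Otherwise fall back to the model's default trace set.
        return self.default_traced_adapter_layers


# With preloaded adapters, only their layers form the trace set.
preloaded = TraceSetExample({"q_proj": object(), "v_proj": object()}, ["q_proj", "k_proj", "v_proj", "o_proj"])
assert preloaded.traced_adapter_layers == ["q_proj", "v_proj"]

# Without preloaded adapters, the default trace set is used.
empty = TraceSetExample({}, ["q_proj", "k_proj", "v_proj", "o_proj"])
assert empty.traced_adapter_layers == ["q_proj", "k_proj", "v_proj", "o_proj"]

The apparent intent, per the commit title, is that warmup traces graphs only for the layer set the preloaded adapters actually use.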
1 change: 1 addition & 0 deletions server/lorax_server/utils/graph.py
@@ -561,6 +561,7 @@ def get_estimated_cache_memory(self) -> int:
     def warmup(self):
         ngraphs = len(CACHED_BATCH_SIZES) * len(CACHED_MAX_RANKS)
         pool = None
+        logger.info("Tracing CUDA graphs with initial adapter layers: {}", self.default_traced_adapter_layers)
         with tqdm(total=ngraphs, desc="Trace CUDA graphs") as pbar:
             for batch_size in reversed(CACHED_BATCH_SIZES):
                 pbar.set_postfix({"batch_size": batch_size})
