From 4a264bcdf19023957a376fa0b3ea5aacc5130d3b Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Mon, 4 Nov 2024 12:14:47 -0800
Subject: [PATCH] ruff

---
 server/lorax_server/models/causal_lm.py     | 6 +++---
 server/lorax_server/models/flash_qwen2.py   | 6 +++---
 server/lorax_server/models/flash_roberta.py | 6 +++---
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py
index c642ecf0..7eeaa9f3 100644
--- a/server/lorax_server/models/causal_lm.py
+++ b/server/lorax_server/models/causal_lm.py
@@ -596,11 +596,11 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option
         # TODO(travis): don't update this if indices haven't changed
         # Use prefill=True in all cases to force use of SGMV, as the batch is heterogenous
         adapter_data = AdapterBatchData.from_meta(
-            meta=batch.adapter_meta, 
-            weights=self.layer_to_adapter_weights, 
+            meta=batch.adapter_meta,
+            weights=self.layer_to_adapter_weights,
             layer_to_lora_weights={},
             punica_wrapper=None,
-            prefill=True, 
+            prefill=True,
             prefill_head_indices=None,
         )
 
diff --git a/server/lorax_server/models/flash_qwen2.py b/server/lorax_server/models/flash_qwen2.py
index 10e2d369..f2c70687 100644
--- a/server/lorax_server/models/flash_qwen2.py
+++ b/server/lorax_server/models/flash_qwen2.py
@@ -122,11 +122,11 @@ def embed(self, batch) -> torch.Tensor:
         adapter_meta = batch.adapter_meta
         prefill = False
         adapter_data = AdapterBatchData.from_meta(
-            meta=adapter_meta, 
-            weights=self.layer_to_adapter_weights, 
+            meta=adapter_meta,
+            weights=self.layer_to_adapter_weights,
             layer_to_lora_weights={},
             punica_wrapper=None,
-            prefill=prefill, 
+            prefill=prefill,
             prefill_head_indices=batch.prefill_head_indices,
         )
         embedding, _ = self.forward(batch, adapter_data=adapter_data)
diff --git a/server/lorax_server/models/flash_roberta.py b/server/lorax_server/models/flash_roberta.py
index 617c2419..74768336 100644
--- a/server/lorax_server/models/flash_roberta.py
+++ b/server/lorax_server/models/flash_roberta.py
@@ -210,11 +210,11 @@ def forward(self, batch: FlashEmbeddingClassificationBatch):
     @tracer.start_as_current_span("embed")
     def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
         adapter_data = AdapterBatchData.from_meta(
-            meta=batch.adapter_meta, 
-            weights=self.layer_to_adapter_weights, 
+            meta=batch.adapter_meta,
+            weights=self.layer_to_adapter_weights,
             layer_to_lora_weights={},
             punica_wrapper=None,
-            prefill=False, 
+            prefill=False,
             prefill_head_indices=None,
         )