
Commit

[format] applied code formatting on changed files in pull request 4926 (hpcaitech#5007)

Co-authored-by: github-actions <[email protected]>
2 people authored and flybird11111 committed Nov 10, 2023
1 parent dba3e04 commit 781a83a
Showing 2 changed files with 3 additions and 4 deletions.
4 changes: 2 additions & 2 deletions colossalai/inference/tensor_parallel/engine.py
@@ -218,7 +218,7 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."

model = model.model if self.shard_config.inference_gptq else model
policy = get_autopolicy(model, shard_config=self.shard_config)

@@ -311,7 +311,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
            seq_start_indexes[i] = start_index
            start_index += curr_seq_len
            max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch

        block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
        batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
        batch_infer_state.seq_len = seq_lengths.to("cuda")
3 changes: 1 addition & 2 deletions tests/test_infer/test_pipeline_infer.py
@@ -24,6 +24,7 @@ def data_gen():
        new_shape[0] = 16
        inputs[k] = v.to("cuda").repeat(*new_shape)


def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    model = transformers.LlamaForCausalLM(
        transformers.LlamaConfig(
@@ -58,7 +59,6 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])

@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
@@ -76,7 +76,6 @@ def check_tp_pipeline_inference(rank, world_size, port):


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")

@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
