From 781a83a4ea19c9ce9693d647acf33f7f505dd82a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 17:08:12 +0800 Subject: [PATCH] [format] applied code formatting on changed files in pull request 4926 (#5007) Co-authored-by: github-actions --- colossalai/inference/tensor_parallel/engine.py | 4 ++-- tests/test_infer/test_pipeline_infer.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 283f719e57fc..3be2132748e0 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -218,7 +218,7 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." - + model = model.model if self.shard_config.inference_gptq else model policy = get_autopolicy(model, shard_config=self.shard_config) @@ -311,7 +311,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState: seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state.seq_len = seq_lengths.to("cuda") diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index 1cf38c1ec19f..3544153da857 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -24,6 +24,7 @@ def data_gen(): new_shape[0] = 16 inputs[k] = v.to("cuda").repeat(*new_shape) + def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): model = transformers.LlamaForCausalLM( transformers.LlamaConfig( @@ -58,7 +59,6 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_si @parameterize("pp_size", [2]) @parameterize("max_output_len", [4]) @parameterize("micro_batch_size", [1]) - @clear_cache_before_run() def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) @@ -76,7 +76,6 @@ def check_tp_pipeline_inference(rank, world_size, port): @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") - @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run()