add assertion for config (hpcaitech#4947)
CjhHa1 committed Oct 20, 2023
1 parent 02ab17e commit 1e3cdfa
Showing 6 changed files with 13 additions and 9 deletions.
3 changes: 2 additions & 1 deletion colossalai/inference/async_engine.py
@@ -79,11 +79,12 @@ def _step(self):
         """
         request_outputs = self.driver.step()
         if request_outputs is not None:
+            print("request_outputs: ", request_outputs)
             for request_output in request_outputs:
                 self._request_tracker.process_request_output(request_output)
         self._request_tracker.add_stop()
 
-    def abort(self, request_id: str):
+    def abort_request(self, request_id: str):
         self.driver.abort(request_id)
 
     def _has_requests_in_progress(self):
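The rename from abort to abort_request changes the engine's public API, so callers must be updated too. A minimal self-contained sketch of the change from the caller's side (the Driver and AsyncEngine stubs below are illustrative stand-ins, not the real ColossalAI classes):

    # Hypothetical stand-ins; only the method rename mirrors the diff above.
    class Driver:
        def abort(self, request_id: str) -> None:
            print(f"aborting {request_id}")

    class AsyncEngine:
        def __init__(self) -> None:
            self.driver = Driver()

        def abort_request(self, request_id: str) -> None:  # formerly `abort`
            self.driver.abort(request_id)

    AsyncEngine().abort_request("req-0")
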
7 changes: 6 additions & 1 deletion colossalai/inference/manager.py
@@ -42,7 +42,12 @@ def __init__(
         running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2
         self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list)
         # all the inputs should be put into req_queue: waiting req list
-
+        assert max_total_token_num >= self.engine.max_batch_size * (
+            self.engine.max_input_len + self.engine.max_output_len
+        ), "max_total_token_num should be at least max_batch_size * (max_input_len + max_output_len)"
+        assert (
+            batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len
+        ), "batch_max_tokens should be at least (max_input_len + max_output_len)"
         self.running_batch: Batch = running_batch
         self.eos_id = eos_id
         self.has_wait_tokens = 0
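The two new assertions pin down a capacity invariant: the engine's total token budget must hold a full batch of maximum-length sequences, and the per-batch token cap must admit at least one maximum-length sequence. (They read self.engine directly, so unlike the guarded running_max_req_size line above they presuppose a non-None engine.) A self-contained sketch of the same checks, with illustrative values rather than this repo's config:

    # Standalone sketch of the capacity invariant behind the new assertions.
    # All values below are illustrative.
    max_batch_size = 8
    max_input_len = 256
    max_output_len = 64
    batch_max_tokens = 512

    seq_budget = max_input_len + max_output_len        # 320 tokens per request
    max_total_token_num = max_batch_size * seq_budget  # 2560, the smallest passing value

    assert max_total_token_num >= max_batch_size * seq_budget, \
        "max_total_token_num should be at least max_batch_size * (max_input_len + max_output_len)"
    assert batch_max_tokens >= seq_budget, \
        "batch_max_tokens should be at least (max_input_len + max_output_len)"
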
2 changes: 0 additions & 2 deletions colossalai/inference/tensor_parallel/engine.py
@@ -61,7 +61,6 @@ def __init__(
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len)
-
         # Constraints relatable with specs of devices and model
         # This may change into an optional arg in the future
         assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
@@ -380,7 +379,6 @@ def forward(self, batch_id, is_prefill):
         Forward is used in Dynamic Batching Manager
         """
         batch = self.cache.pop(batch_id)
-
         if is_prefill:
             input_ = torch.tensor(batch.all_input_ids).cuda()
         else:
6 changes: 3 additions & 3 deletions tests/test_infer/test_dynamic_batching/config.yaml
@@ -5,10 +5,10 @@ engine_config:
   max_input_len: 128
   max_output_len: 32
 # config for app router deployment
-# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig?
+# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig.
 router_config:
-  max_total_token_num: 42
-  batch_max_tokens: 42
+  max_total_token_num: 640
+  batch_max_tokens: 640
   eos_id: 0
   disable_log_stats: False
   log_stats_interval: 10
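The new values follow from the assertions added in manager.py: with max_input_len 128 and max_output_len 32, a single sequence needs 128 + 32 = 160 tokens, so the old value 42 fails both checks, while 640 = 4 * 160 passes them for a max_batch_size of up to 4 (the batch size is not part of this hunk, so 4 is inferred from the numbers, not confirmed by the diff). A quick worked check:

    # Why 42 -> 640, checked against the new manager.py assertions.
    max_input_len, max_output_len = 128, 32      # from engine_config above
    seq_budget = max_input_len + max_output_len  # 160 tokens per sequence
    assert not (42 >= seq_budget)                # old batch_max_tokens of 42 fails
    assert 640 >= seq_budget                     # new batch_max_tokens passes
    assert 640 >= 4 * seq_budget                 # covers a batch of up to 4 max-length sequences
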
3 changes: 2 additions & 1 deletion tests/test_infer/test_dynamic_batching/test_async_engine.py
@@ -23,7 +23,7 @@ def run_async_engine(path: str):
     if model is None or not os.path.exists(model):
         return
 
-    prompt = "Introduce some landmarks in Beijing"
+    prompt = "Introduce some landmarks in London.\nThe Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10"
     sampling_params = SamplingParams()
     asyncio.run(asy_for_loop_test(config, prompt, sampling_params))
 
@@ -32,6 +32,7 @@ async def get_result(engine, prompt, sampling_params):
     request_id = str(uuid.uuid4().hex)
     results = engine.generate(request_id, prompt, sampling_params)
     async for result in results:
+        # print(result)
         assert result is not None
1 change: 0 additions & 1 deletion tests/test_infer/test_dynamic_batching/test_ray_dist.py
@@ -14,7 +14,6 @@
 
 
 def run_ray_dist(path: str):
-    print(f"Using yaml file {path}")
     if not os.path.exists(path):
         return
     config = RayInitConfig.from_yaml_path(path)
