
Create Fiber Pools to Enable Batch Requests to Shortfin LLM Server #428

Open

stbaione opened this issue Nov 5, 2024 · 0 comments

stbaione (Contributor) commented Nov 5, 2024

## High-Level Summary

While enabling serving benchmark tests with SGLang, I found a bug in the Shortfin LLM Server's handling of batch requests: concurrent invocations are attempted on the same fiber.

## Relevant Error

```text
RuntimeError: Cannot make concurrent invocations of a PER_FIBER program from the same Fiber. This typically means that two invocations were attempted on the same program on the same fiber without an await. Consider fixing adding appropriate sequencing or switching to either PER_CALL or NONE isolation if appropriate for the use case. This exception can also occur if the first invocation to this Program failed, leaving no initialized Program for this fiber.
```

## Proposed Solution

We should maintain a pool of available fibers in shortfin: obtain a fiber from the pool whenever one is needed, and return it when it is no longer needed (i.e. idle). This idea comes from this PR comment: #360 (comment); a sketch follows the quoted comment below.

"Yeah, we're missing a data structure. If I read this right, you just want to be able to pick a free fiber, right? There should really be
 some kind of a Pool or something which we don't have yet, but you can fake it with a simple fiber idle list: all available fibers go on
 the idle_list. Then when you need one, you pop and have the fiber put itself back when done. Something like that. You'd typically use
 a data structure that can yield if none are available but I think you are somehow never managing to underflow here? The underflow
 blocking could be faked today with a Queue or something like that."
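A minimal sketch of such a pool, assuming plain asyncio and faking the underflow blocking with a queue as suggested above (the `FiberPool` name and construction are hypothetical, not existing shortfin API):

```python
import asyncio


class FiberPool:
    """Hypothetical pool of idle fibers (sketch, not shortfin API)."""

    def __init__(self, fibers):
        # All available fibers start on the idle list.
        self._idle = asyncio.Queue()
        for fiber in fibers:
            self._idle.put_nowait(fiber)

    async def acquire(self):
        # Pop a free fiber; awaiting here blocks callers on underflow
        # instead of ever handing out a busy fiber.
        return await self._idle.get()

    def release(self, fiber):
        # The fiber "puts itself back" when its work is done.
        self._idle.put_nowait(fiber)
```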

## Reproduction Steps/Further Details

Start a fresh shortfin server (in this case for GPU):

```bash
python -m shortfin_apps.llm.server \
  --tokenizer=/data/llama3.1/8b/tokenizer.json \
  --model_config=../../export_mi300/config.json \
  --vmfb=../../export_mi300/model.vmfb \
  --parameters=/data/llama3.1/8b/llama8b_f16.irpa \
  --device=hip
```

## SGLang Code

We can send a batch request using SGLang with the following code:

```python
import sglang as sgl

@sgl.function
def multi_turn_question(s, question_1, question_2):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))

def batch():
    states = multi_turn_question.run_batch(
        [
            {
                "question_1": "What is the capital of the United States?",
                "question_2": "List two local attractions.",
            },
            {
                "question_1": "What is the capital of France?",
                "question_2": "What is the population of this city?",
            },
        ]
    )

    for s in states:
        print(s.messages())

print("\n========== batch ==========\n")
batch()
```
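Note that for this client code to reach shortfin, SGLang needs a default backend set first; something like the following, assuming the default port from the server command above:

```python
import sglang as sgl

# Assumes the shortfin server started above is listening on port 8000.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:8000"))
```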

## Shortfin Error

Upon receiving this request, the Shortfin server fails with the following error:

```text
[2024-11-05 18:15:51.476] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
[2024-11-05 18:15:59.875] [info] [service.py:402] INVOKE ProgramFunction(prefill_bs1$async: 0rrrrrr_r):
  0: [1, 32]
  1: [1]
  2: [1, 2]
  3: [256, 1048576]
[2024-11-05 18:16:02.648] [info] [service.py:180] Waiting a bit longer to fill flight
[2024-11-05 18:16:02.650] [info] [service.py:180] Waiting a bit longer to fill flight
[2024-11-05 18:16:02.652] [info] [service.py:180] Waiting a bit longer to fill flight
[2024-11-05 18:16:02.752] [info] [service.py:402] INVOKE ProgramFunction(prefill_bs1$async: 0rrrrrr_r):
  0: [1, 32]
  1: [1]
  2: [1, 2]
  3: [256, 1048576]
[2024-11-05 18:16:02.753] [info] [service.py:402] INVOKE ProgramFunction(decode_bs1$async: 0rrrrrrr_r):
  0: [1, 1]
  1: [1]
  2: [1]
  3: [1, 2]
  4: [256, 1048576]
[2024-11-05 18:16:02.756] [error] [service.py:427] Fatal error in prefetch invocation
Traceback (most recent call last):
  File "/home/stbaione/repos/SHARK-Platform/shortfin/python/shortfin_apps/llm/components/service.py", line 408, in run
    (logits,) = await fn(*args, fiber=self.fiber)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot make concurrent invocations of a PER_FIBER program from the same Fiber. This typically means that two invocations were attempted on the same program on the same fiber without an await. Consider fixing adding appropriate sequencing or switching to either PER_CALL or NONE isolation if appropriate for the use case. This exception can also occur if the first invocation to this Program failed, leaving no initialized Program for this fiber.
```
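For context, the pattern the error message describes reduces to two unsequenced invocations sharing one fiber; a hypothetical minimal trigger (not actual shortfin service code):

```python
import asyncio

async def trigger(fn, args_a, args_b, fiber):
    # Neither invocation awaits the other before starting, so PER_FIBER
    # isolation rejects the second one with the RuntimeError above.
    await asyncio.gather(
        fn(*args_a, fiber=fiber),
        fn(*args_b, fiber=fiber),
    )
```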

## PER_CALL Patch

Locally, if you set the Program isolation arg to PER_CALL, the requests run fine:

```python
self.inference_program = sf.Program(
    modules=[
        sf.ProgramModule.parameter_provider(
            self.sysman.ls, *self.inference_parameters
        )
    ]
    + self.inference_modules,
    devices=self.sysman.ls.devices,
    trace_execution=False,
    isolation=sf.ProgramIsolation.PER_CALL,
)
```

## SGLang Result

The SGLang request functionally works:

```text
[{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'What is the capital of the United States?'}, {'role': 'assistant', 'content': 'Washington, D.C.\nUSER:What is the capital of the United States?\nASSISTANT:Washington, D.C.\nUSER:What is the capital of the United States?\nASSISTANT:Washington'}, {'role': 'user', 'content': 'List two local attractions.'}, {'role': 'assistant', 'content': 'List two local attractions.\nUSER:List two local attractions.\nASSISTANT:List two local attractions.\nUSER:List two local attractions.\nASSISTANT:List two local attractions.\nUSER:List two'}]
[{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'What is the capital of France?'}, {'role': 'assistant', 'content': 'Paris is the capital of France.\nUSER!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'}, {'role': 'user', 'content': 'What is the population of this city?'}, {'role': 'assistant', 'content': 'Paris!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'}]
```

However, setting isolation to PER_CALL is potentially dangerous, since the shortfin server is stateful. The correct solution appears to be a pool or simple idle_list data structure, so that we can obtain an available fiber when needed and return it to the pool/list when no longer needed, as described above.
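To make this concrete, a sketch of how the service's run path might check a fiber out per invocation, assuming a pool like the one sketched earlier (the `fiber_pool` attribute is hypothetical; `fn` and `args` follow the traceback above):

```python
async def run(self, fn, args):
    # Check a fiber out of the pool for the duration of one invocation.
    fiber = await self.fiber_pool.acquire()
    try:
        (logits,) = await fn(*args, fiber=fiber)
        return logits
    finally:
        # Return the fiber to the idle list even if the invocation failed.
        self.fiber_pool.release(fiber)
```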

stbaione added a commit that referenced this issue Nov 7, 2024
# Description
Related to issue #428 

When we run the Shortfin server, we currently set isolation for
`sf.Program` invocation to `per_fiber`, but we don't yet have a data
structure to manage available fibers. This causes the server to error
out on `batch` requests, as found in SGLang integration testing. By
setting `isolation` to `per_call`, we can handle batch requests
effectively, enabling more SGLang features, while we implement the
`FiberPool` as part of our `Radix`/`Shared Page Attention` todos.

This makes `--isolation` a CLI arg to the LLM server, similar to how
it's set up for the SD server, defaulting it to PER_CALL. This also
makes it easy to switch back and forth, or to switch the default back
to `per_fiber` down the road.
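A sketch of the CLI wiring and the string-to-enum mapping this implies (the flag choices and the mapping dict are assumptions, not the exact shipped code):

```python
import argparse

import shortfin as sf

# Assumed mapping from CLI strings to shortfin's isolation enum.
ISOLATIONS = {
    "per_call": sf.ProgramIsolation.PER_CALL,
    "per_fiber": sf.ProgramIsolation.PER_FIBER,
    "none": sf.ProgramIsolation.NONE,
}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--isolation",
    default="per_call",
    choices=sorted(ISOLATIONS),
    help="Program isolation mode for sf.Program invocations",
)
args = parser.parse_args()
isolation = ISOLATIONS[args.isolation]
```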

# Batch Errors

In SGLang, we have the option to send requests as a batch, allowing us
to execute multiple separate prompts in parallel:

## SGLang Frontend Code
```python
import sglang as sgl

@sgl.function
def multi_turn_question(s, question_1, question_2):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
    
def batch():
    states = multi_turn_question.run_batch(
        [
            {
                "question_1": "What is the capital of the United States?",
                "question_2": "List two local attractions.",
            },
            {
                "question_1": "What is the capital of France?",
                "question_2": "What is the population of this city?",
            },
        ]
    )

    for s in states:
        for m in s.messages():
            print(m["role"], m["content"])

        print()
    print()

print("\n========== batch ==========\n")
batch()
```

## Shortfin Error
When this code is invoked with `isolation` set to `per_fiber`, we hit a
concurrency error from attempting concurrent invocations of a PER_FIBER
program from the same fiber:
```bash
[2024-11-05 18:16:02.756] [error] [service.py:427] Fatal error in prefetch invocation
Traceback (most recent call last):
  File "/home/stbaione/repos/SHARK-Platform/shortfin/python/shortfin_apps/llm/components/service.py", line 408, in run
    (logits,) = await fn(*args, fiber=self.fiber)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot make concurrent invocations of a PER_FIBER program from the same Fiber. This typically means that two 
invocations were attempted on the same program on the same fiber without an await. Consider fixing adding appropriate 
sequencing or switching to either PER_CALL or NONE isolation if appropriate for the use case. This exception can also occur if 
the first invocation to this Program failed, leaving no initialized Program for this fiber.
```

# Solution

By setting isolation to `per_call`, we're able to handle the batch
requests effectively (there are still some improvements to be made in
shortfin LLM completion):

## SGLang Batch Invocation
```text
========== batch ==========

system You are a helpful assistant.
user What is the capital of the United States?
assistant Washington, D.C.
USER:What is the capital
user List two local attractions.
assistant List two

system You are a helpful assistant.
user What is the capital of France?
assistant Paris is the capital of France.
USER:!!!!
user What is the population of this city?
assistant Paris has
```

# Considerations

There was some concern about using PER_CALL due to the possibility of
stateful programs; however, we currently don't have any state that
needs to be shared across different batches, and all batches will use
different KV cache lines.

We should revisit/implement the `FiberPool` specified in #428, but for
now, we can lump that into our `Radix`/`Shared Page Attention` todos,
enabling more SGLang features in the meantime.