[RFC]: Initial support for Pipeline Parallelism #4461

Closed
andoorve opened this issue Apr 29, 2024 · 1 comment · Fixed by #4412

@andoorve
Collaborator

Motivation.

This RFC describes the initial approach for supporting pipeline parallelism as part of vLLM.

Pipeline parallelism is a technique for splitting a model's layers across multiple devices. For example, a 12-layer model may be partitioned across 4 devices, each responsible for 3 layers. As each device finishes executing its portion of the layers, it sends its output to the next device and can then move on to the next microbatch in a pipelined fashion, hence the name. This is shown in the image below:

[Image: ppworkers drawio]

Here, the input stage handles embedding (E), completes the first 3 layers, and sends the result (S) to the next worker, where it is received (R). The middle two workers each receive, execute their layers, and then send. The last worker receives the output from the previous stage and then computes the final output (logits).
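As a rough illustration of the 12-layer, 4-device example above, an even split of layer indices across stages can be sketched as follows. The function name `partition_layers` is hypothetical and is not taken from the PR; the sketch also assumes the layer count divides evenly by the number of stages.

```python
def partition_layers(num_layers: int, pp_size: int) -> list[range]:
    """Split layer indices evenly across pipeline stages.

    Assumption: num_layers is divisible by pp_size, as in the example above.
    """
    if num_layers % pp_size != 0:
        raise ValueError("this sketch requires num_layers % pp_size == 0")
    per_stage = num_layers // pp_size
    # Stage s owns the contiguous block [s * per_stage, (s + 1) * per_stage).
    return [range(s * per_stage, (s + 1) * per_stage) for s in range(pp_size)]


stages = partition_layers(num_layers=12, pp_size=4)
for rank, layers in enumerate(stages):
    print(rank, list(layers))
```

With 12 layers and 4 stages, stage 0 holds layers 0 to 2 and stage 3 holds layers 9 to 11.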

Compared to tensor parallelism, this technique can have lower communication overhead. However, tensor parallelism may allow increased batching, which can lower latency for a batch. The two techniques are not mutually exclusive and can be combined (for example, tensor parallelism within a node and pipeline parallelism between nodes, where communication cost is high).

This PR will be important for very large models that require multiple nodes, as well as for machines with high communication overhead (for example, those without NVLink).

Proposed Change.

Scope

We are aiming for the simplest possible implementation at this stage. Simplicity is preferred over optimization for now so that the functionality can be merged as seamlessly as possible.

We will also limit our scope to AsyncLLMEngine and RayGPUExecutor for now (reasons discussed below).

Key Ideas

The main conceptual change requires understanding the data dependencies that exist in pipeline parallelism. These are shown in the diagram below.

[Image: pp drawio]

Note that there are 4 streams, each denoted by a different colour. We can think of the requests within each stream as being data dependent. Within one step (in which requests move from the first stage to the last stage, top to bottom), note the data dependencies present in each stream: each stage (set of layers) requires execution to complete on the previous device before it can start executing. Once the step completes on the last device and the output is received, a new step can be scheduled on the first worker.

Note that if we had only one stream (for example, just the red stream), each worker would be idle for 3/4 of the time while it waits for work. This is the case without any pipelining. By adding 3 more streams, for a total of 4 (the same as the number of pipeline stages), we saturate each worker with work. This is also why we restrict the implementation to AsyncLLMEngine: we need to be able to have multiple steps in flight concurrently.
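The utilization argument can be checked with a small idealized simulation. This is only a sketch under simplifying assumptions that are not part of the RFC: every stage takes exactly one tick per step, communication is free, and a stream's next step is scheduled on the first stage as soon as its previous step leaves the last stage. The function name `simulate_utilization` is hypothetical.

```python
from collections import deque


def simulate_utilization(num_stages: int, num_streams: int, num_ticks: int = 1000) -> float:
    """Return the fraction of (stage, tick) slots in which a stage had work."""
    # queues[s] holds the streams waiting to run on stage s, in FIFO order.
    queues = [deque() for _ in range(num_stages)]
    queues[0].extend(range(num_streams))
    busy = 0
    for _ in range(num_ticks):
        # Pop work for every stage first so a stream advances one stage per tick.
        running = [q.popleft() if q else None for q in queues]
        for stage, stream in enumerate(running):
            if stream is None:
                continue  # this stage is idle during this tick
            busy += 1
            # After the last stage, the stream's next step starts at stage 0.
            queues[(stage + 1) % num_stages].append(stream)
    return busy / (num_stages * num_ticks)


one_stream = simulate_utilization(num_stages=4, num_streams=1)
four_streams = simulate_utilization(num_stages=4, num_streams=4)
print(one_stream, four_streams)
```

With one stream the simulated utilization is 0.25, matching the 3/4 idle time described above. With four streams it approaches 1.0 once the pipeline has filled.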

We call the implementation of these streams in vLLM "virtual engines", since all requests on one stream are data independent of those on any other stream, and therefore the steps of each stream can run separately and concurrently.

This lends itself to an elegant implementation in the code. (Note that a virtual engine is not its own class right now. This could be done eventually, but we might not want to for optimization purposes. All of the virtual engines currently "live" within an AsyncLLMEngine.) Each virtual engine currently consists of its own scheduler (and therefore block manager) as well as a corresponding cache engine. This allows completely separate scheduling without interference. New requests can be added to schedulers based on a cost function. By giving each virtual engine a separate, equally sized cache engine, we create the conditions for the pipeline stages to be evenly loaded (especially if we enable chunked prefill), provided the cost function for adding requests to schedulers is defined appropriately. In the future we can optimize this, for example for prefix sharing, but for now this is the simplest implementation IMHO.

Implementation Details

Models

Within each model, we must modify 2 things.

  • Restrict the number of layers on each stage to num_layers / pp_world_size
  • Ensure the appropriate sends and receives are added: every stage except the input stage receives from the previous stage, and every stage except the output stage sends to the next stage.
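The two model changes above can be sketched with a toy pipeline. This is not the PR's model code: `PipelineStage` and `build_pipeline` are hypothetical names, plain Python callables stand in for transformer layers, and in-process queues stand in for the device-to-device send and receive operations.

```python
from queue import Queue
from typing import Callable, Optional

Layer = Callable[[float], float]


class PipelineStage:
    def __init__(self, layers: list[Layer], recv_q: Optional[Queue], send_q: Optional[Queue]):
        self.layers = layers    # only this stage's slice of the model
        self.recv_q = recv_q    # None on the input stage
        self.send_q = send_q    # None on the output stage

    def forward(self, x: Optional[float] = None) -> Optional[float]:
        # Every stage except the input stage receives from the previous stage.
        if self.recv_q is not None:
            x = self.recv_q.get()
        for layer in self.layers:
            x = layer(x)
        # Every stage except the output stage sends onward and returns nothing.
        if self.send_q is not None:
            self.send_q.put(x)
            return None
        return x


def build_pipeline(all_layers: list[Layer], pp_size: int) -> list[PipelineStage]:
    per_stage = len(all_layers) // pp_size  # num_layers / pp_world_size
    links = [Queue() for _ in range(pp_size - 1)]
    return [
        PipelineStage(
            all_layers[rank * per_stage:(rank + 1) * per_stage],
            recv_q=links[rank - 1] if rank > 0 else None,
            send_q=links[rank] if rank < pp_size - 1 else None,
        )
        for rank in range(pp_size)
    ]


# Toy "layers": layer i adds i to its input.
layers = [lambda x, i=i: x + i for i in range(12)]
pipeline = build_pipeline(layers, pp_size=4)
out = pipeline[0].forward(0.0)
for stage in pipeline[1:]:
    out = stage.forward()
print(out)
```

Running the stages in order gives the same result as applying all 12 toy layers on one device, which is the property the real sends and receives need to preserve.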

AsyncLLMEngine and LLMEngine

We modify step_async and engine_step to include a virtual_engine argument which allows us to select the correct scheduler.

We also modify the running loop to allow one async step in flight on each virtual engine at a time, using a fixed-size list with one slot per virtual engine for its running task. Once a task finishes, we schedule another step in the same slot.
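A minimal sketch of such a loop with asyncio is shown below. It is not the PR's code: `run_engine_loop`, `fake_step`, and the `total_steps` stopping condition are hypothetical, and `fake_step` merely sleeps in place of a real engine step.

```python
import asyncio
from typing import Awaitable, Callable


async def run_engine_loop(
    step_async: Callable[[int], Awaitable[int]],
    num_virtual_engines: int,
    total_steps: int,
) -> list[int]:
    """Keep exactly one step in flight per virtual engine."""
    # Fixed-size list: slot i holds the running task of virtual engine i.
    slots = [asyncio.create_task(step_async(ve)) for ve in range(num_virtual_engines)]
    completed: list[int] = []
    while len(completed) < total_steps:
        done, _ = await asyncio.wait(slots, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            ve = slots.index(task)
            completed.append(task.result())
            # Refill the same slot, so this virtual engine never has two steps in flight.
            slots[ve] = asyncio.create_task(step_async(ve))
    # Stop the steps that are still running once enough work has been done.
    for task in slots:
        task.cancel()
    await asyncio.gather(*slots, return_exceptions=True)
    return completed


async def fake_step(virtual_engine: int) -> int:
    # Hypothetical stand-in for one engine step on the given virtual engine.
    await asyncio.sleep(0.001 * (virtual_engine + 1))
    return virtual_engine


completed = asyncio.run(run_engine_loop(fake_step, num_virtual_engines=4, total_steps=8))
print(completed)
```

Because each finished task is replaced in its own slot, a slow virtual engine never blocks the others from starting their next step.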

We modify LLMEngine to hold multiple schedulers and assign each new request to the scheduler with the fewest sequences.
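The assignment rule can be sketched as follows. `SchedulerStub`, `num_seqs`, and `add_request` are hypothetical stand-ins for illustration and do not reflect vLLM's actual scheduler interface; the cost function here is simply the number of sequences a scheduler currently holds.

```python
from dataclasses import dataclass, field


@dataclass
class SchedulerStub:
    """Hypothetical stand-in for the scheduler of one virtual engine."""
    sequences: list[str] = field(default_factory=list)

    def num_seqs(self) -> int:
        return len(self.sequences)


def add_request(schedulers: list[SchedulerStub], request_id: str) -> int:
    """Add the request to the least-loaded scheduler and return its index."""
    costs = [s.num_seqs() for s in schedulers]
    virtual_engine = costs.index(min(costs))  # ties go to the lowest index
    schedulers[virtual_engine].sequences.append(request_id)
    return virtual_engine


schedulers = [SchedulerStub() for _ in range(4)]
assignments = [add_request(schedulers, f"req-{i}") for i in range(6)]
print(assignments)
```

Starting from empty schedulers, six requests are spread round-robin as 0, 1, 2, 3, 0, 1. A different cost function (for example, one based on token counts) could be swapped in without changing the structure.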

Scheduler and CacheEngine

For each virtual engine, we modify the scheduler's block manager and the cache engine to use num_blocks / pp_stages blocks, evenly partitioning the available cache space among virtual engines.
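As a small worked example of this partitioning, the sketch below assumes integer division, so any remainder blocks are left unused. Both the function name and the block count of 8192 are made up for illustration.

```python
def blocks_per_virtual_engine(num_blocks: int, pp_stages: int) -> int:
    """Number of cache blocks given to each virtual engine.

    Assumption: integer division; leftover blocks are simply not used.
    """
    return num_blocks // pp_stages


# Hypothetical figure: 8192 available blocks shared by 4 virtual engines.
print(blocks_per_virtual_engine(8192, 4))
```

Here each of the 4 virtual engines would manage 2048 blocks.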

DistributedGPUExecutorAsync and RayGPUExecutorAsync

We adjust the output code so that it no longer automatically uses the driver worker's (rank 0) output. Sampling is now done on the first TP worker of the last PP stage (when PP = 1, this is the driver worker anyway). Since we have multiple TP groups, we make the first worker of each TP group the driver for its group. This greatly simplifies the design, although there may be some optimization opportunities later. We also pass the virtual engine through to Worker here.

Note that due to the use of Ray, requests are implicitly queued on each worker in a FIFO fashion, which is why multiple steps can be launched at the same time as well.

General

In general, anywhere we do a broadcast, we make src the first TP worker of the group instead of worker 0.
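To make the driver and broadcast-source choice concrete, the sketch below assumes a rank layout in which each TP group occupies a contiguous block of global ranks, that is, rank = pp_rank * tp_size + tp_rank. That layout is an assumption made for illustration, not something stated in this RFC, and both function names are hypothetical.

```python
def tp_group_driver(rank: int, tp_size: int) -> int:
    """Global rank of the first TP worker in this rank's TP group.

    Assumption: rank = pp_rank * tp_size + tp_rank (contiguous TP groups).
    """
    return (rank // tp_size) * tp_size


def sampling_rank(pp_size: int, tp_size: int) -> int:
    """Global rank of the first TP worker of the last PP stage."""
    return (pp_size - 1) * tp_size


pp_size, tp_size = 2, 4
drivers = sorted({tp_group_driver(r, tp_size) for r in range(pp_size * tp_size)})
print(drivers, sampling_rank(pp_size, tp_size))
```

Under this layout, PP = 2 and TP = 4 give group drivers at ranks 0 and 4, with sampling on rank 4. With PP = 1 the sampling rank is 0, which matches the driver worker case described above.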

The following image shows the main idea of the implementation in vLLM at a high level. Four steps, one from each virtual engine, are being executed concurrently. 4 steps are queued on the last worker, while only 1 is queued on the first worker.

[Image: vllm_pp]

Feedback Period.

Feedback can be provided directly on the PR. Based on comments here, we can update the RFC to elaborate.

CC List.

cc: @zhuohan123 @WoosukKwon @simon-mo @youkaichao

Any Other Things.

Initial PR here with tracking items: #4412

@youkaichao
Member

Great job! Will take a look this week.
