feat!: Framework-agnostic device management #6748
Pull Request Test Coverage Report for Build 7545294479
Hey @shadeMe this is looking really good! I'm still looking through the PR, but I thought I would add a few questions that have already popped into mind:
There's no inherent limitation that prevents us from converting that parameter into a
    device_map = ComponentDevice.from_multiple(DeviceMap({
        "classifier": Device.gpu(1),
        "layer_1": Device.cpu(),
        "lm_head": Device.disk()
    }))

    device_map = ComponentDevice.from_multiple(DeviceMap.from_hf({
        "classifier": 1,
        "layer_1": "cpu",
        "lm_head": "disk"
    }))

It's not super ergonomic and there's admittedly a lot of ceremony, but that's the tradeoff for being more explicit/generic.
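To make the second form above more concrete, here is a minimal, standalone sketch of how an HF-style mapping could be translated into typed devices. It is not the code from this PR: `device_from_hf` and `device_map_from_hf` are hypothetical helpers, and the `DeviceType` and `Device` stand-ins only mirror the names used in the description. It assumes the HF convention that a bare integer means a GPU index while `"cpu"` and `"disk"` are passed as strings.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional, Union


class DeviceType(Enum):
    # Stand-in enum; the exact member set is an assumption.
    CPU = "cpu"
    GPU = "cuda"
    DISK = "disk"


@dataclass(frozen=True)
class Device:
    # A device type plus an optional integer identifier.
    type: DeviceType
    id: Optional[int] = None


def device_from_hf(value: Union[int, str]) -> Device:
    """Translate one HF-style device_map value into a typed Device (hypothetical helper)."""
    if isinstance(value, int):
        # HF-style maps use a bare integer for a GPU index.
        return Device(DeviceType.GPU, value)
    if value == "cpu":
        return Device(DeviceType.CPU)
    if value == "disk":
        return Device(DeviceType.DISK)
    raise ValueError(f"Unsupported device_map value: {value!r}")


def device_map_from_hf(hf_map: Dict[str, Union[int, str]]) -> Dict[str, Device]:
    """Convert a whole HF-style mapping, keeping the parameter names as keys."""
    return {name: device_from_hf(value) for name, value in hf_map.items()}


typed = device_map_from_hf({"classifier": 1, "layer_1": "cpu", "lm_head": "disk"})
```

The point of the typed form is that a typo such as `"gpu"` fails loudly at construction time instead of being handed to the backend as an opaque string.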
Left a few comments, in general looks good!
Oh one other thing that would be very useful is that when using I can see how this looks and feels when finishing up the
Ah, good point. I can add a property for that.
This looks good to me!
LGTM 😊
Related Issues
Proposed Changes:
This PR introduces the concept of framework-agnostic device representations. The main impetus behind this change is to move away from stringified representations of devices that are not portable between different frameworks. It also enables support for multi-device inference in a generic manner.
`device.py` contains the following new classes:
- `DeviceType` - An enum representing the types of devices we support.
- `Device` - A tuple of a `DeviceType` and an integer identifier that represents a single device.
- `DeviceMap` - An arbitrary mapping of model parameters to single devices, similar to HF's `device_map` in their `accelerate` library.
- `ComponentDevice` - Essentially a tagged union of `Device` and `DeviceMap`. This class is consumed by downstream components.
Going forward, components can expose a single, optional device parameter in their constructor (`Optional[ComponentDevice]`). The component can then decide what to do with it:
- If it is `None`, the component can either do nothing or automatically pick the best device to load the model. In the latter case, it simply passes the optional parameter to `ComponentDevice.resolve_device(device)`.
- If it is not `None`, the component can still call `ComponentDevice.resolve_device(device)`. In this case, the function immediately returns the passed device without any modification.
- The `ComponentDevice` instance can optionally be persisted in the component itself. Since the class is not trivially serializable to JSON, the `Component.to_dict` and `Component.from_dict` functions must explicitly call `ComponentDevice.to_dict/from_dict` during serialization and deserialization.
- The `ComponentDevice` can be converted to the component's backend representation by calling its `to_xxx` methods.
The following components have been updated to support the above workflow: `NamedEntityExtractor`, `HuggingFaceLocalGenerator`, `TransformersSimilarityRanker`, `ExtractiveReader`.
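The workflow described above can be sketched with self-contained stand-ins. This is an illustration under stated assumptions, not the implementation in `device.py`: the dictionary layout, the `_encode`/`_decode` helpers, the CPU fallback in `resolve_device`, and the `ToyRanker` component are all hypothetical, and only the class and method names mentioned in the description are reused.

```python
import json
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional


class DeviceType(Enum):
    CPU = "cpu"
    GPU = "cuda"


@dataclass(frozen=True)
class Device:
    type: DeviceType
    id: Optional[int] = None


def _encode(device: Device) -> str:
    # Hypothetical string form, e.g. "cpu" or "cuda:1".
    return device.type.value if device.id is None else f"{device.type.value}:{device.id}"


def _decode(text: str) -> Device:
    kind, _, index = text.partition(":")
    return Device(DeviceType(kind), int(index) if index else None)


@dataclass(frozen=True)
class ComponentDevice:
    # Tagged union: exactly one of the two fields is set.
    single: Optional[Device] = None
    multiple: Optional[Dict[str, Device]] = None

    def __post_init__(self) -> None:
        if (self.single is None) == (self.multiple is None):
            raise ValueError("Set exactly one of 'single' or 'multiple'")

    @staticmethod
    def resolve_device(device: Optional["ComponentDevice"]) -> "ComponentDevice":
        if device is not None:
            return device  # passed through without modification
        # Assumption: real code would probe the backend for the best device.
        return ComponentDevice(single=Device(DeviceType.CPU))

    def to_dict(self) -> Dict[str, Any]:
        if self.single is not None:
            return {"type": "single", "device": _encode(self.single)}
        return {"type": "multiple",
                "device_map": {k: _encode(v) for k, v in self.multiple.items()}}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ComponentDevice":
        if data["type"] == "single":
            return cls(single=_decode(data["device"]))
        return cls(multiple={k: _decode(v) for k, v in data["device_map"].items()})


class ToyRanker:
    """Hypothetical component exposing a single optional device parameter."""

    def __init__(self, device: Optional[ComponentDevice] = None) -> None:
        self.device = ComponentDevice.resolve_device(device)

    def to_dict(self) -> Dict[str, Any]:
        # The device is serialized explicitly; it is not JSON-ready on its own.
        return {"type": "ToyRanker", "init_parameters": {"device": self.device.to_dict()}}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToyRanker":
        params = dict(data["init_parameters"])
        params["device"] = ComponentDevice.from_dict(params["device"])
        return cls(**params)


gpu = ComponentDevice(single=Device(DeviceType.GPU, 1))
restored = ToyRanker.from_dict(json.loads(json.dumps(ToyRanker(device=gpu).to_dict())))
```

Resolving the device once in the constructor means the rest of the component only ever sees a concrete `ComponentDevice`, so the `None` case does not leak into serialization or model loading.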
How did you test it?
Unit and e2e tests.
Notes for the reviewer
`ComponentDevice` name. Suggestions are welcome.
Checklist
`fix:`, `feat:`, `build:`, `chore:`, `ci:`, `docs:`, `style:`, `refactor:`, `perf:`, `test:`.