
Basic GPU Support #105

Closed · 7 tasks done
nic-barbara opened this issue Jul 17, 2023 · 2 comments · Fixed by #118

nic-barbara (Member) commented Jul 17, 2023

We need to add GPU support for all REN and LBDN models.

- [x] Check that the typing on all structs is compatible with CUDA arrays and allows transferring to the GPU
- [x] Write a test that sends all REN models/parameterisations to the GPU
- [x] Debug issues arising from this test
- [x] Repeat for LBDN
- [x] Repeat for SandwichFC
- [x] Test on MNIST for LBDN
- [x] Test on an observer for REN

We also need to check whether running the CI with a GPU is actually possible.
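The GPU round-trip test in the checklist could look something like the sketch below. This is only illustrative: the constructor arguments and helper names (`ContractingRENParams`, `REN`, `init_states`) are assumed from the package's current API and may not match the final implementation.

```julia
# Hedged sketch of a GPU transfer test for a REN model.
# Assumes RobustNeuralNetworks.jl and Flux.jl; names are illustrative.
using Flux
using RobustNeuralNetworks

nu, nx, nv, ny = 1, 10, 20, 1
ren_ps = ContractingRENParams{Float32}(nu, nx, nv, ny)
model  = REN(ren_ps) |> gpu        # no-op on CPU, CuArray transfer on GPU

batches = 4
x0 = init_states(model, batches) |> gpu
u  = gpu(rand(Float32, nu, batches))

x1, y = model(x0, u)               # one forward step
@assert size(y) == (ny, batches)
```

Because Flux's `gpu` falls back to the identity when no functional GPU is present, a test like this can run in CPU-only CI as well, which partly answers the CI question above.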

nic-barbara self-assigned this Jul 17, 2023
nic-barbara added the enhancement (New feature or request) label Jul 17, 2023
nic-barbara (Member, Author) commented
No need to write `gpu` and `cpu` methods for all our types anymore. As long as they are all functors (declared with `@functor`), device transfer generalises through Flux.jl, as outlined here.
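As a quick illustration of why no custom methods are needed: any struct declared as a functor automatically works with Flux's `gpu`/`cpu`. The `MyCell` type below is purely illustrative, not part of the package.

```julia
# Hedged sketch: once a type is a functor, Flux's gpu/cpu work on it for free.
using Flux

struct MyCell
    W
    b
end
Flux.@functor MyCell   # declare all fields as functor children

m = MyCell(rand(Float32, 3, 3), zeros(Float32, 3))

# Moves W and b to CuArrays if a functional GPU is available;
# otherwise gpu is the identity and the model is returned unchanged.
m_gpu = m |> gpu
```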

nic-barbara (Member, Author) commented Aug 4, 2023

NOTE: Using `@functor` for GPU specification:

Summary: whatever you specify in the brackets of `@functor MyType (...)` is what gets loaded onto the GPU.


Writing

```julia
@functor ExplicitRENParams (A,)
```

means that only `explicit.A` gets loaded onto the GPU when you call `explicit |> gpu`. If we instead write

```julia
@functor ExplicitRENParams
trainable(m::ExplicitRENParams) = (; )
```

then all fields of `explicit` get loaded onto the GPU, but `Flux.trainable` still returns an empty named tuple (so none of them are trained). This is the behaviour we want.

Note that we use `@functor ...` for the two base parameter types `DirectRENParams` and `DirectLBDNParams`, so this is fine.
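The difference between the two patterns can be sketched with Functors.jl directly, since Flux's `gpu`, `cpu`, and `trainable` all walk the same functor structure. The `Partial` and `Full` types below are illustrative toys, not package types.

```julia
# Minimal sketch of the two @functor patterns via fmap, which applies a
# function to every functor child (exactly what gpu does with CuArray).
using Functors

struct Partial
    A
    B
end
@functor Partial (A,)   # only A is a child: only A moves under fmap/gpu

struct Full
    A
    B
end
@functor Full           # all fields are children: everything moves

p = fmap(x -> 10 .* x, Partial([1.0], [2.0]))
f = fmap(x -> 10 .* x, Full([1.0], [2.0]))

(p.A, p.B)   # ([10.0], [2.0])  -- B left behind
(f.A, f.B)   # ([10.0], [20.0]) -- both transformed
```

With the second pattern, separately overriding `Flux.trainable` to return `(; )` keeps all fields away from the optimiser while still letting `gpu` move them.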

nic-barbara changed the title from "GPU Support" to "Basic GPU Support" Aug 14, 2023
nic-barbara linked a pull request Aug 14, 2023 that will close this issue