
Add rewrite_nodes function #589

Merged
merged 1 commit into from
Jul 24, 2024
Conversation

seanmor5
Copy link
Contributor

@seanmor5 seanmor5 commented Jul 24, 2024

Resolves #529. This is a proper node-rewriting API that should be much more flexible and easier to use than the wrap_nodes approach proposed earlier. LoRA can now be implemented like this:

# units_a, units_b, and scaling are assumed to be defined in the
# enclosing scope; the rewriter receives the node's inputs as a list
# and the node's original output (wx) as the second argument
lora_dense_rewriter = fn [%Axon{} = x], %Axon{} = wx ->
  x
  |> Axon.dropout(rate: 0.1)
  |> Axon.dense(units_a, use_bias: false, name: "lora_a")
  |> Axon.dense(units_b, use_bias: false, name: "lora_b")
  |> Axon.multiply(scaling)
  |> Axon.add(wx)
end

lora_model = Axon.rewrite_nodes(model, fn
  %Axon.Node{op: :dense} -> lora_dense_rewriter
  _ -> :skip
end)

It also supports more complex rewriters. I will continue expanding this API until we have something closer to torch.fx.
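For a sense of how small a rewriter can be, here is a hedged sketch that swaps every ReLU activation for GELU using the same callback contract as the LoRA example above. It assumes activation layers surface as %Axon.Node{op: :relu}, that Axon.gelu/1 exists as an activation layer, and that model is any existing Axon model; these names are illustrative, not taken from this PR.

relu_to_gelu = fn [%Axon{} = x], %Axon{} = _original_output ->
  # Rebuild the node from its single input, ignoring the original output
  Axon.gelu(x)
end

new_model =
  Axon.rewrite_nodes(model, fn
    %Axon.Node{op: :relu} -> relu_to_gelu
    _ -> :skip
  end)

The second clause returning :skip leaves every non-matching node untouched, mirroring the dense-layer example above.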

@seanmor5 seanmor5 merged commit 9fce600 into main Jul 24, 2024
4 of 5 checks passed
@seanmor5 seanmor5 deleted the sm-rewrite-nodes branch July 24, 2024 13:55
Successfully merging this pull request may close these issues.

Unable to replace existing model layers