Add rewrite_nodes function #589

Merged 1 commit on Jul 24, 2024
lib/axon.ex: 77 changes (72 additions, 5 deletions)

@@ -3594,16 +3594,15 @@ defmodule Axon do
relu layers with tanh layers:

     new_model = Axon.map_nodes(model, fn
-      %Axon{op: :relu} = graph ->
-        # Get nodes immediate parent
-        parent = Axon.get_parent(graph)
-        # Replace node with a tanh
-        Axon.tanh(parent)
+      %Axon.Node{op: :relu} = axon_node ->
+        %{axon_node | op: :tanh}

       graph ->
         graph
     end)

+   For more complex graph rewriting and manipulation cases, see
+   `Axon.rewrite_nodes/2`.
"""
@doc type: :graph
def map_nodes(%Axon{output: id, nodes: nodes} = axon, fun) when is_function(fun, 1) do
@@ -3642,6 +3641,74 @@ defmodule Axon do
Enum.reduce(inorder_nodes, acc, fun)
end

@doc """
Rewrite and manipulate nodes in the Axon execution graph.

Axon models are represented as a graph of nodes. Working on these nodes
directly can be difficult and lead to disconnected and invalid graphs.
In some cases, you simply want to rewrite patterns. This function
traverses the nodes of an Axon model, applying the rewrite `fun` to each
node so that you can rewrite some or all of them without breaking the graph.

The rewrite function is an arity-1 function which takes the current Axon node
as input and returns a function that replaces or rewrites the given node.
For example, you can define a simple rewriter which replaces the `:relu`
layers with `:tanh` layers:

    tanh_rewriter = fn [%Axon{} = x], _output ->
      Axon.tanh(x)
    end

    Axon.rewrite_nodes(model, fn
      %Axon.Node{op: :relu} -> tanh_rewriter
      _ -> :skip
    end)

Notice that the rewriter receives all of the original graph inputs *as well as*
the original graph outputs. This makes certain transformations which may rely
on both the input and output, such as LoRA, much easier to perform.
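
As a sketch of that idea, a residual rewriter (the same pattern used in
the tests added by this pull request) wraps a layer's output in a skip
connection:

    residual_rewriter = fn [%Axon{} = x], %Axon{} = out ->
      # `out` stands in for the node's original output, so this adds
      # the layer's input back onto that output
      Axon.add(x, out)
    end

    Axon.rewrite_nodes(model, fn
      %Axon.Node{op: :dense} -> residual_rewriter
      _ -> :skip
    end)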
"""
@doc type: :graph
def rewrite_nodes(%Axon{output: id, nodes: nodes}, fun) when is_function(fun, 1) do
{inorder_nodes, _} = traverse_nodes(id, nodes, [], MapSet.new())

updated_nodes =
Enum.reduce(inorder_nodes, nodes, fn
%{id: original_id, parent: parents} = current_node, nodes ->
rewriter = fun.(current_node)

case rewriter do
:skip ->
nodes

rewriter when is_function(rewriter, 2) ->
input_axons = Enum.map(parents, &%Axon{output: &1, nodes: nodes})
%Axon{output: swapped_id} = placeholder_output = Axon.input("placeholder_output")

%Axon{output: new_node_id, nodes: updated_nodes} =
rewriter.(input_axons, placeholder_output)

# now we swap the IDs for the rewritten model so that anything
# referencing the original node picks up the new, rewritten form
# as its input
original_node = %{updated_nodes[original_id] | id: swapped_id}
updated_node = %{updated_nodes[new_node_id] | id: original_id}

updated_nodes
|> Map.replace(swapped_id, original_node)
|> Map.replace(original_id, updated_node)
end
end)

# if the rewriter removed any nodes (e.g. by returning the input
# directly), orphaned nodes remain in the graph, so we prune them
# by traversing once again
{pruned_nodes, _} = traverse_nodes(id, updated_nodes, [], MapSet.new())
pruned_nodes = Map.new(pruned_nodes, fn %{id: id} = axon_node -> {id, axon_node} end)

%Axon{output: id, nodes: pruned_nodes}
end

defp traverse_nodes(id, nodes, acc, visited) do
if MapSet.member?(visited, id) do
{acc, visited}
test/axon/compiler_test.exs: 118 changes (118 additions, 0 deletions)

@@ -5654,7 +5654,7 @@

x = random({1, 1})

assert {init_fn, predict_fn} = Axon.build(model)

assert %ModelState{
data: %{"custom" => %{"composite" => %{"inner_composite" => %{"a" => a}}}}
@@ -5739,4 +5739,122 @@
assert out =~ "bar:"
end
end

describe "graph manipulation" do
test "rewrite_nodes does nothing if all rewrites are skip" do
model =
Axon.input("x")
|> Axon.dense(10, activation: :relu)

model = Axon.rewrite_nodes(model, fn _ -> :skip end)

{init_fn, predict_fn} = Axon.build(model)
input = Nx.broadcast(1, {1, 10})

%ModelState{data: %{"dense_0" => %{"kernel" => k, "bias" => b}}} =
model_state = init_fn.(input, ModelState.empty())

assert_equal(
predict_fn.(model_state, input),
Axon.Activations.relu(Axon.Layers.dense(input, k, b))
)
end

test "rewrite_nodes applies simple rewriters" do
relu_rewriter = fn [%Axon{} = x], _ ->
Axon.tanh(x)
end

model =
Axon.input("x")
|> Axon.dense(10, activation: :relu)

model =
Axon.rewrite_nodes(model, fn
%Axon.Node{op: :relu} -> relu_rewriter
_ -> :skip
end)

{init_fn, predict_fn} = Axon.build(model)
input = Nx.broadcast(1, {1, 10})

%ModelState{data: %{"dense_0" => %{"kernel" => k, "bias" => b}}} =
model_state = init_fn.(input, ModelState.empty())

assert_equal(
predict_fn.(model_state, input),
Axon.Activations.tanh(Axon.Layers.dense(input, k, b))
)
end

test "rewrite_nodes applies residual rewriter" do
residual_rewriter = fn [%Axon{} = x], %Axon{} = out ->
Axon.add(x, out)
end

model =
Axon.input("x")
|> Axon.dense(10, activation: :relu)

model =
Axon.rewrite_nodes(model, fn
%Axon.Node{op: :dense} -> residual_rewriter
_ -> :skip
end)

{init_fn, predict_fn} = Axon.build(model)
input = Nx.broadcast(1, {1, 10})

%ModelState{data: %{"dense_0" => %{"kernel" => k, "bias" => b}}} =
model_state = init_fn.(input, ModelState.empty())

real_fn = fn input, k, b ->
out = Nx.add(Axon.Layers.dense(input, k, b), input)
Axon.Activations.relu(out)
end

assert_equal(predict_fn.(model_state, input), real_fn.(input, k, b))
end

test "rewrite_nodes properly removes layers" do
remove_relu_rewriter = fn [%Axon{} = x], _out ->
x
end

input = Axon.input("x")
relu_tanh_input = Axon.tanh(Axon.relu(input))

model =
input
|> Axon.relu()
|> Axon.tanh()
|> Axon.relu()
|> Axon.tanh()
|> Axon.tanh()
|> Axon.relu()
|> Axon.relu()
|> Axon.add(relu_tanh_input)

model =
Axon.rewrite_nodes(model, fn
%Axon.Node{op: :relu} -> remove_relu_rewriter
_ -> :skip
end)

{_, predict_fn} = Axon.build(model)
input = Nx.broadcast(1, {1, 10})

real_fn = fn input ->
tanh_input = Axon.Activations.tanh(input)

input
|> Axon.Activations.tanh()
|> Axon.Activations.tanh()
|> Axon.Activations.tanh()
|> Nx.add(tanh_input)
end

assert_equal(predict_fn.(ModelState.empty(), input), real_fn.(input))
end
end
end
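
For reviewers, a minimal end-to-end sketch of the new API as exercised by
the tests above (the model is illustrative, not part of this diff):

    model =
      Axon.input("x")
      |> Axon.dense(10, activation: :relu)

    # swap every :relu node for a tanh activation, leaving the rest intact
    rewritten =
      Axon.rewrite_nodes(model, fn
        %Axon.Node{op: :relu} -> fn [x], _out -> Axon.tanh(x) end
        _ -> :skip
      end)

    {init_fn, predict_fn} = Axon.build(rewritten)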