Refactor containers #590

Merged (2 commits, Jul 24, 2024)

lib/axon.ex: 181 changes (49 additions, 132 deletions)

@@ -279,6 +279,8 @@ defmodule Axon do
   alias __MODULE__, as: Axon
   alias Axon.Parameter
 
+  import Axon.Shared
+
   require Logger
 
   @type t :: %__MODULE__{}
@@ -380,15 +382,6 @@ defmodule Axon do
     }
   end
 
-  defp split_inputs(:container, [inputs]) do
-    {inputs, cache} =
-      deep_map_reduce(inputs, %{}, fn %Axon{output: id, nodes: nodes}, cache ->
-        {id, Map.merge(nodes, cache)}
-      end)
-
-    {[inputs], [], [:layer], cache}
-  end
-
   defp split_inputs(_op, inputs) do
     Enum.reduce(inputs, {[], [], [], %{}}, fn
       %Axon{output: layer_input, nodes: nodes}, {layers, params, args, cache} ->
@@ -704,62 +697,47 @@ defmodule Axon do
   @doc type: :special
   def container(container, opts \\ []) do
     opts = Keyword.validate!(opts, [:name, :meta])
 
-    layer(:container, [container], name: opts[:name], meta: opts[:meta], op_name: :container)
+    {structure_fn, nodes} = destructure(container)
+    layer(structure_fn, nodes, name: opts[:name], meta: opts[:meta], op_name: :container)
   end
 
-  # TODO: This should not be duplicated
-  defp deep_new(%Nx.Tensor{} = x, fun), do: fun.(x)
-
-  defp deep_new(x, fun) when is_number(x), do: fun.(x)
-
-  defp deep_new(map, fun) do
-    {cont, :ok} = Nx.Container.traverse(map, :ok, &recur_traverse(&1, &2, fun))
-    cont
+  defp destructure(container) do
+    {structure, {nodes, _}} = recur_destructure(container, {[], 0})
+    fun = restructure(length(nodes) + 1, structure)
+    {fun, Enum.reverse(nodes)}
   end
 
-  defp recur_traverse(item, :ok, fun) do
-    case item do
-      %Axon{} = t ->
-        {fun.(t), :ok}
-
-      %{axon: :axon} = t ->
-        {fun.(t), :ok}
+  defp recur_destructure(container, acc) do
+    Nx.Container.traverse(container, acc, fn value, {leaves, idx} ->
+      case value do
+        %Axon{} = leaf ->
+          {idx, {[leaf | leaves], idx + 1}}
 
-      container ->
-        {deep_new(container, fun), :ok}
-    end
+        container ->
+          recur_destructure(container, {leaves, idx})
+      end
+    end)
   end
 
-  defp deep_merge(left, right, fun) do
-    case Nx.Container.traverse(left, leaves(right), &recur_merge(&1, &2, fun)) do
-      {merged, []} ->
-        merged
+  for i <- 0..128 do
+    args = Macro.generate_arguments(i, __MODULE__)
 
-      {_merged, _leftover} ->
-        raise ArgumentError,
-              "unable to merge arguments with incompatible" <>
-                " structure"
+    defp restructure(unquote(i), structure) do
+      fn unquote_splicing(args) ->
+        args_tuple = {unquote_splicing(args)}
+        {container, :ok} = recur_restructure(structure, args_tuple)
+        container
+      end
     end
   end
 
-  defp leaves(container) do
-    container
-    |> Nx.Container.reduce([], fn x, acc -> [x | acc] end)
-    |> Enum.reverse()
-  end
-
-  defp recur_merge(left, [right | right_leaves], fun) do
-    case {left, right} do
-      {%Nx.Tensor{} = left, %Nx.Tensor{} = right} ->
-        {fun.(left, right), right_leaves}
-
-      {%Axon{} = left, %Axon{} = right} ->
-        {fun.(left, right), right_leaves}
-
-      {left, right} ->
-        {deep_merge(left, right, fun), right_leaves}
-    end
+  defp recur_restructure(structure, args_tuple) do
+    Nx.Container.traverse(structure, :ok, fn value, :ok ->
+      case value do
+        idx when is_integer(idx) -> {elem(args_tuple, idx), :ok}
+        container -> recur_restructure(container, args_tuple)
+      end
+    end)
  end
 
   @doc """
@@ -3644,35 +3622,31 @@ defmodule Axon do
   end
 
   @doc """
-  Returns a model's output shape from the given input
+  Returns a model's output template from the given input
   template.
+
+  The output template gives you access to the output shape
+  and type of the given input graph.
   """
   @doc type: :graph
   def get_output_shape(%Axon{} = axon, inputs, opts \\ []) do
     {init_fn, forward_fn} = build(axon, opts ++ [raise_on_none: false])
 
-    out =
+    inputs =
+      case inputs do
+        %Nx.Tensor{} = input -> Nx.to_template(input)
+        inputs when is_map(inputs) -> Map.new(inputs, fn {k, v} -> {k, Nx.to_template(v)} end)
+      end
+
+    fun =
       Nx.Defn.jit(
         fn inputs ->
           forward_fn.(init_fn.(inputs, Axon.ModelState.empty()), inputs)
         end,
         compiler: Axon.Defn
-      ).(inputs)
-
-    safe_shape(out)
-  end
+      )
 
-  defp safe_shape(container_or_tensor) do
-    case container_or_tensor do
-      %Axon.None{} = none ->
-        none
-
-      %Nx.Tensor{} = tensor ->
-        Nx.shape(tensor)
-
-      container ->
-        deep_new(container, &Nx.shape/1)
-    end
+    deep_new(apply(fun, [inputs]), &Nx.to_template/1)
   end
 
   @doc """
@@ -3783,74 +3757,17 @@ defmodule Axon do
     if MapSet.member?(visited, id) do
       {acc, visited}
     else
-      %{op: op, parent: parents} = parent = nodes[id]
+      %{parent: parents} = parent = nodes[id]
 
       {acc, visited} =
-        case op do
-          :container ->
-            [container] = parents
-
-            deep_reduce(container, {acc, visited}, fn pid, {acc, visited} ->
-              traverse_nodes(pid, nodes, acc, visited)
-            end)
-
-          _ ->
-            Enum.reduce(parents, {acc, visited}, fn pid, {acc, visited} ->
-              traverse_nodes(pid, nodes, acc, visited)
-            end)
-        end
+        Enum.reduce(parents, {acc, visited}, fn pid, {acc, visited} ->
+          traverse_nodes(pid, nodes, acc, visited)
+        end)
 
       {[parent | acc], MapSet.put(visited, id)}
     end
   end
 
-  # TODO: Do not duplicate
-  defp deep_reduce(item, acc, fun) when is_integer(item) do
-    fun.(item, acc)
-  end
-
-  defp deep_reduce(map, acc, fun) do
-    Nx.Container.reduce(map, acc, &recur_deep_reduce(&1, &2, fun))
-  end
-
-  defp recur_deep_reduce(value, acc, fun) do
-    case value do
-      %Axon{} = val ->
-        fun.(val, acc)
-
-      %Nx.Tensor{} = val ->
-        fun.(val, acc)
-
-      %{axon: :axon} = val ->
-        fun.(val, acc)
-
-      val when is_integer(val) ->
-        fun.(val, acc)
-
-      val ->
-        deep_reduce(val, acc, fun)
-    end
-  end
-
-  defp deep_map_reduce(leaf, acc, fun) when is_integer(leaf), do: fun.(leaf, acc)
-
-  defp deep_map_reduce(container, acc, fun) do
-    Nx.Container.traverse(container, acc, &recur_deep_map_reduce(&1, &2, fun))
-  end
-
-  defp recur_deep_map_reduce(leaf, acc, fun) do
-    case leaf do
-      %Axon{} = leaf ->
-        fun.(leaf, acc)
-
-      %Nx.Tensor{} = leaf ->
-        fun.(leaf, acc)
-
-      container ->
-        deep_map_reduce(container, acc, fun)
-    end
-  end
-
   @doc """
   Pops the top node off of the graph.
 
lib/axon/compiler.ex: 67 changes (0 additions, 67 deletions)

@@ -40,7 +40,6 @@ defmodule Axon.Compiler do
   @moduledoc false
   require Logger
 
-  import Axon.Shared
   alias Axon.StatefulOutput
 
   ## Init JIT Compilation
@@ -549,72 +548,6 @@ defmodule Axon.Compiler do
     {id, model_funs, cache, op_counts, block_cache, model_state_meta}
   end
 
-  defp recur_model_funs(
-         %Axon.Node{id: id, op: :container, parent: [parents]},
-         nodes,
-         cache_and_counts,
-         config
-       ) do
-    {parent_ids, {cache, op_counts, block_cache, model_state_meta}} =
-      deep_map_reduce(parents, cache_and_counts, &to_model_funs(&1, nodes, &2, config))
-
-    op_counts = Map.update(op_counts, :container, 1, fn x -> x + 1 end)
-
-    predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace ->
-      {input, {state, result_cache, none?}} =
-        deep_map_reduce(
-          parent_ids,
-          {state, result_cache, false},
-          fn parent_id, {state, result_cache, none?} ->
-            {input, {state, result_cache}} =
-              call_predict_cache(
-                parent_id,
-                params,
-                inputs,
-                state,
-                cache,
-                result_cache,
-                fn_stacktrace
-              )
-
-            none? = none? or propagating_none?(input)
-            {input, {state, result_cache, none?}}
-          end
-        )
-
-      input = if none?, do: %Axon.None{}, else: input
-
-      {input, {state, result_cache}}
-    end
-
-    init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
-      {parent_template, {parent_params, result_cache, none?}} =
-        deep_map_reduce(parent_ids, {%{}, result_cache, false}, fn
-          parent_id, {params, result_cache, none?} ->
-            {parent_template, {params, result_cache}} =
-              call_init_cache(
-                parent_id,
-                template,
-                params,
-                cache,
-                result_cache,
-                fn_stacktrace,
-                keys
-              )
-
-            none? = none? or propagating_none?(parent_template)
-            {parent_template, {params, result_cache, none?}}
-        end)
-
-      parent_template = if none?, do: %Axon.None{}, else: parent_template
-
-      {parent_template, {parent_params, result_cache}}
-    end
-
-    model_funs = %{predict: predict_fun, init: init_fun}
-    {id, model_funs, cache, op_counts, block_cache, model_state_meta}
-  end
-
   defp recur_model_funs(
          %Axon.Node{
            id: id,
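
The deleted clause is no longer reachable because a container node's op is now the generated structure function, so containers compile through the generic layer path like any other op. A hypothetical illustration of the resulting call shape (the function and values are invented for the example, assuming Axon's convention of passing an options list as the final argument to every layer op):

# Sketch: for a two-leaf container, restructure/2 produces an arity-3
# function, two parent outputs plus the trailing layer opts, which the
# structure function ignores.
structure_fn = fn out0, out1, _opts -> %{a: out0, b: out1} end

# The generic layer path applies it to already-computed parent results:
apply(structure_fn, [Nx.tensor([1, 2]), Nx.tensor([3, 4]), [mode: :inference]])
#=> %{a: Nx.tensor([1, 2]), b: Nx.tensor([3, 4])} (loosely rendered)
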
lib/axon/display.ex: 29 changes (4 additions, 25 deletions)

@@ -3,7 +3,6 @@ defmodule Axon.Display do
   Module for rendering various visual representations of Axon models.
   """
 
-  import Axon.Shared
   alias Axon.Parameter
 
   @compile {:no_warn_undefined, TableRex.Table}
@@ -94,7 +93,8 @@ defmodule Axon.Display do
   defp do_axon_to_rows(
          %Axon.Node{
            id: id,
-           op: :container,
+           op: structure,
+           op_name: :container,
            parent: [parents],
            name: name_fn
          },
@@ -105,7 +105,7 @@
         model_info
       ) do
     {input_names, {cache, op_counts, model_info}} =
-      deep_map_reduce(parents, {cache, op_counts, model_info}, fn
+      Enum.map_reduce(parents, {cache, op_counts, model_info}, fn
        parent_id, {cache, op_counts, model_info} ->
          {_, name, _shape, cache, op_counts, model_info} =
            axon_to_rows(parent_id, nodes, templates, cache, op_counts, model_info)
@@ -119,7 +119,7 @@
     shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)
 
     row = [
-      "#{name} ( #{op_string} #{inspect(input_names)} )",
+      "#{name} ( #{op_string} #{inspect(apply(structure, input_names))} )",
       "#{inspect({})}",
       "#{inspect(shape)}",
       render_options([]),
@@ -311,27 +311,6 @@ defmodule Axon.Display do
     end
   end
 
-  defp recur_axon_to_edges(
-         %Axon.Node{id: id, op: :container, name: name_fn, parent: [parents]},
-         nodes,
-         templates,
-         cache_counts_edgelist
-       ) do
-    {node_inputs, {cache, op_counts, edgelist}} =
-      deep_map_reduce(parents, cache_counts_edgelist, &axon_to_edges(&1, nodes, templates, &2))
-
-    name = name_fn.(:container, op_counts)
-    node_shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)
-    to_node = %{axon: :axon, id: id, op: :container, name: name, shape: node_shape}
-
-    new_edgelist =
-      deep_reduce(node_inputs, edgelist, fn from_node, acc ->
-        [{from_node, to_node} | acc]
-      end)
-
-    {to_node, {cache, op_counts, new_edgelist}}
-  end
-
   defp recur_axon_to_edges(
          %Axon.Node{id: id, op_name: op, name: name_fn, parent: parents},
          nodes,
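
For completeness, a small usage sketch of the display path touched above (input names and template shapes are illustrative; Axon.Display.as_table/2 requires the optional table_rex dependency). With the structure function available as the op, the container row can render its inputs in their actual container shape rather than as a flat list:

model = Axon.container({Axon.input("a"), Axon.input("b")})

templates = %{
  "a" => Nx.template({1, 2}, :f32),
  "b" => Nx.template({1, 3}, :f32)
}

# Renders a summary table; the container row shows its inputs as a tuple.
model
|> Axon.Display.as_table(templates)
|> IO.puts()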