Add simple quantization API (#586)
* Quantization draft

* Finish initial quantization API

* Docs
seanmor5 committed Jul 25, 2024
1 parent 216fafe commit ee8f855
Showing 7 changed files with 528 additions and 8 deletions.
37 changes: 35 additions & 2 deletions lib/axon.ex
@@ -855,6 +855,12 @@ defmodule Axon do
use_bias: true
])

meta =
opts[:meta] ||
%{}
|> Map.put(:units, units)
|> Map.put(:use_bias, opts[:use_bias])

kernel_shape = &Axon.Shape.dense_kernel(&1, units)
bias_shape = &Axon.Shape.dense_bias(&1, units)

@@ -868,7 +874,7 @@ defmodule Axon do
{[x, kernel], :dense}
end

- node = layer(op, inputs, name: opts[:name], meta: opts[:meta], op_name: :dense)
+ node = layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)

if activation = opts[:activation] do
activation(node, activation)
@@ -3666,7 +3672,7 @@ defmodule Axon do
"""
@doc type: :graph
def get_op_counts(%Axon{} = axon) do
- reduce_nodes(axon, %{}, fn %Axon.Node{op: op}, op_counts ->
+ reduce_nodes(axon, %{}, fn %Axon.Node{op_name: op}, op_counts ->
Map.update(op_counts, op, 1, fn x -> x + 1 end)
end)
end
@@ -4096,6 +4102,33 @@ defmodule Axon do
end
end

@doc """
Returns a mapping of layer names to layer properties.
"""
def properties(%Axon{output: id, nodes: nodes}) do
{_, _, properties} = node_properties(id, nodes, {%{}, %{}, %{}})
properties
end

defp node_properties(id, nodes, {cache, op_counts, properties} = acc) do
case cache do
%{^id => _} ->
{cache, op_counts, properties}

%{} ->
%Axon.Node{parent: parents, name: name_fn, op_name: op_name} = nodes[id]

{cache, op_counts, properties} =
Enum.reduce(parents, acc, &node_properties(&1, nodes, &2))

name = name_fn.(op_name, op_counts)
op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end)
properties = Map.put(properties, name, op_name)

{Map.put(cache, id, name), op_counts, properties}
end
end

## Helpers

@valid_initializers [:zeros, :ones, :uniform, :normal, :identity] ++
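The new `properties/1` walk pairs each generated layer name with its `op_name`; the quantization pass below uses this map to find dense layers by name. A minimal usage sketch, assuming Axon's default layer-naming scheme ("dense_0", "relu_0", ...); exact names depend on the model:

model =
  Axon.input("features")
  |> Axon.dense(32, activation: :relu)
  |> Axon.dense(1)

Axon.properties(model)
#=> %{"features" => :input, "dense_0" => :dense,
#=>   "relu_0" => :relu, "dense_1" => :dense}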
11 changes: 11 additions & 0 deletions lib/axon/model_state.ex
@@ -206,6 +206,8 @@ defmodule Axon.ModelState do

defp traverse(%Nx.Tensor{}, acc), do: [Enum.reverse(acc)]

defp traverse(%Axon.Quantization.QTensor{}, acc), do: [Enum.reverse(acc)]

defp traverse(map, acc) do
Enum.flat_map(map, fn {k, value} ->
traverse(value, [k | acc])
@@ -273,6 +275,10 @@ defmodule Axon.ModelState do
new_val = fun.(key, val_lhs, val_rhs)
Map.put(acc, key, new_val)

%Axon.Quantization.QTensor{} = val_rhs ->
new_val = fun.(key, val_lhs, val_rhs)
Map.put(acc, key, new_val)

val_rhs when is_map(val_lhs) and is_map(val_rhs) ->
updated_val = tree_merge(val_lhs, val_rhs, fun)
Map.put(acc, key, updated_val)
@@ -321,6 +327,11 @@ defmodule Axon.ModelState do
{_, %Nx.Tensor{} = tensor}, {count, size} ->
{count + Nx.size(tensor), size + Nx.byte_size(tensor)}

{_, %Axon.Quantization.QTensor{value: value, scale: scale, zero_point: zero}},
{count, size} ->
{count + Nx.size(value) + Nx.size(scale) + Nx.size(zero),
size + Nx.byte_size(value) + Nx.byte_size(scale) + Nx.byte_size(zero)}

{_, map}, {count, size} ->
{inner_count, inner_size} = get_param_info(map)
{count + inner_count, size + inner_size}
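These clauses matter because `QTensor` is a struct, i.e. a map: without an explicit leaf clause, the generic map traversal would recurse into its fields. A small sketch of the leaf dispatch and of how a quantized leaf contributes to the parameter count (illustrative, not the `Axon.ModelState` internals):

count_leaf = fn
  %Axon.Quantization.QTensor{value: v, scale: s, zero_point: z} ->
    # all three component tensors count toward parameter size
    Nx.size(v) + Nx.size(s) + Nx.size(z)

  %Nx.Tensor{} = t ->
    Nx.size(t)
end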
161 changes: 161 additions & 0 deletions lib/axon/quantization.ex
@@ -0,0 +1,161 @@
defmodule Axon.Quantization do
@moduledoc """
Model quantization.
Model quantization is a technique for reducing the memory footprint of
a model by converting portions of a model to use quantized representations.
Typically, these quantized representations are low-precision integers.
This is an **experimental** API which implements weight-only quantization.
The implementation in this module will convert dense layers in a large
model to quantized-variants. The only supported quantization type is
`{:s, 8}`. Axon quantization is inference-only. Training is not currently
supported.
"""
alias Axon.Quantization.Layers
alias Axon.Quantization.QTensor

@doc """
Quantizes a model and a model state.
Given a model and model state, this method will rewrite all
of the dense layers in the model to perform weight-only 8-bit
integer versions of the same operation. It will also replace values
for all dense kernels in the given model state with quantized
tensors.
"""
def quantize(%Axon{} = model, %Axon.ModelState{} = model_state) do
quantized_model = quantize_model(model)
quantized_model_state = quantize_model_state(model, model_state)
{quantized_model, quantized_model_state}
end

@doc """
Replaces standard operations with quantized variants.
The only supported conversion is to convert regular dense layers
to a weight-only 8-bit integer variant. Note that this only replaces
the properties of the model. If you have a pre-trained model state
that you wish to quantize, refer to `Axon.Quantization.quantize_model_state/1`.
All `:dense` layers in the model are replaced with `Axon.Quantization.weight_only_quantized_dense/3`.
"""
def quantize_model(%Axon{} = model) do
quantized_dense_rewriter = fn [%Axon{} = x], _output, units, use_bias ->
weight_only_quantized_dense(x, units, use_bias: use_bias)
end

Axon.rewrite_nodes(model, fn
%Axon.Node{op: :dense, meta: meta} ->
&quantized_dense_rewriter.(&1, &2, meta[:units], meta[:use_bias])

_ ->
:skip
end)
end

@doc """
Returns a quantized model state.
Given a model and a model state, this function will replace
all dense layer kernels with a quantized version of the weight.
Training is not currently supported, so all quantized layers are
automatically frozen.
"""
def quantize_model_state(model, model_state) do
dense_layer_names =
model
|> Axon.properties()
|> Enum.filter(fn {_, v} -> v == :dense end)
|> Enum.map(fn {k, _} -> k end)
|> MapSet.new()

state =
Enum.reduce(dense_layer_names, model_state, fn layer_name, state ->
update_in(state, [Access.key!(:data), layer_name, "kernel"], &QTensor.from_tensor/1)
end)

Axon.ModelState.freeze(state, fn [name | _] ->
MapSet.member?(dense_layer_names, name)
end)
end

## Layers

@doc """
Adds a weight-only quantized dense layer to the network.
This is equivalent to a dense layer, but works on quantized
weights for reducing model memory footprint.
Compiles to `Axon.Quantization.Layers.weight_only_quantized_dense/4`.
## Options
* `:name` - layer name.
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`.
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
"""
def weight_only_quantized_dense(x, units, opts \\ []) do
opts =
Keyword.validate!(opts, [
:name,
:meta,
use_bias: true,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros
])

meta =
opts[:meta] ||
%{}
|> Map.put(:units, units)
|> Map.put(:use_bias, opts[:use_bias])

kernel_shape = &Axon.Shape.dense_kernel(&1, units)
bias_shape = &Axon.Shape.dense_bias(&1, units)

kernel =
Axon.param("kernel", kernel_shape,
initializer: fn shape, type, key ->
fun =
case opts[:kernel_initializer] do
init when is_atom(init) ->
apply(Axon.Initializers, init, [])

fun when is_function(fun) ->
fun
end

tensor =
case fun do
fun when is_function(fun, 2) ->
fun.(shape, type)

fun when is_function(fun, 3) ->
fun.(shape, type, key)
end

QTensor.from_tensor(tensor)
end
)

{inputs, op} =
if opts[:use_bias] do
bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
{[x, kernel, bias], &Layers.weight_only_quantized_dense/4}
else
{[x, kernel], &Layers.weight_only_quantized_dense/3}
end

Axon.layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)
end
end
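An end-to-end usage sketch of the API above (model shape and sizes are illustrative; `quantize/2` returns the rewritten model plus a state whose dense kernels are `QTensor`s and whose dense layers are frozen):

model =
  Axon.input("features", shape: {nil, 128})
  |> Axon.dense(256, activation: :relu)
  |> Axon.dense(10)

{init_fn, _} = Axon.build(model)
model_state = init_fn.(Nx.template({1, 128}, :f32), Axon.ModelState.empty())

{q_model, q_state} = Axon.Quantization.quantize(model, model_state)

{_, predict_fn} = Axon.build(q_model)
predict_fn.(q_state, Nx.broadcast(0.0, {1, 128}))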
43 changes: 43 additions & 0 deletions lib/axon/quantization/layers.ex
@@ -0,0 +1,43 @@
defmodule Axon.Quantization.Layers do
@moduledoc """
Quantized Layer Implementations.
"""
alias Axon.Quantization.QTensor
import Nx.Defn

@doc """
Weight-only quantized version of a dense layer.
It expects the input kernel to be an `Axon.Quantization.QTensor`.
"""
deftransform weight_only_quantized_dense(input, kernel, bias \\ 0, opts \\ []) do
{bias, opts} =
case bias do
%Nx.Tensor{} = bias ->
{bias, opts}

bias when is_number(bias) ->
{bias, opts}

opts when is_list(opts) ->
{Nx.tensor(0), opts}

other ->
raise ArgumentError, "invalid bias, expected a tensor, got #{inspect(other)}"
end

weight_only_quantized_dense_impl(input, kernel, bias, opts)
end

defnp weight_only_quantized_dense_impl(
input,
%QTensor{value: kernel, scale: scale},
bias,
_opts
) do
input
|> Nx.dot([Nx.rank(input) - 1], Nx.as_type(kernel, Nx.type(input)), [0])
|> Nx.multiply(scale)
|> Nx.add(bias)
end
end
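The `defnp` above casts the stored `{:s, 8}` kernel to the input type, contracts, and only then applies the scale. A numeric sketch of the symmetric quantization this relies on (per-tensor scale and illustrative names; the actual packing lives in `Axon.Quantization.QTensor`):

w = Nx.tensor([[0.4, -1.2], [2.0, 0.1]])

# scale chosen so the largest-magnitude weight maps to +/-127
scale = Nx.divide(Nx.reduce_max(Nx.abs(w)), 127)
q = w |> Nx.divide(scale) |> Nx.round() |> Nx.as_type({:s, 8})

x = Nx.tensor([[1.0, 1.0]])

x
|> Nx.dot(Nx.as_type(q, Nx.type(x)))
|> Nx.multiply(scale)
#=> close to Nx.dot(x, w)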