From ee8f855ab007a5ed1daa94d68e8e5ff39eb3f2ad Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Thu, 25 Jul 2024 15:09:01 -0400 Subject: [PATCH] Add simple quantization API (#586) * Quantization draft * Finish initial quantization API * Docs --- lib/axon.ex | 37 ++++- lib/axon/model_state.ex | 11 ++ lib/axon/quantization.ex | 161 +++++++++++++++++++++ lib/axon/quantization/layers.ex | 43 ++++++ lib/axon/quantization/q_tensor.ex | 233 ++++++++++++++++++++++++++++++ lib/axon/shared.ex | 6 - test/axon/quantization_test.exs | 45 ++++++ 7 files changed, 528 insertions(+), 8 deletions(-) create mode 100644 lib/axon/quantization.ex create mode 100644 lib/axon/quantization/layers.ex create mode 100644 lib/axon/quantization/q_tensor.ex create mode 100644 test/axon/quantization_test.exs diff --git a/lib/axon.ex b/lib/axon.ex index b2e03523..28b57621 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -855,6 +855,12 @@ defmodule Axon do use_bias: true ]) + meta = + opts[:meta] || + %{} + |> Map.put(:units, units) + |> Map.put(:use_bias, opts[:use_bias]) + kernel_shape = &Axon.Shape.dense_kernel(&1, units) bias_shape = &Axon.Shape.dense_bias(&1, units) @@ -868,7 +874,7 @@ defmodule Axon do {[x, kernel], :dense} end - node = layer(op, inputs, name: opts[:name], meta: opts[:meta], op_name: :dense) + node = layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense) if activation = opts[:activation] do activation(node, activation) @@ -3666,7 +3672,7 @@ defmodule Axon do """ @doc type: :graph def get_op_counts(%Axon{} = axon) do - reduce_nodes(axon, %{}, fn %Axon.Node{op: op}, op_counts -> + reduce_nodes(axon, %{}, fn %Axon.Node{op_name: op}, op_counts -> Map.update(op_counts, op, 1, fn x -> x + 1 end) end) end @@ -4096,6 +4102,33 @@ defmodule Axon do end end + @doc """ + Returns a mapping of layer names to layer properties. 
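+
+ For example (illustrative output, assuming default generated layer names),
+ a single-layer model maps each layer name to its op name:
+
+     model = Axon.input("data") |> Axon.dense(10)
+     Axon.properties(model)
+     #=> %{"data" => :input, "dense_0" => :dense}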
+ """ + def properties(%Axon{output: id, nodes: nodes}) do + {_, _, properties} = node_properties(id, nodes, {%{}, %{}, %{}}) + properties + end + + defp node_properties(id, nodes, {cache, op_counts, properties} = acc) do + case cache do + %{^id => _} -> + {cache, op_counts, properties} + + %{} -> + %Axon.Node{parent: parents, name: name_fn, op_name: op_name} = nodes[id] + + {cache, op_counts, properties} = + Enum.reduce(parents, acc, &node_properties(&1, nodes, &2)) + + name = name_fn.(op_name, op_counts) + op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end) + properties = Map.put(properties, name, op_name) + + {Map.put(cache, id, name), op_counts, properties} + end + end + ## Helpers @valid_initializers [:zeros, :ones, :uniform, :normal, :identity] ++ diff --git a/lib/axon/model_state.ex b/lib/axon/model_state.ex index 8eede9a6..44a94291 100644 --- a/lib/axon/model_state.ex +++ b/lib/axon/model_state.ex @@ -206,6 +206,8 @@ defmodule Axon.ModelState do defp traverse(%Nx.Tensor{}, acc), do: [Enum.reverse(acc)] + defp traverse(%Axon.Quantization.QTensor{}, acc), do: [Enum.reverse(acc)] + defp traverse(map, acc) do Enum.flat_map(map, fn {k, value} -> traverse(value, [k | acc]) @@ -273,6 +275,10 @@ defmodule Axon.ModelState do new_val = fun.(key, val_lhs, val_rhs) Map.put(acc, key, new_val) + %Axon.Quantization.QTensor{} = val_rhs -> + new_val = fun.(key, val_lhs, val_rhs) + Map.put(acc, key, new_val) + val_rhs when is_map(val_lhs) and is_map(val_rhs) -> updated_val = tree_merge(val_lhs, val_rhs, fun) Map.put(acc, key, updated_val) @@ -321,6 +327,11 @@ defmodule Axon.ModelState do {_, %Nx.Tensor{} = tensor}, {count, size} -> {count + Nx.size(tensor), size + Nx.byte_size(tensor)} + {_, %Axon.Quantization.QTensor{value: value, scale: scale, zero_point: zero}}, + {count, size} -> + {count + Nx.size(value) + Nx.size(scale) + Nx.size(zero), + size + Nx.byte_size(value) + Nx.byte_size(scale) + Nx.byte_size(zero)} + {_, map}, {count, size} -> {inner_count, inner_size} = get_param_info(map) {count + inner_count, size + inner_size} diff --git a/lib/axon/quantization.ex b/lib/axon/quantization.ex new file mode 100644 index 00000000..ef11dee3 --- /dev/null +++ b/lib/axon/quantization.ex @@ -0,0 +1,161 @@ +defmodule Axon.Quantization do + @moduledoc """ + Model quantization. + + Model quantization is a technique for reducing the memory footprint of + a model by converting portions of a model to use quantized representations. + Typically, these quantized representations are low-precision integers. + + This is an **experimental** API which implements weight-only quantization. + The implementation in this module will convert dense layers in a large + model to quantized-variants. The only supported quantization type is + `{:s, 8}`. Axon quantization is inference-only. Training is not currently + supported. + """ + alias Axon.Quantization.Layers + alias Axon.Quantization.QTensor + + @doc """ + Quantizes a model and a model state. + + Given a model and model state, this method will rewrite all + of the dense layers in the model to perform weight-only 8-bit + integer versions of the same operation. It will also replace values + for all dense kernels in the given model state with quantized + tensors. + """ + def quantize(%Axon{} = model, %Axon.ModelState{} = model_state) do + quantized_model = quantize_model(model) + quantized_model_state = quantize_model_state(model, model_state) + {quantized_model, quantized_model_state} + end + + @doc """ + Replaces standard operations with quantized variants. 
+ + The only supported conversion is to convert regular dense layers + to a weight-only 8-bit integer variant. Note that this only replaces + the properties of the model. If you have a pre-trained model state + that you wish to quantize, refer to `Axon.Quantization.quantize_model_state/1`. + + All `:dense` layers in the model are replaced with `Axon.Quantization.weight_only_quantized_dense/3`. + """ + def quantize_model(%Axon{} = model) do + quantized_dense_rewriter = fn [%Axon{} = x], _output, units, use_bias -> + weight_only_quantized_dense(x, units, use_bias: use_bias) + end + + Axon.rewrite_nodes(model, fn + %Axon.Node{op: :dense, meta: meta} -> + &quantized_dense_rewriter.(&1, &2, meta[:units], meta[:use_bias]) + + _ -> + :skip + end) + end + + @doc """ + Returns a quantized model state. + + Given a model and a model state, this function will replace + all dense layer kernels with a quantized version of the weight. + + Training is not currently supported, so all quantized layers are + automatically frozen. + """ + def quantize_model_state(model, model_state) do + dense_layer_names = + model + |> Axon.properties() + |> Enum.filter(fn {_, v} -> v == :dense end) + |> Enum.map(fn {k, _} -> k end) + |> MapSet.new() + + state = + Enum.reduce(dense_layer_names, model_state, fn layer_name, state -> + update_in(state, [Access.key!(:data), layer_name, "kernel"], &QTensor.from_tensor/1) + end) + + Axon.ModelState.freeze(state, fn [name | _] -> + MapSet.member?(dense_layer_names, name) + end) + end + + ## Layers + + @doc """ + Adds a weight-only quantized dense layer to the network. + + This is equivalent to a dense layer, but works on quantized + weights for reducing model memory footprint. + + Compiles to `Axon.Quantization.Layers.weight_only_quantized_dense/4`. + + ## Options + + * `:name` - layer name. + + * `:kernel_initializer` - initializer for `kernel` weights. + Defaults to `:glorot_uniform`. + + * `:bias_initializer` - initializer for `bias` weights. Defaults + to `:zeros`. + + * `:use_bias` - whether the layer should add bias to the output. + Defaults to `true`. 
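+
+ ## Examples
+
+ A minimal usage sketch (the input and layer name are illustrative):
+
+     Axon.input("input")
+     |> Axon.Quantization.weight_only_quantized_dense(64, name: "quantized_dense_0")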
+ """ + def weight_only_quantized_dense(x, units, opts \\ []) do + opts = + Keyword.validate!(opts, [ + :name, + :meta, + use_bias: true, + kernel_initializer: :glorot_uniform, + bias_initializer: :zeros + ]) + + meta = + opts[:meta] || + %{} + |> Map.put(:units, units) + |> Map.put(:use_bias, opts[:use_bias]) + + kernel_shape = &Axon.Shape.dense_kernel(&1, units) + bias_shape = &Axon.Shape.dense_bias(&1, units) + + kernel = + Axon.param("kernel", kernel_shape, + initializer: fn shape, type, key -> + fun = + case opts[:kernel_initializer] do + init when is_atom(init) -> + apply(Axon.Initializers, []) + + fun when is_function(fun) -> + fun + end + + tensor = + case fun do + fun when is_function(fun, 2) -> + fun.(shape, type) + + fun when is_function(fun, 3) -> + fun.(shape, type, key) + end + + QTensor.from_tensor(tensor) + end + ) + + {inputs, op} = + if opts[:use_bias] do + bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer]) + {[x, kernel, bias], &Layers.weight_only_quantized_dense/4} + else + {[x, kernel], &Layers.weight_only_quantized_dense/3} + end + + Axon.layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense) + end +end diff --git a/lib/axon/quantization/layers.ex b/lib/axon/quantization/layers.ex new file mode 100644 index 00000000..d7613939 --- /dev/null +++ b/lib/axon/quantization/layers.ex @@ -0,0 +1,43 @@ +defmodule Axon.Quantization.Layers do + @moduledoc """ + Quantized Layer Implementations. + """ + alias Axon.Quantization.QTensor + import Nx.Defn + + @doc """ + Weight-only quantized version of a dense layer. + + It expects the input kernel to be an `Axon.Quantization.QTensor`. + """ + deftransform weight_only_quantized_dense(input, kernel, bias \\ 0, opts \\ []) do + {bias, opts} = + case bias do + %Nx.Tensor{} = bias -> + {bias, opts} + + bias when is_number(bias) -> + {bias, opts} + + opts when is_list(opts) -> + {Nx.tensor(0), opts} + + other -> + raise ArgumentError, "invalid bias, expected a tensor, got #{inspect(other)}" + end + + weight_only_quantized_dense_impl(input, kernel, bias, opts) + end + + defnp weight_only_quantized_dense_impl( + input, + %QTensor{value: kernel, scale: scale}, + bias, + _opts + ) do + input + |> Nx.dot([Nx.rank(input) - 1], Nx.as_type(kernel, Nx.type(input)), [0]) + |> Nx.multiply(scale) + |> Nx.add(bias) + end +end diff --git a/lib/axon/quantization/q_tensor.ex b/lib/axon/quantization/q_tensor.ex new file mode 100644 index 00000000..a7f7749d --- /dev/null +++ b/lib/axon/quantization/q_tensor.ex @@ -0,0 +1,233 @@ +defmodule Axon.Quantization.QTensor do + @moduledoc """ + Representation of a quantized tensor. + + A quantized tensor stores information about the quantized + value, scale, and zero-point. This module contains lower-level + functions for converting to and from quantized tensors. + + In most cases, you should prefer to use the public APIs in + `Axon.Quantization`. + """ + import Nx.Defn + + @derive {Nx.Container, containers: [:value, :scale, :zero_point]} + defstruct [:value, :scale, :zero_point] + + @doc """ + Converts a regular float tensor into a quantized tensor. 
+ """ + deftransform from_tensor(x, opts \\ []) do + opts = Keyword.validate!(opts, type: {:s, 8}) + + case opts[:type] do + {:s, 8} -> + dynamically_quantize_per_channel(x, min: -128, max: 127, type: {:s, 8}) + + other -> + raise "unsupported quantization type #{inspect(other)}" + end + end + + deftransformp dynamically_quantize_per_channel(input, opts \\ []) do + opts = Keyword.validate!(opts, [:min, :max, :type]) + + unless Nx.type(input) == {:f, 32}, do: raise(ArgumentError, "expected a float tensor") + unless Nx.rank(input) == 2, do: raise(ArgumentError, "expected a 2d tensor") + + target_dtype = opts[:type] + eps = Nx.Constants.epsilon(:f32) + block_size = {1, Nx.axis_size(input, 1)} + zero_point_type = {:s, 64} + + {scale, zero_point} = + choose_quantization_params_affine(input, + mapping_type: :symmetric, + block_size: block_size, + type: opts[:type], + min: opts[:min], + max: opts[:max], + eps: eps, + zero_point_type: zero_point_type + ) + + quantized_value = + quantize_affine(input, scale, zero_point, + block_size: block_size, + type: target_dtype, + min: opts[:min], + max: opts[:max] + ) + + struct(__MODULE__, value: quantized_value, scale: scale, zero_point: zero_point) + end + + deftransformp quantize_affine( + input, + scale, + zero_point, + opts \\ [] + ) do + opts = Keyword.validate!(opts, [:block_size, :type, :min, :max, zero_point_domain: :int]) + + target_dtype = opts[:type] + quant_min = opts[:min] + quant_max = opts[:max] + block_size = opts[:block_size] + zero_point_domain = opts[:zero_point_domain] + + {shape_for_reduction, reduction_dims} = get_reduction_params(block_size, Nx.shape(input)) + + original_shape = Nx.shape(input) + input = Nx.reshape(input, shape_for_reduction) + + scale_shape = + Enum.reduce(reduction_dims, shape_for_reduction, fn i, shape -> + put_elem(shape, i, 1) + end) + + scale = Nx.reshape(scale, scale_shape) + zero_point = Nx.reshape(zero_point, scale_shape) + + quant = + case zero_point_domain do + :int -> + Nx.clip( + Nx.add(Nx.round(Nx.multiply(input, Nx.divide(1, scale))), zero_point), + quant_min, + quant_max + ) + + other -> + raise "unsupported zero point domain #{other}" + end + + Nx.as_type(Nx.reshape(quant, original_shape), target_dtype) + end + + deftransformp choose_quantization_params_affine(input, opts \\ []) do + opts = + Keyword.validate!(opts, [ + :mapping_type, + :block_size, + :type, + :min, + :max, + :eps, + :scale_type, + :zero_point_type, + :zero_point_domain, + preserve_zero: true + ]) + + mapping_type = opts[:mapping_type] + block_size = opts[:block_size] + target_dtype = opts[:type] + preserve_zero = opts[:preserve_zero] + + {quant_min, quant_max} = + get_and_check_qmin_qmax(target_dtype, opts[:min], opts[:max]) + + scale_dtype = opts[:scale_type] || Nx.type(input) + zero_point_dtype = opts[:zero_point_type] || Nx.type(input) + eps = opts[:eps] || Nx.Constants.epsilon(Nx.type(input)) + + {shape_for_reduction, reduction_dims} = get_reduction_params(block_size, Nx.shape(input)) + input = Nx.reshape(input, shape_for_reduction) + + min_val = Nx.reduce_min(input, axes: reduction_dims, keep_axes: false) + max_val = Nx.reduce_max(input, axes: reduction_dims, keep_axes: false) + + {min_val_neg, max_val_pos} = + if preserve_zero do + {Nx.min(min_val, Nx.broadcast(0, min_val)), Nx.max(max_val, Nx.broadcast(0, max_val))} + else + {min_val, max_val} + end + + {scale, zero_point} = + case mapping_type do + :symmetric -> + max_val_pos = Nx.max(Nx.negate(min_val_neg), max_val_pos) + scale = Nx.divide(max_val_pos, 
Nx.divide(Nx.subtract(quant_max, quant_min), 2)) + zero_point = Nx.broadcast(trunc((quant_max + quant_min + 1) / 2), scale) + {scale, zero_point} + + other -> + raise "unsupported mapping #{other}" + end + + scale = Nx.clip(scale, eps, Nx.reduce_max(scale)) + + {Nx.as_type(scale, scale_dtype), Nx.as_type(zero_point, zero_point_dtype)} + end + + deftransformp get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) do + {lower_bound, upper_bound} = + case target_dtype do + {:u, 8} -> {0, 255} + {:s, 8} -> {-128, 127} + {:s, 16} -> {-(2 ** 15), 2 ** 15 - 1} + {:s, 32} -> {-(2 ** 31), 2 ** 31 - 1} + end + + quant_min = + cond do + quant_min == nil -> + lower_bound + + quant_min < lower_bound -> + raise "quant_min out of bounds for target_dtype" + + true -> + quant_min + end + + quant_max = + cond do + quant_max == nil -> + upper_bound + + quant_max > upper_bound -> + raise "quant_max out of bounds for target_dtype" + + true -> + quant_max + end + + {quant_min, quant_max} + end + + deftransformp get_reduction_params(block_size, input_size) do + if tuple_size(block_size) != tuple_size(input_size) do + raise "block_size and input_size must have the same length" + end + + {shape_for_reduction, reduction_dims, _} = + block_size + |> Tuple.to_list() + |> Enum.zip(Tuple.to_list(input_size)) + |> Enum.with_index() + |> Enum.reduce({[], [], 0}, fn {{block, input}, i}, {shape, dims, cur_dim} -> + if block != input and block > 1 do + unless rem(input, block) == 0 do + raise "Expecting input size at #{i} dimension: #{input} to be divisible by block_size at #{i} dimension: #{block}" + end + + shape = [block, div(input, block) | shape] + dims = [cur_dim + 1 | dims] + cur_dim = cur_dim + 2 + + {shape, dims, cur_dim} + else + shape = [input | shape] + dims = if block != 1, do: [cur_dim | dims], else: dims + cur_dim = cur_dim + 1 + + {shape, dims, cur_dim} + end + end) + + {List.to_tuple(Enum.reverse(shape_for_reduction)), Enum.reverse(reduction_dims)} + end +end diff --git a/lib/axon/shared.ex b/lib/axon/shared.ex index 6279488a..87eff5ae 100644 --- a/lib/axon/shared.ex +++ b/lib/axon/shared.ex @@ -192,9 +192,6 @@ defmodule Axon.Shared do defp recur_deep_reduce(value, acc, fun) do case value do - %Axon{} = val -> - fun.(val, acc) - %Nx.Tensor{} = val -> fun.(val, acc) @@ -217,9 +214,6 @@ defmodule Axon.Shared do defp recur_deep_map_reduce(leaf, acc, fun) do case leaf do - %Axon{} = leaf -> - fun.(leaf, acc) - %Nx.Tensor{} = leaf -> fun.(leaf, acc) diff --git a/test/axon/quantization_test.exs b/test/axon/quantization_test.exs new file mode 100644 index 00000000..4a289ce0 --- /dev/null +++ b/test/axon/quantization_test.exs @@ -0,0 +1,45 @@ +defmodule Axon.QuantizationTest do + use Axon.Case, async: true + + alias Axon.ModelState + alias Axon.Quantization.QTensor + + describe "quantize_model_state" do + test "replaces dense kernels with quantized versions" do + model = + Axon.input("input") + |> Axon.dense(10, activation: :relu) + + assert {init_fn, _} = Axon.build(model) + assert %ModelState{} = model_state = init_fn.(Nx.template({1, 1}, :f32), ModelState.empty()) + + assert %{data: %{"dense_0" => %{"kernel" => %QTensor{}}}} = + Axon.Quantization.quantize_model_state(model, model_state) + end + end + + describe "quantize" do + test "returns model and state that execute properly" do + model = + Axon.input("input") + |> Axon.dense(10, activation: :relu) + + assert {init_fn, _} = Axon.build(model) + assert %ModelState{} = model_state = init_fn.(Nx.template({1, 1}, :f32), ModelState.empty()) + + assert 
{quantized_model, quantized_model_state} = + Axon.Quantization.quantize(model, model_state) + + assert {_, predict_fn} = Axon.build(quantized_model) + + real_fn = fn %{data: %{"dense_0" => %{"kernel" => k, "bias" => b}}}, input -> + input + |> Axon.Quantization.Layers.weight_only_quantized_dense(k, b) + |> Axon.Activations.relu() + end + + inp = Nx.broadcast(1.0, {1, 1}) + assert_equal(predict_fn.(quantized_model_state, inp), real_fn.(quantized_model_state, inp)) + end + end +end