[Onnx] Add Adam (apache#9002)
* add adam op

* lint

* remove tests

* lint
AndrewZhaoLuo authored and ylc committed Jan 13, 2022
1 parent 8bbed38 commit 80de207
Showing 2 changed files with 68 additions and 2 deletions.
68 changes: 68 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -3631,6 +3631,73 @@ def _impl_v1(cls, inputs, attr, params):
        return _expr.TupleWrapper(_expr.Tuple(result), len(result))


class Adam(OnnxOpConverter):
    """Operator converter for Adam op."""

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = attr.get("alpha", 0.9)
        beta = attr.get("beta", 0.999)

        # Note in the docs epsilon default is 0.0 but in the tests it is set to 1e-2:
        # https://git.io/Ju5C4
        epsilon = attr.get("epsilon", 1e-2)
        norm_coefficient = attr.get("norm_coefficient", 0.0)
        norm_coefficient_post = attr.get("norm_coefficient_post", 0.0)

        R = inputs[0]
        T = inputs[1]

        assert (
            len(inputs) - 2
        ) % 4 == 0, f"Expect 4-lets for remaining inputs, found {len(inputs) - 2}"

        # convert attributes to constants, proper types
        dtype_inputs = infer_type(inputs[3]).checked_type.dtype
        inverse_alpha = relay.const(1 - alpha, dtype=dtype_inputs)
        alpha = relay.const(alpha, dtype=dtype_inputs)
        inverse_beta = relay.const(1 - beta, dtype=dtype_inputs)
        beta = relay.const(beta, dtype=dtype_inputs)
        epsilon = relay.const(epsilon, dtype=dtype_inputs)
        norm_coefficient = relay.const(norm_coefficient, dtype=dtype_inputs)
        norm_coefficient_post = relay.const(norm_coefficient_post, dtype=dtype_inputs)
        one = relay.const(1, dtype=dtype_inputs)
        T = relay.cast_like(T, inputs[3])

        # Remaining inputs are:
        # [x_1, x_2 ..., x_1_grad, x_2_grad, ... x_1_g_accum, x_2_g_accum..., x_1_g_sq_accum, ...]
        num_input_tensors = (len(inputs) - 2) // 4
        output_tensors = []
        output_accumulated_gradients = []
        output_accumulated_squared_gradients = []
        for i in range(num_input_tensors):
            x = inputs[i + 2]
            g = inputs[i + 2 + num_input_tensors]
            v = inputs[i + 2 + 2 * num_input_tensors]
            h = inputs[i + 2 + 3 * num_input_tensors]

            g_regularized = norm_coefficient * x + g
            v_new = alpha * v + inverse_alpha * g_regularized
            h_new = beta * h + inverse_beta * g_regularized * g_regularized
            h_sqrt = relay.sqrt(h_new) + epsilon

            true_branch = R * relay.sqrt(one - relay.power(beta, T)) / (one - relay.power(alpha, T))
            R_adjusted = relay.If(T > relay.const(0, dtype=dtype_inputs), true_branch, R)

            x_new = x - R_adjusted * (v_new / h_sqrt)
            x_result = (one - norm_coefficient_post) * x_new

            output_tensors.append(x_result)
            output_accumulated_gradients.append(v_new)
            output_accumulated_squared_gradients.append(h_new)

        # append lists together to get final result
        result = (
            output_tensors + output_accumulated_gradients + output_accumulated_squared_gradients
        )
        return _expr.TupleWrapper(_expr.Tuple(result), len(result))
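
For reference, the update rule that the Relay expressions above encode can be written out directly in NumPy. The sketch below is not part of the patch; adam_reference and its argument names are illustrative only. It handles a single tensor, while the converter loops the same computation over every 4-let of remaining inputs.

import numpy as np

def adam_reference(R, T, x, g, v, h, alpha=0.9, beta=0.999, epsilon=1e-2,
                   norm_coefficient=0.0, norm_coefficient_post=0.0):
    """Single-tensor Adam step mirroring the Relay graph built above."""
    # L2-regularize the incoming gradient.
    g_reg = norm_coefficient * x + g
    # Exponential moving averages of the gradient and the squared gradient.
    v_new = alpha * v + (1 - alpha) * g_reg
    h_new = beta * h + (1 - beta) * g_reg * g_reg
    h_sqrt = np.sqrt(h_new) + epsilon
    # Bias-correct the learning rate only after the first step (T > 0),
    # matching the relay.If in the converter.
    if T > 0:
        R_adj = R * np.sqrt(1 - beta ** T) / (1 - alpha ** T)
    else:
        R_adj = R
    x_new = x - R_adj * (v_new / h_sqrt)
    # Optional post-update shrinkage of the parameter tensor.
    x_out = (1 - norm_coefficient_post) * x_new
    return x_out, v_new, h_new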


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -3817,6 +3884,7 @@ def _get_convert_map(opset):
        # Loss functions / training
        "NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
        "Adagrad": Adagrad.get_converter(opset),
        "Adam": Adam.get_converter(opset),
    }


2 changes: 0 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
@@ -4711,8 +4711,6 @@ def verify_eyelike(indata):
)

unsupported_onnx_tests = [
    "test_adam",
    "test_adam_multiple",
    "test_cast_BFLOAT16_to_FLOAT",
    "test_cast_DOUBLE_to_FLOAT16",
    "test_cast_FLOAT_to_BFLOAT16",
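
Since test_adam and test_adam_multiple are removed from the skip list above, the new converter is now exercised by the official ONNX backend node tests. For illustration only, here is a hedged, untested sketch (not part of the commit) of how a single-tensor Adam node could be built with onnx.helper and imported through the new converter; the graph name, tensor names, and shapes are made up for the example, and the node is placed in the "ai.onnx.preview.training" domain that the official node tests use.

from onnx import TensorProto, helper
from tvm import relay

# One Adam node updating a single length-3 tensor X from gradient G,
# with running averages V (gradient) and H (squared gradient).
node = helper.make_node(
    "Adam",
    inputs=["R", "T", "X", "G", "V", "H"],
    outputs=["X_new", "V_new", "H_new"],
    domain="ai.onnx.preview.training",
    alpha=0.9,
    beta=0.999,
    epsilon=1e-2,
    norm_coefficient=0.0,
)

def info(name, dtype, shape):
    return helper.make_tensor_value_info(name, dtype, shape)

graph = helper.make_graph(
    [node],
    "adam_example",
    inputs=[
        info("R", TensorProto.FLOAT, []),  # learning rate (scalar)
        info("T", TensorProto.INT64, []),  # update count (scalar)
        info("X", TensorProto.FLOAT, [3]),
        info("G", TensorProto.FLOAT, [3]),
        info("V", TensorProto.FLOAT, [3]),
        info("H", TensorProto.FLOAT, [3]),
    ],
    outputs=[
        info("X_new", TensorProto.FLOAT, [3]),
        info("V_new", TensorProto.FLOAT, [3]),
        info("H_new", TensorProto.FLOAT, [3]),
    ],
)

model = helper.make_model(
    graph,
    opset_imports=[
        helper.make_opsetid("", 14),
        helper.make_opsetid("ai.onnx.preview.training", 1),
    ],
)

# The importer dispatches the "Adam" node to the converter added in this commit.
mod, params = relay.frontend.from_onnx(model)
print(mod)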
