From cf4673197400d03e328d6ff3cf078316bbc8c6aa Mon Sep 17 00:00:00 2001 From: eb8680 Date: Fri, 30 Oct 2020 11:29:26 -0400 Subject: [PATCH] Add TanhOp and AtanhOp (#387) --- funsor/jax/ops.py | 2 ++ funsor/ops/array.py | 4 ++- funsor/ops/builtin.py | 64 ++++++++++++++++++++++++++++++++++++++++++- funsor/terms.py | 16 +++++++++++ funsor/torch/ops.py | 3 ++ test/test_tensor.py | 6 ++-- test/test_terms.py | 4 ++- 7 files changed, 94 insertions(+), 5 deletions(-) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 2c6cc04b4..d709de2ed 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -17,6 +17,7 @@ ################################################################################ array = (onp.generic, onp.ndarray, DeviceArray, Tracer) +ops.atanh.register(array)(np.arctanh) ops.clamp.register(array, object, object)(np.clip) ops.exp.register(array)(np.exp) ops.full_like.register(array, object)(np.full_like) @@ -26,6 +27,7 @@ ops.permute.register(array, (tuple, list))(np.transpose) ops.sigmoid.register(array)(expit) ops.sqrt.register(array)(np.sqrt) +ops.tanh.register(array)(np.tanh) ops.transpose.register(array, int, int)(np.swapaxes) ops.unsqueeze.register(array, int)(np.expand_dims) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 7e135e8c6..915935553 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -5,7 +5,7 @@ import numpy as np -from .builtin import AssociativeOp, add, exp, log, log1p, max, min, reciprocal, safediv, safesub, sqrt +from .builtin import AssociativeOp, add, atanh, exp, log, log1p, max, min, reciprocal, safediv, safesub, sqrt, tanh from .op import DISTRIBUTIVE_OPS, Op _builtin_all = all @@ -32,6 +32,8 @@ sqrt.register(array)(np.sqrt) exp.register(array)(np.exp) log1p.register(array)(np.log1p) +tanh.register(array)(np.tanh) +atanh.register(array)(np.arctanh) class LogAddExpOp(AssociativeOp): diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index c6a73228d..b14335c2c 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -147,21 +147,78 @@ def log_abs_det_jacobian(x, y): log.set_inv(exp) +class TanhOp(TransformOp): + pass + + +@TanhOp +def tanh(x): + return math.tanh(x) + + +@tanh.set_inv +def tanh_inv(y): + return atanh(y) + + +@tanh.set_log_abs_det_jacobian +def tanh_log_abs_det_jacobian(x, y): + return 2. * (math.log(2.) - x - softplus(-2. * x)) + + +class AtanhOp(TransformOp): + pass + + +@AtanhOp +def atanh(x): + return math.atanh(x) + + +@atanh.set_inv +def atanh_inv(y): + return tanh(y) + + +@atanh.set_log_abs_det_jacobian +def atanh_log_abs_det_jacobian(x, y): + return -tanh.log_abs_det_jacobian(y, x) + + @Op def log1p(x): return math.log1p(x) -@Op +class SigmoidOp(TransformOp): + pass + + +@SigmoidOp def sigmoid(x): return 1 / (1 + exp(-x)) +@sigmoid.set_inv +def sigmoid_inv(y): + return log(y) - log1p(-y) + + +@sigmoid.set_log_abs_det_jacobian +def sigmoid_log_abs_det_jacobian(x, y): + return -softplus(-x) - softplus(x) + + @Op def pow(x, y): return x ** y +@Op +def softplus(x): + return log(1. + exp(x)) + + @AssociativeOp def min(x, y): if hasattr(x, '__min__'): @@ -223,6 +280,7 @@ def lgamma(x): __all__ = [ 'AddOp', 'AssociativeOp', + 'AtanhOp', 'DivOp', 'ExpOp', 'GetitemOp', @@ -232,10 +290,13 @@ def lgamma(x): 'NegOp', 'NullOp', 'ReciprocalOp', + 'SigmoidOp', 'SubOp', + 'TanhOp', 'abs', 'add', 'and_', + 'atanh', 'eq', 'exp', 'ge', @@ -262,6 +323,7 @@ def lgamma(x): 'sigmoid', 'sqrt', 'sub', + 'tanh', 'truediv', 'xor', ] diff --git a/funsor/terms.py b/funsor/terms.py index a54173534..56cf207f9 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -592,6 +592,9 @@ def __neg__(self): def abs(self): return Unary(ops.abs, self) + def atanh(self): + return Unary(ops.atanh, self) + def sqrt(self): return Unary(ops.sqrt, self) @@ -607,6 +610,9 @@ def log1p(self): def sigmoid(self): return Unary(ops.sigmoid, self) + def tanh(self): + return Unary(ops.tanh, self) + def reshape(self, shape): return Unary(ops.ReshapeOp(shape), self) @@ -1634,6 +1640,11 @@ def _abs(x): return Unary(ops.abs, x) +@ops.atanh.register(Funsor) +def _atanh(x): + return Unary(ops.atanh, x) + + @ops.sqrt.register(Funsor) def _sqrt(x): return Unary(ops.sqrt, x) @@ -1669,6 +1680,11 @@ def _sigmoid(x): return Unary(ops.sigmoid, x) +@ops.tanh.register(Funsor) +def _tanh(x): + return Unary(ops.tanh, x) + + __all__ = [ 'Binary', 'Cat', diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 4943cbd8d..9a77fe520 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -11,12 +11,15 @@ ################################################################################ ops.abs.register(torch.Tensor)(torch.abs) +ops.atanh.register(torch.Tensor)(torch.atanh) ops.cholesky_solve.register(torch.Tensor, torch.Tensor)(torch.cholesky_solve) ops.clamp.register(torch.Tensor, object, object)(torch.clamp) ops.exp.register(torch.Tensor)(torch.exp) ops.full_like.register(torch.Tensor, object)(torch.full_like) ops.log1p.register(torch.Tensor)(torch.log1p) +ops.sigmoid.register(torch.Tensor)(torch.sigmoid) ops.sqrt.register(torch.Tensor)(torch.sqrt) +ops.tanh.register(torch.Tensor)(torch.tanh) ops.transpose.register(torch.Tensor, int, int)(torch.transpose) ops.unsqueeze.register(torch.Tensor, int)(torch.unsqueeze) diff --git a/test/test_tensor.py b/test/test_tensor.py index 61b6756c1..92a9c4b14 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -305,7 +305,7 @@ def unary_eval(symbol, x): @pytest.mark.parametrize('dims', [(), ('a',), ('a', 'b')]) @pytest.mark.parametrize('symbol', [ - '~', '-', 'abs', 'sqrt', 'exp', 'log', 'log1p', 'sigmoid', + '~', '-', 'abs', 'atanh', 'sqrt', 'exp', 'log', 'log1p', 'sigmoid', 'tanh', ]) def test_unary(symbol, dims): sizes = {'a': 3, 'b': 4} @@ -316,7 +316,9 @@ def test_unary(symbol, dims): if symbol == '~': data = ops.astype(data, 'uint8') dtype = 2 - if get_backend() != "torch" and symbol in ["abs", "sqrt", "exp", "log", "log1p", "sigmoid"]: + if symbol == 'atanh': + data = ops.clamp(data, -0.99, 0.99) + if get_backend() != "torch" and symbol in ["abs", "atanh", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh"]: expected_data = getattr(ops, symbol)(data) else: expected_data = unary_eval(symbol, data) diff --git a/test/test_terms.py b/test/test_terms.py index 4cc18f0ff..da3d25569 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -238,13 +238,15 @@ def unary_eval(symbol, x): @pytest.mark.parametrize('data', [0, 0.5, 1]) @pytest.mark.parametrize('symbol', [ - '~', '-', 'abs', 'sqrt', 'exp', 'log', 'log1p', 'sigmoid', + '~', '-', 'atanh', 'abs', 'sqrt', 'exp', 'log', 'log1p', 'sigmoid', 'tanh', ]) def test_unary(symbol, data): dtype = 'real' if symbol == '~': data = bool(data) dtype = 2 + if symbol == 'atanh': + data = min(data, 0.99) expected_data = unary_eval(symbol, data) x = Number(data, dtype)