Skip to content

Commit

Permalink
Fix tests (ivy-llc#2)
Browse files Browse the repository at this point in the history
* replace calls with fw

* removed call

* added param

* changed dev_str to device

* replaced test calls to fw

* handle positional and keyword args

* handle positional and keyword args

* changed fourier_encode

* undo change
  • Loading branch information
rush2406 authored Sep 26, 2022
1 parent 087f460 commit 79c7c6d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 42 deletions.
68 changes: 34 additions & 34 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,27 @@
from ivy_tests.test_ivy import helpers


FW_STRS = ['numpy', 'jax', 'tensorflow', 'torch', 'mxnet']
FW_STRS = ["numpy", "jax", "tensorflow", "torch", "mxnet"]


TEST_FRAMEWORKS: Dict[str, callable] = {'numpy': lambda: helpers.get_ivy_numpy(),
'jax': lambda: helpers.get_ivy_jax(),
'tensorflow': lambda: helpers.get_ivy_tensorflow(),
'torch': lambda: helpers.get_ivy_torch(),
'mxnet': lambda: helpers.get_ivy_mxnet()}
TEST_CALL_METHODS: Dict[str, callable] = {'numpy': helpers.np_call,
'jax': helpers.jnp_call,
'tensorflow': helpers.tf_call,
'torch': helpers.torch_call,
'mxnet': helpers.mx_call}
TEST_FRAMEWORKS: Dict[str, callable] = {
"numpy": lambda: helpers.get_ivy_numpy(),
"jax": lambda: helpers.get_ivy_jax(),
"tensorflow": lambda: helpers.get_ivy_tensorflow(),
"torch": lambda: helpers.get_ivy_torch(),
"mxnet": lambda: helpers.get_ivy_mxnet(),
}


@pytest.fixture(autouse=True)
def run_around_tests(device, f, wrapped_mode, compile_graph, call):
if wrapped_mode and call is helpers.tf_graph_call:
def run_around_tests(device, f, wrapped_mode, compile_graph, fw):
if wrapped_mode and fw == "tensorflow_graph":
# ToDo: add support for wrapped_mode and tensorflow compilation
pytest.skip()
if wrapped_mode and call is helpers.jnp_call:
if wrapped_mode and fw == "jax":
# ToDo: add support for wrapped_mode with jax, presumably some errenously wrapped jax methods
pytest.skip()
if 'gpu' in device and call is helpers.np_call:
if "gpu" in device and fw == "numpy":
# Numpy does not support GPU
pytest.skip()
ivy.clear_backend_stack()
Expand All @@ -42,33 +39,33 @@ def run_around_tests(device, f, wrapped_mode, compile_graph, call):
def pytest_generate_tests(metafunc):

# dev_str
raw_value = metafunc.config.getoption('--device')
if raw_value == 'all':
devices = ['cpu', 'gpu:0', 'tpu:0']
raw_value = metafunc.config.getoption("--device")
if raw_value == "all":
devices = ["cpu", "gpu:0", "tpu:0"]
else:
devices = raw_value.split(',')
devices = raw_value.split(",")

# framework
raw_value = metafunc.config.getoption('--framework')
if raw_value == 'all':
raw_value = metafunc.config.getoption("--framework")
if raw_value == "all":
f_strs = TEST_FRAMEWORKS.keys()
else:
f_strs = raw_value.split(',')
f_strs = raw_value.split(",")

# wrapped_mode
raw_value = metafunc.config.getoption('--wrapped_mode')
if raw_value == 'both':
raw_value = metafunc.config.getoption("--wrapped_mode")
if raw_value == "both":
wrapped_modes = [True, False]
elif raw_value == 'true':
elif raw_value == "true":
wrapped_modes = [True]
else:
wrapped_modes = [False]

# compile_graph
raw_value = metafunc.config.getoption('--compile_graph')
if raw_value == 'both':
raw_value = metafunc.config.getoption("--compile_graph")
if raw_value == "both":
compile_modes = [True, False]
elif raw_value == 'true':
elif raw_value == "true":
compile_modes = [True]
else:
compile_modes = [False]
Expand All @@ -80,12 +77,15 @@ def pytest_generate_tests(metafunc):
for wrapped_mode in wrapped_modes:
for compile_graph in compile_modes:
configs.append(
(device, TEST_FRAMEWORKS[f_str](), wrapped_mode, compile_graph, TEST_CALL_METHODS[f_str]))
metafunc.parametrize('device,f,wrapped_mode,compile_graph,call', configs)
(device, TEST_FRAMEWORKS[f_str](), wrapped_mode, compile_graph, f_str)
)
metafunc.parametrize("device,f,wrapped_mode,compile_graph,fw", configs)


def pytest_addoption(parser):
parser.addoption('--device', action="store", default="cpu")
parser.addoption('--framework', action="store", default="numpy,jax,tensorflow,torch")
parser.addoption('--wrapped_mode', action="store", default="false")
parser.addoption('--compile_graph', action="store", default="true")
parser.addoption("--device", action="store", default="cpu")
parser.addoption(
"--framework", action="store", default="numpy,jax,tensorflow,torch"
)
parser.addoption("--wrapped_mode", action="store", default="false")
parser.addoption("--compile_graph", action="store", default="true")
8 changes: 4 additions & 4 deletions ivy_models/transformers/perceiver_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,22 @@ def _forward(self, data, mask=None, queries=None):
data, '{} ... -> ({}) ...'.format(batch_shape_str, batch_shape_str), **batch_shape_dict)
else:
flat_batch_size = 1
data = ivy.expand_dims(data, 0)
data = ivy.expand_dims(data, axis=0)

# flatten the data channels
data = ivy.einops_rearrange(data, 'b ... d -> b (...) d')

# maybe add fourier positional encoding
if self._fourier_encode_input:
axis_pos = list(map(lambda size: ivy.linspace(-1., 1., size, device=self._spec.device), data_shape))
pos = ivy.stack(ivy.meshgrid(*axis_pos), -1)
axis_pos = list(map(lambda size: ivy.linspace(-1., 1., num=size, device=self._spec.device), data_shape))
pos = ivy.stack(ivy.meshgrid(*axis_pos), axis=-1)
pos_flat = ivy.reshape(pos, [-1, len(axis_pos)])
if not ivy.exists(self._spec.max_fourier_freq):
self._spec.max_fourier_freq = ivy.array(data_shape, dtype='float32')
enc_pos = ivy.fourier_encode(
pos_flat, self._spec.max_fourier_freq, self._spec.num_fourier_freq_bands, True, flatten=True)
enc_pos = ivy.einops_repeat(enc_pos, '... -> b ...', b=flat_batch_size)
data = ivy.concat([data, enc_pos], -1)
data = ivy.concat([data, enc_pos], axis=-1)

# batchify latents
x = ivy.einops_repeat(self.v.latents, 'n d -> b n d', b=flat_batch_size)
Expand Down
8 changes: 4 additions & 4 deletions ivy_models_tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# Helpers #
# --------#

def test_feedforward(device, f, call):
def test_feedforward(device, f, fw):
ivy.seed(seed_value=0)
feedforward = FeedForward(4, device=device)
x = ivy.random_uniform(shape=(1, 3, 4), device=device)
ret = feedforward(x)
assert list(ret.shape) == [1, 3, 4]


def test_prenorm(device, f, call):
def test_prenorm(device, f, fw):
ivy.seed(seed_value=0)
att = ivy.MultiHeadAttention(4, device=device)
prenorm = PreNorm(4, att, device=device)
Expand All @@ -42,7 +42,7 @@ def test_prenorm(device, f, call):
"learn_query", [True])
@pytest.mark.parametrize(
"load_weights", [True, False])
def test_perceiver_io_img_classification(device, f, call, batch_shape, img_dims, queries_dim, learn_query,
def test_perceiver_io_img_classification(device, f, fw, batch_shape, img_dims, queries_dim, learn_query,
load_weights):

# params
Expand Down Expand Up @@ -127,7 +127,7 @@ def np_softmax(x):
"queries_dim", [32])
@pytest.mark.parametrize(
"learn_query", [True, False])
def test_perceiver_io_flow_prediction(device, f, call, batch_shape, img_dims, queries_dim, learn_query):
def test_perceiver_io_flow_prediction(device, f, fw, batch_shape, img_dims, queries_dim, learn_query):
# params
input_dim = 3
num_input_axes = 3
Expand Down

0 comments on commit 79c7c6d

Please sign in to comment.