Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code Improvements #1

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import itertools
from ivy_tests import helpers


FW_STRS = ['numpy', 'jax', 'tensorflow', 'tensorflow_graph', 'torch', 'mxnd']


Expand All @@ -28,18 +27,23 @@ def get_test_devices() -> Dict[ivy.Framework, List[str]]:

# setup the global containers to test the source code
TEST_DEV_STRS: Dict[ivy.Framework, List[str]] = get_test_devices()
TEST_FRAMEWORKS: Dict[str, ivy.Framework] = {'numpy': ivy.numpy,
'jax': ivy.jax,
'tensorflow': ivy.tensorflow,
'tensorflow_graph': ivy.tensorflow,
'torch': ivy.torch,
'mxnd': ivy.mxnd}
TEST_CALL_METHODS: Dict[str, callable] = {'numpy': helpers.np_call,
'jax': helpers.jnp_call,
'tensorflow': helpers.tf_call,
'tensorflow_graph': helpers.tf_graph_call,
'torch': helpers.torch_call,
'mxnd': helpers.mx_call}
TEST_FRAMEWORKS: Dict[str, ivy.Framework] = {
'numpy': ivy.numpy,
'jax': ivy.jax,
'tensorflow': ivy.tensorflow,
'tensorflow_graph': ivy.tensorflow,
'torch': ivy.torch,
'mxnd': ivy.mxnd
}

TEST_CALL_METHODS: Dict[str, callable] = {
'numpy': helpers.np_call,
'jax': helpers.jnp_call,
'tensorflow': helpers.tf_call,
'tensorflow_graph': helpers.tf_graph_call,
'torch': helpers.torch_call,
'mxnd': helpers.mx_call
}


@pytest.fixture(autouse=True)
Expand All @@ -49,7 +53,6 @@ def run_around_tests(f):


def pytest_generate_tests(metafunc):

dev_strs = None
f_strs = None

Expand All @@ -68,9 +71,13 @@ def pytest_generate_tests(metafunc):
f_strs = raw_value.split(',')

if dev_strs is not None and f_strs is not None:
params = list(itertools.chain.from_iterable(
[[(item, TEST_FRAMEWORKS[f_str], TEST_CALL_METHODS[f_str])
for item in TEST_DEV_STRS[f_str] if item in dev_strs] for f_str in f_strs]))
params = list(
itertools.chain.from_iterable([
[(item, TEST_FRAMEWORKS[f_str], TEST_CALL_METHODS[f_str])
for item in TEST_DEV_STRS[f_str] if item in dev_strs
] for f_str in f_strs
])
)
metafunc.parametrize('dev_str,f,call', params)

# ToDo: add full support for partial arguments later
Expand Down
4 changes: 3 additions & 1 deletion ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from . import neural_net_stateful
from .neural_net_stateful import *
from . import verbosity
from .framework_handler import get_framework, get_framework_str, set_framework, unset_framework, framework_stack
from .framework_handler import (
get_framework, get_framework_str, set_framework, unset_framework, framework_stack
)


class Array:
Expand Down
Loading