From 64a153289f5591cc3d8da88a4e4a34e47f57c238 Mon Sep 17 00:00:00 2001 From: Matthew Ellis Date: Thu, 31 Oct 2024 15:22:01 -0700 Subject: [PATCH] Make style --- src/python/module/smartredis/client.py | 14 +- src/python/module/smartredis/configoptions.py | 1 - src/python/module/smartredis/dataset_utils.py | 6 +- src/python/module/smartredis/util.py | 2 +- tests/python/test_client.py | 75 +++--- tests/python/test_configoptions.py | 14 +- tests/python/test_dataset_aggregation.py | 71 +++--- tests/python/test_dataset_conversion.py | 1 + tests/python/test_dataset_methods.py | 9 +- tests/python/test_dataset_ops.py | 6 +- tests/python/test_errors.py | 215 +++++++++++++----- tests/python/test_logging.py | 50 ++-- tests/python/test_model_methods_torch.py | 174 ++++++++++---- tests/python/test_nonkeyed_cmd.py | 19 +- tests/python/test_prefixing.py | 8 +- tests/python/test_put_get_bytes.py | 11 +- tests/python/test_put_get_tensor.py | 3 +- tests/python/test_script_methods.py | 24 +- tests/python/test_tensor_ops.py | 3 +- 19 files changed, 463 insertions(+), 243 deletions(-) diff --git a/src/python/module/smartredis/client.py b/src/python/module/smartredis/client.py index fcb5c08e..c86628ef 100644 --- a/src/python/module/smartredis/client.py +++ b/src/python/module/smartredis/client.py @@ -26,14 +26,15 @@ # pylint: disable=too-many-lines,too-many-public-methods import inspect +import io import os import os.path as osp import typing as t + import numpy as np -import io -from .dataset import Dataset from .configoptions import ConfigOptions +from .dataset import Dataset from .error import RedisConnectionError from .smartredisPy import PyClient from .smartredisPy import RedisReplyError as PybindRedisReplyError @@ -95,7 +96,7 @@ def __address_construction( self, cluster: bool, address: t.Optional[str] = None, - logger_name: str = "Default" + logger_name: str = "Default", ) -> PyClient: """Initialize a SmartRedis client @@ -129,8 +130,7 @@ def __address_construction( @staticmethod def __standard_construction( - config_options: t.Optional[ConfigOptions] = None, - logger_name: str = "Default" + config_options: t.Optional[ConfigOptions] = None, logger_name: str = "Default" ) -> PyClient: """Initialize a RedisAI client @@ -607,7 +607,7 @@ def run_script( name: str, fn_name: str, inputs: t.Union[str, t.List[str]], - outputs: t.Union[str, t.List[str]] + outputs: t.Union[str, t.List[str]], ) -> None: """Execute TorchScript stored inside the database @@ -1472,7 +1472,7 @@ def use_dataset_ensemble_prefix(self, use_prefix: bool) -> None: """ typecheck(use_prefix, "use_prefix", bool) return self._client.use_dataset_ensemble_prefix(use_prefix) - + @exception_handler def use_bytes_ensemble_prefix(self, use_prefix: bool) -> None: """Control whether byte keys are prefixed (e.g. 
in an diff --git a/src/python/module/smartredis/configoptions.py b/src/python/module/smartredis/configoptions.py index 8755bdf1..698d75ab 100644 --- a/src/python/module/smartredis/configoptions.py +++ b/src/python/module/smartredis/configoptions.py @@ -32,7 +32,6 @@ from .smartredisPy import PyConfigOptions from .util import exception_handler, typecheck - if t.TYPE_CHECKING: from typing_extensions import ParamSpec diff --git a/src/python/module/smartredis/dataset_utils.py b/src/python/module/smartredis/dataset_utils.py index c60886cb..533490eb 100644 --- a/src/python/module/smartredis/dataset_utils.py +++ b/src/python/module/smartredis/dataset_utils.py @@ -198,15 +198,13 @@ def transform_to_xarray(dataset: Dataset) -> t.Dict: # Extract dimensions in correct form dims_final = [ dataset.get_meta_strings(dim_field_name)[0] - for dim_field_name - in get_data(dataset, variable_name, "dim") + for dim_field_name in get_data(dataset, variable_name, "dim") ] # Extract attributes in correct form attrs_final = { attr_field_name: dataset.get_meta_strings(attr_field_name)[0] - for attr_field_name - in get_data(dataset, variable_name, "attr") + for attr_field_name in get_data(dataset, variable_name, "attr") } # Add coordinates to the correct data name diff --git a/src/python/module/smartredis/util.py b/src/python/module/smartredis/util.py index 56e5bd5e..05e60d80 100644 --- a/src/python/module/smartredis/util.py +++ b/src/python/module/smartredis/util.py @@ -28,8 +28,8 @@ from functools import wraps import numpy as np -from . import error +from . import error from .smartredisPy import RedisReplyError as PybindRedisReplyError from .smartredisPy import c_get_last_error_location diff --git a/tests/python/test_client.py b/tests/python/test_client.py index 14798a2d..21915291 100644 --- a/tests/python/test_client.py +++ b/tests/python/test_client.py @@ -25,10 +25,11 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import os -import pytest +import pytest from smartredis import Client, ConfigOptions + def test_serialization(context): c = Client(None, logger_name=context) assert str(c) != repr(c) @@ -45,6 +46,7 @@ def test_address(context): # check if SSDB was set anyway assert os.environ["SSDB"] == ssdb + # Globals for Client constructor testing ac_original = Client._Client__address_construction sc_original = Client._Client__standard_construction @@ -52,34 +54,45 @@ def test_address(context): target_address = os.environ["SSDB"] co_envt = ConfigOptions.create_from_environment("") + @pytest.mark.parametrize( - "args, kwargs, expected_constructor", [ - # address constructions - [(False,), {}, "address"], - [(False,), {"address": target_address}, "address"], - [(False,), {"address": target_address, "logger_name": "log_name"}, "address"], - [(False,), {"logger_name": "log_name"}, "address"], - [(False, target_address), {}, "address"], - [(False, target_address), {"logger_name": "log_name"}, "address"], - [(False, target_address, "log_name"), {}, "address"], - [(), {"cluster": cluster_mode}, "address"], - [(), {"cluster": cluster_mode, "address": target_address}, "address"], - [(), {"cluster": cluster_mode, "address": target_address, "logger_name": "log_name"}, "address"], - [(), {"cluster": cluster_mode, "logger_name": "log_name"}, "address"], - # standard constructions - [(None,), {}, "standard"], - [(None,), {"logger_name": "log_name"}, "standard"], - [(None, "log_name"), {}, "standard"], - [(co_envt,), {}, "standard"], - [(co_envt,), {"logger_name": "log_name"}, "standard"], - [(co_envt, "log_name"), {}, "standard"], - [(), {}, "standard"], - [(), {"config_options": None}, "standard"], - [(), {"config_options": None, "logger_name": "log_name"}, "standard"], - [(), {"config_options": co_envt}, "standard"], - [(), {"config_options": co_envt, "logger_name": "log_name"}, "standard"], - [(), {"logger_name": "log_name"}, "standard"], -]) + "args, kwargs, expected_constructor", + [ + # address constructions + [(False,), {}, "address"], + [(False,), {"address": target_address}, "address"], + [(False,), {"address": target_address, "logger_name": "log_name"}, "address"], + [(False,), {"logger_name": "log_name"}, "address"], + [(False, target_address), {}, "address"], + [(False, target_address), {"logger_name": "log_name"}, "address"], + [(False, target_address, "log_name"), {}, "address"], + [(), {"cluster": cluster_mode}, "address"], + [(), {"cluster": cluster_mode, "address": target_address}, "address"], + [ + (), + { + "cluster": cluster_mode, + "address": target_address, + "logger_name": "log_name", + }, + "address", + ], + [(), {"cluster": cluster_mode, "logger_name": "log_name"}, "address"], + # standard constructions + [(None,), {}, "standard"], + [(None,), {"logger_name": "log_name"}, "standard"], + [(None, "log_name"), {}, "standard"], + [(co_envt,), {}, "standard"], + [(co_envt,), {"logger_name": "log_name"}, "standard"], + [(co_envt, "log_name"), {}, "standard"], + [(), {}, "standard"], + [(), {"config_options": None}, "standard"], + [(), {"config_options": None, "logger_name": "log_name"}, "standard"], + [(), {"config_options": co_envt}, "standard"], + [(), {"config_options": co_envt, "logger_name": "log_name"}, "standard"], + [(), {"logger_name": "log_name"}, "standard"], + ], +) def test_client_constructor(args, kwargs, expected_constructor, monkeypatch): ac_got_called = False sc_got_called = False @@ -96,9 +109,11 @@ def mock_standard_constructor(*a, **kw): return sc_original(*a, **kw) 
monkeypatch.setattr( - Client, "_Client__address_construction", mock_address_constructor) + Client, "_Client__address_construction", mock_address_constructor + ) monkeypatch.setattr( - Client, "_Client__standard_construction", mock_standard_constructor) + Client, "_Client__standard_construction", mock_standard_constructor + ) Client(*args, **kwargs) diff --git a/tests/python/test_configoptions.py b/tests/python/test_configoptions.py index a605b532..a73f030d 100644 --- a/tests/python/test_configoptions.py +++ b/tests/python/test_configoptions.py @@ -33,6 +33,7 @@ ##### # Test attempts to use API functions from non-factory object + def test_non_factory_configobject(): co = ConfigOptions() with pytest.raises(RedisRuntimeError): @@ -46,6 +47,7 @@ def test_non_factory_configobject(): with pytest.raises(RedisRuntimeError): _ = co.override_string_option("key", "value") + def test_options(monkeypatch): monkeypatch.setenv("test_integer_key", "42") monkeypatch.setenv("test_string_key", "charizard") @@ -58,8 +60,7 @@ def test_options(monkeypatch): _ = co.get_integer_option("test_integer_key_that_is_not_really_present") co.override_integer_option("test_integer_key_that_is_not_really_present", 42) assert co.is_configured("test_integer_key_that_is_not_really_present") - assert co.get_integer_option( - "test_integer_key_that_is_not_really_present") == 42 + assert co.get_integer_option("test_integer_key_that_is_not_really_present") == 42 # string option tests assert co.get_string_option("test_string_key") == "charizard" @@ -68,8 +69,10 @@ def test_options(monkeypatch): _ = co.get_string_option("test_string_key_that_is_not_really_present") co.override_string_option("test_string_key_that_is_not_really_present", "meowth") assert co.is_configured("test_string_key_that_is_not_really_present") - assert co.get_string_option( - "test_string_key_that_is_not_really_present") == "meowth" + assert ( + co.get_string_option("test_string_key_that_is_not_really_present") == "meowth" + ) + def test_options_with_suffix(monkeypatch): monkeypatch.setenv("integer_key_suffixtest", "42") @@ -92,5 +95,4 @@ def test_options_with_suffix(monkeypatch): _ = co.get_string_option("string_key_that_is_not_really_present") co.override_string_option("string_key_that_is_not_really_present", "meowth") assert co.is_configured("string_key_that_is_not_really_present") - assert co.get_string_option( - "string_key_that_is_not_really_present") == "meowth" + assert co.get_string_option("string_key_that_is_not_really_present") == "meowth" diff --git a/tests/python/test_dataset_aggregation.py b/tests/python/test_dataset_aggregation.py index a156cbe7..ccfd920c 100644 --- a/tests/python/test_dataset_aggregation.py +++ b/tests/python/test_dataset_aggregation.py @@ -25,9 +25,9 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import numpy as np +from smartredis import * from smartredis import Client, Dataset from smartredis.error import * -from smartredis import * def test_aggregation(context): @@ -53,77 +53,87 @@ def test_aggregation(context): # Confirm that poll for list length works correctly actual_length = num_datasets poll_result = client.poll_list_length(list_name, actual_length, 100, 5) - if (poll_result == False): + if poll_result == False: raise RuntimeError( f"Polling for list length of {actual_length} returned " - f"False for known length of {actual_length}.") + f"False for known length of {actual_length}." 
+ ) log_data(context, LLDebug, "Polling 1") poll_result = client.poll_list_length(list_name, actual_length + 1, 100, 5) - if (poll_result == True): + if poll_result == True: raise RuntimeError( f"Polling for list length of {actual_length + 1} returned " - f"True for known length of {actual_length}.") + f"True for known length of {actual_length}." + ) log_data(context, LLDebug, "Polling 2") # Confirm that poll for greater than or equal list length works correctly poll_result = client.poll_list_length_gte(list_name, actual_length - 1, 100, 5) - if (poll_result == False): + if poll_result == False: raise RuntimeError( f"Polling for list length greater than or equal to {actual_length - 1} " - f"returned False for known length of {actual_length}.") + f"returned False for known length of {actual_length}." + ) log_data(context, LLDebug, "Polling 3") poll_result = client.poll_list_length_gte(list_name, actual_length, 100, 5) - if (poll_result == False): + if poll_result == False: raise RuntimeError( f"Polling for list length greater than or equal to {actual_length} " - f"returned False for known length of {actual_length}.") + f"returned False for known length of {actual_length}." + ) log_data(context, LLDebug, "Polling 4") poll_result = client.poll_list_length_gte(list_name, actual_length + 1, 100, 5) - if (poll_result == True): + if poll_result == True: raise RuntimeError( f"Polling for list length greater than or equal to {actual_length + 1} " - f"returned True for known length of {actual_length}.") + f"returned True for known length of {actual_length}." + ) log_data(context, LLDebug, "Polling 5") # Confirm that poll for less than or equal list length works correctly poll_result = client.poll_list_length_lte(list_name, actual_length - 1, 100, 5) - if (poll_result == True): + if poll_result == True: raise RuntimeError( f"Polling for list length less than or equal to {actual_length - 1} " - f"returned True for known length of {actual_length}.") + f"returned True for known length of {actual_length}." + ) log_data(context, LLDebug, "Polling 6") poll_result = client.poll_list_length_lte(list_name, actual_length, 100, 5) - if (poll_result == False): + if poll_result == False: raise RuntimeError( f"Polling for list length less than or equal to {actual_length} " - f"returned False for known length of {actual_length}.") + f"returned False for known length of {actual_length}." + ) log_data(context, LLDebug, "Polling 7") poll_result = client.poll_list_length_lte(list_name, actual_length + 1, 100, 5) - if (poll_result == False): + if poll_result == False: raise RuntimeError( f"Polling for list length less than or equal to {actual_length + 1} " - f"returned False for known length of {actual_length}.") + f"returned False for known length of {actual_length}." + ) log_data(context, LLDebug, "Polling 8") # Check the list length list_length = client.get_list_length(list_name) - if (list_length != actual_length): + if list_length != actual_length: raise RuntimeError( f"The list length of {list_length} does not match expected " - f"value of {actual_length}.") + f"value of {actual_length}." + ) log_data(context, LLDebug, "List length check") - + # Check the return of a range of datasets from the aggregated list num_datasets = client.get_dataset_list_range(list_name, 0, 1) - if (len(num_datasets) != 2): + if len(num_datasets) != 2: raise RuntimeError( f"The length is {len(num_datasets)}, which does not " - f"match expected value of 2.") + f"match expected value of 2." 
+ ) log_data(context, LLDebug, "Retrieve datasets from list checked") # Retrieve datasets via the aggregation list @@ -131,34 +141,36 @@ def test_aggregation(context): if len(datasets) != list_length: raise RuntimeError( f"The number of datasets received {len(datasets)} " - f"does not match expected value of {list_length}.") + f"does not match expected value of {list_length}." + ) for ds in datasets: check_dataset(ds) log_data(context, LLDebug, "DataSet list retrieval") - + # Rename a list of datasets client.rename_list(list_name, "new_list_name") renamed_list_datasets = client.get_datasets_from_list("new_list_name") if len(renamed_list_datasets) != list_length: raise RuntimeError( f"The number of datasets received {len(datasets)} " - f"does not match expected value of {list_length}.") + f"does not match expected value of {list_length}." + ) for ds in renamed_list_datasets: check_dataset(ds) log_data(context, LLDebug, "DataSet list rename complete") - + # Copy a list of datasets client.copy_list("new_list_name", "copied_list_name") copied_list_datasets = client.get_datasets_from_list("copied_list_name") if len(copied_list_datasets) != list_length: raise RuntimeError( f"The number of datasets received {len(datasets)} " - f"does not match expected value of {list_length}.") + f"does not match expected value of {list_length}." + ) for ds in copied_list_datasets: check_dataset(ds) log_data(context, LLDebug, "DataSet list copied") - - + # ------------ helper functions --------------------------------- @@ -174,6 +186,7 @@ def create_dataset(name): dataset.add_meta_scalar("test_scalar", scalar) return dataset + def check_dataset(ds): comp_array = np.array([1, 2, 3, 4]) tensor_name = "test_array" diff --git a/tests/python/test_dataset_conversion.py b/tests/python/test_dataset_conversion.py index e4e4740c..a095dcff 100644 --- a/tests/python/test_dataset_conversion.py +++ b/tests/python/test_dataset_conversion.py @@ -618,6 +618,7 @@ def test_raise_exception_if_xarray_not_found(monkeypatch): and provide instruction on how to fix it """ import sys + import smartredis.dataset_utils as _dsu monkeypatch.setattr(sys, "path", []) diff --git a/tests/python/test_dataset_methods.py b/tests/python/test_dataset_methods.py index 3be54720..c900cd9c 100644 --- a/tests/python/test_dataset_methods.py +++ b/tests/python/test_dataset_methods.py @@ -27,13 +27,13 @@ import numpy as np from smartredis import Dataset + def test_serialize_dataset(): - """Test serializing a dataset - """ + """Test serializing a dataset""" dataset = Dataset("test-dataset") - data = np.uint8([2,4,8]) + data = np.uint8([2, 4, 8]) dataset.add_tensor("u8_tensor", data) - data = np.double([2.0,4.1,8.3, 5.6]) + data = np.double([2.0, 4.1, 8.3, 5.6]) dataset.add_tensor("double_tensor", data) dataset.add_meta_scalar("float2_scalar", float(3.1415926535)) dataset.add_meta_scalar("float_scalar", np.double(3.1415926535)) @@ -190,6 +190,7 @@ def test_dataset_inspection(context): assert str == d.get_metadata_field_type("metastring") assert np.uint32 == d.get_metadata_field_type("u32_scalar") + # ------- Helper Functions ----------------------------------------------- diff --git a/tests/python/test_dataset_ops.py b/tests/python/test_dataset_ops.py index 0136ab08..5b0558d6 100644 --- a/tests/python/test_dataset_ops.py +++ b/tests/python/test_dataset_ops.py @@ -25,6 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import os + import numpy as np import pytest from smartredis import Client, Dataset @@ -124,22 +125,19 @@ def test_delete_dataset(context): def test_rename_nonexisting_dataset(context): - client = Client(None, logger_name=context) with pytest.raises(RedisReplyError): client.rename_dataset("not-a-tensor", "still-not-a-tensor") def test_copy_nonexistant_dataset(context): - client = Client(None, logger_name=context) with pytest.raises(RedisReplyError): client.copy_dataset("not-a-tensor", "still-not-a-tensor") def test_dataset_get_name(): - """Test getting a dataset name - """ + """Test getting a dataset name""" dataset = Dataset("test-dataset") name = dataset.get_name() assert name == "test-dataset" diff --git a/tests/python/test_errors.py b/tests/python/test_errors.py index 16c81a7e..402439ca 100644 --- a/tests/python/test_errors.py +++ b/tests/python/test_errors.py @@ -25,16 +25,16 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os +from os import environ import numpy as np import pytest -from os import environ from smartredis import * from smartredis.error import * from smartredis.util import Dtypes +test_gpu = environ.get("SR_TEST_DEVICE", "cpu").lower() == "gpu" -test_gpu = environ.get("SR_TEST_DEVICE","cpu").lower() == "gpu" @pytest.fixture def cfg_opts() -> ConfigOptions: @@ -102,14 +102,16 @@ def test_missing_script_function(context): c = Client(None, logger_name=context) c.set_function("bad-function", bad_function) with pytest.raises(RedisReplyError): - c.run_script("bad-function", "not-a-function-in-script", ["bad-func-tensor"], ["output"]) + c.run_script( + "bad-function", "not-a-function-in-script", ["bad-func-tensor"], ["output"] + ) with pytest.raises(RedisReplyError): - c.run_script("bad-function", "not-a-function-in-script", "bad-func-tensor", "output") + c.run_script( + "bad-function", "not-a-function-in-script", "bad-func-tensor", "output" + ) -@pytest.mark.skipif( - not test_gpu, - reason="SR_TEST_DEVICE does not specify 'gpu'" -) + +@pytest.mark.skipif(not test_gpu, reason="SR_TEST_DEVICE does not specify 'gpu'") def test_bad_function_execution_multigpu(use_cluster, context): """Error raised inside function""" @@ -118,24 +120,41 @@ def test_bad_function_execution_multigpu(use_cluster, context): data = np.array([1, 2, 3, 4]) c.put_tensor("bad-func-tensor", data) with pytest.raises(RedisReplyError): - c.run_script_multigpu("bad-function", "bad_function", ["bad-func-tensor"], ["output"], 0, 0, 2) + c.run_script_multigpu( + "bad-function", "bad_function", ["bad-func-tensor"], ["output"], 0, 0, 2 + ) with pytest.raises(RedisReplyError): - c.run_script_multigpu("bad-function", "bad_function", "bad-func-tensor", "output", 0, 0, 2) + c.run_script_multigpu( + "bad-function", "bad_function", "bad-func-tensor", "output", 0, 0, 2 + ) -@pytest.mark.skipif( - not test_gpu, - reason="SR_TEST_DEVICE does not specify 'gpu'" -) +@pytest.mark.skipif(not test_gpu, reason="SR_TEST_DEVICE does not specify 'gpu'") def test_missing_script_function_multigpu(context): """User requests to run a function not in the script""" c = Client(None, logger_name=context) c.set_function_multigpu("bad-function", bad_function, 0, 1) with pytest.raises(RedisReplyError): - c.run_script_multigpu("bad-function", "not-a-function-in-script", ["bad-func-tensor"], ["output"], 0, 0, 2) + c.run_script_multigpu( + "bad-function", + "not-a-function-in-script", + ["bad-func-tensor"], + ["output"], + 0, + 0, + 2, + ) with pytest.raises(RedisReplyError): - c.run_script_multigpu("bad-function", 
"not-a-function-in-script", "bad-func-tensor", "output", 0, 0, 2) + c.run_script_multigpu( + "bad-function", + "not-a-function-in-script", + "bad-func-tensor", + "output", + 0, + 0, + 2, + ) def test_wrong_model_name(mock_data, mock_model, context): @@ -172,9 +191,11 @@ def test_bad_device(context): with pytest.raises(TypeError): c.set_script("key", "some_script", device="not-a-gpu") + ##### # Test type errors from bad parameter types to Client API calls + def test_bad_type_put_tensor(context): c = Client(None, logger_name=context) array = np.array([1, 2, 3, 4]) @@ -264,6 +285,7 @@ def test_bad_type_set_function(context): with pytest.raises(TypeError): c.set_function("key", bad_function, 42) + def test_bad_type_set_function_multigpu(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): @@ -275,9 +297,10 @@ def test_bad_type_set_function_multigpu(context): with pytest.raises(TypeError): c.set_function_multigpu("key", bad_function, 0, "not an integer") with pytest.raises(ValueError): - c.set_function_multigpu("key", bad_function, -1, 1) # invalid first GPU + c.set_function_multigpu("key", bad_function, -1, 1) # invalid first GPU with pytest.raises(ValueError): - c.set_function_multigpu("key", bad_function, 0, 0) # invalid num GPUs + c.set_function_multigpu("key", bad_function, 0, 0) # invalid num GPUs + def test_bad_type_set_script(context): c = Client(None, logger_name=context) @@ -291,6 +314,7 @@ def test_bad_type_set_script(context): with pytest.raises(TypeError): c.set_script(key, script, 42) + def test_bad_type_set_script_multigpu(context): c = Client(None, logger_name=context) key = "key_for_script" @@ -310,6 +334,7 @@ def test_bad_type_set_script_multigpu(context): with pytest.raises(ValueError): c.set_script_multigpu(key, script, first_gpu, 0) + def test_bad_type_set_script_from_file(context): c = Client(None, logger_name=context) key = "key_for_script" @@ -322,6 +347,7 @@ def test_bad_type_set_script_from_file(context): with pytest.raises(TypeError): c.set_script_from_file(key, scriptfile, 42) + def test_bad_type_set_script_from_file_multigpu(context): c = Client(None, logger_name=context) key = "key_for_script" @@ -337,6 +363,7 @@ def test_bad_type_set_script_from_file_multigpu(context): with pytest.raises(TypeError): c.set_script_from_file_multigpu(key, scriptfile, first_gpu, "not an integer") + def test_bad_type_get_script(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): @@ -393,11 +420,17 @@ def test_bad_type_run_script_multigpu_str(context): with pytest.raises(TypeError): c.run_script_multigpu(key, fn_name, inputs, 42, offset, first_gpu, num_gpus) with pytest.raises(TypeError): - c.run_script_multigpu(key, fn_name, inputs, outputs, "not an integer", first_gpu, num_gpus) + c.run_script_multigpu( + key, fn_name, inputs, outputs, "not an integer", first_gpu, num_gpus + ) with pytest.raises(TypeError): - c.run_script_multigpu(key, fn_name, inputs, outputs, offset, "not an integer", num_gpus) + c.run_script_multigpu( + key, fn_name, inputs, outputs, offset, "not an integer", num_gpus + ) with pytest.raises(TypeError): - c.run_script_multigpu(key, fn_name, inputs, outputs, offset, first_gpu, "not an integer") + c.run_script_multigpu( + key, fn_name, inputs, outputs, offset, first_gpu, "not an integer" + ) with pytest.raises(ValueError): c.run_script_multigpu(key, fn_name, inputs, outputs, offset, -1, num_gpus) with pytest.raises(ValueError): @@ -422,11 +455,17 @@ def test_bad_type_run_script_multigpu_list(context): with 
pytest.raises(TypeError): c.run_script_multigpu(key, fn_name, inputs, 42, offset, first_gpu, num_gpus) with pytest.raises(TypeError): - c.run_script_multigpu(key, fn_name, inputs, outputs, "not an integer", first_gpu, num_gpus) + c.run_script_multigpu( + key, fn_name, inputs, outputs, "not an integer", first_gpu, num_gpus + ) with pytest.raises(TypeError): - c.run_script_multigpu(key, fn_name, inputs, outputs, offset, "not an integer", num_gpus) + c.run_script_multigpu( + key, fn_name, inputs, outputs, offset, "not an integer", num_gpus + ) with pytest.raises(TypeError): - c.run_script_multigpu(key, fn_name, inputs, outputs, offset, first_gpu, "not an integer") + c.run_script_multigpu( + key, fn_name, inputs, outputs, offset, first_gpu, "not an integer" + ) with pytest.raises(ValueError): c.run_script_multigpu(key, fn_name, inputs, outputs, offset, -1, num_gpus) with pytest.raises(ValueError): @@ -455,10 +494,13 @@ def test_bad_type_set_model(mock_model, context): with pytest.raises(TypeError): c.set_model("simple_cnn", model, "TORCH", "CPU", batch_size="not_an_integer") with pytest.raises(TypeError): - c.set_model("simple_cnn", model, "TORCH", "CPU", min_batch_size="not_an_integer") + c.set_model( + "simple_cnn", model, "TORCH", "CPU", min_batch_size="not_an_integer" + ) with pytest.raises(TypeError): c.set_model("simple_cnn", model, "TORCH", "CPU", tag=42) + def test_bad_type_set_model_multigpu(mock_model, context): c = Client(None, logger_name=context) model = mock_model.create_torch_cnn() @@ -478,9 +520,13 @@ def test_bad_type_set_model_multigpu(mock_model, context): with pytest.raises(ValueError): c.set_model_multigpu("simple_cnn", model, "TORCH", 0, 0) with pytest.raises(TypeError): - c.set_model_multigpu("simple_cnn", model, "TORCH", 0, 1, batch_size="not_an_integer") + c.set_model_multigpu( + "simple_cnn", model, "TORCH", 0, 1, batch_size="not_an_integer" + ) with pytest.raises(TypeError): - c.set_model_multigpu("simple_cnn", model, "TORCH", 0, 1, min_batch_size="not_an_integer") + c.set_model_multigpu( + "simple_cnn", model, "TORCH", 0, 1, min_batch_size="not_an_integer" + ) with pytest.raises(TypeError): c.set_model_multigpu("simple_cnn", model, "TORCH", 0, 1, tag=42) @@ -501,12 +547,17 @@ def test_bad_type_set_model_from_file(context): with pytest.raises(TypeError): c.set_model_from_file("simple_cnn", modelfile, "TORCH", "BAD_DEVICE") with pytest.raises(TypeError): - c.set_model_from_file("simple_cnn", modelfile, "TORCH", "CPU", batch_size="not_an_integer") + c.set_model_from_file( + "simple_cnn", modelfile, "TORCH", "CPU", batch_size="not_an_integer" + ) with pytest.raises(TypeError): - c.set_model_from_file("simple_cnn", modelfile, "TORCH", "CPU", min_batch_size="not_an_integer") + c.set_model_from_file( + "simple_cnn", modelfile, "TORCH", "CPU", min_batch_size="not_an_integer" + ) with pytest.raises(TypeError): c.set_model_from_file("simple_cnn", modelfile, "TORCH", "CPU", tag=42) + def test_bad_type_set_model_from_file_multigpu(context): modelfile = "bad filename but right parameter type" c = Client(None, logger_name=context) @@ -517,20 +568,33 @@ def test_bad_type_set_model_from_file_multigpu(context): with pytest.raises(TypeError): c.set_model_from_file_multigpu("simple_cnn", modelfile, 42, 0, 1) with pytest.raises(TypeError): - c.set_model_from_file_multigpu("simple_cnn", modelfile, "UNSUPPORTED_ENGINE", 0, 1) + c.set_model_from_file_multigpu( + "simple_cnn", modelfile, "UNSUPPORTED_ENGINE", 0, 1 + ) with pytest.raises(TypeError): - 
c.set_model_from_file_multigpu("simple_cnn", modelfile, "TORCH", "not an integer", 1) + c.set_model_from_file_multigpu( + "simple_cnn", modelfile, "TORCH", "not an integer", 1 + ) with pytest.raises(TypeError): - c.set_model_from_file_multigpu("simple_cnn", modelfile, "TORCH", 0, "not an integer") + c.set_model_from_file_multigpu( + "simple_cnn", modelfile, "TORCH", 0, "not an integer" + ) with pytest.raises(TypeError): - c.set_model_from_file_multigpu("simple_cnn", modelfile, "TORCH", 0, 1, batch_size="not_an_integer") + c.set_model_from_file_multigpu( + "simple_cnn", modelfile, "TORCH", 0, 1, batch_size="not_an_integer" + ) with pytest.raises(TypeError): - c.set_model_from_file_multigpu("simple_cnn", modelfile, "TORCH", 0, 1, min_batch_size="not_an_integer") + c.set_model_from_file_multigpu( + "simple_cnn", modelfile, "TORCH", 0, 1, min_batch_size="not_an_integer" + ) with pytest.raises(TypeError): - c.set_model_from_file_multigpu("simple_cnn", modelfile, "TORCH", 0, 1, min_batch_timeout="not_an_integer") + c.set_model_from_file_multigpu( + "simple_cnn", modelfile, "TORCH", 0, 1, min_batch_timeout="not_an_integer" + ) with pytest.raises(TypeError): c.set_model_from_file_multigpu("simple_cnn", modelfile, "TORCH", 0, 1, tag=42) + def test_bad_type_run_model(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): @@ -552,12 +616,13 @@ def test_bad_type_run_model_multigpu(context): with pytest.raises(ValueError): c.run_model_multigpu("simple_cnn", 0, 0, 0) + def test_bad_type_delete_model_multigpu(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): c.delete_model_multigpu(42, 0, 1) with pytest.raises(TypeError): - c.delete_model_multigpu("simple_cnn", "not an integer", 1) + c.delete_model_multigpu("simple_cnn", "not an integer", 1) with pytest.raises(TypeError): c.delete_model_multigpu("simple_cnn", 0, "not an integer") with pytest.raises(ValueError): @@ -565,13 +630,14 @@ def test_bad_type_delete_model_multigpu(context): with pytest.raises(ValueError): c.delete_model_multigpu("simple_cnn", 0, 0) + def test_bad_type_delete_script_multigpu(context): c = Client(None, logger_name=context) script_name = "my_script" with pytest.raises(TypeError): c.delete_script_multigpu(42, 0, 1) with pytest.raises(TypeError): - c.delete_script_multigpu(script_name, "not an integer", 1) + c.delete_script_multigpu(script_name, "not an integer", 1) with pytest.raises(TypeError): c.delete_script_multigpu(script_name, 0, "not an integer") with pytest.raises(ValueError): @@ -579,6 +645,7 @@ def test_bad_type_delete_script_multigpu(context): with pytest.raises(ValueError): c.delete_script_multigpu(script_name, 0, 0) + def test_bad_type_tensor_exists(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): @@ -745,16 +812,19 @@ def test_bad_type_save(context): with pytest.raises(TypeError): c.save("not a list") + def test_bad_type_append_to_list(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): c.append_to_list(42, 42) + def test_bad_type_delete_list(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): c.delete_list(42) + def test_bad_type_copy_list(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): @@ -762,6 +832,7 @@ def test_bad_type_copy_list(context): with pytest.raises(TypeError): c.copy_list("src", 42) + def test_bad_type_rename_list(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): @@ -769,11 +840,13 @@ def 
test_bad_type_rename_list(context): with pytest.raises(TypeError): c.rename_list("src", 42) + def test_bad_type_get_list_length(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): c.get_list_length(42) + def test_bad_type_poll_list_length(context): c = Client(None, logger_name=context) name = "mylist" @@ -789,6 +862,7 @@ def test_bad_type_poll_list_length(context): with pytest.raises(TypeError): c.poll_list_length(name, len, pollfreq, "not an integer") + def test_bad_type_poll_list_length_gte(context): c = Client(None, logger_name=context) name = "mylist" @@ -804,6 +878,7 @@ def test_bad_type_poll_list_length_gte(context): with pytest.raises(TypeError): c.poll_list_length_gte(name, len, pollfreq, "not an integer") + def test_bad_type_poll_list_length_lte(context): c = Client(None, logger_name=context) name = "mylist" @@ -819,11 +894,13 @@ def test_bad_type_poll_list_length_lte(context): with pytest.raises(TypeError): c.poll_list_length_lte(name, len, pollfreq, "not an integer") + def test_bad_type_get_datasets_from_list(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): c.get_datasets_from_list(42) + def test_bad_type_get_dataset_list_range(context): c = Client(None, logger_name=context) listname = "my_list" @@ -836,17 +913,18 @@ def test_bad_type_get_dataset_list_range(context): with pytest.raises(TypeError): c.get_dataset_list_range(listname, start_index, "not an integer") + def test_bad_type_set_model_chunk_size(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): c.set_model_chunk_size("not an integer") + ##### # Test type errors from bad parameter types to logging calls -@pytest.mark.parametrize("log_fn", [ - (log_data,), (log_warning,), (log_error,) -]) + +@pytest.mark.parametrize("log_fn", [(log_data,), (log_warning,), (log_error,)]) def test_bad_type_log_function(context, log_fn): c = Client(None, logger_name=context) with pytest.raises(TypeError): @@ -856,6 +934,7 @@ def test_bad_type_log_function(context, log_fn): with pytest.raises(TypeError): log_fn("test_bad_type_log_function", LLInfo, 42) + def test_bad_type_client_log(context): c = Client(None, logger_name=context) with pytest.raises(TypeError): @@ -871,6 +950,7 @@ def test_bad_type_client_log(context): with pytest.raises(TypeError): c.log_error(LLInfo, 42) + def test_bad_type_dataset_log(context): d = Dataset(context) with pytest.raises(TypeError): @@ -886,6 +966,7 @@ def test_bad_type_dataset_log(context): with pytest.raises(TypeError): d.log_error(LLInfo, 42) + def test_bad_type_logcontext_log(context): lc = LogContext(context) with pytest.raises(TypeError): @@ -901,13 +982,16 @@ def test_bad_type_logcontext_log(context): with pytest.raises(TypeError): lc.log_error(LLInfo, 42) + ##### # Test type errors from bad parameter types to Dataset API calls + def test_bad_type_dataset(): with pytest.raises(TypeError): d = Dataset(42) + def test_bad_type_add_tensor(): d = Dataset("test-dataset") with pytest.raises(TypeError): @@ -934,8 +1018,7 @@ def test_set_data_wrong_type(): def test_add_tensor_wrong_type(): - """A call to Dataset.add_tensor is made with the wrong type - """ + """A call to Dataset.add_tensor is made with the wrong type""" d = Dataset("test_dataset") data = np.array([1, 2, 3, 4]) with pytest.raises(TypeError): @@ -943,27 +1026,26 @@ def test_add_tensor_wrong_type(): with pytest.raises(TypeError): d.add_tensor("tensorname", 42) + def test_get_tensor_wrong_type(): - """A call to Dataset.get_tensor is made with the wrong type - """ + 
"""A call to Dataset.get_tensor is made with the wrong type""" d = Dataset("test_dataset") with pytest.raises(TypeError): d.get_tensor(42) def test_add_meta_scalar_wrong_type(): - """A call to Dataset.add_meta_scalar is made with the wrong type - """ + """A call to Dataset.add_meta_scalar is made with the wrong type""" d = Dataset("test_dataset") data = np.array([1, 2, 3, 4]) with pytest.raises(TypeError): d.add_meta_scalar(42, 42) with pytest.raises(TypeError): - d.add_meta_scalar("scalarname", data) # array, not scalar + d.add_meta_scalar("scalarname", data) # array, not scalar + def test_add_meta_string_wrong_type(): - """A call to Dataset.add_meta_string is made with the wrong type - """ + """A call to Dataset.add_meta_string is made with the wrong type""" d = Dataset("test_dataset") with pytest.raises(TypeError): d.add_meta_string(42, "metastring") @@ -972,78 +1054,82 @@ def test_add_meta_string_wrong_type(): def test_get_meta_scalars_wrong_type(): - """A call to Dataset.get_meta_scalars is made with the wrong type - """ + """A call to Dataset.get_meta_scalars is made with the wrong type""" d = Dataset("test_dataset") with pytest.raises(TypeError): d.get_meta_scalars(42) def test_get_meta_strings_wrong_type(): - """A call to Dataset.get_meta_strings is made with the wrong type - """ + """A call to Dataset.get_meta_strings is made with the wrong type""" d = Dataset("test_dataset") with pytest.raises(TypeError): d.get_meta_strings(42) + def test_get_tensor_type_wrong_type(): - """A call to Dataset.get_tensor_type is made with the wrong type - """ + """A call to Dataset.get_tensor_type is made with the wrong type""" d = Dataset("test_dataset") with pytest.raises(TypeError): d.get_tensor_type(42) + def test_get_metadata_field_type_wrong_type(): - """A call to Dataset.get_metadata_field_type is made with the wrong type - """ + """A call to Dataset.get_metadata_field_type is made with the wrong type""" d = Dataset("test_dataset") with pytest.raises(TypeError): d.get_metadata_field_type(42) + def test_from_string_wrong_type(): - """A call to Dataset.get_metadata_field_type is made with the wrong type - """ + """A call to Dataset.get_metadata_field_type is made with the wrong type""" with pytest.raises(TypeError): Dtypes.from_string("Incorrect input") + def test_metadata_from_numpy_wrong_type(): - """A call to Dataset.add_meta_scalar is made with the wrong type - """ + """A call to Dataset.add_meta_scalar is made with the wrong type""" array = np.array(["Incorrect Input"]) with pytest.raises(TypeError): Dtypes.metadata_from_numpy(array) + def test_get_tensor_names_wrong_type(): - """A call to Dataset.get_tensor_names is made with the wrong type - """ + """A call to Dataset.get_tensor_names is made with the wrong type""" d = Dataset("test_dataset") with pytest.raises(TypeError): d.get_tensor_names(42) + ##### # Test type errors from bad parameter types to ConfigOptions API calls + def test_create_from_environment_wrong_type(): """Ensure create_from_environment doesn't accept an invalid db_prefix param""" with pytest.raises(TypeError): _ = ConfigOptions.create_from_environment(42) + def test_get_integer_option_wrong_type(cfg_opts: ConfigOptions): """Ensure get_integer_option raises an exception on an invalid key type""" with pytest.raises(TypeError): _ = cfg_opts.get_integer_option(42) + def test_get_string_option_wrong_type(cfg_opts: ConfigOptions): """Ensure get_string_option raises an exception on an invalid key type""" with pytest.raises(TypeError): _ = cfg_opts.get_string_option(42) + def 
test_is_configured_wrong_type(cfg_opts: ConfigOptions): """Ensure is_configured raises an exception on an invalid key type""" with pytest.raises(TypeError): _ = cfg_opts.is_configured(42) + def test_override_integer_option_wrong_type(cfg_opts: ConfigOptions): """Ensure override_integer_option raises an exception on an invalid key type and when an invalid value for the target storage type is encountered""" @@ -1057,6 +1143,7 @@ def test_override_integer_option_wrong_type(cfg_opts: ConfigOptions): with pytest.raises(TypeError): _ = cfg_opts.override_integer_option(key, value) + def test_override_string_option_wrong_type(cfg_opts: ConfigOptions): """Ensure override_string_option raises an exception on an invalid key type and when an invalid value for the target storage type is encountered""" @@ -1071,9 +1158,11 @@ def test_override_string_option_wrong_type(cfg_opts: ConfigOptions): with pytest.raises(TypeError): _ = cfg_opts.override_string_option(key, value) + #### # Utility functions + def bad_function(data): """Bad function which only raises an exception""" return False diff --git a/tests/python/test_logging.py b/tests/python/test_logging.py index a1584055..0406a4d8 100644 --- a/tests/python/test_logging.py +++ b/tests/python/test_logging.py @@ -24,41 +24,51 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pytest from smartredis import * from smartredis.error import * -import pytest -@pytest.mark.parametrize("log_level", [ - LLQuiet, LLInfo, LLDebug, LLDeveloper -]) + +@pytest.mark.parametrize("log_level", [LLQuiet, LLInfo, LLDebug, LLDeveloper]) def test_logging_string(context, log_level): - log_data(context, log_level, f"This is data logged from a string ({log_level.name})") - log_warning(context, log_level, f"This is a warning logged from a string ({log_level.name})") - log_error(context, log_level, f"This is an error logged from a string ({log_level.name})") + log_data( + context, log_level, f"This is data logged from a string ({log_level.name})" + ) + log_warning( + context, log_level, f"This is a warning logged from a string ({log_level.name})" + ) + log_error( + context, log_level, f"This is an error logged from a string ({log_level.name})" + ) -@pytest.mark.parametrize("log_level", [ - LLQuiet, LLInfo, LLDebug, LLDeveloper -]) + +@pytest.mark.parametrize("log_level", [LLQuiet, LLInfo, LLDebug, LLDeveloper]) def test_logging_client(context, log_level): c = Client(None, logger_name=context) c.log_data(log_level, f"This is data logged from a client ({log_level.name})") - c.log_warning(log_level, f"This is a warning logged from a client ({log_level.name})") + c.log_warning( + log_level, f"This is a warning logged from a client ({log_level.name})" + ) c.log_error(log_level, f"This is an error logged from a client ({log_level.name})") -@pytest.mark.parametrize("log_level", [ - LLQuiet, LLInfo, LLDebug, LLDeveloper -]) + +@pytest.mark.parametrize("log_level", [LLQuiet, LLInfo, LLDebug, LLDeveloper]) def test_logging_dataset(context, log_level): d = Dataset(context) d.log_data(log_level, f"This is data logged from a dataset ({log_level.name})") - d.log_warning(log_level, f"This is a warning logged from a dataset ({log_level.name})") + d.log_warning( + log_level, f"This is a warning logged from a dataset ({log_level.name})" + ) d.log_error(log_level, f"This is an error logged from a dataset ({log_level.name})") -@pytest.mark.parametrize("log_level", [ - LLQuiet, LLInfo, LLDebug, 
LLDeveloper -]) + +@pytest.mark.parametrize("log_level", [LLQuiet, LLInfo, LLDebug, LLDeveloper]) def test_logging_logcontext(context, log_level): lc = LogContext(context) lc.log_data(log_level, f"This is data logged from a logcontext ({log_level.name})") - lc.log_warning(log_level, f"This is a warning logged from a logcontext ({log_level.name})") - lc.log_error(log_level, f"This is an error logged from a logcontext ({log_level.name})") + lc.log_warning( + log_level, f"This is a warning logged from a logcontext ({log_level.name})" + ) + lc.log_error( + log_level, f"This is an error logged from a logcontext ({log_level.name})" + ) diff --git a/tests/python/test_model_methods_torch.py b/tests/python/test_model_methods_torch.py index 6a3d6be1..41dea800 100644 --- a/tests/python/test_model_methods_torch.py +++ b/tests/python/test_model_methods_torch.py @@ -25,14 +25,15 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os +from os import environ -import torch import pytest -from os import environ +import torch from smartredis import Client from smartredis.error import * -test_gpu = environ.get("SR_TEST_DEVICE","cpu").lower() == "gpu" +test_gpu = environ.get("SR_TEST_DEVICE", "cpu").lower() == "gpu" + def test_set_model(mock_model, context): model = mock_model.create_torch_cnn() @@ -74,6 +75,7 @@ def test_torch_inference(mock_model, context): out_data = c.get_tensor("torch_cnn_output") assert out_data.shape == (1, 1, 1, 1) + def test_batch_exceptions(mock_model, context): # get model and set into database mock_model.create_torch_cnn(filepath="./torch_cnn.pt") @@ -84,113 +86,199 @@ def test_batch_exceptions(mock_model, context): min_batch_timeout = 1 with pytest.raises(RedisRuntimeError): c.set_model_from_file( - "file_cnn", "./torch_cnn.pt", "TORCH", "CPU", - batch_size=0, min_batch_size=0, min_batch_timeout=min_batch_timeout + "file_cnn", + "./torch_cnn.pt", + "TORCH", + "CPU", + batch_size=0, + min_batch_size=0, + min_batch_timeout=min_batch_timeout, ) with pytest.raises(RedisRuntimeError): c.set_model_from_file( - "file_cnn", "./torch_cnn.pt", "TORCH", "CPU", - batch_size=0, min_batch_size=min_batch_size, min_batch_timeout=0 + "file_cnn", + "./torch_cnn.pt", + "TORCH", + "CPU", + batch_size=0, + min_batch_size=min_batch_size, + min_batch_timeout=0, ) with pytest.raises(RedisRuntimeError): c.set_model_from_file( - "file_cnn", "./torch_cnn.pt", "TORCH", "CPU", - batch_size=batch_size, min_batch_size=0, min_batch_timeout=min_batch_timeout + "file_cnn", + "./torch_cnn.pt", + "TORCH", + "CPU", + batch_size=batch_size, + min_batch_size=0, + min_batch_timeout=min_batch_timeout, ) with pytest.raises(RedisRuntimeError): c.set_model_from_file_multigpu( - "file_cnn", "./torch_cnn.pt", "TORCH", 1, 1, - batch_size=0, min_batch_size=0, min_batch_timeout=min_batch_timeout + "file_cnn", + "./torch_cnn.pt", + "TORCH", + 1, + 1, + batch_size=0, + min_batch_size=0, + min_batch_timeout=min_batch_timeout, ) with pytest.raises(RedisRuntimeError): c.set_model_from_file_multigpu( - "file_cnn", "./torch_cnn.pt", "TORCH", 1, 1, - batch_size=0, min_batch_size=min_batch_size, min_batch_timeout=0 + "file_cnn", + "./torch_cnn.pt", + "TORCH", + 1, + 1, + batch_size=0, + min_batch_size=min_batch_size, + min_batch_timeout=0, ) with pytest.raises(RedisRuntimeError): c.set_model_from_file_multigpu( - "file_cnn", "./torch_cnn.pt", "TORCH", 1, 1, - batch_size=batch_size, min_batch_size=0, min_batch_timeout=min_batch_timeout + "file_cnn", + "./torch_cnn.pt", + "TORCH", + 1, + 1, + 
batch_size=batch_size, + min_batch_size=0, + min_batch_timeout=min_batch_timeout, ) with pytest.raises(RedisRuntimeError): c.set_model( - "file_cnn", model, "TORCH", "CPU", - batch_size=0, min_batch_size=0, min_batch_timeout=min_batch_timeout + "file_cnn", + model, + "TORCH", + "CPU", + batch_size=0, + min_batch_size=0, + min_batch_timeout=min_batch_timeout, ) with pytest.raises(RedisRuntimeError): c.set_model( - "file_cnn", model, "TORCH", "CPU", - batch_size=0, min_batch_size=min_batch_size, min_batch_timeout=0 + "file_cnn", + model, + "TORCH", + "CPU", + batch_size=0, + min_batch_size=min_batch_size, + min_batch_timeout=0, ) with pytest.raises(RedisRuntimeError): c.set_model( - "file_cnn", model, "TORCH", "CPU", - batch_size=batch_size, min_batch_size=0, min_batch_timeout=min_batch_timeout + "file_cnn", + model, + "TORCH", + "CPU", + batch_size=batch_size, + min_batch_size=0, + min_batch_timeout=min_batch_timeout, ) with pytest.raises(RedisRuntimeError): c.set_model_multigpu( - "file_cnn", model, "TORCH", 1, 1, - batch_size=0, min_batch_size=0, min_batch_timeout=min_batch_timeout + "file_cnn", + model, + "TORCH", + 1, + 1, + batch_size=0, + min_batch_size=0, + min_batch_timeout=min_batch_timeout, ) with pytest.raises(RedisRuntimeError): c.set_model_multigpu( - "file_cnn", model, "TORCH", 1, 1, - batch_size=0, min_batch_size=min_batch_size, min_batch_timeout=0 + "file_cnn", + model, + "TORCH", + 1, + 1, + batch_size=0, + min_batch_size=min_batch_size, + min_batch_timeout=0, ) with pytest.raises(RedisRuntimeError): c.set_model_multigpu( - "file_cnn", model, "TORCH", 1, 1, - batch_size=batch_size, min_batch_size=0, min_batch_timeout=min_batch_timeout + "file_cnn", + model, + "TORCH", + 1, + 1, + batch_size=batch_size, + min_batch_size=0, + min_batch_timeout=min_batch_timeout, ) + def test_batch_warning_set_model_from_file(mock_model, context, capfd): # get model and set into database mock_model.create_torch_cnn(filepath="./torch_cnn.pt") c = Client(None, logger_name=context) c.set_model_from_file( - "file_cnn", "./torch_cnn.pt", "TORCH", "CPU", - batch_size=1, min_batch_size=1, min_batch_timeout=0 + "file_cnn", + "./torch_cnn.pt", + "TORCH", + "CPU", + batch_size=1, + min_batch_size=1, + min_batch_timeout=0, ) captured = capfd.readouterr() assert "WARNING" in captured.err -@pytest.mark.skipif( - not test_gpu, - reason="SR_TEST_DEVICE does not specify 'gpu'" -) + +@pytest.mark.skipif(not test_gpu, reason="SR_TEST_DEVICE does not specify 'gpu'") def test_batch_warning_set_model_from_file_multigpu(mock_model, context, capfd): # get model and set into database mock_model.create_torch_cnn(filepath="./torch_cnn.pt") c = Client(None, logger_name=context) c.set_model_from_file_multigpu( - "file_cnn", "./torch_cnn.pt", "TORCH", 1, 1, - batch_size=1, min_batch_size=1, min_batch_timeout=0 + "file_cnn", + "./torch_cnn.pt", + "TORCH", + 1, + 1, + batch_size=1, + min_batch_size=1, + min_batch_timeout=0, ) captured = capfd.readouterr() assert "WARNING" in captured.err + def test_batch_warning_set_model(mock_model, context, capfd): # get model and set into database model = mock_model.create_torch_cnn() c = Client(None, logger_name=context) c.set_model( - "file_cnn", model, "TORCH", "CPU", - batch_size=1, min_batch_size=1, min_batch_timeout=0 + "file_cnn", + model, + "TORCH", + "CPU", + batch_size=1, + min_batch_size=1, + min_batch_timeout=0, ) captured = capfd.readouterr() assert "WARNING" in captured.err -@pytest.mark.skipif( - not test_gpu, - reason="SR_TEST_DEVICE does not specify 'gpu'" -) + 
+@pytest.mark.skipif(not test_gpu, reason="SR_TEST_DEVICE does not specify 'gpu'") def test_batch_warning_set_model_multigpu(mock_model, context, capfd): # get model and set into database model = mock_model.create_torch_cnn() c = Client(None, logger_name=context) c.set_model_multigpu( - "file_cnn", model, "TORCH", 1, 1, - batch_size=1, min_batch_size=1, min_batch_timeout=0 + "file_cnn", + model, + "TORCH", + 1, + 1, + batch_size=1, + min_batch_size=1, + min_batch_timeout=0, ) captured = capfd.readouterr() assert "WARNING" in captured.err diff --git a/tests/python/test_nonkeyed_cmd.py b/tests/python/test_nonkeyed_cmd.py index ec9d7f05..626b703c 100644 --- a/tests/python/test_nonkeyed_cmd.py +++ b/tests/python/test_nonkeyed_cmd.py @@ -28,21 +28,21 @@ import numpy as np import pytest -from smartredis import Client -from smartredis import ConfigOptions +from smartredis import Client, ConfigOptions from smartredis.error import * def test_dbnode_info_command(context): ssdb = os.environ["SSDB"] - addresses = ssdb.split(',') + addresses = ssdb.split(",") client = Client(None, logger_name=context) info = client.get_db_node_info(addresses) assert len(info) > 0 + def test_dbcluster_info_command(mock_model, context): ssdb = os.environ["SSDB"] - addresses = ssdb.split(',') + addresses = ssdb.split(",") co = ConfigOptions().create_from_environment("") client = Client(co, logger_name=context) @@ -72,12 +72,13 @@ def test_dbcluster_info_command(mock_model, context): with pytest.raises(RedisRuntimeError): client.get_ai_info(addresses, "bad_key") + def test_flushdb_command(context): # from within the testing framework, there is no way # of knowing each db node that is being used, so skip # if on cluster ssdb = os.environ["SSDB"] - addresses = ssdb.split(',') + addresses = ssdb.split(",") if os.environ["SR_DB_TYPE"] == "Clustered": return @@ -130,8 +131,12 @@ def test_save_command(context): # for each address, check that the timestamp of the last SAVE increases after calling Client::save for address in addresses: - save_time_before = client.get_db_node_info([address])[0]["Persistence"]["rdb_last_save_time"] + save_time_before = client.get_db_node_info([address])[0]["Persistence"][ + "rdb_last_save_time" + ] client.save([address]) - save_time_after = client.get_db_node_info([address])[0]["Persistence"]["rdb_last_save_time"] + save_time_after = client.get_db_node_info([address])[0]["Persistence"][ + "rdb_last_save_time" + ] assert save_time_before <= save_time_after diff --git a/tests/python/test_prefixing.py b/tests/python/test_prefixing.py index 4677d5a3..582be11c 100644 --- a/tests/python/test_prefixing.py +++ b/tests/python/test_prefixing.py @@ -24,11 +24,12 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-import numpy as np import os +import numpy as np from smartredis import Client, Dataset + def test_prefixing(context, monkeypatch): # configure prefix variables monkeypatch.setenv("SSKEYOUT", "prefix_test") @@ -55,6 +56,7 @@ def test_prefixing(context, monkeypatch): assert c.key_exists("prefix_test.test_tensor") assert not c.key_exists("test_tensor") + def test_model_prefixing(mock_model, context, monkeypatch): # configure prefix variables monkeypatch.setenv("SSKEYOUT", "prefix_test") @@ -101,8 +103,10 @@ def test_list_prefixing(context, monkeypatch): assert c.key_exists("prefix_test.dataset_test_list") assert not c.key_exists("dataset_test_list") + # ------------ helper functions --------------------------------- + def create_dataset(name): array = np.array([1, 2, 3, 4]) string = "test_meta_strings" @@ -112,4 +116,4 @@ def create_dataset(name): dataset.add_tensor("test_array", array) dataset.add_meta_string("test_string", string) dataset.add_meta_scalar("test_scalar", scalar) - return dataset \ No newline at end of file + return dataset diff --git a/tests/python/test_put_get_bytes.py b/tests/python/test_put_get_bytes.py index a56c718a..9291c659 100644 --- a/tests/python/test_put_get_bytes.py +++ b/tests/python/test_put_get_bytes.py @@ -24,10 +24,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import numpy as np -import os import io +import os +import numpy as np from smartredis import Client # ----- Tests ----------------------------------------------------------- @@ -38,12 +38,11 @@ def test_put_get_bytes(mock_data, context): client = Client(None, logger_name=context) - data = np.random.rand(2,8,4,2,30) + data = np.random.rand(2, 8, 4, 2, 30) bytes = io.BytesIO(data.tobytes()) - + client.put_bytes("python_bytes", bytes) retrieved_bytes = client.get_bytes("python_bytes") - assert(bytes.getvalue() == retrieved_bytes.getvalue()) - + assert bytes.getvalue() == retrieved_bytes.getvalue() diff --git a/tests/python/test_put_get_tensor.py b/tests/python/test_put_get_tensor.py index f0bfc5c7..c63223ce 100644 --- a/tests/python/test_put_get_tensor.py +++ b/tests/python/test_put_get_tensor.py @@ -24,10 +24,9 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import numpy as np import os - +import numpy as np from smartredis import Client # ----- Tests ----------------------------------------------------------- diff --git a/tests/python/test_script_methods.py b/tests/python/test_script_methods.py index baa43c10..3b8c4512 100644 --- a/tests/python/test_script_methods.py +++ b/tests/python/test_script_methods.py @@ -24,18 +24,20 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-import pytest import inspect import os.path as osp from os import environ + import numpy as np +import pytest import torch from smartredis import Client file_path = osp.dirname(osp.abspath(__file__)) -test_gpu = environ.get("SR_TEST_DEVICE","cpu").lower() == "gpu" +test_gpu = environ.get("SR_TEST_DEVICE", "cpu").lower() == "gpu" + def test_set_get_function(context): c = Client(None, logger_name=context) @@ -97,24 +99,22 @@ def test_run_script_list(context): out, expected, "Returned array from script not equal to expected result" ) -@pytest.mark.skipif( - not test_gpu, - reason="SR_TEST_DEVICE does not specify 'gpu'" -) + +@pytest.mark.skipif(not test_gpu, reason="SR_TEST_DEVICE does not specify 'gpu'") def test_run_script_multigpu_str(use_cluster, context): data = np.array([[1, 2, 3, 4, 5]]) c = Client(None, use_cluster, logger_name=context) c.put_tensor("script-test-data", data) c.set_function_multigpu("one-to-one", one_to_one, 0, 2) - c.run_script_multigpu("one-to-one", "one_to_one", "script-test-data", "script-test-out", 0, 0, 2) + c.run_script_multigpu( + "one-to-one", "one_to_one", "script-test-data", "script-test-out", 0, 0, 2 + ) out = c.get_tensor("script-test-out") assert out == 5 -@pytest.mark.skipif( - not test_gpu, - reason="SR_TEST_DEVICE does not specify 'gpu'" -) + +@pytest.mark.skipif(not test_gpu, reason="SR_TEST_DEVICE does not specify 'gpu'") def test_run_script_multigpu_list(use_cluster, context): data = np.array([[1, 2, 3, 4]]) data_2 = np.array([[5, 6, 7, 8]]) @@ -130,7 +130,7 @@ def test_run_script_multigpu_list(use_cluster, context): ["srpt-multi-out-output"], 0, 0, - 2 + 2, ) out = c.get_tensor("srpt-multi-out-output") expected = np.array([4, 8]) diff --git a/tests/python/test_tensor_ops.py b/tests/python/test_tensor_ops.py index 2061de29..a40c0e95 100644 --- a/tests/python/test_tensor_ops.py +++ b/tests/python/test_tensor_ops.py @@ -25,6 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os + import numpy as np import pytest from smartredis import Client @@ -78,14 +79,12 @@ def test_delete_tensor(context): def test_rename_nonexisting_key(context): - client = Client(None, logger_name=context) with pytest.raises(RedisReplyError): client.rename_tensor("not-a-tensor", "still-not-a-tensor") def test_copy_nonexistant_key(context): - client = Client(None, logger_name=context) with pytest.raises(RedisReplyError): client.copy_tensor("not-a-tensor", "still-not-a-tensor")
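
The reformatting in this patch is consistent with Black-style wrapping (88-column limit, one argument per line with a trailing comma) and isort-style import grouping; the patch itself does not name the tools, so treat that as an assumption. Below is a minimal, self-contained sketch of the conventions applied throughout the test files above, using hypothetical names rather than code from the patch:

import os
from os import environ

import numpy as np


def describe_polling_failure(kind: str, expected: int, observed: int) -> str:
    # Standard-library imports come first, third-party imports after a blank
    # line, and two blank lines separate top-level definitions, matching the
    # reordered import blocks and spacing changes in the diffs above.
    device = environ.get("SR_TEST_DEVICE", "cpu").lower()
    data = np.array([1, 2, 3, 4])
    # Calls that exceed the line limit wrap to one element per line with a
    # trailing comma, mirroring the rewrapped error messages in
    # tests/python/test_dataset_aggregation.py.
    return " ".join(
        [
            f"Polling for list length {kind} {expected}",
            f"returned an unexpected result for known length {observed}",
            f"(device={device}, sample={data.tolist()}, cwd={os.getcwd()}).",
        ]
    )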