Skip to content

Commit

Permalink
Make style
Browse files Browse the repository at this point in the history
  • Loading branch information
mellis13 committed Oct 31, 2024
1 parent fec070b commit 64a1532
Show file tree
Hide file tree
Showing 19 changed files with 463 additions and 243 deletions.
14 changes: 7 additions & 7 deletions src/python/module/smartredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@

# pylint: disable=too-many-lines,too-many-public-methods
import inspect
import io
import os
import os.path as osp
import typing as t

import numpy as np
import io

from .dataset import Dataset
from .configoptions import ConfigOptions
from .dataset import Dataset
from .error import RedisConnectionError
from .smartredisPy import PyClient
from .smartredisPy import RedisReplyError as PybindRedisReplyError
Expand Down Expand Up @@ -95,7 +96,7 @@ def __address_construction(
self,
cluster: bool,
address: t.Optional[str] = None,
logger_name: str = "Default"
logger_name: str = "Default",
) -> PyClient:
"""Initialize a SmartRedis client
Expand Down Expand Up @@ -129,8 +130,7 @@ def __address_construction(

@staticmethod
def __standard_construction(
config_options: t.Optional[ConfigOptions] = None,
logger_name: str = "Default"
config_options: t.Optional[ConfigOptions] = None, logger_name: str = "Default"
) -> PyClient:
"""Initialize a RedisAI client
Expand Down Expand Up @@ -607,7 +607,7 @@ def run_script(
name: str,
fn_name: str,
inputs: t.Union[str, t.List[str]],
outputs: t.Union[str, t.List[str]]
outputs: t.Union[str, t.List[str]],
) -> None:
"""Execute TorchScript stored inside the database
Expand Down Expand Up @@ -1472,7 +1472,7 @@ def use_dataset_ensemble_prefix(self, use_prefix: bool) -> None:
"""
typecheck(use_prefix, "use_prefix", bool)
return self._client.use_dataset_ensemble_prefix(use_prefix)

@exception_handler
def use_bytes_ensemble_prefix(self, use_prefix: bool) -> None:
"""Control whether byte keys are prefixed (e.g. in an
Expand Down
1 change: 0 additions & 1 deletion src/python/module/smartredis/configoptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from .smartredisPy import PyConfigOptions
from .util import exception_handler, typecheck


if t.TYPE_CHECKING:
from typing_extensions import ParamSpec

Expand Down
6 changes: 2 additions & 4 deletions src/python/module/smartredis/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,13 @@ def transform_to_xarray(dataset: Dataset) -> t.Dict:
# Extract dimensions in correct form
dims_final = [
dataset.get_meta_strings(dim_field_name)[0]
for dim_field_name
in get_data(dataset, variable_name, "dim")
for dim_field_name in get_data(dataset, variable_name, "dim")
]

# Extract attributes in correct form
attrs_final = {
attr_field_name: dataset.get_meta_strings(attr_field_name)[0]
for attr_field_name
in get_data(dataset, variable_name, "attr")
for attr_field_name in get_data(dataset, variable_name, "attr")
}

# Add coordinates to the correct data name
Expand Down
2 changes: 1 addition & 1 deletion src/python/module/smartredis/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from functools import wraps

import numpy as np
from . import error

from . import error
from .smartredisPy import RedisReplyError as PybindRedisReplyError
from .smartredisPy import c_get_last_error_location

Expand Down
75 changes: 45 additions & 30 deletions tests/python/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import pytest

import pytest
from smartredis import Client, ConfigOptions


def test_serialization(context):
c = Client(None, logger_name=context)
assert str(c) != repr(c)
Expand All @@ -45,41 +46,53 @@ def test_address(context):
# check if SSDB was set anyway
assert os.environ["SSDB"] == ssdb


# Globals for Client constructor testing
ac_original = Client._Client__address_construction
sc_original = Client._Client__standard_construction
cluster_mode = os.environ["SR_DB_TYPE"] == "Clustered"
target_address = os.environ["SSDB"]
co_envt = ConfigOptions.create_from_environment("")


@pytest.mark.parametrize(
"args, kwargs, expected_constructor", [
# address constructions
[(False,), {}, "address"],
[(False,), {"address": target_address}, "address"],
[(False,), {"address": target_address, "logger_name": "log_name"}, "address"],
[(False,), {"logger_name": "log_name"}, "address"],
[(False, target_address), {}, "address"],
[(False, target_address), {"logger_name": "log_name"}, "address"],
[(False, target_address, "log_name"), {}, "address"],
[(), {"cluster": cluster_mode}, "address"],
[(), {"cluster": cluster_mode, "address": target_address}, "address"],
[(), {"cluster": cluster_mode, "address": target_address, "logger_name": "log_name"}, "address"],
[(), {"cluster": cluster_mode, "logger_name": "log_name"}, "address"],
# standard constructions
[(None,), {}, "standard"],
[(None,), {"logger_name": "log_name"}, "standard"],
[(None, "log_name"), {}, "standard"],
[(co_envt,), {}, "standard"],
[(co_envt,), {"logger_name": "log_name"}, "standard"],
[(co_envt, "log_name"), {}, "standard"],
[(), {}, "standard"],
[(), {"config_options": None}, "standard"],
[(), {"config_options": None, "logger_name": "log_name"}, "standard"],
[(), {"config_options": co_envt}, "standard"],
[(), {"config_options": co_envt, "logger_name": "log_name"}, "standard"],
[(), {"logger_name": "log_name"}, "standard"],
])
"args, kwargs, expected_constructor",
[
# address constructions
[(False,), {}, "address"],
[(False,), {"address": target_address}, "address"],
[(False,), {"address": target_address, "logger_name": "log_name"}, "address"],
[(False,), {"logger_name": "log_name"}, "address"],
[(False, target_address), {}, "address"],
[(False, target_address), {"logger_name": "log_name"}, "address"],
[(False, target_address, "log_name"), {}, "address"],
[(), {"cluster": cluster_mode}, "address"],
[(), {"cluster": cluster_mode, "address": target_address}, "address"],
[
(),
{
"cluster": cluster_mode,
"address": target_address,
"logger_name": "log_name",
},
"address",
],
[(), {"cluster": cluster_mode, "logger_name": "log_name"}, "address"],
# standard constructions
[(None,), {}, "standard"],
[(None,), {"logger_name": "log_name"}, "standard"],
[(None, "log_name"), {}, "standard"],
[(co_envt,), {}, "standard"],
[(co_envt,), {"logger_name": "log_name"}, "standard"],
[(co_envt, "log_name"), {}, "standard"],
[(), {}, "standard"],
[(), {"config_options": None}, "standard"],
[(), {"config_options": None, "logger_name": "log_name"}, "standard"],
[(), {"config_options": co_envt}, "standard"],
[(), {"config_options": co_envt, "logger_name": "log_name"}, "standard"],
[(), {"logger_name": "log_name"}, "standard"],
],
)
def test_client_constructor(args, kwargs, expected_constructor, monkeypatch):
ac_got_called = False
sc_got_called = False
Expand All @@ -96,9 +109,11 @@ def mock_standard_constructor(*a, **kw):
return sc_original(*a, **kw)

monkeypatch.setattr(
Client, "_Client__address_construction", mock_address_constructor)
Client, "_Client__address_construction", mock_address_constructor
)
monkeypatch.setattr(
Client, "_Client__standard_construction", mock_standard_constructor)
Client, "_Client__standard_construction", mock_standard_constructor
)

Client(*args, **kwargs)

Expand Down
14 changes: 8 additions & 6 deletions tests/python/test_configoptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#####
# Test attempts to use API functions from non-factory object


def test_non_factory_configobject():
co = ConfigOptions()
with pytest.raises(RedisRuntimeError):
Expand All @@ -46,6 +47,7 @@ def test_non_factory_configobject():
with pytest.raises(RedisRuntimeError):
_ = co.override_string_option("key", "value")


def test_options(monkeypatch):
monkeypatch.setenv("test_integer_key", "42")
monkeypatch.setenv("test_string_key", "charizard")
Expand All @@ -58,8 +60,7 @@ def test_options(monkeypatch):
_ = co.get_integer_option("test_integer_key_that_is_not_really_present")
co.override_integer_option("test_integer_key_that_is_not_really_present", 42)
assert co.is_configured("test_integer_key_that_is_not_really_present")
assert co.get_integer_option(
"test_integer_key_that_is_not_really_present") == 42
assert co.get_integer_option("test_integer_key_that_is_not_really_present") == 42

# string option tests
assert co.get_string_option("test_string_key") == "charizard"
Expand All @@ -68,8 +69,10 @@ def test_options(monkeypatch):
_ = co.get_string_option("test_string_key_that_is_not_really_present")
co.override_string_option("test_string_key_that_is_not_really_present", "meowth")
assert co.is_configured("test_string_key_that_is_not_really_present")
assert co.get_string_option(
"test_string_key_that_is_not_really_present") == "meowth"
assert (
co.get_string_option("test_string_key_that_is_not_really_present") == "meowth"
)


def test_options_with_suffix(monkeypatch):
monkeypatch.setenv("integer_key_suffixtest", "42")
Expand All @@ -92,5 +95,4 @@ def test_options_with_suffix(monkeypatch):
_ = co.get_string_option("string_key_that_is_not_really_present")
co.override_string_option("string_key_that_is_not_really_present", "meowth")
assert co.is_configured("string_key_that_is_not_really_present")
assert co.get_string_option(
"string_key_that_is_not_really_present") == "meowth"
assert co.get_string_option("string_key_that_is_not_really_present") == "meowth"
Loading

0 comments on commit 64a1532

Please sign in to comment.