From bce91cb0b6b821e1b1a579c40f19311e847577b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ole=20Petter=20L=C3=B8d=C3=B8en?= Date: Wed, 14 Jun 2023 09:11:34 +0200 Subject: [PATCH] fix: make apply_condition work for 2D numpy arrays also (#78) --- .../consumer_function/utils.py | 8 ++------ .../core/consumers/test_consumer_utils.py | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/ecalc/libraries/libecalc/common/libecalc/core/consumers/legacy_consumer/consumer_function/utils.py b/src/ecalc/libraries/libecalc/common/libecalc/core/consumers/legacy_consumer/consumer_function/utils.py index dfae76a07..d50c64787 100644 --- a/src/ecalc/libraries/libecalc/common/libecalc/core/consumers/legacy_consumer/consumer_function/utils.py +++ b/src/ecalc/libraries/libecalc/common/libecalc/core/consumers/legacy_consumer/consumer_function/utils.py @@ -93,7 +93,7 @@ def apply_condition(input_array: NDArray[np.float64], condition: Optional[NDArra Args: input_array: Array with input values - condition: Array of 1 or 0 describing wether or not conditions are met + condition: Array of 1 or 0 describing whether conditions are met Returns: Returns the input_array where conditions are applied (values set to 0 where condition is 0) @@ -101,11 +101,7 @@ def apply_condition(input_array: NDArray[np.float64], condition: Optional[NDArra if condition is None: return deepcopy(input_array) else: - return ( - np.where(np.any(condition, axis=0), input_array, 0) - if np.ndim(input_array) == 2 - else np.where(condition, input_array, 0) - ) + return np.where(condition, input_array, 0) def get_power_loss_factor_from_expression( diff --git a/src/ecalc/libraries/libecalc/common/tests/core/consumers/test_consumer_utils.py b/src/ecalc/libraries/libecalc/common/tests/core/consumers/test_consumer_utils.py index 6541c2a7a..41baa8b1a 100644 --- a/src/ecalc/libraries/libecalc/common/tests/core/consumers/test_consumer_utils.py +++ b/src/ecalc/libraries/libecalc/common/tests/core/consumers/test_consumer_utils.py @@ -1,4 +1,8 @@ +import numpy as np from libecalc.core.consumers.consumer_system import ConsumerSystem +from libecalc.core.consumers.legacy_consumer.consumer_function.utils import ( + apply_condition, +) def test_topologically_sort_consumers_by_crossover(): @@ -18,3 +22,19 @@ def test_topologically_sort_consumers_by_crossover(): ConsumerSystem._topologically_sort_consumers_by_crossover(crossover=[0, 3, 1], consumers=unsorted_consumers) == sorted_consumers ) + + +def test_apply_condition() -> None: + """Test that apply_condition sets elements in 1D array and columns in 2D array to zero""" + condition = np.asarray([1, 0, 1, 0]) + input_array_1D = np.asarray([10, 10, 10, 10]) + input_array_2D = np.asarray([[10, 10, 10, 10], [10, 10, 10, 10]]) + + input_array_1D_after_condition = apply_condition(input_array=input_array_1D, condition=condition) + input_array_2D_after_condition = apply_condition(input_array=input_array_2D, condition=condition) + + expected_input_array_1D_after_condition = np.asarray([10, 0, 10, 0]) + expected_input_array_2D_after_condition = np.asarray([[10, 0, 10, 0], [10, 0, 10, 0]]) + + assert np.array_equal(input_array_1D_after_condition, expected_input_array_1D_after_condition) + assert np.array_equal(input_array_2D_after_condition, expected_input_array_2D_after_condition)