diff --git a/andes/core/model/modeldata.py b/andes/core/model/modeldata.py index 29b21308f..a66f2cb8d 100644 --- a/andes/core/model/modeldata.py +++ b/andes/core/model/modeldata.py @@ -4,13 +4,13 @@ import logging from collections import OrderedDict -from typing import Iterable, Sized import numpy as np from andes.core.model.modelcache import ModelCache from andes.core.param import (BaseParam, DataParam, IdxParam, NumParam, TimerParam) from andes.shared import pd +from andes.utils.func import validate_keys_values logger = logging.getLogger(__name__) @@ -277,7 +277,7 @@ def find_param(self, prop): return out - def find_idx(self, keys, values, allow_none=False, default=False): + def find_idx(self, keys, values, allow_none=False, default=False, allow_all=False): """ Find `idx` of devices whose values match the given pattern. @@ -288,49 +288,65 @@ def find_idx(self, keys, values, allow_none=False, default=False): values : array, array of arrays, Sized Values for the corresponding key to search for. If keys is a str, values should be an array of elements. If keys is a list, values should be an array of arrays, each corresponds to the key. - allow_none : bool, Sized + allow_none : bool, Sized, optional Allow key, value to be not found. Used by groups. - default : bool + default : bool, optional Default idx to return if not found (missing) + allow_all : bool, optional + If True, returns a list of lists where each nested list contains all the matches for the + corresponding search criteria. Returns ------- list indices of devices - """ - if isinstance(keys, str): - keys = (keys,) - if not isinstance(values, (int, float, str, np.floating)) and not isinstance(values, Iterable): - raise ValueError(f"value must be a string, scalar or an iterable, got {values}") - if len(values) > 0 and not isinstance(values[0], (list, tuple, np.ndarray)): - values = (values,) + Notes + ----- + - Only the first match is returned by default. + - If all matches are needed, set `allow_all` to True. + + Examples + -------- + >>> # Use example case of IEEE 14-bus system with PVD1 + >>> ss = andes.load(andes.get_case('ieee14/ieee14_pvd1.xlsx')) + + >>> # To find the idx of `PVD1` with `name` of 'PVD1_1' and 'PVD1_2' + >>> ss.PVD1.find_idx(keys='name', values=['PVD1_1', 'PVD1_2']) + [1, 2] + + >>> # To find the idx of `PVD1` connected to bus 4 + >>> ss.PVD1.find_idx(keys='bus', values=[4]) + [1] - elif isinstance(keys, Sized): - if not isinstance(values, Iterable): - raise ValueError(f"value must be an iterable, got {values}") + >>> # To find ALL the idx of `PVD1` with `gammap` equals to 0.1 + >>> ss.PVD1.find_idx(keys='gammap', values=[0.1], allow_all=True) + [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] - if len(values) > 0 and not isinstance(values[0], Iterable): - raise ValueError(f"if keys is an iterable, values must be an iterable of iterables. got {values}") + >>> # To find the idx of `PVD1` with `gammap` equals to 0.1 and `name` of 'PVD1_1' + >>> ss.PVD1.find_idx(keys=['gammap', 'name'], values=[[0.1], ['PVD1_1']]) + [1] + """ - if len(keys) != len(values): - raise ValueError("keys and values must have the same length") + keys, values = validate_keys_values(keys, values) v_attrs = [self.__dict__[key].v for key in keys] idxes = [] for v_search in zip(*values): - v_idx = None + v_idx = [] for pos, v_attr in enumerate(zip(*v_attrs)): if all([i == j for i, j in zip(v_search, v_attr)]): - v_idx = self.idx.v[pos] - break - if v_idx is None: + v_idx.append(self.idx.v[pos]) + if not v_idx: if allow_none is False: raise IndexError(f'{list(keys)}={v_search} not found in {self.class_name}') else: - v_idx = default + v_idx = [default] - idxes.append(v_idx) + if allow_all: + idxes.append(v_idx) + else: + idxes.append(v_idx[0]) return idxes diff --git a/andes/models/group.py b/andes/models/group.py index 44480bac3..0c33a49af 100644 --- a/andes/models/group.py +++ b/andes/models/group.py @@ -5,7 +5,7 @@ import numpy as np from andes.core.service import BackRef -from andes.utils.func import list_flatten +from andes.utils.func import list_flatten, validate_keys_values logger = logging.getLogger(__name__) @@ -243,30 +243,71 @@ def set(self, src: str, idx, attr, value): return True - def find_idx(self, keys, values, allow_none=False, default=None): + def find_idx(self, keys, values, allow_none=False, default=None, allow_all=False): """ Find indices of devices that satisfy the given `key=value` condition. This method iterates over all models in this group. + + Parameters + ---------- + keys : str, array-like, Sized + A string or an array-like of strings containing the names of parameters for the search criteria. + values : array, array of arrays, Sized + Values for the corresponding key to search for. If keys is a str, values should be an array of + elements. If keys is a list, values should be an array of arrays, each corresponding to the key. + allow_none : bool, optional + Allow key, value to be not found. Used by groups. Default is False. + default : bool, optional + Default idx to return if not found (missing). Default is None. + allow_all : bool, optional + Return all matches if set to True. Default is False. + + Returns + ------- + list + Indices of devices. """ + + keys, values = validate_keys_values(keys, values) + + n_mdl, n_pair = len(self.models), len(values[0]) + indices_found = [] # `indices_found` contains found indices returned from all models of this group for model in self.models.values(): - indices_found.append(model.find_idx(keys, values, allow_none=True, default=default)) - - out = [] - for idx, idx_found in enumerate(zip(*indices_found)): - if not allow_none: - if idx_found.count(None) == len(idx_found): - missing_values = [item[idx] for item in values] - raise IndexError(f'{list(keys)} = {missing_values} not found in {self.class_name}') - - real_idx = default - for item in idx_found: - if item is not None: - real_idx = item + indices_found.append(model.find_idx(keys, values, allow_none=True, default=default, allow_all=True)) + + # --- find missing pairs --- + i_val_miss = [] + for i in range(n_pair): + idx_cross_mdls = [indices_found[j][i] for j in range(n_mdl)] + if all(item == [default] for item in idx_cross_mdls): + i_val_miss.append(i) + + if (not allow_none) and i_val_miss: + miss_pairs = [] + for i in i_val_miss: + miss_pairs.append([values[j][i] for j in range(len(keys))]) + raise IndexError(f'{keys} = {miss_pairs} not found in {self.class_name}') + + # --- output --- + out_pre = [] + for i in range(n_pair): + idx_cross_mdls = [indices_found[j][i] for j in range(n_mdl)] + if all(item == [default] for item in idx_cross_mdls): + out_pre.append([default]) + continue + for item in idx_cross_mdls: + if item != [default]: + out_pre.append(item) break - out.append(real_idx) + + if allow_all: + out = out_pre + else: + out = [item[0] for item in out_pre] + return out def _check_src(self, src: str): diff --git a/andes/models/misc/output.py b/andes/models/misc/output.py index 1f01f3c92..8dfb0b4bc 100644 --- a/andes/models/misc/output.py +++ b/andes/models/misc/output.py @@ -50,9 +50,9 @@ def in1d(self, addr, v_code): """ if v_code == 'x': - return np.in1d(self.xidx, addr) + return np.isin(self.xidx, addr) if v_code == 'y': - return np.in1d(self.yidx, addr) + return np.isin(self.yidx, addr) raise NotImplementedError("v_code <%s> not recognized" % v_code) diff --git a/andes/utils/func.py b/andes/utils/func.py index d9e615a36..d0b172456 100644 --- a/andes/utils/func.py +++ b/andes/utils/func.py @@ -1,5 +1,6 @@ import functools import operator +from typing import Iterable, Sized from andes.shared import np @@ -36,3 +37,50 @@ def interp_n2(t, x, y): """ return y[:, 0] + (t - x[0]) * (y[:, 1] - y[:, 0]) / (x[1] - x[0]) + + +def validate_keys_values(keys, values): + """ + Validate the inputs for the func `find_idx`. + + Parameters + ---------- + keys : str, array-like, Sized + A string or an array-like of strings containing the names of parameters for the search criteria. + values : array, array of arrays, Sized + Values for the corresponding key to search for. If keys is a str, values should be an array of + elements. If keys is a list, values should be an array of arrays, each corresponds to the key. + + Returns + ------- + tuple + Sanitized keys and values + + Raises + ------ + ValueError + If the inputs are not valid. + """ + if isinstance(keys, str): + keys = (keys,) + if not isinstance(values, (int, float, str, np.floating)) and not isinstance(values, Iterable): + raise ValueError(f"value must be a string, scalar or an iterable, got {values}") + + if len(values) > 0 and not isinstance(values[0], (list, tuple, np.ndarray)): + values = (values,) + + elif isinstance(keys, Sized): + if not isinstance(values, Iterable): + raise ValueError(f"value must be an iterable, got {values}") + + if len(values) > 0 and not isinstance(values[0], Iterable): + raise ValueError(f"if keys is an iterable, values must be an iterable of iterables. got {values}") + + if len(keys) != len(values): + raise ValueError("keys and values must have the same length") + + if isinstance(values[0], Iterable): + if not all([len(val) == len(values[0]) for val in values]): + raise ValueError("All items in values must have the same length") + + return keys, values diff --git a/docs/source/release-notes.rst b/docs/source/release-notes.rst index 9b053e084..dfef9bd87 100644 --- a/docs/source/release-notes.rst +++ b/docs/source/release-notes.rst @@ -19,6 +19,7 @@ v1.9.3 (2024-04-XX) - Adjust `BusFreq.Tw.default` to 0.1. - Add parameter from_csv=None in TDS.run() to allow loading data from CSV files at TDS begining. - Fix `TDS.init()` and `TDS._csv_step()` to fit loading from CSV when `Output` exists. +- Add parameter `allow_all=False` to `ModelData.find_idx()` `GroupBase.find_idx()` to allow searching all matches. v1.9.2 (2024-03-25) ------------------- diff --git a/tests/test_group.py b/tests/test_group.py index 9f7a21bad..dcb9e97fb 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -71,6 +71,7 @@ def test_group_access(self): [6, 7, 8, 1]) # --- find_idx --- + # same Model self.assertListEqual(ss.DG.find_idx('name', ['PVD1_1', 'PVD1_2']), ss.PVD1.find_idx('name', ['PVD1_1', 'PVD1_2']), ) @@ -82,6 +83,22 @@ def test_group_access(self): [('PVD1_1', 'PVD1_2'), (1.0, 1.0)])) + # cross Model, given results + self.assertListEqual(ss.StaticGen.find_idx(keys='bus', + values=[1, 2, 3, 4]), + [1, 2, 3, 6]) + self.assertListEqual(ss.StaticGen.find_idx(keys='bus', + values=[1, 2, 3, 4], + allow_all=True), + [[1], [2], [3], [6]]) + + self.assertListEqual(ss.StaticGen.find_idx(keys='bus', + values=[1, 2, 3, 4, 2024], + allow_none=True, + default=2011, + allow_all=True), + [[1], [2], [3], [6], [2011]]) + # --- get_field --- ff = ss.DG.get_field('f', list(ss.DG._idx2model.keys()), 'v_code') self.assertTrue(any([item == 'y' for item in ff])) diff --git a/tests/test_model_set.py b/tests/test_model_set.py index e3ef8b5ed..df29b54d6 100644 --- a/tests/test_model_set.py +++ b/tests/test_model_set.py @@ -54,3 +54,56 @@ def test_model_set(self): ss.GENROU.set("M", np.array(["GENROU_4"]), "v", 6.0) np.testing.assert_equal(ss.GENROU.M.v[3], 6.0) self.assertEqual(ss.TDS.Teye[omega_addr[3], omega_addr[3]], 6.0) + + def test_find_idx(self): + ss = andes.load(andes.get_case('ieee14/ieee14_pvd1.xlsx')) + mdl = ss.PVD1 + + # not allow all matches + self.assertListEqual(mdl.find_idx(keys='gammap', values=[0.1], allow_all=False), + [1]) + + # allow all matches + self.assertListEqual(mdl.find_idx(keys='gammap', values=[0.1], allow_all=True), + [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + + # multiple values + self.assertListEqual(mdl.find_idx(keys='name', values=['PVD1_1', 'PVD1_2'], + allow_none=False, default=False), + [1, 2]) + # non-existing value + self.assertListEqual(mdl.find_idx(keys='name', values=['PVD1_999'], + allow_none=True, default=False), + [False]) + + # non-existing value is not allowed + with self.assertRaises(IndexError): + mdl.find_idx(keys='name', values=['PVD1_999'], + allow_none=False, default=False) + + # multiple keys + self.assertListEqual(mdl.find_idx(keys=['gammap', 'name'], + values=[[0.1, 0.1], ['PVD1_1', 'PVD1_2']]), + [1, 2]) + + # multiple keys, with non-existing values + self.assertListEqual(mdl.find_idx(keys=['gammap', 'name'], + values=[[0.1, 0.1], ['PVD1_1', 'PVD1_999']], + allow_none=True, default='CURENT'), + [1, 'CURENT']) + + # multiple keys, with non-existing values not allowed + with self.assertRaises(IndexError): + mdl.find_idx(keys=['gammap', 'name'], + values=[[0.1, 0.1], ['PVD1_1', 'PVD1_999']], + allow_none=False, default=999) + + # multiple keys, values are not iterable + with self.assertRaises(ValueError): + mdl.find_idx(keys=['gammap', 'name'], + values=[0.1, 0.1]) + + # multiple keys, items length are inconsistent in values + with self.assertRaises(ValueError): + mdl.find_idx(keys=['gammap', 'name'], + values=[[0.1, 0.1], ['PVD1_1']])