diff --git a/shublang/shublang.py b/shublang/shublang.py index 330670c..481db71 100644 --- a/shublang/shublang.py +++ b/shublang/shublang.py @@ -130,7 +130,7 @@ def decode(iterable, encoding): @Pipe def find(iterable, sub, start=None, end=None): """Returns the lowest index in the string where the sub is found. - If specified, the start and end params serve to slice the string + If specified, the start and end params serve to slice the string where sub should be searched. :param iterable: collection of data to transform @@ -332,6 +332,11 @@ def extract_currency(iterable): def urljoin(iterable, base): return (parse.urljoin(base, url) for url in iterable) +@Pipe +def identity(iterable, element): + """ Return the same element is passed as parameter.""" + return (element) + filter = where map = select diff --git a/tests/test_functions.py b/tests/test_functions.py index d18252a..6b2795e 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -400,3 +400,27 @@ def test_join(test_input, expected): ) def test_urljoin(test_input, expected): assert evaluate(*test_input) == expected + +@pytest.mark.parametrize( + "test_input,expected", + [ + ( + ['identity(True)', + [10, 'far', ['boo', 3]]], + True, + ), + ( + ['identity("InStock")', + ["In Stock.", "Only 3 in Stock", "Stock Ok"]], + "InStock", + ), + ( + ['identity((1,2,3,4,5))', + "foo"], + (1,2,3,4,5), + ), + ] +) + +def test_identity(test_input, expected): + assert evaluate(*test_input) == expected