diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index f7158e7..ec7cd41 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -1,3 +1,5 @@ +import re + import numpy as np from osiris.cairo.serde.utils import felt_to_int, from_fp @@ -17,6 +19,9 @@ def deserializer(serialized, dtype): elif dtype.startswith('Tensor<'): return deserialize_tensor(serialized, dtype) + elif dtype.startswith('MutMatrix<'): + return deserialize_matrix(serialized, dtype) + elif dtype.startswith('('): # Tuple return deserialize_tuple(serialized, dtype) @@ -27,7 +32,7 @@ def deserializer(serialized, dtype): def deserialize_fp(serialized): parts = serialized.split() value = from_fp(int(parts[0])) - if len(parts) > 1 and parts[1] == '1': # Check for negative sign + if len(parts) > 1 and parts[1] == 'true': # Check for negative sign value = -value return value @@ -60,24 +65,56 @@ def deserialize_tensor(serialized, dtype): def deserialize_tuple(serialized, dtype): types = dtype[1:-1].split(', ') - if 'Tensor' in types[0]: - tensor_end = find_nth_occurrence(serialized, ']', 2) - depth = 1 - for i in range(tensor_end, len(serialized)): - if serialized[i] == '[': - depth += 1 - elif serialized[i] == ']': - depth -= 1 - if depth == 0: - tensor_end = i + 1 - break - part1 = deserializer(serialized[:tensor_end].strip(), types[0]) - part2 = deserializer(serialized[tensor_end:].strip(), types[1]) - else: - split_index = serialized.find(']') + 2 + # Check if there is no space between span and matrix. + is_no_space = re.search(r']\{', serialized) + if is_no_space: + split_index = is_no_space.start() + 1 part1 = deserializer(serialized[:split_index].strip(), types[0]) part2 = deserializer(serialized[split_index:].strip(), types[1]) - return part1, part2 + return part1, part2 + else: + if 'Tensor' in types[0]: + tensor_end = find_nth_occurrence(serialized, ']', 2) + depth = 1 + for i in range(tensor_end, len(serialized)): + if serialized[i] == '[': + depth += 1 + elif serialized[i] == ']': + depth -= 1 + if depth == 0: + tensor_end = i + 1 + break + part1 = deserializer(serialized[:tensor_end].strip(), types[0]) + part2 = deserializer(serialized[tensor_end:].strip(), types[1]) + else: + split_index = serialized.find(']') + 2 + part1 = deserializer(serialized[:split_index].strip(), types[0]) + part2 = deserializer(serialized[split_index:].strip(), types[1]) + return part1, part2 + + +def deserialize_matrix(serialized, dtype): + + # Extract inner dtype + pattern = r"<(.*)>" + inner_dtype = re.search(pattern, dtype).group(1) + + # Extract the matrix content and shape from the serialized string + content, shape_str = serialized.split("} ") + # Last two numbers are the shape + shape = tuple(map(int, shape_str.split()[-2:])) + + # Use regex to find all occurrences of ': ' followed by any characters until the next ' :' or end of string + pattern = r': (.*?)(?=\s\d+: |$)' + elements = re.findall(pattern, content) + + # Deserialize each element using the appropriate deserializer based on dtype + deserialized_elements = [deserializer( + element, inner_dtype) for element in elements] + + # Reshape the deserialized elements into a numpy array of the specified shape + matrix = np.array(deserialized_elements).reshape(shape) + return matrix def find_nth_occurrence(string, sub_string, n): diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py index e78502e..ae53c45 100644 --- a/tests/test_deserialize.py +++ b/tests/test_deserialize.py @@ -16,11 +16,11 @@ def test_deserialize_int(): def test_deserialize_fp(): - serialized = '2780037 0' + serialized = '2780037 false' deserialized = deserializer(serialized, 'FP16x16') assert isclose(deserialized, 42.42, rel_tol=1e-7) - serialized = '2780037 1' + serialized = '2780037 true' deserialized = deserializer(serialized, 'FP16x16') assert isclose(deserialized, -42.42, rel_tol=1e-7) @@ -36,7 +36,7 @@ def test_deserialize_array_int(): def test_deserialize_arr_fixed_point(): - serialized = '[2780037 0 2780037 1]' + serialized = '[2780037 false 2780037 true]' deserialized = deserializer(serialized, 'Span') expected = np.array([42.42, -42.42], dtype=np.float64) assert np.all(np.isclose(deserialized, expected, atol=1e-7)) @@ -54,11 +54,16 @@ def test_deserialize_tensor_int(): def test_deserialize_tensor_fixed_point(): - serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1]' + serialized = '[2 2] [2780037 false 2780037 false 2780037 true 2780037 true]' expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]]) deserialized = deserializer(serialized, 'Tensor') assert np.allclose(deserialized, expected_array, atol=1e-7) +def test_deserialize_matrix_fixed_point(): + serialized = "{0: 2780037 false 2: 2780037 false 1: 2780037 true 3: 2780037 true} 4 2 2" + expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]]) + deserialized = deserializer(serialized, 'MutMatrix') + assert np.allclose(deserialized, expected_array, atol=1e-7) def test_deserialize_tuple_int(): serialized = '1 3' @@ -75,14 +80,21 @@ def test_deserialize_tuple_span(): def test_deserialize_tuple_span_tensor_fp(): - serialized = '[1 2] [2 2] [2780037 0 2780037 0 2780037 1 2780037 1]' + serialized = '[1 2] [2 2] [2780037 false 2780037 false 2780037 true 2780037 true]' deserialized = deserializer(serialized, '(Span, Tensor)') expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]])) npt.assert_array_equal(deserialized[0], expected[0]) assert np.allclose(deserialized[1], expected[1], atol=1e-7) - serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1] [1 2]' + serialized = '[2 2] [2780037 false 2780037 false 2780037 true 2780037 true] [1 2]' deserialized = deserializer(serialized, '(Tensor, Span)') expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2])) assert np.allclose(deserialized[0], expected[0], atol=1e-7) npt.assert_array_equal(deserialized[1], expected[1]) + +def test_deserialize_tuple_matrix_fp(): + serialized = '[1 2]{0: 2780037 false 2: 2780037 false 1: 2780037 true 3: 2780037 true} 4 2 2' + deserialized = deserializer(serialized, '(Span, MutMatrix)') + expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]])) + npt.assert_array_equal(deserialized[0], expected[0]) + assert np.allclose(deserialized[1], expected[1], atol=1e-7) \ No newline at end of file