Skip to content

Commit

Permalink
Merge branch 'refactor-osiris-deserializer'
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Mar 18, 2024
2 parents f5f268a + 2664a69 commit 2125631
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 23 deletions.
71 changes: 54 additions & 17 deletions osiris/cairo/serde/deserialize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import numpy as np

from osiris.cairo.serde.utils import felt_to_int, from_fp
Expand All @@ -17,6 +19,9 @@ def deserializer(serialized, dtype):
elif dtype.startswith('Tensor<'):
return deserialize_tensor(serialized, dtype)

elif dtype.startswith('MutMatrix<'):
return deserialize_matrix(serialized, dtype)

elif dtype.startswith('('): # Tuple
return deserialize_tuple(serialized, dtype)

Expand All @@ -27,7 +32,7 @@ def deserializer(serialized, dtype):
def deserialize_fp(serialized):
parts = serialized.split()
value = from_fp(int(parts[0]))
if len(parts) > 1 and parts[1] == '1': # Check for negative sign
if len(parts) > 1 and parts[1] == 'true': # Check for negative sign
value = -value
return value

Expand Down Expand Up @@ -60,24 +65,56 @@ def deserialize_tensor(serialized, dtype):

def deserialize_tuple(serialized, dtype):
types = dtype[1:-1].split(', ')
if 'Tensor' in types[0]:
tensor_end = find_nth_occurrence(serialized, ']', 2)
depth = 1
for i in range(tensor_end, len(serialized)):
if serialized[i] == '[':
depth += 1
elif serialized[i] == ']':
depth -= 1
if depth == 0:
tensor_end = i + 1
break
part1 = deserializer(serialized[:tensor_end].strip(), types[0])
part2 = deserializer(serialized[tensor_end:].strip(), types[1])
else:
split_index = serialized.find(']') + 2
# Check if there is no space between span and matrix.
is_no_space = re.search(r']\{', serialized)
if is_no_space:
split_index = is_no_space.start() + 1
part1 = deserializer(serialized[:split_index].strip(), types[0])
part2 = deserializer(serialized[split_index:].strip(), types[1])
return part1, part2
return part1, part2
else:
if 'Tensor' in types[0]:
tensor_end = find_nth_occurrence(serialized, ']', 2)
depth = 1
for i in range(tensor_end, len(serialized)):
if serialized[i] == '[':
depth += 1
elif serialized[i] == ']':
depth -= 1
if depth == 0:
tensor_end = i + 1
break
part1 = deserializer(serialized[:tensor_end].strip(), types[0])
part2 = deserializer(serialized[tensor_end:].strip(), types[1])
else:
split_index = serialized.find(']') + 2
part1 = deserializer(serialized[:split_index].strip(), types[0])
part2 = deserializer(serialized[split_index:].strip(), types[1])
return part1, part2


def deserialize_matrix(serialized, dtype):

# Extract inner dtype
pattern = r"<(.*)>"
inner_dtype = re.search(pattern, dtype).group(1)

# Extract the matrix content and shape from the serialized string
content, shape_str = serialized.split("} ")
# Last two numbers are the shape
shape = tuple(map(int, shape_str.split()[-2:]))

# Use regex to find all occurrences of ': ' followed by any characters until the next ' :' or end of string
pattern = r': (.*?)(?=\s\d+: |$)'
elements = re.findall(pattern, content)

# Deserialize each element using the appropriate deserializer based on dtype
deserialized_elements = [deserializer(
element, inner_dtype) for element in elements]

# Reshape the deserialized elements into a numpy array of the specified shape
matrix = np.array(deserialized_elements).reshape(shape)
return matrix


def find_nth_occurrence(string, sub_string, n):
Expand Down
24 changes: 18 additions & 6 deletions tests/test_deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ def test_deserialize_int():


def test_deserialize_fp():
serialized = '2780037 0'
serialized = '2780037 false'
deserialized = deserializer(serialized, 'FP16x16')
assert isclose(deserialized, 42.42, rel_tol=1e-7)

serialized = '2780037 1'
serialized = '2780037 true'
deserialized = deserializer(serialized, 'FP16x16')
assert isclose(deserialized, -42.42, rel_tol=1e-7)

Expand All @@ -36,7 +36,7 @@ def test_deserialize_array_int():


def test_deserialize_arr_fixed_point():
serialized = '[2780037 0 2780037 1]'
serialized = '[2780037 false 2780037 true]'
deserialized = deserializer(serialized, 'Span<FP16x16>')
expected = np.array([42.42, -42.42], dtype=np.float64)
assert np.all(np.isclose(deserialized, expected, atol=1e-7))
Expand All @@ -54,11 +54,16 @@ def test_deserialize_tensor_int():


def test_deserialize_tensor_fixed_point():
serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1]'
serialized = '[2 2] [2780037 false 2780037 false 2780037 true 2780037 true]'
expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]])
deserialized = deserializer(serialized, 'Tensor<FP16x16>')
assert np.allclose(deserialized, expected_array, atol=1e-7)

def test_deserialize_matrix_fixed_point():
serialized = "{0: 2780037 false 2: 2780037 false 1: 2780037 true 3: 2780037 true} 4 2 2"
expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]])
deserialized = deserializer(serialized, 'MutMatrix<FP16x16>')
assert np.allclose(deserialized, expected_array, atol=1e-7)

def test_deserialize_tuple_int():
serialized = '1 3'
Expand All @@ -75,14 +80,21 @@ def test_deserialize_tuple_span():


def test_deserialize_tuple_span_tensor_fp():
serialized = '[1 2] [2 2] [2780037 0 2780037 0 2780037 1 2780037 1]'
serialized = '[1 2] [2 2] [2780037 false 2780037 false 2780037 true 2780037 true]'
deserialized = deserializer(serialized, '(Span<u32>, Tensor<FP16x16>)')
expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]]))
npt.assert_array_equal(deserialized[0], expected[0])
assert np.allclose(deserialized[1], expected[1], atol=1e-7)

serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1] [1 2]'
serialized = '[2 2] [2780037 false 2780037 false 2780037 true 2780037 true] [1 2]'
deserialized = deserializer(serialized, '(Tensor<FP16x16>, Span<u32>)')
expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2]))
assert np.allclose(deserialized[0], expected[0], atol=1e-7)
npt.assert_array_equal(deserialized[1], expected[1])

def test_deserialize_tuple_matrix_fp():
serialized = '[1 2]{0: 2780037 false 2: 2780037 false 1: 2780037 true 3: 2780037 true} 4 2 2'
deserialized = deserializer(serialized, '(Span<u32>, MutMatrix<FP16x16>)')
expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]]))
npt.assert_array_equal(deserialized[0], expected[0])
assert np.allclose(deserialized[1], expected[1], atol=1e-7)

0 comments on commit 2125631

Please sign in to comment.