Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deserialize matrix #10

Merged
merged 4 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 54 additions & 17 deletions osiris/cairo/serde/deserialize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import numpy as np

from osiris.cairo.serde.utils import felt_to_int, from_fp
Expand All @@ -17,6 +19,9 @@ def deserializer(serialized, dtype):
elif dtype.startswith('Tensor<'):
return deserialize_tensor(serialized, dtype)

elif dtype.startswith('MutMatrix<'):
return deserialize_matrix(serialized, dtype)

elif dtype.startswith('('): # Tuple
return deserialize_tuple(serialized, dtype)

Expand All @@ -27,7 +32,7 @@ def deserializer(serialized, dtype):
def deserialize_fp(serialized):
    """Deserialize a fixed-point value and apply its sign flag.

    The serialized form is whitespace-separated. The first token is the
    integer magnitude, which is converted with ``from_fp``. An optional
    second token equal to ``'true'`` marks the value as negative; any other
    second token (for example ``'false'``) leaves it positive.

    Args:
        serialized: String such as ``'2780037 false'``.

    Returns:
        The value returned by ``from_fp``, negated when the sign flag is
        ``'true'``.
    """
    parts = serialized.split()
    value = from_fp(int(parts[0]))
    # The sign flag is the literal string 'true' when the value is negative.
    if len(parts) > 1 and parts[1] == 'true':
        value = -value
    return value

Expand Down Expand Up @@ -60,24 +65,56 @@ def deserialize_tensor(serialized, dtype):

def deserialize_tuple(serialized, dtype):
    """Deserialize a two-element tuple into a ``(first, second)`` pair.

    ``dtype`` is the tuple type string, e.g. ``'(Span<u32>, Tensor<FP16x16>)'``;
    its outer parentheses are stripped and the remainder is split on ``', '``.
    Only the first two element types are used. The serialized string is cut
    into two pieces at an index chosen as follows:

    * if a ``]`` is immediately followed by ``{`` (a bracketed value directly
      followed by a matrix, with no separating space), cut right after
      that ``]``;
    * otherwise, if the first element type mentions ``Tensor``, cut after the
      bracket that closes the tensor's data section;
    * otherwise, cut two characters past the first ``]``.

    Each piece is stripped and passed to ``deserializer`` with its type.

    Args:
        serialized: Serialized tuple contents.
        dtype: Tuple type string wrapped in parentheses.

    Returns:
        Tuple ``(part1, part2)`` of the deserialized elements.
    """
    types = dtype[1:-1].split(', ')

    # Check if there is no space between span and matrix.
    is_no_space = re.search(r']\{', serialized)
    if is_no_space:
        split_index = is_no_space.start() + 1
    elif 'Tensor' in types[0]:
        # Start scanning at the second ']' and stop once the bracket depth
        # returns to zero; the cut falls just after that bracket.
        tensor_end = find_nth_occurrence(serialized, ']', 2)
        depth = 1
        for i in range(tensor_end, len(serialized)):
            if serialized[i] == '[':
                depth += 1
            elif serialized[i] == ']':
                depth -= 1
                if depth == 0:
                    tensor_end = i + 1
                    break
        split_index = tensor_end
    else:
        # Skip the closing bracket and the single space that follows it.
        split_index = serialized.find(']') + 2

    part1 = deserializer(serialized[:split_index].strip(), types[0])
    part2 = deserializer(serialized[split_index:].strip(), types[1])
    return part1, part2


def deserialize_matrix(serialized, dtype):
    """Deserialize a ``MutMatrix<...>`` string into a NumPy array.

    The serialized form is ``{<index>: <element> <index>: <element> ...}``
    followed by a space and whitespace-separated integers, for example
    ``'{0: 2780037 false 1: 2780037 true} 2 1 2'``. The last two integers
    are used as the array shape. Elements are kept in the order in which
    they appear in the string; the numeric indices are not used to place
    them.

    Args:
        serialized: Serialized matrix string.
        dtype: Matrix type string such as ``'MutMatrix<FP16x16>'``; the text
            between the outermost ``<`` and ``>`` is the element type passed
            to ``deserializer``.

    Returns:
        ``numpy.ndarray`` of the deserialized elements, reshaped to the
        two trailing dimensions.

    Raises:
        ValueError: If ``dtype`` has no ``<...>`` part, if ``serialized``
            does not contain exactly one ``'} '`` separator, or if the
            element count does not match the shape.
    """
    # Extract the inner dtype, failing with a clear message instead of an
    # AttributeError on ``None`` when the angle brackets are missing.
    inner_match = re.search(r"<(.*)>", dtype)
    if inner_match is None:
        raise ValueError(
            f"Cannot extract element type from matrix dtype {dtype!r}")
    inner_dtype = inner_match.group(1)

    # Split the dictionary-like content from the trailing integers.
    content, shape_str = serialized.split("} ")
    # Last two numbers are the shape.
    shape = tuple(int(dim) for dim in shape_str.split()[-2:])

    # Capture the text after each ': ' up to the next '<digits>: ' key
    # (or the end of the content).
    elements = re.findall(r': (.*?)(?=\s\d+: |$)', content)

    # Deserialize each element with the deserializer for the inner dtype.
    values = [deserializer(element, inner_dtype) for element in elements]

    return np.array(values).reshape(shape)


def find_nth_occurrence(string, sub_string, n):
Expand Down
24 changes: 18 additions & 6 deletions tests/test_deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ def test_deserialize_int():


def test_deserialize_fp():
    """FP16x16 values deserialize with the sign taken from 'true'/'false'."""
    # A 'false' flag yields a positive value.
    serialized = '2780037 false'
    deserialized = deserializer(serialized, 'FP16x16')
    assert isclose(deserialized, 42.42, rel_tol=1e-7)

    # A 'true' flag yields the negated value.
    serialized = '2780037 true'
    deserialized = deserializer(serialized, 'FP16x16')
    assert isclose(deserialized, -42.42, rel_tol=1e-7)

Expand All @@ -36,7 +36,7 @@ def test_deserialize_array_int():


def test_deserialize_arr_fixed_point():
    """A Span<FP16x16> deserializes to a float array with per-element signs."""
    serialized = '[2780037 false 2780037 true]'
    deserialized = deserializer(serialized, 'Span<FP16x16>')
    expected = np.array([42.42, -42.42], dtype=np.float64)
    assert np.all(np.isclose(deserialized, expected, atol=1e-7))
Expand All @@ -54,11 +54,16 @@ def test_deserialize_tensor_int():


def test_deserialize_tensor_fixed_point():
    """A 2x2 Tensor<FP16x16> deserializes to the expected float array."""
    serialized = '[2 2] [2780037 false 2780037 false 2780037 true 2780037 true]'
    expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]])
    deserialized = deserializer(serialized, 'Tensor<FP16x16>')
    assert np.allclose(deserialized, expected_array, atol=1e-7)

def test_deserialize_matrix_fixed_point():
    """A 2x2 MutMatrix<FP16x16> deserializes to the expected float array."""
    # Keys appear out of numeric order; the expected array follows the order
    # in which the elements appear in the string.
    serialized = "{0: 2780037 false 2: 2780037 false 1: 2780037 true 3: 2780037 true} 4 2 2"
    expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]])
    deserialized = deserializer(serialized, 'MutMatrix<FP16x16>')
    assert np.allclose(deserialized, expected_array, atol=1e-7)

def test_deserialize_tuple_int():
serialized = '1 3'
Expand All @@ -75,14 +80,21 @@ def test_deserialize_tuple_span():


def test_deserialize_tuple_span_tensor_fp():
    """Span/Tensor tuples deserialize correctly in either element order."""
    # Span first, tensor second.
    serialized = '[1 2] [2 2] [2780037 false 2780037 false 2780037 true 2780037 true]'
    deserialized = deserializer(serialized, '(Span<u32>, Tensor<FP16x16>)')
    expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]]))
    npt.assert_array_equal(deserialized[0], expected[0])
    assert np.allclose(deserialized[1], expected[1], atol=1e-7)

    # Tensor first, span second.
    serialized = '[2 2] [2780037 false 2780037 false 2780037 true 2780037 true] [1 2]'
    deserialized = deserializer(serialized, '(Tensor<FP16x16>, Span<u32>)')
    expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2]))
    assert np.allclose(deserialized[0], expected[0], atol=1e-7)
    npt.assert_array_equal(deserialized[1], expected[1])

def test_deserialize_tuple_matrix_fp():
    """A (Span, MutMatrix) tuple with no space before the matrix deserializes."""
    # Note the span's ']' is immediately followed by the matrix's '{'.
    serialized = '[1 2]{0: 2780037 false 2: 2780037 false 1: 2780037 true 3: 2780037 true} 4 2 2'
    deserialized = deserializer(serialized, '(Span<u32>, MutMatrix<FP16x16>)')
    expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]]))
    npt.assert_array_equal(deserialized[0], expected[0])
    assert np.allclose(deserialized[1], expected[1], atol=1e-7)
Loading