Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Deserializer + Support Tuple #8

Merged
merged 4 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 61 additions & 127 deletions osiris/cairo/serde/deserialize.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,73 @@
import json
import re

import numpy as np

from .utils import felt_to_int, from_fp


def deserializer(serialized: str, data_type: str, fp_impl='FP16x16'):
"""
Main deserialization function that handles various data types.

:param serialized: The serialized list of data.
:param data_type: The type of data to deserialize ('uint', 'signed_int', 'fixed_point', etc.).
:param fp_impl: The implementation detail, used for fixed-point deserialization.
:return: The deserialized data.
"""
def deserializer(serialized: str, dtype: str):
# Check if the serialized data is a string and needs conversion
if isinstance(serialized, str):
serialized = convert_data(serialized)

# Function to deserialize individual elements within a tuple
def deserialize_element(element, element_type):
if element_type in ("u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"):
return deserialize_int(element)
elif element_type.startswith("FP"):
return deserialize_fixed_point(element, element_type)
elif element_type.startswith("Span<") and element_type.endswith(">"):
inner_type = element_type[5:-1]
if inner_type.startswith("FP"):
return deserialize_arr_fixed_point(element, inner_type)
else:
return deserialize_arr_int(element)
elif element_type.startswith("Tensor<") and element_type.endswith(">"):
inner_type = element_type[7:-1]
if inner_type.startswith("FP"):
return deserialize_tensor_fixed_point(element, inner_type)
else:
return deserialize_tensor_int(element)
elif element_type.startswith("(") and element_type.endswith(")"):
# Recursive call for nested tuples
return deserializer(element, element_type)
else:
raise ValueError(f"Unsupported data type: {element_type}")

# Handle tuple data type
if dtype.startswith("(") and dtype.endswith(")"):
types = dtype[1:-1].split(", ")
deserialized_elements = []
i = 0 # Initialize loop counter

while i < len(serialized):
ele_type = types[len(deserialized_elements)]

if ele_type.startswith("Tensor<"):
# For Tensors, take two elements from serialized (shape and data)
ele = serialized[i:i+2]
i += 2
else:
# For other types, take one element
ele = serialized[i]
i += 1

if ele_type.startswith("Tensor<"):
deserialized_elements.append(
deserialize_element(ele, ele_type))
else:
deserialized_elements.append(
deserialize_element([ele], ele_type))

if len(deserialized_elements) != len(types):
raise ValueError(
"Serialized data length does not match tuple length")

return tuple(deserialized_elements)

serialized = convert_data(serialized)

if data_type == 'int':
return deserialize_int(serialized)
elif data_type == 'fixed_point':
return deserialize_fixed_point(serialized, fp_impl)
elif data_type == 'arr_int':
return deserialize_arr_int(serialized)
elif data_type == 'arr_fixed_point':
return deserialize_arr_fixed_point(serialized, fp_impl)
elif data_type == 'tensor_int':
return deserialize_tensor_int(serialized)
elif data_type == 'tensor_fixed_point':
return deserialize_tensor_fixed_point(serialized)
# TODO: Support Tuples
# elif data_type == 'tensor_fixed_point':
# return deserialize_tensor_fixed_point(serialized, fp_impl)
# elif data_type == 'tuple_uint':
# return deserialize_tuple_uint(serialized)
# elif data_type == 'tuple_signed_int':
# return deserialize_tuple_signed_int(serialized)
# elif data_type == 'tuple_fixed_point':
# return deserialize_tuple_fixed_point(serialized, fp_impl)
# elif data_type == 'tuple_tensor_uint':
# return deserialize_tuple_tensor_uint(serialized)
# elif data_type == 'tuple_tensor_signed_int':
# return deserialize_tuple_tensor_signed_int(serialized)
# elif data_type == 'tuple_tensor_fixed_point':
# return deserialize_tuple_tensor_fixed_point(serialized, fp_impl)
else:
raise ValueError(f"Unknown data type: {data_type}")
return deserialize_element(serialized, dtype)


def parse_return_value(return_value):
Expand Down Expand Up @@ -149,91 +171,3 @@ def deserialize_tensor_fixed_point(serialized: list, impl='FP16x16') -> np.array
data = deserialize_arr_fixed_point([serialized[1]], impl)

return np.array(data, dtype=np.float64).reshape(shape)


# ================= TUPLE UINT =================


# def deserialize_tuple_uint(serialized: list):
# return np.array(serialized[0], dtype=np.int64)


# # ================= TUPLE SIGNED INT =================


# def deserialize_tuple_signed_int(serialized: list):
# num_ele = (len(serialized)) // 2

# deserialized_array = np.empty(num_ele, dtype=np.int64)

# for i in range(num_ele):
# deserialized_array[i] = deserialize_signed_int(
# serialized[i*2: 3 + i*2])

# return deserialized_array

# # ================= TUPLE FIXED POINT =================


# def deserialize_tuple_fixed_point(serialized: list, impl='FP16x16'):
# num_ele = (len(serialized)) // 2

# deserialized_array = np.empty(num_ele, dtype=np.float64)

# for i in range(num_ele):
# deserialized_array[i] = deserialize_fixed_point(
# serialized[i*2: 3 + i*2], impl)

# return deserialized_array


# # ================= TUPLE TENSOR UINT =================

# def deserialize_tuple_tensor_uint(serialized: list):
# return deserialize_tuple_tensor(serialized, deserialize_arr_uint)

# # ================= TUPLE TENSOR SIGNED INT =================


# def deserialize_tuple_tensor_signed_int(serialized: list):
# return deserialize_tuple_tensor(serialized, deserialize_arr_signed_int)

# # ================= TUPLE TENSOR FIXED POINT =================


# def deserialize_tuple_tensor_fixed_point(serialized: list, impl='FP16x16'):
# return deserialize_tuple_tensor(serialized, deserialize_arr_fixed_point, impl)


# # ================= HELPERS =================


# def extract_shape(serialized, start_index):
# """ Extracts the shape part of a tensor from a serialized list. """
# num_shape_elements = serialized[start_index]
# shape = serialized[start_index + 1: start_index + 1 + num_shape_elements]
# return shape, start_index + 1 + num_shape_elements


# def extract_data(serialized, start_index, deserialization_func, impl=None):
# """ Extracts and deserializes the data part of a tensor from a serialized list. """
# num_data_elements = serialized[start_index]
# end_index = start_index + 1 + num_data_elements
# data_serialized = serialized[start_index: end_index]
# if impl:
# data = deserialization_func(data_serialized, impl)
# else:
# data = deserialization_func(data_serialized)
# return data, end_index


# def deserialize_tuple_tensor(serialized, deserialization_func, impl=None):
# """ Generic deserialization function for a tuple of tensors. """
# deserialized_tensors = []
# i = 0
# while i < len(serialized):
# shape, i = extract_shape(serialized, i)
# data, i = extract_data(serialized, i, deserialization_func, impl)
# tensor = data.reshape(shape)
# deserialized_tensors.append(tensor)
# return tuple(deserialized_tensors)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "giza-osiris"
version = "0.2.3"
version = "0.2.4"
description = "Osiris is a Python library designed for efficient data conversion and management, primarily transforming data into Cairo programs"
authors = ["Fran Algaba <[email protected]>"]
readme = "README.md"
Expand Down
86 changes: 32 additions & 54 deletions tests/test_deserialize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import numpy.testing as npt
import pytest
from math import isclose

Expand All @@ -7,105 +8,82 @@

def test_deserialize_int():
serialized = '[{"Int":"2A"}]'
deserialized = deserializer(serialized, 'int')
deserialized = deserializer(serialized, 'u32')
assert deserialized == 42

serialized = '[{"Int":"800000000000010FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD7"}]'
deserialized = deserializer(serialized, 'int')
deserialized = deserializer(serialized, 'i32')
assert deserialized == -42


def test_deserialize_fp():
serialized = '[{"Int":"2A6B85"}, {"Int":"0"}]'
deserialized = deserializer(serialized, 'fixed_point', 'FP16x16')
deserialized = deserializer(serialized, 'FP16x16')
assert isclose(deserialized, 42.42, rel_tol=1e-7)

serialized = '[{"Int":"2A6B85"}, {"Int":"1"}]'
deserialized = deserializer(serialized, 'fixed_point', 'FP16x16')
deserialized = deserializer(serialized, 'FP16x16')
assert isclose(deserialized, -42.42, rel_tol=1e-7)


def test_deserialize_array_int():
serialized = '[{"Array": [{"Int": "0x1"}, {"Int": "0x2"}]}]'
deserialized = deserializer(serialized, 'arr_int')
deserialized = deserializer(serialized, 'Span<u32>')
assert np.array_equal(deserialized, np.array([1, 2], dtype=np.int64))

serialized = '[{"Array": [{"Int": "2A"}, {"Int": "800000000000010FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD7"}]}]'
deserialized = deserializer(serialized, 'arr_int')
deserialized = deserializer(serialized, 'Span<i32>')
assert np.array_equal(deserialized, np.array([42, -42], dtype=np.int64))


def test_deserialize_arr_fixed_point():
serialized = '[{"Array": [{"Int": "2A6B85"}, {"Int": "0"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]'
deserialized = deserializer(serialized, 'arr_fixed_point')
deserialized = deserializer(serialized, 'Span<FP16x16>')
expected = np.array([42.42, -42.42], dtype=np.float64)
assert np.all(np.isclose(deserialized, expected, atol=1e-7))


def test_deserialize_tensor_int():
serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "0x1"}, {"Int": "0x2"}, {"Int": "0x3"}, {"Int": "0x4"}]}]'
deserialized = deserializer(serialized, 'tensor_int')
deserialized = deserializer(serialized, 'Tensor<i32>')
assert np.array_equal(deserialized, np.array(
([1, 2], [3, 4]), dtype=np.int64))

serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A"}, {"Int": "2A"},{"Int": "800000000000010FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD7"}, {"Int": "800000000000010FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD7"}]}]'
deserialized = deserializer(serialized, 'tensor_int')
deserialized = deserializer(serialized, 'Tensor<i32>')
assert np.array_equal(deserialized, np.array([[42, 42], [-42, -42]]))


def test_deserialize_tensor_fixed_point():
serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]'
expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]])
deserialized = deserializer(serialized, 'tensor_fixed_point')
deserialized = deserializer(serialized, 'Tensor<FP16x16>')
assert np.allclose(deserialized, expected_array, atol=1e-7)


# def test_deserialize_tuple_uint():
# serialized = [1, 2]
# deserialized = deserialize_tuple_uint(serialized)
# assert np.array_equal(deserialized, np.array([1, 2], dtype=np.int64))
def test_deserialize_tuple_int():
serialized = '[{"Int":"0x1"},{"Int":"0x3"}]'
deserialized = deserializer(serialized, '(u32, u32)')
assert deserialized == (1, 3)


# def test_deserialize_tuple_signed_int():
# serialized = [42, 0, 42, 1, 42, 0]
# deserialized = deserialize_tuple_signed_int(serialized)
# assert np.array_equal(deserialized, np.array(
# [42, -42, 42], dtype=np.int64))
def test_deserialize_tuple_span():
serialized = '[{"Array":[{"Int":"0x1"},{"Int":"0x2"}]},{"Int":"0x3"}]'
deserialized = deserializer(serialized, '(Span<u32>, u32)')
expected = (np.array([1, 2]), 3)
npt.assert_array_equal(deserialized[0], expected[0])
assert deserialized[1] == expected[1]


# def test_deserialize_tuple_fixed_point():
# serialized = [2780037, 0, 2780037, 1, 2780037, 0]
# deserialized = deserialize_tuple_fixed_point(serialized)
# expected = np.array([42.42, -42.42, 42.42], dtype=np.float64)
# assert np.all(np.isclose(deserialized, expected, atol=1e-7))
def test_deserialize_tuple_span_tensor_fp():
serialized = '[{"Array":[{"Int":"0x1"},{"Int":"0x2"}]},{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]'
deserialized = deserializer(serialized, '(Span<u32>, Tensor<FP16x16>)')
expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]]))
npt.assert_array_equal(deserialized[0], expected[0])
assert np.allclose(deserialized[1], expected[1], atol=1e-7)

# def test_deserialize_tensor_tuple_tensor_uint():
# serialized = [2, 2, 2, 4, 1, 2, 3, 4, 2, 2, 2, 4, 5, 6, 7, 8]
# deserialized = deserialize_tuple_tensor_uint(serialized)

# assert np.array_equal(deserialized[0], np.array(
# [[1, 2], [3, 4]], dtype=np.int64))
# assert np.array_equal(deserialized[1], np.array(
# [[5, 6], [7, 8]], dtype=np.int64))


# def test_deserialize_tensor_tuple_tensor_signed_int():
# serialized = [2, 2, 2, 8, 42,
# 0, 42, 0, 42, 1, 42, 1, 2, 2, 2, 8, 42,
# 0, 42, 0, 42, 1, 42, 1]
# deserialized = deserialize_tuple_tensor_signed_int(serialized)

# expected_array = np.array([[42, 42], [-42, -42]])
# assert np.allclose(deserialized[0], expected_array, atol=1e-7)
# assert np.allclose(deserialized[1], expected_array, atol=1e-7)


# def test_deserialize_tensor_tuple_tensor_fixed_point():
# serialized = [2, 2, 2, 8, 2780037,
# 0, 2780037, 0, 2780037, 1, 2780037, 1, 2, 2, 2, 8, 2780037,
# 0, 2780037, 0, 2780037, 1, 2780037, 1]
# deserialized = deserialize_tuple_tensor_fixed_point(serialized)

# expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]])
# assert np.allclose(deserialized[0], expected_array, atol=1e-7)
# assert np.allclose(deserialized[1], expected_array, atol=1e-7)
serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}, {"Array":[{"Int":"0x1"},{"Int":"0x2"}]}]'
deserialized = deserializer(serialized, '(Tensor<FP16x16>, Span<u32>)')
expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2]))
assert np.allclose(deserialized[0], expected[0], atol=1e-7)
npt.assert_array_equal(deserialized[1], expected[1])
Loading