Skip to content

Commit

Permalink
feature: set default allow_pickle param to False (aws#4557)
Browse files Browse the repository at this point in the history
* breaking: set default allow_pickle param to False

* breaking: fix unit tests and linting

NumpyDeserializer will not deserialize pickled object arrays
unless the allow_pickle flag is explicitly set to True

* fix: black-check

---------

Co-authored-by: Ashwin Krishna <[email protected]>
  • Loading branch information
2 people authored and jiapinw committed Jun 25, 2024
1 parent 11f6f70 commit ab7561f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
17 changes: 14 additions & 3 deletions src/sagemaker/base_deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,14 @@ class NumpyDeserializer(SimpleBaseDeserializer):
single array.
"""

def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=False):
"""Initialize a ``NumpyDeserializer`` instance.
Args:
dtype (str): The dtype of the data (default: None).
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
is expected from the inference endpoint (default: "application/x-npy").
allow_pickle (bool): Allow loading pickled object arrays (default: True).
allow_pickle (bool): Allow loading pickled object arrays (default: False).
"""
super(NumpyDeserializer, self).__init__(accept=accept)
self.dtype = dtype
Expand All @@ -227,10 +227,21 @@ def deserialize(self, stream, content_type):
if content_type == "application/json":
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
if content_type == "application/x-npy":
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
try:
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
except ValueError as ve:
raise ValueError(
"Please set the param allow_pickle=True \
to deserialize pickle objects in NumpyDeserializer"
).with_traceback(ve.__traceback__)
if content_type == "application/x-npz":
try:
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
except ValueError as ve:
raise ValueError(
"Please set the param allow_pickle=True \
to deserialize pickle objectsin NumpyDeserializer"
).with_traceback(ve.__traceback__)
finally:
stream.close()
finally:
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/sagemaker/deserializers/test_deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def test_numpy_deserializer_from_npy(numpy_deserializer):
assert np.array_equal(array, result)


def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
def test_numpy_deserializer_from_npy_object_array():
numpy_deserializer = NumpyDeserializer(allow_pickle=True)
array = np.array([{"a": "", "b": ""}, {"c": "", "d": ""}])
stream = io.BytesIO()
np.save(stream, array)
Expand Down

0 comments on commit ab7561f

Please sign in to comment.