diff --git a/mars/core/__init__.py b/mars/core/__init__.py index abc0385114..3a62197466 100644 --- a/mars/core/__init__.py +++ b/mars/core/__init__.py @@ -20,20 +20,15 @@ EntityData, ENTITY_TYPE, Chunk, - ChunkData, - CHUNK_TYPE, Tileable, TileableData, TILEABLE_TYPE, Object, ObjectData, ObjectChunk, - ObjectChunkData, OBJECT_TYPE, OBJECT_CHUNK_TYPE, FuseChunk, - FuseChunkData, - FUSE_CHUNK_TYPE, OutputType, register_output_types, get_output_types, diff --git a/mars/core/entity/__init__.py b/mars/core/entity/__init__.py index bb6edabaf0..809e8856aa 100644 --- a/mars/core/entity/__init__.py +++ b/mars/core/entity/__init__.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .chunks import Chunk, ChunkData, CHUNK_TYPE +from .chunks import Chunk from .core import Entity, EntityData, ENTITY_TYPE from .executable import ExecutableTuple, _ExecuteAndFetchMixin -from .fuse import FuseChunk, FuseChunkData, FUSE_CHUNK_TYPE +from .fuse import FuseChunk from .objects import ( ObjectChunk, - ObjectChunkData, Object, ObjectData, OBJECT_CHUNK_TYPE, diff --git a/mars/core/entity/chunks.py b/mars/core/entity/chunks.py index 1b1a769d58..ae93ec756b 100644 --- a/mars/core/entity/chunks.py +++ b/mars/core/entity/chunks.py @@ -14,10 +14,10 @@ from ...serialization.serializables import BoolField, FieldTypes, TupleField from ...utils import tokenize -from .core import EntityData, Entity +from .core import EntityData -class ChunkData(EntityData): +class Chunk(EntityData): __slots__ = () is_broadcaster = BoolField("is_broadcaster", default=False) @@ -56,13 +56,3 @@ def _update_key(self): *(getattr(self, k, None) for k in self._keys_ if k != "_index"), ), ) - - -class Chunk(Entity): - _allow_data_type_ = (ChunkData,) - - def __repr__(self): - return f"{type(self).__name__}({self._data.__repr__()})" - - -CHUNK_TYPE = (Chunk, ChunkData) diff --git a/mars/core/entity/fuse.py b/mars/core/entity/fuse.py index 892c5ea07f..72ba8c5ea0 100644 --- a/mars/core/entity/fuse.py +++ b/mars/core/entity/fuse.py @@ -15,15 +15,13 @@ import numpy as np from ...serialization.serializables import ReferenceField -from .chunks import ChunkData, Chunk, CHUNK_TYPE +from .chunks import Chunk -class FuseChunkData(ChunkData): +class FuseChunk(Chunk): __slots__ = ("_inited",) - _chunk = ReferenceField( - "chunk", CHUNK_TYPE, on_serialize=lambda x: x.data if hasattr(x, "data") else x - ) + _chunk = ReferenceField("chunk", Chunk) def __init__(self, *args, **kwargs): self._inited = False @@ -63,11 +61,3 @@ def __setattr__(self, attr, value): @property def nbytes(self): return np.prod(self.shape) * self.dtype.itemsize - - -class FuseChunk(Chunk): - __slots__ = () - _allow_data_type_ = (FuseChunkData,) - - -FUSE_CHUNK_TYPE = (FuseChunkData, FuseChunk) diff --git a/mars/core/entity/objects.py b/mars/core/entity/objects.py index d9dd9ed612..1c52293165 100644 --- a/mars/core/entity/objects.py +++ b/mars/core/entity/objects.py @@ -15,13 +15,13 @@ from typing import Any, Dict from ...serialization.serializables import FieldTypes, ListField -from .chunks import ChunkData, Chunk +from .chunks import Chunk from .core import Entity from .executable import _ToObjectMixin from .tileables import TileableData -class ObjectChunkData(ChunkData): +class ObjectChunk(Chunk): # chunk whose data could be any serializable __slots__ = () type_name = "Object" @@ -48,23 +48,12 @@ def get_params_from_data(cls, data: Any) -> Dict[str, Any]: return dict() -class ObjectChunk(Chunk): - __slots__ = () 
- _allow_data_type_ = (ObjectChunkData,) - type_name = "Object" - - class ObjectData(TileableData, _ToObjectMixin): __slots__ = () type_name = "Object" # optional fields - _chunks = ListField( - "chunks", - FieldTypes.reference(ObjectChunkData), - on_serialize=lambda x: [it.data for it in x] if x is not None else x, - on_deserialize=lambda x: [ObjectChunk(it) for it in x] if x is not None else x, - ) + _chunks = ListField("chunks", FieldTypes.reference(ObjectChunk)) def __init__(self, op=None, nsplits=None, chunks=None, **kw): super().__init__(_op=op, _nsplits=nsplits, _chunks=chunks, **kw) @@ -96,4 +85,4 @@ class Object(Entity, _ToObjectMixin): OBJECT_TYPE = (Object, ObjectData) -OBJECT_CHUNK_TYPE = (ObjectChunk, ObjectChunkData) +OBJECT_CHUNK_TYPE = (ObjectChunk,) diff --git a/mars/core/entity/output_types.py b/mars/core/entity/output_types.py index 85fc45a984..3967fa9c5d 100644 --- a/mars/core/entity/output_types.py +++ b/mars/core/entity/output_types.py @@ -15,7 +15,7 @@ import functools from enum import Enum -from .fuse import FUSE_CHUNK_TYPE +from .fuse import FuseChunk from .objects import OBJECT_TYPE, OBJECT_CHUNK_TYPE @@ -83,7 +83,7 @@ def get_output_types(*objs, unknown_as=None): for obj in objs: if obj is None: continue - elif isinstance(obj, FUSE_CHUNK_TYPE): + elif isinstance(obj, FuseChunk): obj = obj.chunk try: diff --git a/mars/core/graph/builder/chunk.py b/mars/core/graph/builder/chunk.py index d6c81900a0..d18cc452e6 100644 --- a/mars/core/graph/builder/chunk.py +++ b/mars/core/graph/builder/chunk.py @@ -26,7 +26,7 @@ Union, ) -from ....core import FUSE_CHUNK_TYPE, CHUNK_TYPE, TILEABLE_TYPE +from ....core import FuseChunk, TILEABLE_TYPE, Chunk from ....typing import EntityType, TileableType, ChunkType from ....utils import copy_tileables, build_fetch from ...entity.tileables import handler @@ -223,7 +223,7 @@ def _tile( chunks = [] if need_process is not None: for t in need_process: - if isinstance(t, CHUNK_TYPE): + if isinstance(t, Chunk): chunks.append(self._get_data(t)) elif isinstance(t, TILEABLE_TYPE): to_update_tileables.append(self._get_data(t)) @@ -304,7 +304,7 @@ def _iter(self): # so that fetch chunk can be generated. # Use chunk key as the key to make sure the copied chunk can be build to a fetch. 
processed_chunks = ( - c.chunk.key if isinstance(c, FUSE_CHUNK_TYPE) else c.key + c.chunk.key if isinstance(c, FuseChunk) else c.key for c in chunk_graph.result_chunks ) self._processed_chunks.update(processed_chunks) @@ -406,7 +406,7 @@ def _process_node(self, entity: EntityType): if entity.key in self._processed_chunks: if entity not in self._chunk_to_fetch: # gen fetch - fetch_chunk = build_fetch(entity).data + fetch_chunk = build_fetch(entity) self._chunk_to_fetch[entity] = fetch_chunk return self._chunk_to_fetch[entity] return entity @@ -417,7 +417,7 @@ def _select_inputs(self, inputs: List[ChunkType]): if inp.key in self._processed_chunks: # gen fetch if inp not in self._chunk_to_fetch: - fetch_chunk = build_fetch(inp).data + fetch_chunk = build_fetch(inp) self._chunk_to_fetch[inp] = fetch_chunk new_inputs.append(self._chunk_to_fetch[inp]) else: diff --git a/mars/core/operand/core.py b/mars/core/operand/core.py index b2862cb2d4..20643592fe 100644 --- a/mars/core/operand/core.py +++ b/mars/core/operand/core.py @@ -28,8 +28,9 @@ from ..context import Context from ..mode import is_eager_mode from ..entity import ( - OutputType, + Chunk, ExecutableTuple, + OutputType, get_chunk_types, get_tileable_types, get_output_types, @@ -78,14 +79,14 @@ def _check_if_gpu(cls, inputs: List[TileableType]): def _tokenize_output(self, output_idx: int, **kw): return f"{self._key}_{output_idx}" - def _create_chunk(self, output_idx: int, index: Tuple[int], **kw) -> ChunkType: + def _create_chunk(self, output_idx: int, index: Tuple[int], **kw) -> Chunk: output_type = kw.pop("output_type", None) or self._get_output_type(output_idx) if not output_type: raise ValueError("output_type should be specified") if isinstance(output_type, (list, tuple)): output_type = output_type[output_idx] - chunk_type, chunk_data_type = get_chunk_types(output_type) + (chunk_data_type,) = get_chunk_types(output_type) kw["_i"] = output_idx kw["op"] = self kw["index"] = index @@ -97,8 +98,7 @@ def _create_chunk(self, output_idx: int, index: Tuple[int], **kw) -> ChunkType: if "_key" not in kw: kw["_key"] = self._tokenize_output(output_idx, **kw) - data = chunk_data_type(**kw) - return chunk_type(data) + return chunk_data_type(**kw) def _new_chunks( self, inputs: List[ChunkType], kws: List[Dict] = None, **kw @@ -130,7 +130,7 @@ def _new_chunks( # for each output chunk, hold the reference to the other outputs # so that either no one or everyone are gc collected for j, t in enumerate(chunks): - t.data._siblings = [c.data for c in chunks[:j] + chunks[j + 1 :]] + t._siblings = [c for c in chunks[:j] + chunks[j + 1 :]] return chunks def new_chunks( diff --git a/mars/core/operand/fuse.py b/mars/core/operand/fuse.py index 1f1d64c244..e2ad449a29 100644 --- a/mars/core/operand/fuse.py +++ b/mars/core/operand/fuse.py @@ -14,7 +14,7 @@ from ... 
import opcodes from ...serialization.serializables import ReferenceField -from ..entity import FuseChunk, FuseChunkData, NotSupportTile +from ..entity import FuseChunk, NotSupportTile from ..graph import ChunkGraph from .base import Operand @@ -30,8 +30,7 @@ class FuseChunkMixin: __slots__ = () def _create_chunk(self, output_idx, index, **kw): - data = FuseChunkData(_index=index, _op=self, **kw) - return FuseChunk(data) + return FuseChunk(_index=index, _op=self, **kw) @classmethod def tile(cls, op): diff --git a/mars/dataframe/arithmetic/core.py b/mars/dataframe/arithmetic/core.py index 1073811857..536412aef1 100644 --- a/mars/dataframe/arithmetic/core.py +++ b/mars/dataframe/arithmetic/core.py @@ -19,9 +19,9 @@ import numpy as np import pandas as pd -from ...core import ENTITY_TYPE, CHUNK_TYPE, recursive_tile +from ...core import ENTITY_TYPE, recursive_tile from ...serialization.serializables import AnyField -from ...tensor.core import TENSOR_TYPE, TENSOR_CHUNK_TYPE, ChunkData, Chunk +from ...tensor.core import TENSOR_TYPE, TENSOR_CHUNK_TYPE, Chunk from ...utils import classproperty, get_dtype from ..align import ( align_series_series, @@ -421,7 +421,7 @@ def _operator(self): @classmethod def _calc_properties(cls, x1, x2=None, axis="columns"): - is_chunk = isinstance(x1, CHUNK_TYPE) + is_chunk = isinstance(x1, Chunk) if isinstance(x1, (DATAFRAME_TYPE, DATAFRAME_CHUNK_TYPE)) and ( x2 is None @@ -625,7 +625,7 @@ def _new_chunks(self, inputs, kws=None, **kw): property_inputs = reversed(property_inputs) properties = self._calc_properties(*property_inputs) - inputs = [inp for inp in inputs if isinstance(inp, (Chunk, ChunkData))] + inputs = [inp for inp in inputs if isinstance(inp, Chunk)] shape = properties.pop("shape") if "shape" in kw: diff --git a/mars/dataframe/arithmetic/tests/test_arithmetic.py b/mars/dataframe/arithmetic/tests/test_arithmetic.py index e6a79e610d..7fef83ae3d 100644 --- a/mars/dataframe/arithmetic/tests/test_arithmetic.py +++ b/mars/dataframe/arithmetic/tests/test_arithmetic.py @@ -1190,9 +1190,9 @@ def test_both_one_chunk(func_name, func_opts): assert isinstance(c.op, func_opts.op) assert len(c.inputs) == 2 # test the left side - assert c.inputs[0] is df1.chunks[0].data + assert c.inputs[0] is df1.chunks[0] # test the right side - assert c.inputs[1] is df2.chunks[0].data + assert c.inputs[1] is df2.chunks[0] @pytest.mark.parametrize("func_name, func_opts", binary_functions.items()) @@ -1522,7 +1522,7 @@ def test_arithmetic_lazy_chunk_meta(): df2 = df + 1 df2 = tile(df2) - chunk = df2.chunks[0].data + chunk = df2.chunks[0] assert chunk._FIELDS["_dtypes"].get(chunk) is None pd.testing.assert_series_equal(chunk.dtypes, df.dtypes) assert chunk._FIELDS["_dtypes"].get(chunk) is not None diff --git a/mars/dataframe/base/drop.py b/mars/dataframe/base/drop.py index 19f44b09a1..45ed5e4d1f 100644 --- a/mars/dataframe/base/drop.py +++ b/mars/dataframe/base/drop.py @@ -18,7 +18,7 @@ import numpy as np from ... 
import opcodes -from ...core import Entity, Chunk, CHUNK_TYPE, OutputType, recursive_tile +from ...core import Entity, Chunk, OutputType, recursive_tile from ...serialization.serializables import AnyField, StringField from ..core import IndexValue, DATAFRAME_TYPE, SERIES_TYPE, INDEX_CHUNK_TYPE from ..operands import DataFrameOperand, DataFrameOperandMixin @@ -172,7 +172,7 @@ def tile(cls, op: "DataFrameDrop"): @classmethod def execute(cls, ctx, op: "DataFrameDrop"): inp = op.inputs[0] - if isinstance(op.index, CHUNK_TYPE): + if isinstance(op.index, Chunk): index_val = ctx[op.index.key] else: index_val = op.index diff --git a/mars/dataframe/core.py b/mars/dataframe/core.py index 33289d38ba..51b0c54428 100644 --- a/mars/dataframe/core.py +++ b/mars/dataframe/core.py @@ -25,7 +25,6 @@ import pandas as pd from ..core import ( - ChunkData, Chunk, Tileable, HasShapeTileableData, @@ -452,7 +451,7 @@ def refresh_dtypes(tileable: ENTITY_TYPE): ) -class LazyMetaChunkData(ChunkData): +class LazyMetaChunk(Chunk): __slots__ = _lazy_chunk_meta_properties def _set_tileable_meta( @@ -470,11 +469,9 @@ def _set_tileable_meta( setattr(self, _tileable_dtypes_property, dtypes) -def is_chunk_meta_lazy(chunk: ChunkData) -> bool: +def is_chunk_meta_lazy(chunk: Chunk) -> bool: chunk = chunk.data if hasattr(chunk, "data") else chunk - return isinstance(chunk, LazyMetaChunkData) and hasattr( - chunk, _tileable_key_property - ) + return isinstance(chunk, LazyMetaChunk) and hasattr(chunk, _tileable_key_property) @functools.lru_cache(maxsize=128) @@ -513,7 +510,7 @@ def _gen_chunk_dtypes(instance: Chunk, index: int) -> Optional[pd.Series]: return dtypes def __get__(self, instance, owner=None): - if not issubclass(owner, LazyMetaChunkData): # pragma: no cover + if not issubclass(owner, LazyMetaChunk): # pragma: no cover return super().__get__(instance, owner) try: @@ -566,7 +563,7 @@ def _gen_chunk_index_value(instance: Chunk, index: int) -> Optional[IndexValue]: return index_value def __get__(self, instance, owner=None): - if not issubclass(owner, LazyMetaChunkData): # pragma: no cover + if not issubclass(owner, LazyMetaChunk): # pragma: no cover return super().__get__(instance, owner) try: @@ -615,7 +612,7 @@ def _gen_chunk_columns_value(instance: Chunk, index: int) -> Optional[IndexValue return columns_value def __get__(self, instance, owner=None): - if not issubclass(owner, LazyMetaChunkData): # pragma: no cover + if not issubclass(owner, LazyMetaChunk): # pragma: no cover return super().__get__(instance, owner) try: @@ -636,7 +633,7 @@ def __get__(self, instance, owner=None): return columns_value -class IndexChunkData(ChunkData): +class IndexChunk(Chunk): __slots__ = () type_name = "Index" @@ -732,12 +729,6 @@ def index_value(self): return self._index_value -class IndexChunk(Chunk): - __slots__ = () - _allow_data_type_ = (IndexChunkData,) - type_name = "Index" - - def _on_deserialize_index_value(index_value): if index_value is None: return @@ -827,12 +818,7 @@ class IndexData(HasShapeTileableData, _ToPandasMixin): _index_value = ReferenceField( "index_value", IndexValue, on_deserialize=_on_deserialize_index_value ) - _chunks = ListField( - "chunks", - FieldTypes.reference(IndexChunkData), - on_serialize=lambda x: [it.data for it in x] if x is not None else x, - on_deserialize=lambda x: [IndexChunk(it) for it in x] if x is not None else x, - ) + _chunks = ListField("chunks", FieldTypes.reference(IndexChunk)) def __init__( self, @@ -1156,7 +1142,7 @@ class MultiIndex(Index): __slots__ = () -class 
BaseSeriesChunkData(LazyMetaChunkData): +class BaseSeriesChunk(LazyMetaChunk): __slots__ = () # required fields @@ -1253,13 +1239,7 @@ def index_value(self): return self._index_value -class SeriesChunkData(BaseSeriesChunkData): - type_name = "Series" - - -class SeriesChunk(Chunk): - __slots__ = () - _allow_data_type_ = (SeriesChunkData,) +class SeriesChunk(BaseSeriesChunk): type_name = "Series" @@ -1272,12 +1252,7 @@ class BaseSeriesData(HasShapeTileableData, _ToPandasMixin): _index_value = ReferenceField( "index_value", IndexValue, on_deserialize=_on_deserialize_index_value ) - _chunks = ListField( - "chunks", - FieldTypes.reference(SeriesChunkData), - on_serialize=lambda x: [it.data for it in x] if x is not None else x, - on_deserialize=lambda x: [SeriesChunk(it) for it in x] if x is not None else x, - ) + _chunks = ListField("chunks", FieldTypes.reference(SeriesChunk)) def __init__( self, @@ -1812,9 +1787,9 @@ def median( ) -class BaseDataFrameChunkData(LazyMetaChunkData): +class BaseDataFrameChunk(LazyMetaChunk): __slots__ = ("_dtypes_value",) - _no_copy_attrs_ = ChunkData._no_copy_attrs_ | {"_dtypes", "_columns_value"} + _no_copy_attrs_ = Chunk._no_copy_attrs_ | {"_dtypes", "_columns_value"} # required fields _shape = TupleField( @@ -1852,7 +1827,7 @@ def __init__( self._dtypes_value = None def __on_deserialize__(self): - super(BaseDataFrameChunkData, self).__on_deserialize__() + super(BaseDataFrameChunk, self).__on_deserialize__() self._dtypes_value = None def __len__(self): @@ -1940,19 +1915,10 @@ def columns_value(self): return self._columns_value -class DataFrameChunkData(BaseDataFrameChunkData): +class DataFrameChunk(BaseDataFrameChunk): type_name = "DataFrame" -class DataFrameChunk(Chunk): - __slots__ = () - _allow_data_type_ = (DataFrameChunkData,) - type_name = "DataFrame" - - def __len__(self): - return len(self._data) - - class BaseDataFrameData(HasShapeTileableData, _ToPandasMixin): __slots__ = "_accessors", "_dtypes_value", "_dtypes_dict" @@ -1962,14 +1928,7 @@ class BaseDataFrameData(HasShapeTileableData, _ToPandasMixin): "index_value", IndexValue, on_deserialize=_on_deserialize_index_value ) _columns_value = ReferenceField("columns_value", IndexValue) - _chunks = ListField( - "chunks", - FieldTypes.reference(DataFrameChunkData), - on_serialize=lambda x: [it.data for it in x] if x is not None else x, - on_deserialize=lambda x: [DataFrameChunk(it) for it in x] - if x is not None - else x, - ) + _chunks = ListField("chunks", FieldTypes.reference(DataFrameChunk)) def __init__( self, @@ -2553,7 +2512,7 @@ def apply_if_callable(maybe_callable, obj, **kwargs): return data -class DataFrameGroupByChunkData(BaseDataFrameChunkData): +class DataFrameGroupByChunk(BaseDataFrameChunk): type_name = "DataFrameGroupBy" _key_dtypes = SeriesField("key_dtypes") @@ -2597,16 +2556,7 @@ def __init__(self, key_dtypes=None, selection=None, **kw): super().__init__(_key_dtypes=key_dtypes, _selection=selection, **kw) -class DataFrameGroupByChunk(Chunk): - __slots__ = () - _allow_data_type_ = (DataFrameGroupByChunkData,) - type_name = "DataFrameGroupBy" - - def __len__(self): - return len(self._data) - - -class SeriesGroupByChunkData(BaseSeriesChunkData): +class SeriesGroupByChunk(BaseSeriesChunk): type_name = "SeriesGroupBy" _key_dtypes = AnyField("key_dtypes") @@ -2648,28 +2598,12 @@ def __init__(self, key_dtypes=None, **kw): super().__init__(_key_dtypes=key_dtypes, **kw) -class SeriesGroupByChunk(Chunk): - __slots__ = () - _allow_data_type_ = (SeriesGroupByChunkData,) - type_name = "SeriesGroupBy" 
- - def __len__(self): - return len(self._data) - - class DataFrameGroupByData(BaseDataFrameData): type_name = "DataFrameGroupBy" _key_dtypes = SeriesField("key_dtypes") _selection = AnyField("selection") - _chunks = ListField( - "chunks", - FieldTypes.reference(DataFrameGroupByChunkData), - on_serialize=lambda x: [it.data for it in x] if x is not None else x, - on_deserialize=lambda x: [DataFrameGroupByChunk(it) for it in x] - if x is not None - else x, - ) + _chunks = ListField("chunks", FieldTypes.reference(DataFrameGroupByChunk)) @property def key_dtypes(self): @@ -2711,14 +2645,7 @@ class SeriesGroupByData(BaseSeriesData): type_name = "SeriesGroupBy" _key_dtypes = AnyField("key_dtypes") - _chunks = ListField( - "chunks", - FieldTypes.reference(SeriesGroupByChunkData), - on_serialize=lambda x: [it.data for it in x] if x is not None else x, - on_deserialize=lambda x: [SeriesGroupByChunk(it) for it in x] - if x is not None - else x, - ) + _chunks = ListField("chunks", FieldTypes.reference(SeriesGroupByChunk)) @property def key_dtypes(self): @@ -2795,7 +2722,7 @@ def __hash__(self): return super().__hash__() -class CategoricalChunkData(ChunkData): +class CategoricalChunk(Chunk): __slots__ = () type_name = "Categorical" @@ -2875,12 +2802,6 @@ def categories_value(self): return self._categories_value -class CategoricalChunk(Chunk): - __slots__ = () - _allow_data_type_ = (CategoricalChunkData,) - type_name = "Categorical" - - class CategoricalData(HasShapeTileableData, _ToPandasMixin): __slots__ = ("_cache",) type_name = "Categorical" @@ -2890,14 +2811,7 @@ class CategoricalData(HasShapeTileableData, _ToPandasMixin): _categories_value = ReferenceField( "categories_value", IndexValue, on_deserialize=_on_deserialize_index_value ) - _chunks = ListField( - "chunks", - FieldTypes.reference(CategoricalChunkData), - on_serialize=lambda x: [it.data for it in x] if x is not None else x, - on_deserialize=lambda x: [CategoricalChunk(it) for it in x] - if x is not None - else x, - ) + _chunks = ListField("chunks", FieldTypes.reference(CategoricalChunk)) def __init__( self, @@ -3013,19 +2927,19 @@ def __hash__(self): INDEX_TYPE = (Index, IndexData) -INDEX_CHUNK_TYPE = (IndexChunk, IndexChunkData) +INDEX_CHUNK_TYPE = (IndexChunk,) SERIES_TYPE = (Series, SeriesData) -SERIES_CHUNK_TYPE = (SeriesChunk, SeriesChunkData) +SERIES_CHUNK_TYPE = (SeriesChunk,) DATAFRAME_TYPE = (DataFrame, DataFrameData) -DATAFRAME_CHUNK_TYPE = (DataFrameChunk, DataFrameChunkData) +DATAFRAME_CHUNK_TYPE = (DataFrameChunk,) DATAFRAME_GROUPBY_TYPE = (DataFrameGroupBy, DataFrameGroupByData) -DATAFRAME_GROUPBY_CHUNK_TYPE = (DataFrameGroupByChunk, DataFrameGroupByChunkData) +DATAFRAME_GROUPBY_CHUNK_TYPE = (DataFrameGroupByChunk,) SERIES_GROUPBY_TYPE = (SeriesGroupBy, SeriesGroupByData) -SERIES_GROUPBY_CHUNK_TYPE = (SeriesGroupByChunk, SeriesGroupByChunkData) +SERIES_GROUPBY_CHUNK_TYPE = (SeriesGroupByChunk,) GROUPBY_TYPE = (GroupBy,) + DATAFRAME_GROUPBY_TYPE + SERIES_GROUPBY_TYPE GROUPBY_CHUNK_TYPE = DATAFRAME_GROUPBY_CHUNK_TYPE + SERIES_GROUPBY_CHUNK_TYPE CATEGORICAL_TYPE = (Categorical, CategoricalData) -CATEGORICAL_CHUNK_TYPE = (CategoricalChunk, CategoricalChunkData) +CATEGORICAL_CHUNK_TYPE = (CategoricalChunk,) TILEABLE_TYPE = ( INDEX_TYPE + SERIES_TYPE + DATAFRAME_TYPE + GROUPBY_TYPE + CATEGORICAL_TYPE ) diff --git a/mars/dataframe/datastore/tests/test_datastore.py b/mars/dataframe/datastore/tests/test_datastore.py index 099a02399f..d95f01ebc7 100644 --- a/mars/dataframe/datastore/tests/test_datastore.py +++ 
b/mars/dataframe/datastore/tests/test_datastore.py @@ -29,7 +29,7 @@ def test_to_csv(): assert r.chunk_shape[1] == 1 for i, c in enumerate(r.chunks): assert type(c.op).__name__ == "DataFrameToCSV" - assert c.inputs[0] is r.inputs[0].chunks[i].data + assert c.inputs[0] is r.inputs[0].chunks[i] # test one file r = df.to_csv("out.csv") @@ -38,5 +38,5 @@ def test_to_csv(): assert r.chunk_shape[1] == 1 for i, c in enumerate(r.chunks): assert len(c.inputs) == 2 - assert c.inputs[0].inputs[0] is r.inputs[0].chunks[i].data + assert c.inputs[0].inputs[0] is r.inputs[0].chunks[i] assert type(c.inputs[1].op).__name__ == "DataFrameToCSVStat" diff --git a/mars/dataframe/indexing/tests/test_indexing.py b/mars/dataframe/indexing/tests/test_indexing.py index d1e037b999..56e4847c97 100644 --- a/mars/dataframe/indexing/tests/test_indexing.py +++ b/mars/dataframe/indexing/tests/test_indexing.py @@ -939,7 +939,7 @@ def test_getitem_lazy_chunk_meta(): df2 = df[[0, 2]] df2 = tile(df2) - chunk = df2.chunks[0].data + chunk = df2.chunks[0] assert chunk._FIELDS["_dtypes"].get(chunk) is None pd.testing.assert_series_equal(chunk.dtypes, df.dtypes[[0, 2]]) assert chunk._FIELDS["_dtypes"].get(chunk) is not None @@ -953,7 +953,7 @@ def test_getitem_lazy_chunk_meta(): df2 = df[2] df2 = tile(df2) - chunk = df2.chunks[0].data + chunk = df2.chunks[0] assert chunk._FIELDS["_index_value"].get(chunk) is None pd.testing.assert_index_equal(chunk.index_value.to_pandas(), pd.RangeIndex(3)) assert chunk._FIELDS["_index_value"].get(chunk) is not None diff --git a/mars/dataframe/operands.py b/mars/dataframe/operands.py index 4addcf41b4..ccf66d55c5 100644 --- a/mars/dataframe/operands.py +++ b/mars/dataframe/operands.py @@ -18,7 +18,7 @@ import numpy as np import pandas as pd -from ..core import FuseChunkData, FuseChunk, ENTITY_TYPE, OutputType +from ..core import FuseChunk, ENTITY_TYPE, OutputType from ..core.operand import ( Operand, TileableOperandMixin, @@ -467,9 +467,7 @@ class DataFrameFuseChunkMixin(FuseChunkMixin, DataFrameOperandMixin): __slots__ = () def _create_chunk(self, output_idx, index, **kw): - data = FuseChunkData(_index=index, _shape=kw.pop("shape", None), _op=self, **kw) - - return FuseChunk(data) + return FuseChunk(_index=index, _shape=kw.pop("shape", None), _op=self, **kw) class DataFrameFuseChunk(Fuse, DataFrameFuseChunkMixin): diff --git a/mars/optimization/logical/chunk/tests/test_column_pruning.py b/mars/optimization/logical/chunk/tests/test_column_pruning.py index 953c3b2f0c..7c39a3f9b6 100644 --- a/mars/optimization/logical/chunk/tests/test_column_pruning.py +++ b/mars/optimization/logical/chunk/tests/test_column_pruning.py @@ -58,8 +58,8 @@ def test_groupby_read_csv(gen_data1): graph, fuse_enabled=False, tile_context=context ) chunk_graph = next(chunk_graph_builder.build()) - chunk1 = context[df1.data].chunks[0].data - chunk2 = context[df2.data].chunks[0].data + chunk1 = context[df1.data].chunks[0] + chunk2 = context[df2.data].chunks[0] records = optimize(chunk_graph) opt_chunk1 = records.get_optimization_result(chunk1) assert opt_chunk1 is None diff --git a/mars/optimization/logical/chunk/tests/test_head.py b/mars/optimization/logical/chunk/tests/test_head.py index 42db361ced..62a60c036c 100644 --- a/mars/optimization/logical/chunk/tests/test_head.py +++ b/mars/optimization/logical/chunk/tests/test_head.py @@ -58,8 +58,8 @@ def test_read_csv_head(gen_data1): graph, fuse_enabled=False, tile_context=context ) chunk_graph = next(chunk_graph_builder.build()) - chunk1 = context[df1.data].chunks[0].data - 
chunk2 = context[df2.data].chunks[0].data + chunk1 = context[df1.data].chunks[0] + chunk2 = context[df2.data].chunks[0] records = optimize(chunk_graph) assert records.get_optimization_result(chunk1) is None opt_chunk2 = records.get_optimization_result(chunk2) diff --git a/mars/optimization/logical/common/column_pruning.py b/mars/optimization/logical/common/column_pruning.py index 72a4dfac6b..91b39589d4 100644 --- a/mars/optimization/logical/common/column_pruning.py +++ b/mars/optimization/logical/common/column_pruning.py @@ -15,7 +15,7 @@ from abc import ABCMeta, abstractmethod from typing import Any, List -from ....core import OperandType, TileableType, CHUNK_TYPE +from ....core import OperandType, TileableType, Chunk from ....dataframe.datasource.core import ColumnPruneSupportedDataSourceMixin from ....dataframe.utils import parse_index from ....utils import implements @@ -103,7 +103,7 @@ def apply(self, op: OperandType): data_source_node = self._graph.predecessors(node)[0] if ( - isinstance(node, CHUNK_TYPE) + isinstance(node, Chunk) and self._graph.count_successors(data_source_node) == 1 ): # merge into data source only for chunk @@ -120,7 +120,7 @@ def apply(self, op: OperandType): ) new_entity = ( data_source_op.new_tileable - if not isinstance(node, CHUNK_TYPE) + if not isinstance(node, Chunk) else data_source_op.new_chunk ) new_data_source_node = new_entity( @@ -192,7 +192,7 @@ def apply(self, op: OperandType): kws.append(params) new_entity = ( new_op.new_tileables - if not isinstance(node, CHUNK_TYPE) + if not isinstance(node, Chunk) else new_op.new_chunks ) new_outputs = [t.data for t in new_entity([new_data_source_node], kws=kws)] @@ -227,7 +227,7 @@ def _need_prune(self, op: OperandType) -> bool: and op.col_names is not None ): selected_columns = self._get_selected_columns(op) - if not isinstance(op.outputs[0], CHUNK_TYPE) and not selected_columns: + if not isinstance(op.outputs[0], Chunk) and not selected_columns: # no columns selected, skip return False return True diff --git a/mars/optimization/logical/common/head.py b/mars/optimization/logical/common/head.py index 7daeb8a392..c8824f815f 100644 --- a/mars/optimization/logical/common/head.py +++ b/mars/optimization/logical/common/head.py @@ -14,7 +14,7 @@ from typing import List -from ....core import OperandType, TileableType, CHUNK_TYPE +from ....core import OperandType, TileableType, Chunk from ....dataframe.base.value_counts import DataFrameValueCounts from ....dataframe.datasource.core import HeadOptimizedDataSource from ....dataframe.sort.core import DataFrameSortOperand @@ -83,7 +83,7 @@ def apply(self, op: OperandType): new_input_params.update(input_node.extra_params) new_entity = ( new_input_op.new_tileable - if not isinstance(node, CHUNK_TYPE) + if not isinstance(node, Chunk) else new_input_op.new_chunk ) new_input_node = new_entity(input_node.inputs, kws=[new_input_params]).data @@ -119,9 +119,7 @@ def apply(self, op: OperandType): params = node.params.copy() params.update(node.extra_params) new_entity = ( - new_op.new_tileable - if not isinstance(node, CHUNK_TYPE) - else new_op.new_chunk + new_op.new_tileable if not isinstance(node, Chunk) else new_op.new_chunk ) new_node = new_entity([new_input_node], kws=[params]).data self._replace_node(node, new_node) diff --git a/mars/optimization/physical/numexpr.py b/mars/optimization/physical/numexpr.py index 8294b9da55..644018aa1f 100644 --- a/mars/optimization/physical/numexpr.py +++ b/mars/optimization/physical/numexpr.py @@ -231,7 +231,7 @@ def _fuse_nodes(self, fuses: 
List[_Fuse], fuse_cls): kws=[tail_chunk.params], _key=tail_chunk.key, _chunk=tail_chunk, - ).data + ) graph.add_node(fused_chunk) for node in graph.iter_successors(tail_chunk): diff --git a/mars/remote/core.py b/mars/remote/core.py index 29076a66db..47437ca51a 100644 --- a/mars/remote/core.py +++ b/mars/remote/core.py @@ -16,7 +16,7 @@ from functools import partial from .. import opcodes -from ..core import ENTITY_TYPE, ChunkData, Tileable +from ..core import ENTITY_TYPE, Chunk, Tileable from ..core.custom_log import redirect_custom_log from ..core.operand import ObjectOperand from ..dataframe.core import DATAFRAME_TYPE, SERIES_TYPE, INDEX_TYPE @@ -73,7 +73,7 @@ def _set_inputs(self, inputs): if raw_inputs is not None: for raw_inp in raw_inputs: if self._no_prepare(raw_inp): - if not isinstance(self._inputs[0], ChunkData): + if not isinstance(self._inputs[0], Chunk): # not in tile, set_inputs from tileable mapping[raw_inp] = next(function_inputs) else: diff --git a/mars/remote/operands.py b/mars/remote/operands.py index 1433e05874..76b0c2f674 100644 --- a/mars/remote/operands.py +++ b/mars/remote/operands.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..core import FuseChunkData, FuseChunk +from ..core import FuseChunk from ..core.operand import Fuse, FuseChunkMixin, ObjectOperandMixin @@ -20,9 +20,7 @@ class RemoteFuseChunkMixin(ObjectOperandMixin, FuseChunkMixin): __slots__ = () def _create_chunk(self, output_idx, index, **kw): - data = FuseChunkData(_index=index, _op=self, **kw) - - return FuseChunk(data) + return FuseChunk(_index=index, _op=self, **kw) class RemoteFuseChunk(RemoteFuseChunkMixin, Fuse): diff --git a/mars/services/subtask/core.py b/mars/services/subtask/core.py index ce8bfdc81c..17cc253053 100644 --- a/mars/services/subtask/core.py +++ b/mars/services/subtask/core.py @@ -15,7 +15,7 @@ from enum import Enum from typing import Iterable, List, Optional, Set, Tuple -from ...core import ChunkGraph, DAG, ChunkData +from ...core import ChunkGraph, DAG, Chunk from ...resource import Resource from ...serialization.serializables.field_type import TupleType from ...serialization.serializables import ( @@ -69,7 +69,7 @@ class Subtask(Serializable): stage_id: str = StringField("stage_id") # chunks that need meta updated update_meta_chunks: List[ChunkType] = ListField( - "update_meta_chunks", FieldTypes.reference(ChunkData) + "update_meta_chunks", FieldTypes.reference(Chunk) ) # A unique and deterministic key for subtask compute logic. See logic_key in operator.py. 
logic_key: str = StringField("logic_key") diff --git a/mars/services/task/analyzer/analyzer.py b/mars/services/task/analyzer/analyzer.py index b9ce259283..d34322aec9 100644 --- a/mars/services/task/analyzer/analyzer.py +++ b/mars/services/task/analyzer/analyzer.py @@ -180,11 +180,11 @@ def _gen_input_chunks( inp_chunk, n_reducers=n_reducers, shuffle_fetch_type=self._shuffle_fetch_type, - ).data + ) chunk_to_fetch_chunk[inp_chunk] = fetch_chunk inp_fetch_chunks.append(fetch_chunk) else: - fetch_chunk = build_fetch(inp_chunk).data + fetch_chunk = build_fetch(inp_chunk) chunk_to_fetch_chunk[inp_chunk] = fetch_chunk inp_fetch_chunks.append(fetch_chunk) @@ -276,7 +276,7 @@ def _gen_subtask_info( copied_op = chunk.op.copy() copied_op._key = chunk.op.key out_chunks = [ - c.data + c for c in copied_op.new_chunks( inp_chunks, kws=[c.params.copy() for c in chunk.op.outputs] ) diff --git a/mars/services/task/analyzer/assigner.py b/mars/services/task/analyzer/assigner.py index 2a1cbc938c..87c100eb94 100644 --- a/mars/services/task/analyzer/assigner.py +++ b/mars/services/task/analyzer/assigner.py @@ -19,7 +19,7 @@ import numpy as np -from ....core import ChunkGraph, ChunkData +from ....core import ChunkGraph, Chunk from ....core.operand import Operand from ....lib.ordered_set import OrderedSet from ....resource import Resource @@ -43,7 +43,7 @@ def __init__( self._band_resource = band_resource @abstractmethod - def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType]: + def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[Chunk, BandType]: """ Assign start nodes to bands. @@ -128,7 +128,7 @@ def _calc_band_assign_limits( def _assign_by_bfs( cls, undirected_chunk_graph: ChunkGraph, - start: ChunkData, + start: Chunk, band: BandType, initial_sizes: Dict[BandType, int], spread_limits: Dict[BandType, float], @@ -159,9 +159,7 @@ def _assign_by_bfs( break initial_sizes[band] -= assigned - def _build_undirected_chunk_graph( - self, chunk_to_assign: List[ChunkData] - ) -> ChunkGraph: + def _build_undirected_chunk_graph(self, chunk_to_assign: List[Chunk]) -> ChunkGraph: chunk_graph = self._chunk_graph.copy() # remove edges for all chunk_to_assign which may contain chunks # that need be reassigned @@ -172,9 +170,7 @@ def _build_undirected_chunk_graph( return chunk_graph.build_undirected() @implements(AbstractGraphAssigner.assign) - def assign( - self, cur_assigns: Dict[str, BandType] = None - ) -> Dict[ChunkData, BandType]: + def assign(self, cur_assigns: Dict[str, BandType] = None) -> Dict[Chunk, BandType]: graph = self._chunk_graph assign_result = dict() cur_assigns = cur_assigns or dict() diff --git a/mars/services/task/execution/mars/stage.py b/mars/services/task/execution/mars/stage.py index 6cced7300b..7f290971b4 100644 --- a/mars/services/task/execution/mars/stage.py +++ b/mars/services/task/execution/mars/stage.py @@ -278,8 +278,7 @@ async def _update_result_meta( update_meta_chunks = chunk_to_result.keys() - set( itertools.chain.from_iterable( - (c.data for c in tiled_tileable.chunks) - for tiled_tileable in tile_context.values() + tuple(tiled_tileable.chunks) for tiled_tileable in tile_context.values() ) ) @@ -292,7 +291,7 @@ async def _update_result_meta( ) worker_meta_api_to_chunk_delays[meta_api][c] = call for tileable in tile_context.values(): - chunks = [c.data for c in tileable.chunks] + chunks = tileable.chunks for c, params_fields in zip(chunks, self._get_params_fields(tileable)): address = chunk_to_result[c].meta["bands"][0][0] meta_api = await 
WorkerMetaAPI.create(session_id, address) diff --git a/mars/services/task/supervisor/processor.py b/mars/services/task/supervisor/processor.py index 0ce692b2ce..46c5da67b8 100644 --- a/mars/services/task/supervisor/processor.py +++ b/mars/services/task/supervisor/processor.py @@ -269,8 +269,7 @@ def _get_stage_tile_context(self, result_chunks: Set[Chunk]) -> TileContext: continue tiled_tileable = self._preprocessor.tile_context.get(tileable) if tiled_tileable is not None: - tileable_chunks = [c.data for c in tiled_tileable.chunks] - if any(c not in result_chunks for c in tileable_chunks): + if any(c not in result_chunks for c in tiled_tileable.chunks): continue tile_context[tileable] = tiled_tileable collected.add(tileable) @@ -305,7 +304,7 @@ def _update_result_meta( from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE from ....tensor.core import TENSOR_TYPE - chunks = [c.data for c in tileable.chunks] + chunks = tileable.chunks if isinstance(tileable, DATAFRAME_TYPE): for c in chunks: i, j = c.index @@ -315,13 +314,13 @@ def _update_result_meta( shape = shape if not update_shape else [None, None] if i > 0: # update dtypes_value - c0j = chunk_to_result[tileable.cix[0, j].data].meta + c0j = chunk_to_result[tileable.cix[0, j]].meta meta["dtypes_value"] = c0j["dtypes_value"] if update_shape: shape[1] = c0j["shape"][1] if j > 0: # update index_value - ci0 = chunk_to_result[tileable.cix[i, 0].data].meta + ci0 = chunk_to_result[tileable.cix[i, 0]].meta meta["index_value"] = ci0["index_value"] if update_shape: shape[0] = ci0["shape"][0] @@ -344,7 +343,7 @@ def _update_result_meta( for i, ind in enumerate(c.index): ind0 = [0] * ndim ind0[i] = ind - c0 = tileable.cix[tuple(ind0)].data + c0 = tileable.cix[tuple(ind0)] shape.append(chunk_to_result[c0].meta["shape"][i]) meta["shape"] = tuple(shape) if i > 0: diff --git a/mars/tensor/arithmetic/tests/test_arithmetic.py b/mars/tensor/arithmetic/tests/test_arithmetic.py index 18f3b4f987..9583f18f60 100644 --- a/mars/tensor/arithmetic/tests/test_arithmetic.py +++ b/mars/tensor/arithmetic/tests/test_arithmetic.py @@ -56,10 +56,10 @@ def test_add(): assert t3.key != k1 assert t3.shape == (3, 4) assert len(t3.chunks) == 4 - assert t3.chunks[0].inputs == [t1.chunks[0].data, t2.chunks[0].data] - assert t3.chunks[1].inputs == [t1.chunks[1].data, t2.chunks[1].data] - assert t3.chunks[2].inputs == [t1.chunks[2].data, t2.chunks[0].data] - assert t3.chunks[3].inputs == [t1.chunks[3].data, t2.chunks[1].data] + assert t3.chunks[0].inputs == [t1.chunks[0], t2.chunks[0]] + assert t3.chunks[1].inputs == [t1.chunks[1], t2.chunks[1]] + assert t3.chunks[2].inputs == [t1.chunks[2], t2.chunks[0]] + assert t3.chunks[3].inputs == [t1.chunks[3], t2.chunks[1]] assert t3.op.dtype == np.dtype("f8") assert t3.chunks[0].op.dtype == np.dtype("f8") @@ -68,18 +68,18 @@ def test_add(): t1, t4 = tile(t1, t4) assert t4.shape == (3, 4) assert len(t3.chunks) == 4 - assert t4.chunks[0].inputs == [t1.chunks[0].data] + assert t4.chunks[0].inputs == [t1.chunks[0]] assert t4.chunks[0].op.rhs == 1 - assert t4.chunks[1].inputs == [t1.chunks[1].data] + assert t4.chunks[1].inputs == [t1.chunks[1]] assert t4.chunks[1].op.rhs == 1 - assert t4.chunks[2].inputs == [t1.chunks[2].data] + assert t4.chunks[2].inputs == [t1.chunks[2]] assert t4.chunks[2].op.rhs == 1 - assert t4.chunks[3].inputs == [t1.chunks[3].data] + assert t4.chunks[3].inputs == [t1.chunks[3]] assert t4.chunks[3].op.rhs == 1 t5 = add([1, 2, 3, 4], 1) tile(t5) - assert t4.chunks[0].inputs == [t1.chunks[0].data] + assert 
t4.chunks[0].inputs == [t1.chunks[0]] t2 = ones(4, chunk_size=2) t6 = ones((3, 4), chunk_size=2, gpu=True) @@ -350,10 +350,10 @@ def test_unify_chunk_add(): t1, t2, t3 = tile(t1, t2, t3) assert len(t3.chunks) == 2 - assert t3.chunks[0].inputs[0] == t1.chunks[0].data - assert t3.chunks[0].inputs[1] == t2.chunks[0].data - assert t3.chunks[1].inputs[0] == t1.chunks[1].data - assert t3.chunks[1].inputs[1] == t2.chunks[0].data + assert t3.chunks[0].inputs[0] == t1.chunks[0] + assert t3.chunks[0].inputs[1] == t2.chunks[0] + assert t3.chunks[1].inputs[0] == t1.chunks[1] + assert t3.chunks[1].inputs[1] == t2.chunks[0] def test_frexp(): diff --git a/mars/tensor/base/map_chunk.py b/mars/tensor/base/map_chunk.py index 1857e1c08c..94fe94958f 100644 --- a/mars/tensor/base/map_chunk.py +++ b/mars/tensor/base/map_chunk.py @@ -15,7 +15,7 @@ import numpy as np from ... import opcodes -from ...core import ENTITY_TYPE, CHUNK_TYPE, recursive_tile +from ...core import ENTITY_TYPE, recursive_tile, Chunk from ...core.custom_log import redirect_custom_log from ...serialization.serializables import ( FunctionField, @@ -162,7 +162,7 @@ def execute(cls, ctx, op: "TensorMapChunk"): if op.with_chunk_index: kwargs["chunk_index"] = out_chunk.index - chunks = find_objects(args, CHUNK_TYPE) + find_objects(kwargs, CHUNK_TYPE) + chunks = find_objects(args, Chunk) + find_objects(kwargs, Chunk) mapping = {chunk: ctx[chunk.key] for chunk in chunks} args = replace_objects(args, mapping) kwargs = replace_objects(kwargs, mapping) diff --git a/mars/tensor/base/tests/test_base.py b/mars/tensor/base/tests/test_base.py index 6fd18c4b38..e99ecb04cb 100644 --- a/mars/tensor/base/tests/test_base.py +++ b/mars/tensor/base/tests/test_base.py @@ -407,7 +407,7 @@ def test_isin(): assert len(mask.chunks) == len(element.chunks) assert len(mask.op.inputs[1].chunks) == 1 - assert mask.chunks[0].inputs[0] is element.chunks[0].data + assert mask.chunks[0].inputs[0] is element.chunks[0] element = 2 * arange(4, chunk_size=1).reshape(2, 2) test_elements = tensor([1, 2, 4, 8], chunk_size=2) diff --git a/mars/tensor/core.py b/mars/tensor/core.py index 7344bfe339..66d4cfc48f 100644 --- a/mars/tensor/core.py +++ b/mars/tensor/core.py @@ -24,7 +24,6 @@ from ..core import ( HasShapeTileable, - ChunkData, Chunk, HasShapeTileableData, OutputType, @@ -56,9 +55,9 @@ class TensorOrder(Enum): F_ORDER = "F" -class TensorChunkData(ChunkData): +class TensorChunk(Chunk): __slots__ = () - _no_copy_attrs_ = ChunkData._no_copy_attrs_ | {"dtype"} + _no_copy_attrs_ = Chunk._no_copy_attrs_ | {"dtype"} type_name = "Tensor" # required fields @@ -159,15 +158,6 @@ def nbytes(self): return np.prod(self.shape) * self.dtype.itemsize -class TensorChunk(Chunk): - __slots__ = () - _allow_data_type_ = (TensorChunkData,) - type_name = "Tensor" - - def __len__(self): - return len(self._data) - - class TensorData(HasShapeTileableData, _ExecuteAndFetchMixin): __slots__ = () type_name = "Tensor" @@ -178,12 +168,7 @@ class TensorData(HasShapeTileableData, _ExecuteAndFetchMixin): ) # optional fields _dtype = DataTypeField("dtype") - _chunks = ListField( - "chunks", - FieldTypes.reference(TensorChunkData), - on_serialize=lambda x: [it.data for it in x] if x is not None else x, - on_deserialize=lambda x: [TensorChunk(it) for it in x] if x is not None else x, - ) + _chunks = ListField("chunks", FieldTypes.reference(TensorChunk)) def __init__( self, @@ -723,7 +708,7 @@ class Indexes(Serializable): TENSOR_TYPE = (Tensor, TensorData) -TENSOR_CHUNK_TYPE = (TensorChunk, TensorChunkData) 
+TENSOR_CHUNK_TYPE = (TensorChunk,) register_output_types(OutputType.tensor, TENSOR_TYPE, TENSOR_CHUNK_TYPE) register_output_types(OutputType.scalar, TENSOR_TYPE, TENSOR_CHUNK_TYPE) diff --git a/mars/tensor/random/tests/test_random.py b/mars/tensor/random/tests/test_random.py index 27d9f02668..650f556df5 100644 --- a/mars/tensor/random/tests/test_random.py +++ b/mars/tensor/random/tests/test_random.py @@ -186,7 +186,7 @@ def test_permutation(): assert len(x.chunks) == 3 assert np.isnan(x.chunks[0].shape[0]) - assert x.chunks[0].inputs[0].inputs[0].inputs[0].key == arr.chunks[0].data.key + assert x.chunks[0].inputs[0].inputs[0].inputs[0].key == arr.chunks[0].key arr = rand(3, 3, chunk_size=2) x = permutation(arr) diff --git a/mars/tensor/statistics/histogram.py b/mars/tensor/statistics/histogram.py index baaf847396..e51362472c 100644 --- a/mars/tensor/statistics/histogram.py +++ b/mars/tensor/statistics/histogram.py @@ -57,7 +57,7 @@ def check(self): if width is None: return self._width = width = yield from recursive_tile(width) - yield [c.data for c in width.chunks] + yield width.chunks def __call__(self): return diff --git a/mars/utils.py b/mars/utils.py index f4a22d313f..e845f5f55d 100644 --- a/mars/utils.py +++ b/mars/utils.py @@ -642,9 +642,9 @@ def build_fetch_tileable(tileable: TileableType) -> TileableType: def build_fetch(entity: EntityType) -> EntityType: - from .core import CHUNK_TYPE, ENTITY_TYPE + from .core import Chunk, ENTITY_TYPE - if isinstance(entity, CHUNK_TYPE): + if isinstance(entity, Chunk): return build_fetch_chunk(entity) elif isinstance(entity, ENTITY_TYPE): return build_fetch_tileable(entity)