Skip to content

Commit

Permalink
Remove chunk entity in DAG
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi committed Feb 10, 2023
1 parent 37c08a6 commit 67c7717
Show file tree
Hide file tree
Showing 35 changed files with 127 additions and 278 deletions.
5 changes: 0 additions & 5 deletions mars/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,15 @@
EntityData,
ENTITY_TYPE,
Chunk,
ChunkData,
CHUNK_TYPE,
Tileable,
TileableData,
TILEABLE_TYPE,
Object,
ObjectData,
ObjectChunk,
ObjectChunkData,
OBJECT_TYPE,
OBJECT_CHUNK_TYPE,
FuseChunk,
FuseChunkData,
FUSE_CHUNK_TYPE,
OutputType,
register_output_types,
get_output_types,
Expand Down
5 changes: 2 additions & 3 deletions mars/core/entity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .chunks import Chunk, ChunkData, CHUNK_TYPE
from .chunks import Chunk
from .core import Entity, EntityData, ENTITY_TYPE
from .executable import ExecutableTuple, _ExecuteAndFetchMixin
from .fuse import FuseChunk, FuseChunkData, FUSE_CHUNK_TYPE
from .fuse import FuseChunk
from .objects import (
ObjectChunk,
ObjectChunkData,
Object,
ObjectData,
OBJECT_CHUNK_TYPE,
Expand Down
14 changes: 2 additions & 12 deletions mars/core/entity/chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

from ...serialization.serializables import BoolField, FieldTypes, TupleField
from ...utils import tokenize
from .core import EntityData, Entity
from .core import EntityData


class ChunkData(EntityData):
class Chunk(EntityData):
__slots__ = ()

is_broadcaster = BoolField("is_broadcaster", default=False)
Expand Down Expand Up @@ -56,13 +56,3 @@ def _update_key(self):
*(getattr(self, k, None) for k in self._keys_ if k != "_index"),
),
)


class Chunk(Entity):
_allow_data_type_ = (ChunkData,)

def __repr__(self):
return f"{type(self).__name__}({self._data.__repr__()})"


CHUNK_TYPE = (Chunk, ChunkData)
16 changes: 3 additions & 13 deletions mars/core/entity/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@
import numpy as np

from ...serialization.serializables import ReferenceField
from .chunks import ChunkData, Chunk, CHUNK_TYPE
from .chunks import Chunk


class FuseChunkData(ChunkData):
class FuseChunk(Chunk):
__slots__ = ("_inited",)

_chunk = ReferenceField(
"chunk", CHUNK_TYPE, on_serialize=lambda x: x.data if hasattr(x, "data") else x
)
_chunk = ReferenceField("chunk", Chunk)

def __init__(self, *args, **kwargs):
self._inited = False
Expand Down Expand Up @@ -63,11 +61,3 @@ def __setattr__(self, attr, value):
@property
def nbytes(self):
return np.prod(self.shape) * self.dtype.itemsize


class FuseChunk(Chunk):
__slots__ = ()
_allow_data_type_ = (FuseChunkData,)


FUSE_CHUNK_TYPE = (FuseChunkData, FuseChunk)
19 changes: 4 additions & 15 deletions mars/core/entity/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from typing import Any, Dict

from ...serialization.serializables import FieldTypes, ListField
from .chunks import ChunkData, Chunk
from .chunks import Chunk
from .core import Entity
from .executable import _ToObjectMixin
from .tileables import TileableData


class ObjectChunkData(ChunkData):
class ObjectChunk(Chunk):
# chunk whose data could be any serializable
__slots__ = ()
type_name = "Object"
Expand All @@ -48,23 +48,12 @@ def get_params_from_data(cls, data: Any) -> Dict[str, Any]:
return dict()


class ObjectChunk(Chunk):
__slots__ = ()
_allow_data_type_ = (ObjectChunkData,)
type_name = "Object"


class ObjectData(TileableData, _ToObjectMixin):
__slots__ = ()
type_name = "Object"

# optional fields
_chunks = ListField(
"chunks",
FieldTypes.reference(ObjectChunkData),
on_serialize=lambda x: [it.data for it in x] if x is not None else x,
on_deserialize=lambda x: [ObjectChunk(it) for it in x] if x is not None else x,
)
_chunks = ListField("chunks", FieldTypes.reference(ObjectChunk))

def __init__(self, op=None, nsplits=None, chunks=None, **kw):
super().__init__(_op=op, _nsplits=nsplits, _chunks=chunks, **kw)
Expand Down Expand Up @@ -96,4 +85,4 @@ class Object(Entity, _ToObjectMixin):


OBJECT_TYPE = (Object, ObjectData)
OBJECT_CHUNK_TYPE = (ObjectChunk, ObjectChunkData)
OBJECT_CHUNK_TYPE = (ObjectChunk,)
4 changes: 2 additions & 2 deletions mars/core/entity/output_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import functools
from enum import Enum

from .fuse import FUSE_CHUNK_TYPE
from .fuse import FuseChunk
from .objects import OBJECT_TYPE, OBJECT_CHUNK_TYPE


Expand Down Expand Up @@ -83,7 +83,7 @@ def get_output_types(*objs, unknown_as=None):
for obj in objs:
if obj is None:
continue
elif isinstance(obj, FUSE_CHUNK_TYPE):
elif isinstance(obj, FuseChunk):
obj = obj.chunk

try:
Expand Down
10 changes: 5 additions & 5 deletions mars/core/graph/builder/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
Union,
)

from ....core import FUSE_CHUNK_TYPE, CHUNK_TYPE, TILEABLE_TYPE
from ....core import FuseChunk, TILEABLE_TYPE, Chunk
from ....typing import EntityType, TileableType, ChunkType
from ....utils import copy_tileables, build_fetch
from ...entity.tileables import handler
Expand Down Expand Up @@ -223,7 +223,7 @@ def _tile(
chunks = []
if need_process is not None:
for t in need_process:
if isinstance(t, CHUNK_TYPE):
if isinstance(t, Chunk):
chunks.append(self._get_data(t))
elif isinstance(t, TILEABLE_TYPE):
to_update_tileables.append(self._get_data(t))
Expand Down Expand Up @@ -304,7 +304,7 @@ def _iter(self):
# so that fetch chunk can be generated.
# Use chunk key as the key to make sure the copied chunk can be build to a fetch.
processed_chunks = (
c.chunk.key if isinstance(c, FUSE_CHUNK_TYPE) else c.key
c.chunk.key if isinstance(c, FuseChunk) else c.key
for c in chunk_graph.result_chunks
)
self._processed_chunks.update(processed_chunks)
Expand Down Expand Up @@ -406,7 +406,7 @@ def _process_node(self, entity: EntityType):
if entity.key in self._processed_chunks:
if entity not in self._chunk_to_fetch:
# gen fetch
fetch_chunk = build_fetch(entity).data
fetch_chunk = build_fetch(entity)
self._chunk_to_fetch[entity] = fetch_chunk
return self._chunk_to_fetch[entity]
return entity
Expand All @@ -417,7 +417,7 @@ def _select_inputs(self, inputs: List[ChunkType]):
if inp.key in self._processed_chunks:
# gen fetch
if inp not in self._chunk_to_fetch:
fetch_chunk = build_fetch(inp).data
fetch_chunk = build_fetch(inp)
self._chunk_to_fetch[inp] = fetch_chunk
new_inputs.append(self._chunk_to_fetch[inp])
else:
Expand Down
12 changes: 6 additions & 6 deletions mars/core/operand/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
from ..context import Context
from ..mode import is_eager_mode
from ..entity import (
OutputType,
Chunk,
ExecutableTuple,
OutputType,
get_chunk_types,
get_tileable_types,
get_output_types,
Expand Down Expand Up @@ -78,14 +79,14 @@ def _check_if_gpu(cls, inputs: List[TileableType]):
def _tokenize_output(self, output_idx: int, **kw):
return f"{self._key}_{output_idx}"

def _create_chunk(self, output_idx: int, index: Tuple[int], **kw) -> ChunkType:
def _create_chunk(self, output_idx: int, index: Tuple[int], **kw) -> Chunk:
output_type = kw.pop("output_type", None) or self._get_output_type(output_idx)
if not output_type:
raise ValueError("output_type should be specified")

if isinstance(output_type, (list, tuple)):
output_type = output_type[output_idx]
chunk_type, chunk_data_type = get_chunk_types(output_type)
(chunk_data_type,) = get_chunk_types(output_type)
kw["_i"] = output_idx
kw["op"] = self
kw["index"] = index
Expand All @@ -97,8 +98,7 @@ def _create_chunk(self, output_idx: int, index: Tuple[int], **kw) -> ChunkType:
if "_key" not in kw:
kw["_key"] = self._tokenize_output(output_idx, **kw)

data = chunk_data_type(**kw)
return chunk_type(data)
return chunk_data_type(**kw)

def _new_chunks(
self, inputs: List[ChunkType], kws: List[Dict] = None, **kw
Expand Down Expand Up @@ -130,7 +130,7 @@ def _new_chunks(
# for each output chunk, hold the reference to the other outputs
# so that either no one or everyone are gc collected
for j, t in enumerate(chunks):
t.data._siblings = [c.data for c in chunks[:j] + chunks[j + 1 :]]
t._siblings = [c for c in chunks[:j] + chunks[j + 1 :]]
return chunks

def new_chunks(
Expand Down
5 changes: 2 additions & 3 deletions mars/core/operand/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ... import opcodes
from ...serialization.serializables import ReferenceField
from ..entity import FuseChunk, FuseChunkData, NotSupportTile
from ..entity import FuseChunk, NotSupportTile
from ..graph import ChunkGraph
from .base import Operand

Expand All @@ -30,8 +30,7 @@ class FuseChunkMixin:
__slots__ = ()

def _create_chunk(self, output_idx, index, **kw):
data = FuseChunkData(_index=index, _op=self, **kw)
return FuseChunk(data)
return FuseChunk(_index=index, _op=self, **kw)

@classmethod
def tile(cls, op):
Expand Down
8 changes: 4 additions & 4 deletions mars/dataframe/arithmetic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import numpy as np
import pandas as pd

from ...core import ENTITY_TYPE, CHUNK_TYPE, recursive_tile
from ...core import ENTITY_TYPE, recursive_tile
from ...serialization.serializables import AnyField
from ...tensor.core import TENSOR_TYPE, TENSOR_CHUNK_TYPE, ChunkData, Chunk
from ...tensor.core import TENSOR_TYPE, TENSOR_CHUNK_TYPE, Chunk
from ...utils import classproperty, get_dtype
from ..align import (
align_series_series,
Expand Down Expand Up @@ -421,7 +421,7 @@ def _operator(self):

@classmethod
def _calc_properties(cls, x1, x2=None, axis="columns"):
is_chunk = isinstance(x1, CHUNK_TYPE)
is_chunk = isinstance(x1, Chunk)

if isinstance(x1, (DATAFRAME_TYPE, DATAFRAME_CHUNK_TYPE)) and (
x2 is None
Expand Down Expand Up @@ -625,7 +625,7 @@ def _new_chunks(self, inputs, kws=None, **kw):
property_inputs = reversed(property_inputs)
properties = self._calc_properties(*property_inputs)

inputs = [inp for inp in inputs if isinstance(inp, (Chunk, ChunkData))]
inputs = [inp for inp in inputs if isinstance(inp, Chunk)]

shape = properties.pop("shape")
if "shape" in kw:
Expand Down
6 changes: 3 additions & 3 deletions mars/dataframe/arithmetic/tests/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,9 +1190,9 @@ def test_both_one_chunk(func_name, func_opts):
assert isinstance(c.op, func_opts.op)
assert len(c.inputs) == 2
# test the left side
assert c.inputs[0] is df1.chunks[0].data
assert c.inputs[0] is df1.chunks[0]
# test the right side
assert c.inputs[1] is df2.chunks[0].data
assert c.inputs[1] is df2.chunks[0]


@pytest.mark.parametrize("func_name, func_opts", binary_functions.items())
Expand Down Expand Up @@ -1522,7 +1522,7 @@ def test_arithmetic_lazy_chunk_meta():
df2 = df + 1
df2 = tile(df2)

chunk = df2.chunks[0].data
chunk = df2.chunks[0]
assert chunk._FIELDS["_dtypes"].get(chunk) is None
pd.testing.assert_series_equal(chunk.dtypes, df.dtypes)
assert chunk._FIELDS["_dtypes"].get(chunk) is not None
Expand Down
4 changes: 2 additions & 2 deletions mars/dataframe/base/drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np

from ... import opcodes
from ...core import Entity, Chunk, CHUNK_TYPE, OutputType, recursive_tile
from ...core import Entity, Chunk, OutputType, recursive_tile
from ...serialization.serializables import AnyField, StringField
from ..core import IndexValue, DATAFRAME_TYPE, SERIES_TYPE, INDEX_CHUNK_TYPE
from ..operands import DataFrameOperand, DataFrameOperandMixin
Expand Down Expand Up @@ -172,7 +172,7 @@ def tile(cls, op: "DataFrameDrop"):
@classmethod
def execute(cls, ctx, op: "DataFrameDrop"):
inp = op.inputs[0]
if isinstance(op.index, CHUNK_TYPE):
if isinstance(op.index, Chunk):
index_val = ctx[op.index.key]
else:
index_val = op.index
Expand Down
Loading

0 comments on commit 67c7717

Please sign in to comment.