Map Dask Series to Dask Series (dask#4872)
* index-test needed fix

* single-partition-error

* added code to make it work

* add tests

* delete some comments

* remove seed set

* updated tests

* remove sort_index and add tests
Justin Waugh authored and jcrist committed Jun 17, 2019
1 parent f7d73f8 commit 255cc5b
Showing 2 changed files with 80 additions and 1 deletion.
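For context, a minimal sketch of the behavior this commit enables: mapping a dask Series with another dask Series, which previously raised a TypeError. The variable names are illustrative, not taken from the diff.

    import pandas as pd
    import dask.dataframe as dd

    base = pd.Series(['a', 'b', 'c', 'a'])
    mapper = pd.Series([10, 20, 30], index=['a', 'b', 'c'])

    dbase = dd.from_pandas(base, npartitions=2)
    dmapper = dd.from_pandas(mapper, npartitions=2)

    # Before this commit, Series.map accepted only dicts, callables, and
    # in-memory pandas Series; a dask Series raised TypeError. It now
    # dispatches to series_map() and returns a lazy dask Series.
    result = dbase.map(dmapper)
    assert result.compute().equals(base.map(mapper))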
57 changes: 57 additions & 0 deletions dask/dataframe/core.py
@@ -2374,6 +2374,8 @@ def isin(self, values):
     @insert_meta_param_description(pad=12)
     @derived_from(pd.Series)
     def map(self, arg, na_action=None, meta=no_default):
+        if is_series_like(arg) and is_dask_collection(arg):
+            return series_map(self, arg)
         if not (isinstance(arg, dict) or
                 callable(arg) or
                 is_series_like(arg) and not is_dask_collection(arg)):
@@ -5059,3 +5061,58 @@ def meta_warning(df):
            "  Before: .apply(func)\n"
            "  After:  .apply(func, meta=%s)\n" % str(meta_str))
     return msg
+
+
+def mapseries(base_chunk, concat_map):
+    return base_chunk.map(concat_map)
+
+
+def mapseries_combine(index, concat_result):
+    final_series = concat_result.sort_index()
+    final_series.index = index
+    return final_series
+
+
+def series_map(base_series, map_series):
+    npartitions = base_series.npartitions
+    split_out = map_series.npartitions
+
+    dsk = {}
+
+    base_token_key = tokenize(base_series, split_out)
+    base_split_prefix = 'base-split-{}'.format(base_token_key)
+    base_shard_prefix = 'base-shard-{}'.format(base_token_key)
+    for i, key in enumerate(base_series.__dask_keys__()):
+        dsk[(base_split_prefix, i)] = (hash_shard, key, split_out)
+        for j in range(split_out):
+            dsk[(base_shard_prefix, 0, i, j)] = (getitem, (base_split_prefix, i), j)
+
+    map_token_key = tokenize(map_series)
+    map_split_prefix = 'map-split-{}'.format(map_token_key)
+    map_shard_prefix = 'map-shard-{}'.format(map_token_key)
+    for i, key in enumerate(map_series.__dask_keys__()):
+        dsk[(map_split_prefix, i)] = (hash_shard, key, split_out, split_out_on_index, None)
+        for j in range(split_out):
+            dsk[(map_shard_prefix, 0, i, j)] = (getitem, (map_split_prefix, i), j)
+
+    token_key = tokenize(base_series, map_series)
+    map_prefix = 'map-series-{}'.format(token_key)
+    for i in range(npartitions):
+        for j in range(split_out):
+            dsk[(map_prefix, i, j)] = (mapseries,
+                                       (base_shard_prefix, 0, i, j),
+                                       (_concat, [(map_shard_prefix, 0, k, j) for k in range(split_out)]))
+
+    final_prefix = 'map-series-combine-{}'.format(token_key)
+    for i, key in enumerate(base_series.index.__dask_keys__()):
+        dsk[(final_prefix, i)] = (mapseries_combine, key, (_concat, [(map_prefix, i, j) for j in range(split_out)]))
+
+    meta = map_series._meta.copy()
+    meta.index = base_series._meta.index
+    meta = make_meta(meta)
+
+    dependencies = [base_series, map_series, base_series.index]
+    graph = HighLevelGraph.from_collections(final_prefix, dsk, dependencies=dependencies)
+    divisions = list(base_series.divisions)
+
+    return new_dd_object(graph, final_prefix, meta, divisions)
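To make the graph above easier to follow: series_map hash-partitions the base series on its values and the mapper on its index, so that each key and its mapping meet in the same shard; it then maps shard by shard and reassembles each base partition in its original order. Below is a rough single-machine sketch of that shuffle, assuming a sorted base index (as dd.from_pandas produces by default); hash_shard and _concat are dask internals, and this sketch only mimics their effect, it is not the dask implementation.

    import pandas as pd

    def hashed_map_sketch(base, mapper, split_out=3):
        # Hash base *values* and mapper *index* with the same hash function,
        # so a value and its mapping land in the same shard j.
        base_hash = pd.util.hash_pandas_object(base, index=False).values % split_out
        map_hash = pd.util.hash_pandas_object(mapper.index).values % split_out
        pieces = []
        for j in range(split_out):
            b = base[base_hash == j]    # one base shard
            m = mapper[map_hash == j]   # mapper rows whose keys hash to j
            pieces.append(b.map(m))     # the mapseries task
        # mapseries_combine: reassemble and restore the original row order
        # (valid here because the base index is assumed sorted).
        out = pd.concat(pieces).sort_index()
        out.index = base.index
        return out

Note that split_out is taken from map_series.npartitions, so the width of the shuffle scales with the mapper rather than with the typically larger base series.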
24 changes: 23 additions & 1 deletion dask/dataframe/tests/test_dataframe.py
@@ -1171,7 +1171,6 @@ def test_map():
     assert_eq(ddf.b.map(lk), df.b.map(lk))
     assert_eq(ddf.b.map(lk, meta=ddf.b), df.b.map(lk))
     assert_eq(ddf.b.map(lk, meta=('b', 'i8')), df.b.map(lk))
-    pytest.raises(TypeError, lambda: ddf.a.map(d.b))


 def test_concat():
@@ -3739,3 +3738,26 @@ def test_dtype_cast():
     assert ddf.B.dtype == np.int64
     # fails
     assert ddf.A.dtype == np.int32
+
+
+@pytest.mark.parametrize("base_npart", [1, 4])
+@pytest.mark.parametrize("map_npart", [1, 3])
+@pytest.mark.parametrize("sorted_index", [False, True])
+@pytest.mark.parametrize("sorted_map_index", [False, True])
+def test_series_map(base_npart, map_npart, sorted_index, sorted_map_index):
+    base = pd.Series([''.join(np.random.choice(['a', 'b', 'c'], size=3)) for x in range(100)])
+    if not sorted_index:
+        index = np.arange(100)
+        np.random.shuffle(index)
+        base.index = index
+    map_index = [''.join(x) for x in product('abc', repeat=3)]
+    mapper = pd.Series(np.random.randint(50, size=len(map_index)), index=map_index)
+    if not sorted_map_index:
+        map_index = np.array(map_index)
+        np.random.shuffle(map_index)
+        mapper.index = map_index
+    expected = base.map(mapper)
+    dask_base = dd.from_pandas(base, npartitions=base_npart)
+    dask_map = dd.from_pandas(mapper, npartitions=map_npart)
+    result = dask_base.map(dask_map)
+    dd.utils.assert_eq(expected, result)
