Skip to content

Commit

Permalink
Simple bind with infer storage type (#32)
Browse files Browse the repository at this point in the history
* Symbol binding for sparse tensor development. (#31)

* Initial checkin

* Add init functions for simple bind in graph_executor

* Add simple_bind c_api

* Add simple bind c-api

* Assign zeros to in_args, arg_grads, and aux_states

* Add simple_bind2 python interface

* Fix python interface bugs

* Interface changes

* Fix

* Fix core dump

* Add bind_ith_exec c_api

* Change simple_bind2

* Fix seg fault

* Finish simple_bind

* Change _bind_ith_exec

* Refactor simple_bind initialization flow for bind

* Consolidate bind and simple_bind graph init flow

* Fix bug

* Clean up

* Add comments

* Clean up

* Clean up

* Minor correction

* Rename APIs in graph executor

* Refactor

* Rebase

* Delete deprecated functions

* Move more front-end work to backend

* Bug fix

* Fix failed tests

* Minor fix

* Fix lint

* Fix lint

* Revert unnecessary changes

* Revert

* Revert

* Clean up

* Fix lint

Conflicts:
	python/mxnet/symbol.py
	src/executor/graph_executor.cc

* Add inferstorage to graph executor

* re-enable tests for sparse embedding with simple_bind

* type switch fix in sparse embedding"
;
  • Loading branch information
eric-haibin-lin authored May 17, 2017
1 parent b9540ae commit 774a362
Show file tree
Hide file tree
Showing 14 changed files with 1,361 additions and 379 deletions.
33 changes: 33 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,39 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle,
NDArrayHandle *aux_states,
ExecutorHandle shared_exec,
ExecutorHandle *out);

MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
mx_uint* shared_buffer_len,
const char*** shared_buffer_name_list,
NDArrayHandle** shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
/*!
* \brief set a call back to notify the completion of operation
*/
Expand Down
33 changes: 33 additions & 0 deletions include/mxnet/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ class Executor {
* \return array of outputs in the executor.
*/
virtual const std::vector<NDArray> &outputs() const = 0;
/*!
* \brief get input argument map, key is arg name, value is arg's NDArray.
* \return input argument map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& in_arg_map() const = 0;
/*!
* \brief get input argument graident map, key is arg name, value is gradient's NDArray.
* \return input argument gradient map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& arg_grad_map() const = 0;
/*!
* \brief get aux state map, key is arg name, value is aux state's NDArray.
* \return aux state map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
/*!
* \brief Create an operator by bind symbol with context and arguments.
* If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp.
Expand All @@ -91,6 +106,24 @@ class Executor {
const std::vector<OpReqType> &grad_req_type,
const std::vector<NDArray> &aux_states,
Executor* shared_exec = NULL);

static Executor* SimpleBind(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& group2ctx,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::unordered_map<std::string, int>& arg_stype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states,
std::unordered_map<std::string, NDArray>*
shared_data_arrays = nullptr,
Executor* shared_exec = nullptr);
/*!
* \brief the prototype of user-defined monitor callback
*/
Expand Down
8 changes: 4 additions & 4 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class NDArray {
} else if (storage_type == kCSRStorage) {
aux_types = {CSR_IND_PTR_TYPE, CSR_IDX_DTYPE};
} else {
LOG(FATAL) << "Unknown storage type";
LOG(FATAL) << "Unknown storage type" << storage_type;
}
}
// Assign default shapes if not given
Expand All @@ -139,7 +139,7 @@ class NDArray {
// aux shapes for indptr and indices
aux_shapes = {TShape({0}), TShape({0})};
} else {
LOG(FATAL) << "Unknown storage type";
LOG(FATAL) << "Unknown storage type" << storage_type;
}
}
if (storage_shape.Size() == 0) {
Expand All @@ -149,7 +149,7 @@ class NDArray {
} else if (storage_type == kCSRStorage) {
storage_shape = aux_shapes[csr::kIdx];
} else {
LOG(FATAL) << "Unknown storage type";
LOG(FATAL) << "Unknown storage type" << storage_type;
}
}
ptr_ = std::make_shared<Chunk>(storage_type, storage_shape, ctx, delay_alloc,
Expand Down Expand Up @@ -735,7 +735,7 @@ class NDArray {
if (skip_free == false) {
Storage::Get()->Free(h);
for (size_t i = 0; i < aux_h.size(); i++) {
Storage::Get()->Free(aux_h[i]);
if (aux_h[i].size > 0) Storage::Get()->Free(aux_h[i]);
}
}
}, shandle.ctx, var);
Expand Down
81 changes: 5 additions & 76 deletions python/mxnet/module/executor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging
from collections import OrderedDict

import numpy as np

from .. import context as ctx
Expand Down Expand Up @@ -559,6 +558,7 @@ def update_metric(self, eval_metric, labels):

def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
"""Internal utility function to bind the i-th executor.
This function utilizes simple_bind python interface.
"""
shared_exec = None if shared_group is None else shared_group.execs[i]
context = self.contexts[i]
Expand All @@ -568,85 +568,14 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
if label_shapes is not None:
input_shapes.update(dict(label_shapes))

arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes)
assert arg_shapes is not None, "shape inference failed"

input_types = {x.name: x.dtype for x in data_shapes}
if label_shapes is not None:
input_types.update({x.name: x.dtype for x in label_shapes})
arg_types, _, aux_types = self.symbol.infer_type(**input_types)
assert arg_types is not None, "type inference failed"

arg_arrays = []
grad_arrays = {} if self.for_training else None

def _get_or_reshape(name, shared_data_arrays, arg_shape, arg_type, context, logger):
"""Internal helper to get a memory block or re-use by re-shaping."""
if name in shared_data_arrays:
arg_arr = shared_data_arrays[name]

if np.prod(arg_arr.shape) >= np.prod(arg_shape):
# nice, we can directly re-use this data blob
assert arg_arr.dtype == arg_type
arg_arr = arg_arr.reshape(arg_shape)
else:
logger.warning(('bucketing: data "%s" has a shape %s' % (name, arg_shape)) +
(', which is larger than already allocated ') +
('shape %s' % (arg_arr.shape,)) +
('. Need to re-allocate. Consider putting ') +
('default_bucket_key to') +
(' be the bucket taking the largest input for better ') +
('memory sharing.'))
arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)

# replace existing shared array because the new one is bigger
shared_data_arrays[name] = arg_arr
else:
arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)
shared_data_arrays[name] = arg_arr

return arg_arr

# create or borrow arguments and gradients
for j in range(len(self.arg_names)):
name = self.arg_names[j]
if name in self.param_names: # model parameters
if shared_exec is None:
arg_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
if self.grad_req[name] != 'null':
grad_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
grad_arrays[name] = grad_arr
else:
arg_arr = shared_exec.arg_dict[name]
assert arg_arr.shape == arg_shapes[j]
assert arg_arr.dtype == arg_types[j]
if self.grad_req[name] != 'null':
grad_arrays[name] = shared_exec.grad_dict[name]
else: # data, label, or states
arg_arr = _get_or_reshape(name, shared_data_arrays, arg_shapes[j], arg_types[j],
context, self.logger)

# data might also need grad if inputs_need_grad is True
if self.grad_req[name] != 'null':
grad_arrays[name] = _get_or_reshape('grad of ' + name, shared_data_arrays,
arg_shapes[j], arg_types[j], context,
self.logger)

arg_arrays.append(arg_arr)

# create or borrow aux variables
if shared_exec is None:
aux_arrays = [nd.zeros(s, context, dtype=t) for s, t in zip(aux_shapes, aux_types)]
else:
for j, arr in enumerate(shared_exec.aux_arrays):
assert aux_shapes[j] == arr.shape
assert aux_types[j] == arr.dtype
aux_arrays = shared_exec.aux_arrays[:]

executor = self.symbol.bind(ctx=context, args=arg_arrays,
args_grad=grad_arrays, aux_states=aux_arrays,
grad_req=self.grad_req, shared_exec=shared_exec)
# Get the total bytes allocated for this executor
executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
type_dict=input_types, param_names=self.param_names,
shared_exec=shared_exec,
shared_data_arrays=shared_data_arrays, **input_shapes)
self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
return executor

Expand Down
Loading

0 comments on commit 774a362

Please sign in to comment.