Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #47 from antinucleon/master
Browse files Browse the repository at this point in the history
Change  NArray Interface
  • Loading branch information
antinucleon committed Sep 7, 2015
2 parents 726817f + 089cce4 commit e25b129
Show file tree
Hide file tree
Showing 17 changed files with 356 additions and 127 deletions.
2 changes: 1 addition & 1 deletion dmlc-core
Empty file removed doc/user-guide/executor.md
Empty file.
Empty file removed doc/user-guide/symbol.md
Empty file.
41 changes: 38 additions & 3 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,47 @@ MXNET_DLL int MXNArrayListLoad(const char* fname,
mx_uint *out_name_size,
const char*** out_names);
/*!
* \brief wait until all the operation with respect NArray
* to this NArray is finished, always call this before fetching data out
* \brief Perform a synchronize copy from a continugous CPU memory region.
*
* This function will call WaitToWrite before the copy is performed.
* This is useful to copy data from existing memory region that are
* not wrapped by NArray(thus dependency not being tracked).
*
* \param handle the NArray handle
* \param data the data source to copy from.
* \param size the memory size we want to copy from.
*/
MXNET_DLL int MXNArraySyncCopyFromCPU(NArrayHandle handle,
const mx_float *data,
size_t size);
/*!
* \brief Perform a synchronize copyto a continugous CPU memory region.
*
* This function will call WaitToRead before the copy is performed.
* This is useful to copy data from existing memory region that are
* not wrapped by NArray(thus dependency not being tracked).
*
* \param handle the NArray handle
* \param data the data source to copy into.
* \param size the memory size we want to copy into.
*/
MXNET_DLL int MXNArraySyncCopyToCPU(NArrayHandle handle,
mx_float *data,
size_t size);
/*!
* \brief Wait until all the pending writes with respect NArray are finished.
* Always call this before read data out synchronizely.
* \param handle the NArray handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNArrayWaitToRead(NArrayHandle handle);
/*!
* \brief Wait until all the pending read/write with respect NArray are finished.
* Always call this before write data into NArray synchronizely.
* \param handle the NArray handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNArrayWait(NArrayHandle handle);
MXNET_DLL int MXNArrayWaitToWrite(NArrayHandle handle);
/*!
* \brief wait until all delayed operations in
* the system is completed
Expand Down
26 changes: 22 additions & 4 deletions include/mxnet/dag_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,29 @@ class DAGEngine {
*/
virtual void PushDelete(Fn delete_fun, Context exec_ctx, Variable var) = 0;
/*!
* \brief Wait for variable.
* \param var The variable we should wait for, this function returns when all
* the operations related to var has been completed.
* \brief Wait to read a variable.
*
* The caller should read the content immediately in a synchronized way,
* before any subsequent write operations are issued.
* The subsequent write operations to the variable can destroy the content.
*
* \param var The variable we should wait for,
* This function returns when all the write operations to this
* var has been completed.
*/
virtual void WaitToRead(Variable var) = 0;
/*!
* \brief Wait to write a variable.
*
* The caller should rwrite the content immediately in a synchronized way,
* before any subsequent write operations are issued.
* The subsequent write operations to the variable can destroy the content.
*
* \param var The variable we should wait for,
* This function returns when all the read/write operations
* on var has been completed.
*/
virtual void WaitForVar(Variable var) = 0;
virtual void WaitToWrite(Variable var) = 0;
/*!
* \brief Wait until all the activity of dag engine finishes.
*/
Expand Down
39 changes: 36 additions & 3 deletions include/mxnet/narray.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,21 @@ class NArray {
inline bool is_none() const {
return ptr_.get() == nullptr;
}
/*! \brief wait until the result of the NArray is computed */
inline void Wait() const {
/*!
* \brief Block until all the pending write operations with respect
* to current NArray are finished, and read can be performed.
*/
inline void WaitToRead() const {
if (is_none()) return;
DAGEngine::Get()->WaitForVar(ptr_->var);
DAGEngine::Get()->WaitToRead(ptr_->var);
}
/*!
* \brief Block until all the pending read/write operations with respect
* to current NArray are finished, and write can be performed.
*/
inline void WaitToWrite() const {
if (is_none()) return;
DAGEngine::Get()->WaitToWrite(ptr_->var);
}
/*! \return the associated DAG variable of the narray.*/
inline DAGEngine::Variable var() const {
Expand Down Expand Up @@ -166,6 +177,28 @@ class NArray {
* \return the new copy
*/
NArray Copy(Context ctx) const;
/*!
* \brief Do a synchronize copy from a continugous CPU memory region.
*
* This function will call WaitToWrite before the copy is performed.
* This is useful to copy data from existing memory region that are
* not wrapped by NArray(thus dependency not being tracked).
*
* \param data the data source to copy from.
* \param size the memory size we want to copy from.
*/
void SyncCopyFromCPU(const real_t *data, size_t size) const;
/*!
* \brief Do a synchronize copy to a continugous CPU memory region.
*
* This function will call WaitToRead before the copy is performed.
* This is useful to copy data from existing memory region that are
* not wrapped by NArray(thus dependency not being tracked).
*
* \param data the data source to copyinto.
* \param size the memory size we want to copy into.
*/
void SyncCopyToCPU(real_t *data, size_t size) const;
/*!
* \brief Slice a NArray
* \param begin begin index in first dim
Expand Down
105 changes: 87 additions & 18 deletions python/mxnet/narray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import ctypes
import warnings
import sys
import numpy as np
from .base import _LIB, string_types, numeric_types
from .base import c_array, py_str, c_str
from .base import mx_uint, mx_float, NArrayHandle, FunctionHandle
from .base import ctypes2numpy_shared, ctypes2buffer
from .base import ctypes2buffer
from .base import check_call
from .context import Context

Expand Down Expand Up @@ -183,7 +184,9 @@ def __setitem__(self, in_slice, value):
if value.handle is not self.handle:
value.copyto(self)
elif isinstance(value, numeric_types):
return NArray._set_value(float(value), out=self)
NArray._set_value(float(value), out=self)
elif isinstance(value, (np.ndarray, np.generic)):
self._sync_copyfrom(value)
else:
raise TypeError('type %s not supported' % str(type(value)))

Expand All @@ -193,9 +196,47 @@ def __getitem__(self, in_slice):
raise Exception("Set NArray should use empty index array[:] += value")
return self

def wait(self):
"""Wait until the data on current NArray is available."""
check_call(_LIB.MXNArrayWait(self.handle))
def _sync_copyfrom(self, source_array):
"""Peform an synchronize copy from the array.
Parameters
----------
source_array : array_like
The data source we should like to copy from.
"""
if not isinstance(source_array, np.ndarray):
try:
source_array = np.array(source_array, dtype=np.float32)
except:
raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(array)))
source_array = np.ascontiguousarray(source_array, dtype=np.float32)

if source_array.shape != self.shape:
raise ValueError('array shape do not match the shape of NArray')

check_call(_LIB.MXNArraySyncCopyFromCPU(
self.handle,
source_array.ctypes.data_as(ctypes.POINTER(mx_float)),
source_array.size))

def wait_to_read(self):
"""Block until all pending writes operations on current NArray are finished.
This function will return when all the pending writes to the current
NArray finishes. There can still be pending read going on when the
function returns.
"""
check_call(_LIB.MXNArrayWaitToRead(self.handle))

def wait_to_write(self):
"""Block until all pending read/write operations on current NArray are finished.
This function will return when all the pending writes to the current
NArray finishes. There can still be pending read going on when the
function returns.
"""
check_call(_LIB.MXNArrayWaitToWrite(self.handle))

@property
def shape(self):
Expand All @@ -217,28 +258,29 @@ def context(self):
Returns
-------
the context of current NArray
context : mxnet.Context
The context of current NArray.
"""
dev_mask = ctypes.c_int()
dev_id = ctypes.c_int()
check_call(_LIB.MXNArrayGetContext(
self.handle, ctypes.byref(dev_mask), ctypes.byref(dev_id)))
return Context(Context.devmask2type[dev_mask.value], dev_id.value)

@property
def numpy(self):
"""Return a numpy representation of current array.
This array have to sit on CPU
def asnumpy(self):
"""Return a copied numpy array of current array.
Returns
-------
a numpy array view
array : numpy.ndarray
A copy of array content.
"""
self.wait()
pdata = ctypes.POINTER(mx_float)()
check_call(_LIB.MXNArrayGetData(self.handle, ctypes.byref(pdata)))
return ctypes2numpy_shared(pdata, self.shape)
data = np.empty(self.shape, dtype=np.float32)
check_call(_LIB.MXNArraySyncCopyToCPU(
self.handle,
data.ctypes.data,
data.size))
return data

def copyto(self, other):
"""Copy the content of current array to other.
Expand Down Expand Up @@ -271,8 +313,8 @@ def copyto(self, other):
# pylint: enable= no-member


def create(shape, ctx=None):
"""Create a new NArray, with specified shape.
def empty(shape, ctx=None):
"""Create an empty uninitialized new NArray, with specified shape.
Parameters
----------
Expand All @@ -292,6 +334,33 @@ def create(shape, ctx=None):
return NArray(handle=_new_alloc_handle(shape, ctx, False))


def array(source_array, ctx=None):
"""Create a new NArray that copies content from source_array.
Parameters
----------
source_array : array_like
Source data to create NArray from.
ctx : Context, optional
The context of the NArray, default to current default context.
Returns
-------
out: Array
The created NArray.
"""

if not isinstance(source_array, np.ndarray):
try:
source_array = np.array(source_array, dtype=np.float32)
except:
raise TypeError('source_array must be array like object')
arr = empty(source_array.shape, ctx)
arr[:] = source_array
return arr


def load(fname):
"""Load narray from binary file.
Expand Down
26 changes: 24 additions & 2 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,31 @@ int MXNArraySaveRawBytes(NArrayHandle handle,
API_END();
}

int MXNArrayWait(NArrayHandle handle) {
int MXNArraySyncCopyFromCPU(NArrayHandle handle,
const mx_float *data,
size_t size) {
API_BEGIN();
static_cast<NArray*>(handle)->Wait();
static_cast<NArray*>(handle)->SyncCopyFromCPU(data, size);
API_END();
}

int MXNArraySyncCopyToCPU(NArrayHandle handle,
mx_float *data,
size_t size) {
API_BEGIN();
static_cast<NArray*>(handle)->SyncCopyToCPU(data, size);
API_END();
}

int MXNArrayWaitToRead(NArrayHandle handle) {
API_BEGIN();
static_cast<NArray*>(handle)->WaitToRead();
API_END();
}

int MXNArrayWaitToWrite(NArrayHandle handle) {
API_BEGIN();
static_cast<NArray*>(handle)->WaitToWrite();
API_END();
}

Expand Down
5 changes: 4 additions & 1 deletion src/dag_engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ class NaiveEngine : public DAGEngine {
this->Push(delete_fun, exec_ctx, {}, {var});
}

void WaitForVar(Variable var) override {
void WaitToRead(Variable var) override {
}

void WaitToWrite(Variable var) override {
}

void WaitForAll() override {
Expand Down
Loading

0 comments on commit e25b129

Please sign in to comment.