Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

ImageRec #46

Merged
merged 16 commits into from
Sep 7, 2015
Merged
27 changes: 24 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,19 @@ ifndef DMLC_CORE
endif


ifneq ($(USE_OPENMP_ITER), 1)
export NO_OPENMP = 1
endif

ifneq ($(USE_OPENMP_ITER), 1)
export NO_OPENMP = 1
endif

# use customized config file
include $(config)
include mshadow/make/mshadow.mk
include $(DMLC_CORE)/make/dmlc.mk
unexport NO_OPENMP

# all tge possible warning tread
WARNFLAGS= -Wall
Expand All @@ -39,10 +48,21 @@ endif

# setup opencv
ifeq ($(USE_OPENCV),1)
CFLAGS+= -DCXXNET_USE_OPENCV=1
CFLAGS+= -DMXNET_USE_OPENCV=1
LDFLAGS+= `pkg-config --libs opencv`
else
CFLAGS+= -DCXXNET_USE_OPENCV=0
CFLAGS+= -DMXNET_USE_OPENCV=0
endif

# setup opencv
ifeq ($(USE_OPENCV_DECODER),1)
CFLAGS+= -DMXNET_USE_OPENCV_DECODER=1
else
CFLAGS+= -DMXNET_USE_OPENCV_DECODER=0
endif

ifeq ($(USE_OPENMP_ITER), 1)
CFLAGS += -fopenmp
endif

ifeq ($(USE_CUDNN), 1)
Expand All @@ -62,7 +82,7 @@ endif
ENGINE=naive_engine.o
BIN = tests/test_simple_engine
OBJ = narray_function_cpu.o
OBJCXX11 = narray.o c_api.o operator.o symbol.o storage.o static_graph.o graph_executor.o io.o iter_mnist.o $(ENGINE)
OBJCXX11 = narray.o c_api.o operator.o symbol.o storage.o static_graph.o graph_executor.o io.o iter_mnist.o iter_image_recordio.o $(ENGINE)
CUOBJ = narray_function_gpu.o
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
Expand Down Expand Up @@ -92,6 +112,7 @@ operator.o: src/operator/operator.cc
c_api.o: src/c_api.cc
io.o: src/io/io.cc
iter_mnist.o: src/io/iter_mnist.cc src/io/*.h
iter_image_recordio.o: src/io/iter_image_recordio.cc src/io/*.h

# Rules for operators
OPERATOR_HDR=$(wildcard src/operator/*-inl.h)
Expand Down
2 changes: 1 addition & 1 deletion dmlc-core
101 changes: 101 additions & 0 deletions example/cifar10/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,109 @@ def RandomInit(narray):
flatten = mx.symbol.Flatten(data=pool, name="flatten1")
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
loss = mx.symbol.Softmax(data=fc, name="softmax")
args_list = loss.list_arguments()

data_shape = (128, 3, 28, 28)
arg_shapes, out_shapes, aux_shapes = loss.infer_shape(data=data_shape)

arg_narrays = [mx.narray.create(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
grad_narrays = [mx.narray.create(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]

inputs = dict(zip(args_list, arg_narrays))

name2shape = dict(zip(args_list, arg_shapes))
pred = mx.narray.create(out_shapes[0])

np.random.seed(0)
# set random weight
for name, narray in inputs.items():
if "weight" in name:
tmp = mx.narray.create(name2shape[name])
tmp.numpy[:] = np.random.uniform(-0.07, 0.07, name2shape[name])
tmp.copyto(narray)
if "bias" in name:
narray[:] = 0.0

# bind executer
# TODO(bing): think of a better bind interface
executor = loss.bind(mx.Context('gpu'), arg_narrays, grad_narrays)
# update

out_narray = executor.heads()[0]
grad_narray = mx.narray.create(out_narray.shape)

epoch = 9
lr = 0.1
wd = 0.0004

def Update(grad, weight):
weight[:] -= lr * grad / batch_size

block = list(zip(grad_narrays, arg_narrays))

#check data
get_data.GetCifar10()
train_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=True,
rand_mirror=True,
input_shape=(3,28,28),
batch_size=128,
nthread=1)
test_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/test.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=True,
rand_mirror=True,
input_shape=(3,28,28),
batch_size=100,
nthread=1)

tmp_label = mx.narray.create(name2shape["sm_label"])

def test_cifar():
acc_train = 0.
acc_val = 0.
for i in range(epoch):
# train
print("Epoch %d" % i)
train_acc = 0.0
val_acc = 0.0
train_nbatch = 0
val_nbatch = 0
for data, label in train_dataiter:
data = data
tmp_label.numpy[:] = label.numpy.reshape(tmp_label.shape)
data.copyto(inputs["data"])
tmp_label.copyto(inputs["sm_label"])
executor.forward()
out_narray.copyto(pred)
train_acc += CalAcc(pred.numpy, label.numpy.flatten())
train_nbatch += 1
out_narray.copyto(grad_narray)
executor.backward([grad_narray])

for grad, weight in block:
Update(grad, weight)

# evaluate
for data, label in val_dataiter:
data = data
label = label.numpy.flatten()
data.copyto(inputs["data"])
executor.forward()
out_narray.copyto(pred)
val_acc += CalAcc(pred.numpy, label)
val_nbatch += 1
acc_train = train_acc / train_nbatch
acc_val = val_acc / val_nbatch
print("Train Acc: ", train_acc / train_nbatch)
print("Valid Acc: ", val_acc / val_nbatch)
train_dataiter.reset()
val_dataiter.reset()
assert(acc_train > 0.98)
assert(acc_val > 0.97)

if __name__ == "__main__":
test_cifar()
16 changes: 16 additions & 0 deletions include/mxnet/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,21 @@ struct DataIteratorReg
} \
DMLC_REGISTRY_REGISTER(::mxnet::DataIteratorReg, DataIteratorReg, name) \
.set_body(__create__ ## DataIteratorType ## __)
/*!
* \brief Macro to register chained Iterators
*
* \code
* // example of registering a imagerec iterator
* MXNET_REGISTER_IO_CHAINED_ITERATOR(ImageRec, ImageRecordIter, BatchIter)
* .describe("batched image record data iterator");
*
* \endcode
*/
#define MXNET_REGISTER_IO_CHAINED_ITER(name, ChainedDataIterType, HoldingDataIterType) \
static ::mxnet::IIterator<DataBatch>* __create__ ## ChainedDataIteratorType ## __() { \
return new HoldingDataIterType(new ChainedDataIterType); \
} \
DMLC_REGISTRY_REGISTER(::mxnet::DataIteratorReg, DataIteratorReg, name) \
.set_body(__create__ ## ChainedDataIteratorType ## __)
} // namespace mxnet
#endif // MXNET_IO_H_
4 changes: 2 additions & 2 deletions make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ USE_CUDA_PATH = NONE
# whether use opencv during compilation
# you can disable it, however, you will not able to use
# imbin iterator
USE_OPENCV = 0
USE_OPENCV_DECODER = 0
USE_OPENCV = 1
USE_OPENCV_DECODER = 1
# whether use CUDNN R3 library
USE_CUDNN = 0
# add the path to CUDNN libary to link and compile flag
Expand Down
5 changes: 5 additions & 0 deletions src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ namespace common {
* \brief Random Engine
*/
typedef std::mt19937 RANDOM_ENGINE;
// Get a double float, prnd is the pointer to a Random Engine
#define NextDouble(prnd) std::generate_canonical<float, 10>(*prnd)
// Get a random int in [0, range)
#define NextUInt32(range, prnd) static_cast<uint32_t> \
(floor(std::generate_canonical<float, 10>(*prnd) * range))

/*!
* \brief Helper functions.
Expand Down
Loading