Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-769] set MXNET_HOME as base for downloaded models through base…
Browse files Browse the repository at this point in the history
….data_dir() (#11636)

* set MXNET_DATA_DIR as base for downloaded models through base.data_dir()
push joblib to save containers so is not required when running

* MXNET_DATA_DIR -> MXNET_HOME
  • Loading branch information
larroy authored and marcoabreu committed Aug 2, 2018
1 parent a93905d commit 564e01a
Show file tree
Hide file tree
Showing 27 changed files with 201 additions and 89 deletions.
2 changes: 1 addition & 1 deletion ci/docker_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import subprocess
import json
import build as build_util
from joblib import Parallel, delayed



Expand All @@ -43,6 +42,7 @@ def build_save_containers(platforms, registry, load_cache) -> int:
:param load_cache: Load cache before building
:return: 1 if error occurred, 0 otherwise
"""
from joblib import Parallel, delayed
if len(platforms) == 0:
return 0

Expand Down
4 changes: 2 additions & 2 deletions contrib/clojure-package/examples/scripts/get_cifar_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

set -evx

if [ ! -z "$MXNET_DATA_DIR" ]; then
data_path="$MXNET_DATA_DIR"
if [ ! -z "$MXNET_HOME" ]; then
data_path="$MXNET_HOME"
else
data_path="./data"
fi
Expand Down
4 changes: 2 additions & 2 deletions contrib/clojure-package/examples/scripts/get_mnist_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

set -evx

if [ ! -z "$MXNET_DATA_DIR" ]; then
data_path="$MXNET_DATA_DIR"
if [ ! -z "$MXNET_HOME" ]; then
data_path="$MXNET_HOME"
else
data_path="./data"
fi
Expand Down
4 changes: 2 additions & 2 deletions contrib/clojure-package/scripts/get_cifar_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

set -evx

if [ ! -z "$MXNET_DATA_DIR" ]; then
data_path="$MXNET_DATA_DIR"
if [ ! -z "$MXNET_HOME" ]; then
data_path="$MXNET_HOME"
else
data_path="./data"
fi
Expand Down
4 changes: 2 additions & 2 deletions contrib/clojure-package/scripts/get_mnist_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

set -evx

if [ ! -z "$MXNET_DATA_DIR" ]; then
data_path="$MXNET_DATA_DIR"
if [ ! -z "$MXNET_HOME" ]; then
data_path="$MXNET_HOME"
else
data_path="./data"
fi
Expand Down
4 changes: 4 additions & 0 deletions docs/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca
- Values: String ```(default='https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'```
- The repository url to be used for Gluon datasets and pre-trained models.

* MXNET_HOME
- Data directory in the filesystem for storage, for example when downloading gluon models.
- Default in *nix is .mxnet APPDATA/mxnet in windows.

Settings for Minimum Memory Usage
---------------------------------
- Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
Expand Down
24 changes: 22 additions & 2 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

import atexit
import ctypes
import inspect
import os
import sys
import warnings

import inspect
import platform
import numpy as np

from . import libinfo
Expand Down Expand Up @@ -59,6 +59,26 @@
py_str = lambda x: x


def data_dir_default():
"""
:return: default data directory depending on the platform and environment variables
"""
system = platform.system()
if system == 'Windows':
return os.path.join(os.environ.get('APPDATA'), 'mxnet')
else:
return os.path.join(os.path.expanduser("~"), '.mxnet')


def data_dir():
"""
:return: data directory in the filesystem for storage, for example when downloading models
"""
return os.getenv('MXNET_HOME', data_dir_default())


class _NullType(object):
"""Placeholder for arguments"""
def __repr__(self):
Expand Down
9 changes: 5 additions & 4 deletions python/mxnet/contrib/text/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import vocab
from ... import ndarray as nd
from ... import registry
from ... import base


def register(embedding_cls):
Expand Down Expand Up @@ -496,7 +497,7 @@ class GloVe(_TokenEmbedding):
----------
pretrained_file_name : str, default 'glove.840B.300d.txt'
The name of the pre-trained token embedding file.
embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
embedding_root : str, default $MXNET_HOME/embeddings
The root directory for storing embedding-related files.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown token.
Expand Down Expand Up @@ -541,7 +542,7 @@ def _get_download_file_name(cls, pretrained_file_name):
return archive

def __init__(self, pretrained_file_name='glove.840B.300d.txt',
embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
embedding_root=os.path.join(base.data_dir(), 'embeddings'),
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
GloVe._check_pretrained_file_names(pretrained_file_name)

Expand Down Expand Up @@ -600,7 +601,7 @@ class FastText(_TokenEmbedding):
----------
pretrained_file_name : str, default 'wiki.en.vec'
The name of the pre-trained token embedding file.
embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
embedding_root : str, default $MXNET_HOME/embeddings
The root directory for storing embedding-related files.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown token.
Expand Down Expand Up @@ -642,7 +643,7 @@ def _get_download_file_name(cls, pretrained_file_name):
return '.'.join(pretrained_file_name.split('.')[:-1])+'.zip'

def __init__(self, pretrained_file_name='wiki.simple.vec',
embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
embedding_root=os.path.join(base.data_dir(), 'embeddings'),
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
FastText._check_pretrained_file_names(pretrained_file_name)

Expand Down
11 changes: 5 additions & 6 deletions python/mxnet/gluon/contrib/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
from ...data import dataset
from ...utils import download, check_sha1, _get_repo_file_url
from ....contrib import text
from .... import nd

from .... import nd, base

class _LanguageModelDataset(dataset._DownloadedDataset): # pylint: disable=abstract-method
def __init__(self, root, namespace, vocabulary):
Expand Down Expand Up @@ -116,7 +115,7 @@ class WikiText2(_WikiText):
Parameters
----------
root : str, default '~/.mxnet/datasets/wikitext-2'
root : str, default $MXNET_HOME/datasets/wikitext-2
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
Expand All @@ -127,7 +126,7 @@ class WikiText2(_WikiText):
The sequence length of each sample, regardless of the sentence boundary.
"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'wikitext-2'),
segment='train', vocab=None, seq_len=35):
self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
self._data_file = {'train': ('wiki.train.tokens',
Expand All @@ -154,7 +153,7 @@ class WikiText103(_WikiText):
Parameters
----------
root : str, default '~/.mxnet/datasets/wikitext-103'
root : str, default $MXNET_HOME/datasets/wikitext-103
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
Expand All @@ -164,7 +163,7 @@ class WikiText103(_WikiText):
seq_len : int, default 35
The sequence length of each sample, regardless of the sentence boundary.
"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'wikitext-103'),
segment='train', vocab=None, seq_len=35):
self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a')
self._data_file = {'train': ('wiki.train.tokens',
Expand Down
18 changes: 9 additions & 9 deletions python/mxnet/gluon/data/vision/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from .. import dataset
from ...utils import download, check_sha1, _get_repo_file_url
from .... import nd, image, recordio
from .... import nd, image, recordio, base


class MNIST(dataset._DownloadedDataset):
Expand All @@ -40,7 +40,7 @@ class MNIST(dataset._DownloadedDataset):
Parameters
----------
root : str, default '~/.mxnet/datasets/mnist'
root : str, default $MXNET_HOME/datasets/mnist
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
Expand All @@ -51,7 +51,7 @@ class MNIST(dataset._DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'mnist'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'mnist'),
train=True, transform=None):
self._train = train
self._train_data = ('train-images-idx3-ubyte.gz',
Expand Down Expand Up @@ -101,7 +101,7 @@ class FashionMNIST(MNIST):
Parameters
----------
root : str, default '~/.mxnet/datasets/fashion-mnist'
root : str, default $MXNET_HOME/datasets/fashion-mnist'
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
Expand All @@ -112,7 +112,7 @@ class FashionMNIST(MNIST):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'fashion-mnist'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'fashion-mnist'),
train=True, transform=None):
self._train = train
self._train_data = ('train-images-idx3-ubyte.gz',
Expand All @@ -134,7 +134,7 @@ class CIFAR10(dataset._DownloadedDataset):
Parameters
----------
root : str, default '~/.mxnet/datasets/cifar10'
root : str, default $MXNET_HOME/datasets/cifar10
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
Expand All @@ -145,7 +145,7 @@ class CIFAR10(dataset._DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar10'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar10'),
train=True, transform=None):
self._train = train
self._archive_file = ('cifar-10-binary.tar.gz', 'fab780a1e191a7eda0f345501ccd62d20f7ed891')
Expand Down Expand Up @@ -197,7 +197,7 @@ class CIFAR100(CIFAR10):
Parameters
----------
root : str, default '~/.mxnet/datasets/cifar100'
root : str, default $MXNET_HOME/datasets/cifar100
Path to temp folder for storing data.
fine_label : bool, default False
Whether to load the fine-grained (100 classes) or coarse-grained (20 super-classes) labels.
Expand All @@ -210,7 +210,7 @@ class CIFAR100(CIFAR10):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar100'),
def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar100'),
fine_label=False, train=True, transform=None):
self._train = train
self._archive_file = ('cifar-100-binary.tar.gz', 'a0bb982c76b83111308126cc779a992fa506b90b')
Expand Down
17 changes: 9 additions & 8 deletions python/mxnet/gluon/model_zoo/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
__all__ = ['get_model_file', 'purge']
import os
import zipfile
import logging

from ..utils import download, check_sha1
from ... import base, util

_model_sha1 = {name: checksum for checksum, name in [
('44335d1f0046b328243b32a26a4fbd62d9057b45', 'alexnet'),
Expand Down Expand Up @@ -68,7 +70,7 @@ def short_hash(name):
raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
return _model_sha1[name][:8]

def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
def get_model_file(name, root=os.path.join(base.data_dir(), 'models')):
r"""Return location for the pretrained on local file system.
This function will download from online model zoo when model cannot be found or has mismatch.
Expand All @@ -78,7 +80,7 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
----------
name : str
Name of the model.
root : str, default '~/.mxnet/models'
root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
Returns
Expand All @@ -95,12 +97,11 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
if check_sha1(file_path, sha1_hash):
return file_path
else:
print('Mismatch in the content of model file detected. Downloading again.')
logging.warning('Mismatch in the content of model file detected. Downloading again.')
else:
print('Model file is not found. Downloading.')
logging.info('Model file not found. Downloading to %s.', file_path)

if not os.path.exists(root):
os.makedirs(root)
util.makedirs(root)

zip_file_path = os.path.join(root, file_name+'.zip')
repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
Expand All @@ -118,12 +119,12 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
else:
raise ValueError('Downloaded file has different hash. Please try again.')

def purge(root=os.path.join('~', '.mxnet', 'models')):
def purge(root=os.path.join(base.data_dir(), 'models')):
r"""Purge all pretrained model files in local file store.
Parameters
----------
root : str, default '~/.mxnet/models'
root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
"""
root = os.path.expanduser(root)
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/model_zoo/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_model(name, **kwargs):
Number of classes for the output layer.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
root : str, default '$MXNET_HOME/models'
Location for keeping the model parameters.
Returns
Expand Down
5 changes: 3 additions & 2 deletions python/mxnet/gluon/model_zoo/vision/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ....context import cpu
from ...block import HybridBlock
from ... import nn
from .... import base

# Net
class AlexNet(HybridBlock):
Expand Down Expand Up @@ -68,7 +69,7 @@ def hybrid_forward(self, F, x):

# Constructor
def alexnet(pretrained=False, ctx=cpu(),
root=os.path.join('~', '.mxnet', 'models'), **kwargs):
root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""AlexNet model from the `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
Parameters
Expand All @@ -77,7 +78,7 @@ def alexnet(pretrained=False, ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
"""
net = AlexNet(**kwargs)
Expand Down
Loading

0 comments on commit 564e01a

Please sign in to comment.