forked from BUTSpeechFIT/VBx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
78 lines (65 loc) · 2.09 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2018 Phonexia
# Author: Jan Profant <[email protected]>
# All Rights Reserved
from distutils.core import setup
import glob
import os
from setuptools.command.install import install
from setuptools.command.develop import develop
from setuptools import find_packages
import tempfile
import zipfile
# Directory (relative to the current working directory) holding one folder per model.
MODELS_DIR = 'VBx/models'
# Names of the model folders whose split archives are unpacked at install time.
MODELS = ['ResNet101_8kHz', 'ResNet101_16kHz']


def install_scripts(directory):
    """Reassemble the split model archives and unpack them next to their parts.

    For every entry of ``MODELS`` the ``*.pth.zip.part*`` files found in
    ``<MODELS_DIR>/<model>/nnet`` are concatenated, in sorted filename order,
    into a temporary zip archive. The archive listing is printed and its
    contents are extracted into that same ``nnet`` directory.

    Args:
        directory (str): path handed over by the calling setup command. It is
            not used; model locations are derived from ``MODELS_DIR``, which
            is relative to the current working directory.

    Raises:
        FileNotFoundError: if a model's ``nnet`` directory is missing or
            contains no archive parts.
    """
    import shutil  # local import: only this function needs it

    # unpack multiple zip files into .pth file
    for model in MODELS:
        nnet_dir = os.path.join(MODELS_DIR, model, 'nnet')
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.isdir(nnet_dir):
            raise FileNotFoundError(f'{nnet_dir} does not exist.')

        # Lexicographic sort: part suffixes must sort in the order they were split.
        zip_parts = sorted(glob.glob(os.path.join(nnet_dir, '*.pth.zip.part*')))
        if not zip_parts:
            raise FileNotFoundError(f'No *.pth.zip.part* files found in {nnet_dir}.')

        # The temporary archive is closed and removed when the `with` block exits,
        # so nothing is left behind in the temp directory.
        with tempfile.TemporaryFile() as temp_zip:
            for zip_part in zip_parts:
                with open(zip_part, 'rb') as f:
                    # Stream in chunks instead of loading a whole part into memory.
                    shutil.copyfileobj(f, temp_zip)
            temp_zip.seek(0)
            with zipfile.ZipFile(temp_zip, 'r') as fzip:
                fzip.printdir()
                fzip.extractall(path=nnet_dir)
class PostDevelopCommand(develop):
    """`develop` command that runs `install_scripts` after the standard step."""

    def run(self):
        """Run the stock `develop` command, then the post-install hook."""
        super().run()
        # The hook receives the egg path as its single positional argument.
        hook_args = (self.egg_path,)
        self.execute(install_scripts, hook_args, msg='Running post install scripts')
class PostInstallCommand(install):
    """`install` command that runs `install_scripts` after the standard step."""

    def run(self):
        """Run the stock `install` command, then the post-install hook."""
        super().run()
        # The hook receives the library install directory as its single argument.
        hook_args = (self.install_lib,)
        self.execute(install_scripts, hook_args, msg='Running post install scripts')
# Package metadata. The custom `install` / `develop` commands registered in
# `cmdclass` run the post-install hook after the standard command finishes.
setup(
    name='VBx',
    version='1.1',
    packages=find_packages(),
    url='https://github.com/fnlandini/VBx_dev',
    install_requires=[
        'numpy',
        'scipy',
        # Real PyPI distribution name; the `sklearn` placeholder package is
        # deprecated and installing it fails.
        'scikit-learn',
        'numexpr',
        'h5py',
        'onnxruntime',
        'soundfile',
        'torch==1.6.0',
        'kaldi_io',
        'tabulate',
        'intervaltree',
    ],
    dependency_links=[],
    license='Apache License, Version 2.0',
    cmdclass={'install': PostInstallCommand, 'develop': PostDevelopCommand},
)