-
Notifications
You must be signed in to change notification settings - Fork 1
/
setup.py
64 lines (54 loc) · 1.94 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from pathlib import Path, PurePath
import torch
import architecture as arch
import utils
def setup(config, *, device="cuda:0"):
    """Build the objects needed for a run from a configuration object.

    Constructs the input image description, the head description, the output
    file layout (creating the output directory if needed), the image/label
    folder paths, and the segmentation network, then moves the network onto
    ``device``.

    Args:
        config: Configuration object read via attribute access. Keys used:
            ``dataset.{parameters,id,root}``, optional
            ``dataset.partitions.{image,label}``,
            ``architecture.heads.{info,subhead_count}``,
            ``architecture.trunk.{structure,parameters}`` and
            ``output.{root,label_colors}``. ``dataset.parameters`` and
            ``architecture.trunk.parameters`` are unpacked with ``**``, so
            they must be mappings. NOTE(review): presumably an
            OmegaConf/attr-dict style object (``config.dataset`` must also
            support ``in``) - confirm.
        device: Device the network is moved to; any value accepted by
            ``torch.device``. Defaults to ``"cuda:0"`` (the previously
            hard-coded device); pass ``"cpu"`` on machines without CUDA.

    Returns:
        dict with keys ``"image_info"``, ``"heads_info"``, ``"output_files"``,
        ``"state_folder"``, ``"image_folder"``, ``"label_folder"`` (``None``
        when the dataset config has no ``partitions``) and ``"net"``.

    Raises:
        FileExistsError: if the output directory path already exists but is
            not a directory.
    """
    # INPUT IMAGE INFORMATION
    image_info = utils.ImageInfo(**config.dataset.parameters)

    # ARCH HEAD INFORMATION
    heads_info = arch.HeadsInfo(
        heads_info=config.architecture.heads.info,
        input_size=config.dataset.parameters.input_size,
        subhead_count=config.architecture.heads.subhead_count,
    )

    # OUTPUT_FILES
    # Each dataset gets its own sub-directory (named by dataset.id) under
    # output.root.
    resolved_root = (Path(config.output.root) / str(config.dataset.id)).resolve()
    # exist_ok=True already tolerates an existing directory, so no separate
    # exists()/is_dir() pre-check is needed.
    resolved_root.mkdir(parents=True, exist_ok=True)
    output_root = PurePath(resolved_root)
    output_files = utils.OutputFiles(
        root_path=output_root,
        label_colors=config.output.label_colors,
        image_info=image_info,
    )

    # STATE_FOLDER
    state_folder = output_files.get_sub_root(output_files.STATE)

    # RENDERING PATHS
    # TODO: move into output_files
    dataset_root = PurePath(config.dataset.root)
    if "partitions" in config.dataset:
        partitions = config.dataset.partitions
        image_folder = dataset_root / partitions.image
        label_folder = dataset_root / partitions.label
    else:
        # Unpartitioned dataset: images are taken directly from the root and
        # no label folder is set.
        image_folder = dataset_root
        label_folder = None

    # NETWORK ARCHITECTURE
    structure = arch.Structure(
        input_channels=image_info.channel_count,
        structure=config.architecture.trunk.structure,
    )
    trunk = arch.VGGTrunk(structure=structure, **config.architecture.trunk.parameters)
    net = arch.SegmentationNet10aTwoHead(
        trunk=trunk, heads=heads_info.build_heads(trunk.feature_count)
    )
    # NOTE(review): net is presumably a torch.nn.Module, whose .to() moves
    # parameters in place, which is why the return value is not reassigned.
    net.to(torch.device(device))

    return {
        "image_info": image_info,
        "heads_info": heads_info,
        "output_files": output_files,
        "state_folder": state_folder,
        "image_folder": image_folder,
        "label_folder": label_folder,
        "net": net,
    }