-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
77 lines (64 loc) · 2.65 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# !/usr/bin/env python
# coding: utf-8
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
def load_data(device, data_f):
    """
    Load one scene from an ``.npz`` file and build its camera-space ray directions.

    :param device: Device CPU/GPU onto which the ray directions and origin are moved
    :param data_f: Data file (.npz) containing at least the keys
        ``"images"``, ``"poses"`` and ``"focal"``
    :return: images, poses, init_ds, init_o
        - images: ``data["images"] / 255`` as a NumPy array (left on the host)
        - poses: ``data["poses"]`` exactly as stored in the file
        - init_ds: tensor of shape (H, W, 3) with one ray direction per pixel,
          scaled so that every z component is -1
        - init_o: tensor ``[0, 0, 2.25]``, the ray origin before a pose is applied
    """
    # Use the NpzFile as a context manager so its file handle is closed;
    # everything needed afterwards is read into memory inside the block.
    with np.load(data_f) as data:
        images = data["images"] / 255
        poses = data["poses"]
        focal = float(data["focal"])

    # Assumes images are laid out (N, H, W[, C]) -- TODO confirm for all datasets.
    # Height and width are taken separately so non-square images get a grid
    # of the right shape; for square images this matches a single shared size.
    img_h, img_w = images.shape[1], images.shape[2]
    # Pixel-centre coordinates measured from the image centre.
    xs = (torch.arange(img_w) - (img_w / 2 - 0.5)).float()
    ys = (torch.arange(img_h) - (img_h / 2 - 0.5)).float()
    # "xy" indexing yields (H, W) grids; ys is negated so the first row has
    # positive y.
    xs, ys = torch.meshgrid(xs, -ys, indexing="xy")
    pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)
    init_ds = (pixel_coords / focal).to(device)
    init_o = torch.tensor([0.0, 0.0, 2.25]).to(device)
    return images, poses, init_ds, init_o
def set_up_train_data(device, conf):
    """
    Load every file in ``conf["train_dir"]`` with ``load_data`` and combine the results.

    Files are processed in sorted filename order so that the combined
    image/pose indices are reproducible (``os.listdir`` returns entries in an
    arbitrary, filesystem-dependent order).

    :param device: Device CPU/GPU, passed through to ``load_data``
    :param conf: Mapping with a ``"train_dir"`` entry naming the training folder
    :return: num_of_files, images, poses, init_ds, init_o
        - num_of_files: number of directory entries loaded
        - images: every file's images concatenated along dim 0
        - poses: every file's poses concatenated along dim 0
        - init_ds: per-file ray directions stacked along a new dim 0
        - init_o: per-file ray origins stacked along a new dim 0
        If the folder is empty, the four tensors are ``None``.
    """
    train_folder = conf["train_dir"]
    train_files = sorted(os.listdir(train_folder))
    num_of_files = len(train_files)
    if num_of_files == 0:
        return 0, None, None, None, None

    image_parts, pose_parts, ds_parts, o_parts = [], [], [], []
    for file_name in train_files:
        image_set, pose_set, init_ds_set, init_o_set = load_data(
            device, os.path.join(train_folder, file_name)
        )
        # as_tensor avoids torch.tensor()'s copy-construct warning for tensor
        # inputs; the cat/stack below produces the final, independent copy.
        image_parts.append(torch.as_tensor(image_set))
        pose_parts.append(torch.as_tensor(pose_set))
        ds_parts.append(init_ds_set)
        o_parts.append(init_o_set)

    # Combine once rather than re-allocating the growing tensors on every
    # iteration, which copies the accumulated data quadratically.
    images = torch.cat(image_parts, dim=0)
    poses = torch.cat(pose_parts, dim=0)
    init_ds = torch.stack(ds_parts, dim=0)
    init_o = torch.stack(o_parts, dim=0)
    return num_of_files, images, poses, init_ds, init_o
def set_up_test_data(images, device, poses, init_ds, init_o, conf, test_idx=150):
    """
    Set up the held-out test view and the mask that selects the training views.

    Saves a preview of the test image to ``{conf["model"]}/results/test_img.png``,
    creating the directory if needed. Also rotates the camera-space rays by the
    test pose's 3x3 rotation block.

    :param images: Images, indexable as ``images[i]``; NumPy array or tensor
    :param device: Device CPU/GPU
    :param poses: Poses; ``poses[i, :3, :3]`` is used as the rotation
    :param init_ds: Ray directions, shape (H, W, 3)
    :param init_o: Ray origin, presumably shape (3,) as built by ``load_data``
    :param conf: Mapping with a ``"model"`` entry naming the output folder
    :param test_idx: Test file index for visualization; negative values count from the end
    :return: test_ds, test_os, test_img, train_idxs
        - test_ds: rotated ray directions, shape (H, W, 3)
        - test_os: rotated origin expanded (as a view) to ``test_ds.shape``
        - test_img: the test image as a float32 tensor on ``device``
        - train_idxs: boolean NumPy mask, True for every index except the test one
    :raises IndexError: if ``test_idx`` is out of range for ``images``
    """
    num_images = len(images)
    if not -num_images <= test_idx < num_images:
        raise IndexError(f"test_idx {test_idx} out of range for {num_images} images")
    # Normalise negative indices so the mask below excludes the same image that
    # images[test_idx] selects; otherwise the test image leaks into training.
    if test_idx < 0:
        test_idx += num_images

    out_path = f'{conf["model"]}/results/test_img.png'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # Draw on a dedicated figure and always close it, so repeated calls neither
    # leak figures nor draw over the caller's active figure.
    fig, ax = plt.subplots()
    try:
        ax.imshow(images[test_idx])
        fig.savefig(out_path)
    finally:
        plt.close(fig)

    # torch.as_tensor accepts NumPy arrays and tensors of any dtype/device.
    # The legacy torch.Tensor(...) constructor may reject non-float32 tensors,
    # such as the float64 image tensors produced from ``data["images"] / 255``.
    test_img = torch.as_tensor(images[test_idx], dtype=torch.float32, device=device)
    test_R = torch.as_tensor(poses[test_idx, :3, :3], dtype=torch.float32, device=device)
    test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds)
    test_os = (test_R @ init_o).expand(test_ds.shape)
    train_idxs = np.arange(num_images) != test_idx
    return test_ds, test_os, test_img, train_idxs