Skip to content

Commit

Permalink
fix the data fix length in global search
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanqer authored Nov 12, 2021
1 parent 2808810 commit 824df08
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions batch_gen_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,15 @@ def __init__(self, num_classes, actions_dict, gt_path, features_path, sample_rat
file_ptr = open(vid, 'r')
self.list_of_examples = file_ptr.read().split('\n')[:-1]
file_ptr.close()

self.mask = torch.ones(self.num_classes, 2000, dtype=torch.float)

if '50salads' in gt_path:
self.fix_size = 5000
elif 'breakfast' in gt_path:
self.fix_size = 2000
elif 'gtea' in gt_path:
self.fix_size = 1000

self.mask = torch.ones(self.num_classes, self.fix_size, dtype=torch.float)


def __getitem__(self, idx):
Expand All @@ -45,8 +52,8 @@ def __getitem__(self, idx):
batch_input = torch.from_numpy(batch_input)
batch_target = torch.from_numpy(batch_target)

batch_input = torch.nn.functional.interpolate(batch_input.unsqueeze(0), size=5000, mode='nearest').squeeze()
batch_target = torch.nn.functional.interpolate(batch_target.unsqueeze(0).unsqueeze(0), size=5000, mode='nearest').squeeze().long()
batch_input = torch.nn.functional.interpolate(batch_input.unsqueeze(0), size=self.fix_size, mode='nearest').squeeze()
batch_target = torch.nn.functional.interpolate(batch_target.unsqueeze(0).unsqueeze(0), size=self.fix_size, mode='nearest').squeeze().long()

np.save(self.features_path + batch.split('.')[0] + '_fix', batch_input.numpy())
np.save(self.features_path + batch.split('.')[0] + '_fix_label', batch_target.numpy())
Expand Down

0 comments on commit 824df08

Please sign in to comment.