Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use multiple cluster partitions #76

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions clustering/cluster_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,36 @@ def fit_cluster(embeddings, method='Agglomerative', k=1000, l2normalize=True,

PARTITION = finch_partition

if PARTITION == -1:
labels = c
n_clusters = num_clust[0]
if type(PARTITION) == int:
if PARTITION == -1:
labels = c
n_clusters = num_clust[0]

else:
labels = c[:,PARTITION]
n_clusters = num_clust[PARTITION]
print('Taking partition {} from finch'.format(PARTITION))
print("Fitted " + str(n_clusters) + " clusters with " + str(method))

else:
labels = c[:,PARTITION]
n_clusters = num_clust[PARTITION]
print('Taking partition {} from finch'.format(PARTITION))
labels = []
n_clusters = []
for p in PARTITION:
labels.append(c[:,p])
n_cluster = num_clust[p]
n_clusters.append(n_cluster)
print('Taking partition {} from finch'.format(p))
print("Fitted " + str(n_cluster) + " clusters with " + str(method))
labels = np.array(labels)
labels = np.transpose(labels)
# labels_t = []
# for j in range(len(labels[0])):
# cur_labels = []
# for i in range(len(labels)):
# cur_labels.append(labels[i][j])
# labels_t.append(cur_labels)
# labels = labels_t


elif method == 'OPTICS':
trained_cluster_obj = OPTICS(min_samples=3, max_eps=0.20, cluster_method='dbscan', metric='cosine', n_jobs=-1).fit(embeddings)
Expand All @@ -237,7 +259,6 @@ def fit_cluster(embeddings, method='Agglomerative', k=1000, l2normalize=True,
print(labels.shape)
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)

print("Fitted " + str(n_clusters) + " clusters with " + str(method))
return labels


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

TRAIN:
DATASET: "ucf101"

DATASET:
VID_PATH: '/media/diskstation/datasets/UCF101/jpg'
ANNOTATION_PATH: /media/diskstation/datasets/UCF101/UCF101_Flow/json/ucf101_01.json
CLUSTER_PATH: '/media/diskstation/datasets/UCF101/vid_clusters.txt'

TARGET_TYPE_T: 'cluster_label'
TARGET_TYPE_V: 'label'

SAMPLING_STRATEGY: 'random_semi_hard'
POSITIVE_SAMPLING_P: 0.2

POS_CHANNEL_REPLACE: True
PROB_POS_CHANNEL_REPLACE: 0.35

CHANNEL_EXTENSIONS: 'optical_u'
OPTICAL_U_PATH: /media/diskstation/datasets/UCF101/UCF101_Flow/tvl1_flow/u


MODEL:
ARCH: '3dresnet'

RESNET:
MODEL_DEPTH: 18
N_CLASSES: 2048 #512
# N_INPUT_CHANNELS: 3
SHORTCUT: 'B'
CONV1_T_SIZE: 7
CONV1_T_STRIDE: 1
NO_MAX_POOl: true
WIDEN_FACTOR: 1

ITERCLUSTER:
METHOD: 'finch'
FINCH_PARTITION: [0,3]

DATA:
SAMPLE_SIZE: 128
SAMPLE_DURATION: 16
INPUT_CHANNEL_NUM: 3 #4

LOSS:
MARGIN: 0.2
LOCAL_LOCAL_CONTRAST: True


OPTIM:
LR: 0.1
MOMENTUM: 0.5
2 changes: 1 addition & 1 deletion config/default_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@
_C.ITERCLUSTER.ADAPTIVEP = False
_C.ITERCLUSTER.WARMUP_EPOCHS = 0
_C.ITERCLUSTER.L2_NORMALIZE = True
_C.ITERCLUSTER.FINCH_PARTITION = 0
_C.ITERCLUSTER.FINCH_PARTITION = [0]

# -----------------------------------------------------------------------------
# Misc options
Expand Down
1 change: 1 addition & 0 deletions datasets/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def build_data_loader(split, cfg, is_master_proc=True, triplets=True,
intra_negative=cfg.LOSS.INTRA_NEGATIVE,
modality=cfg.DATASET.MODALITY,
predict_temporal_ds=cfg.MODEL.PREDICT_TEMPORAL_DS,
multi_partition=len(cfg.ITERCLUSTER.FINCH_PARTITION) > 1,
is_master_proc=is_master_proc)

# ============================ Build DataLoader ============================
Expand Down
3 changes: 2 additions & 1 deletion datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_data(split, video_path, annotation_path, dataset_name, input_type,
negative_sampling=False, positive_sampling_p=1.0,
pos_channel_replace=False, prob_pos_channel_replace=None, modality=False, predict_temporal_ds=False,
relative_speed_perception=False,
local_local_contrast=False, intra_negative=False, is_master_proc=True):
local_local_contrast=False, intra_negative=False, multi_partition=False, is_master_proc=True):


'''
Expand Down Expand Up @@ -119,6 +119,7 @@ def get_data(split, video_path, annotation_path, dataset_name, input_type,
prob_pos_channel_replace=prob_pos_channel_replace,
modality=modality,
sample_duration=sample_duration,
multi_partition=multi_partition,
predict_temporal_ds=predict_temporal_ds,
relative_speed_perception=relative_speed_perception,
local_local_contrast=local_local_contrast,
Expand Down
43 changes: 38 additions & 5 deletions datasets/triplets_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self,
intra_negative=False,
modality=False,
predict_temporal_ds=False,
multi_partition=False,
image_name_formatter=lambda x: f'image_{x:05d}.jpg',
target_type='label'):

Expand All @@ -71,6 +72,7 @@ def __init__(self,
self.intra_negative = intra_negative
self.modality = modality
self.predict_temporal_ds = predict_temporal_ds
self.multi_partition = multi_partition
self.max_sr = 4
self.shuffle = Shuffle()

Expand Down Expand Up @@ -99,12 +101,27 @@ def __init__(self,

self.gt_target_type='label'

self.data_labels = np.array([data[self.target_type] for data in self.data])

if self.target_type == 'label':
self.data_labels = np.array([data[self.target_type] for data in self.data])
self.label_to_indices = {label: np.where(self.data_labels == label)[0] for label in self.class_names.keys()}


else: # target_type == 'cluster_label'
self.label_to_indices = {label: np.where(self.data_labels == label)[0] for label in self.cluster_labels}
if self.multi_partition:
self.data_labels1 = np.array([data[self.target_type][0] for data in self.data])
self.data_labels2 = np.array([data[self.target_type][1] for data in self.data])
self.cluster_labels1 = set(self.data_labels1)
self.cluster_labels2 = set(self.data_labels2)
self.label_to_indices1 = {label: np.where(self.data_labels1 ==
label)[0] for label in self.cluster_labels1}
self.label_to_indices2 = {label: np.where(self.data_labels2 ==
label)[0] for label in self.cluster_labels2}
else:
self.data_labels = np.array([data[self.target_type] for data in self.data])
self.cluster_labels = set(self.data_labels)
self.label_to_indices = {label: np.where(self.data_labels ==
label)[0] for label in self.cluster_labels}

def __getitem__(self, index):
anchor=self.data[index]
Expand All @@ -117,11 +134,27 @@ def __getitem__(self, index):
positive = anchor.copy()

else: #sample positive from same a_target (of type target_type - 'label' or 'cluster_label')
p_idx = np.random.choice(self.label_to_indices[a_target])
if self.multi_partition and self.target_type != "label":
p_idx = np.random.choice(self.label_to_indices1[a_target[0]])
while p_idx == index and len(self.label_to_indices1[a_target[0]]) > 1:
p_idx = np.random.choice(self.label_to_indices1[a_target[0]])

# Pick different video from anchor if there is more than 1 video with target a_target
while p_idx == index and len(self.label_to_indices[a_target]) > 1:
else:
p_idx = np.random.choice(self.label_to_indices[a_target])

# # Pick different video from anchor if there is more than 1 video with target a_target
# while p_idx == index and len(self.label_to_indices[a_target]) > 1:
# p_idx = np.random.choice(self.label_to_indices[a_target])
# positive = self.data[p_idx]
# p_idx = np.random.choice(clust_choices)

# while p_idx == index and len(clust_choices) > 1:
# p_idx = np.random.choice(clust_choices)
# Pick different video from anchor if there is more than 1 video with target a_target
while p_idx == index and len(self.label_to_indices[a_target]) > 1:
p_idx = np.random.choice(self.label_to_indices[a_target])

positive = self.data[p_idx]

p_target = positive[self.target_type]
Expand Down
8 changes: 7 additions & 1 deletion datasets/ucf101.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,13 @@ def read_cluster_labels(self):
return None
with open(self.cluster_path, 'r') as f:
cluster_labels = f.readlines()
cluster_labels = [int(id.replace('\n', '')) for id in cluster_labels]
import re
cluster_labels = [re.sub(' +', ' ', id.replace('\n', '')) for id in cluster_labels]
cluster_labels = [id.replace(' ', ',') for id in cluster_labels]
# print(cluster_labels)
# print(json.loads(cluster_labels[0]))
cluster_labels = [tuple(json.loads(id)) for id in cluster_labels]
print(len(cluster_labels), cluster_labels[0])
if self.is_master_proc:
print('retrieved {} cluster id from file: {}'.format(len(cluster_labels), self.cluster_path))
return cluster_labels
Expand Down
73 changes: 50 additions & 23 deletions loss/triplet_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn as nn
import torch.nn.functional as F
from hyptorch.pmath import dist_matrix, dist

import numpy as np

class MemTripletLoss(nn.Module):
#outputSize = ndata
Expand Down Expand Up @@ -214,6 +214,8 @@ def forward(self, embeddings, labels, gt_labels, sampling_strategy="random_negat
# Get list of (anchor idx, positive idx, negative idx) triplets
self.triplet_selector = NegativeTripletSelector(self.margin, sampling_strategy, self.dist_metric)
triplets = self.triplet_selector.get_triplets(embeddings, labels) # list of dim: [3, batch_size]

#get False Positive and False Negative Count
gt_labels = gt_labels.reshape((-1, 1))
gt_a = gt_labels[triplets[0],:]
gt_p = gt_labels[triplets[1],:]
Expand All @@ -222,6 +224,7 @@ def forward(self, embeddings, labels, gt_labels, sampling_strategy="random_negat
false_positive = (gt_a!=gt_p).sum()
false_negatvie = (gt_a==gt_n).sum()



# Compute anchor/positive and anchor/negative distances. ap_dists and
# an_dists are tensors with dim: [batch_size]
Expand Down Expand Up @@ -298,33 +301,57 @@ def get_triplets(self, embeddings, labels, distance_matrix=None):
# tensor with dim: [(batch_size * 2), (batch_size * 2)]
distance_matrix = pdist(embeddings, eps=0, dist_metric=self.dist_metric)

# Get tensor with unique labels (<= (batch_size * 2))
unique_labels, counts = torch.unique(labels, return_counts=True)
triplets_indices = [[] for i in range(3)]

# Assert that there is no -1 (noise) label
assert(-1 not in unique_labels)
if type(labels) is tuple: #TODO: make it a flag
labels1 = labels[0]
labels2 = labels[1]

triplets_indices = [[] for i in range(3)]
for label in unique_labels:
for i in range(len(labels2)//2):
# label1 = labels1[i]
label2 = labels2[i]
ap_indices = np.array([i, i+len(labels2)//2])

# Get embeddings indices with current label
label_mask = labels == label
label_indices = torch.where(label_mask)[0]
if label_indices.shape[0] < 2: # must have at least anchor and positive with same label
continue
# where_not_label1 = np.where(labels1 != label1)
negative_indices = np.where(labels2 != label2)[0]
if negative_indices.shape[0] == 0:
continue

# Get embeddings indices without current label
negative_indices = torch.where(torch.logical_not(label_mask))[0]
if negative_indices.shape[0] == 0: # must have at least one negative
continue
# Sample anchor/positive/negative triplet
triplet_label_pairs = self.get_one_one_triplets(
ap_indices, negative_indices, distance_matrix,
)
triplets_indices[0].extend(triplet_label_pairs[0])
triplets_indices[1].extend(triplet_label_pairs[1])
triplets_indices[2].extend(triplet_label_pairs[2])
else:
# Get tensor with unique labels (<= (batch_size * 2))
unique_labels, counts = torch.unique(labels, return_counts=True)

# Sample anchor/positive/negative triplet
triplet_label_pairs = self.get_one_one_triplets(
label_indices, negative_indices, distance_matrix,
)
triplets_indices[0].extend(triplet_label_pairs[0])
triplets_indices[1].extend(triplet_label_pairs[1])
triplets_indices[2].extend(triplet_label_pairs[2])
# Assert that there is no -1 (noise) label
assert(-1 not in unique_labels)

triplets_indices = [[] for i in range(3)]
for label in unique_labels:

# Get embeddings indices with current label
label_mask = labels == label
label_indices = torch.where(label_mask)[0]
if label_indices.shape[0] < 2: # must have at least anchor and positive with same label
continue

# Get embeddings indices without current label
negative_indices = torch.where(torch.logical_not(label_mask))[0]
if negative_indices.shape[0] == 0: # must have at least one negative
continue

# Sample anchor/positive/negative triplet
triplet_label_pairs = self.get_one_one_triplets(
label_indices, negative_indices, distance_matrix,
)
triplets_indices[0].extend(triplet_label_pairs[0])
triplets_indices[1].extend(triplet_label_pairs[1])
triplets_indices[2].extend(triplet_label_pairs[2])

return triplets_indices

Expand Down
Loading