From 3f732b6f2703bb83df8028707e69745533019b08 Mon Sep 17 00:00:00 2001
From: Kostas Chalikias
Date: Wed, 29 Jan 2020 21:52:21 +0000
Subject: [PATCH] refactor cluster selection so it can be replaced, start
 adding a DBSCAN-based selector

---
 .gitignore                      |   5 +-
 pysparnn/cluster_index.py       |  51 ++--------
 pysparnn/cluster_selection.py   | 156 ++++++++++++++++++++++++++++++++
 tests/test_cluster_selection.py |  52 +++++++++++
 4 files changed, 223 insertions(+), 41 deletions(-)
 create mode 100644 pysparnn/cluster_selection.py
 create mode 100644 tests/test_cluster_selection.py

diff --git a/.gitignore b/.gitignore
index 1dbc687..297ac65 100644
--- a/.gitignore
+++ b/.gitignore
@@ -58,5 +58,8 @@ docs/_build/
 # PyBuilder
 target/
 
-#Ipython Notebook
+# Ipython Notebook
 .ipynb_checkpoints
+
+# IDEs
+.idea

diff --git a/pysparnn/cluster_index.py b/pysparnn/cluster_index.py
index f798cc7..8f0b8eb 100644
--- a/pysparnn/cluster_index.py
+++ b/pysparnn/cluster_index.py
@@ -14,19 +14,7 @@
 import numpy as _np
 
 import pysparnn.matrix_distance
-
-
-def _k_best(tuple_list, k):
-    """For a list of tuples [(distance, value), ...] - Get the k-best tuples by
-    distance.
-    Args:
-        tuple_list: List of tuples. (distance, value)
-        k: Number of tuples to return.
-    """
-    tuple_lst = sorted(tuple_list, key=lambda x: x[0],
-                       reverse=False)[:k]
-
-    return tuple_lst
+from pysparnn.cluster_selection import _k_best, DefaultClusterSelector
 
 
 def _filter_unique(tuple_list):
@@ -81,6 +69,7 @@ class ClusterIndex(object):
 
     def __init__(self, features, records_data,
                  distance_type=pysparnn.matrix_distance.CosineDistance,
+                 cluster_selector_type=DefaultClusterSelector,
                  matrix_size=None,
                  parent=None):
         """Create a search index composed of recursively defined
@@ -96,12 +85,15 @@ def __init__(self, features, records_data,
             distance_type: Class that defines the distance measure to use.
+            cluster_selector_type: Class that picks cluster roots and assigns
+                records to clusters. Defaults to DefaultClusterSelector.
             matrix_size: Ideal size for matrix multiplication. This controls
                 the depth of the tree. Defaults to 2 levels (approx). Highly
-                reccomended that the default value is used.
+                recommended that the default value is used.
""" self.is_terminal = False self.parent = parent self.distance_type = distance_type + self.cluster_selector = cluster_selector_type(distance_type) self.desired_matrix_size = matrix_size features = distance_type.features_to_matrix(features) num_records = features.shape[0] @@ -122,28 +112,7 @@ def __init__(self, features, records_data, self.is_terminal = False records_data = _np.array(records_data) - records_index = list(_np.arange(features.shape[0])) - clusters_size = min(self.matrix_size, num_records) - clusters_selection = _random.sample(records_index, clusters_size) - clusters_selection = features[clusters_selection] - - item_to_clusters = _collections.defaultdict(list) - - root = distance_type(clusters_selection, - list(_np.arange(clusters_selection.shape[0]))) - - root.remove_near_duplicates() - root = distance_type(root.matrix, - list(_np.arange(root.matrix.shape[0]))) - - rng_step = self.matrix_size - for rng in range(0, features.shape[0], rng_step): - max_rng = min(rng + rng_step, features.shape[0]) - records_rng = features[rng:max_rng] - for i, clstrs in enumerate(root.nearest_search(records_rng)): - _random.shuffle(clstrs) - for _, cluster in _k_best(clstrs, k=1): - item_to_clusters[cluster].append(i + rng) + clusters_selection, item_to_clusters = self.cluster_selector.select_clusters(features) clusters = [] cluster_keeps = [] @@ -298,7 +267,7 @@ def search(self, features, k=1, k_clusters=1, """ # search no more than 1k records at once - # helps keap the matrix multiplies small + # helps keep the matrix multiplies small batch_size = 1000 results = [] rng_step = batch_size @@ -361,7 +330,7 @@ class MultiClusterIndex(object): Creating more Indexes (random cluster allocations) increases the chances of finding a good match. - There are three perameters that impact recall. Will discuss them all + There are three parameters that impact recall. Will discuss them all here: 1) MuitiClusterIndex(matrix_size) This impacts the tree structure (see cluster index documentation). @@ -424,7 +393,7 @@ class docstring for a description of the method. self.indexes = [] for _ in range(num_indexes): self.indexes.append((ClusterIndex(features, records_data, - distance_type, matrix_size))) + distance_type=distance_type, matrix_size=matrix_size))) def insert(self, feature, record): """Insert a single record into the index. diff --git a/pysparnn/cluster_selection.py b/pysparnn/cluster_selection.py new file mode 100644 index 0000000..854bf3a --- /dev/null +++ b/pysparnn/cluster_selection.py @@ -0,0 +1,156 @@ +import random as _random +import numpy as _np + +import collections as _collections + +from abc import ABC, abstractmethod + +from sklearn.cluster import DBSCAN + + +def _k_best(tuple_list, k): + """For a list of tuples [(distance, value), ...] - Get the k-best tuples by + distance. + Args: + tuple_list: List of tuples. (distance, value) + k: Number of tuples to return. + """ + tuple_lst = sorted(tuple_list, key=lambda x: x[0], + reverse=False)[:k] + + return tuple_lst + + +class ClusterSelector(ABC): + + @abstractmethod + def select_clusters(self, features): + pass + + +class DefaultClusterSelector(ClusterSelector): + """ + Default cluster selector, picks sqrt(num_records) random points (at most 1000) + and allocates points to their nearest category. 
+    pick. This can often split similar points across multiple tree paths.
+    """
+
+    def __init__(self, distance_type):
+        self._distance_type = distance_type
+
+    def select_clusters(self, features):
+        # number of points to cluster
+        num_records = features.shape[0]
+
+        matrix_size = max(int(_np.sqrt(num_records)), 1000)
+
+        # set num_clusters = min(max(sqrt(num_records), 1000), num_records)
+        clusters_size = min(matrix_size, num_records)
+
+        # make list [0, 1, ..., num_records-1]
+        records_index = list(_np.arange(features.shape[0]))
+        # randomly choose num_clusters records as the cluster roots
+        # this randomizes both selection and order of features in the selection
+        clusters_selection = _random.sample(records_index, clusters_size)
+        clusters_selection = features[clusters_selection]
+
+        # create structure to store clusters
+        item_to_clusters = _collections.defaultdict(list)
+
+        # create a distance_type object containing the cluster roots
+        # labeling them as 0 to N-1 in their current (random) order
+        root = self._distance_type(clusters_selection,
+                                   list(_np.arange(clusters_selection.shape[0])))
+
+        # remove duplicate cluster roots
+        root.remove_near_duplicates()
+        # initialize distance type object with the remaining cluster roots
+        root = self._distance_type(root.matrix,
+                                   list(_np.arange(root.matrix.shape[0])))
+
+        rng_step = matrix_size
+        # walk features in steps of matrix_size = max(sqrt(num_records), 1000)
+        for rng in range(0, features.shape[0], rng_step):
+            # don't exceed the array length on the last step
+            max_rng = min(rng + rng_step, features.shape[0])
+            records_rng = features[rng:max_rng]
+            # find the nearest cluster root for each feature in the step
+            for i, clstrs in enumerate(root.nearest_search(records_rng)):
+                _random.shuffle(clstrs)
+                for _, cluster in _k_best(clstrs, k=1):
+                    # add each feature to its nearest cluster; the cluster label is
+                    # the index the root was given after near-duplicate removal
+                    item_to_clusters[cluster].append(i + rng)
+
+        # row index in clusters_selection maps to key in item_to_clusters (assumes no
+        # near-duplicate roots were dropped); values are row indices into features
+        return clusters_selection, item_to_clusters
+
+
+class DbscanClusterSelector(ClusterSelector):
+    """
+    DBSCAN-based cluster selector: picks max(sqrt(num_records), 1000) random
+    points, capped at num_records, then groups neighbours within that random
+    selection before allocating the remaining features to the groups.
+    """
+
+    def __init__(self, distance_type):
+        self._distance_type = distance_type
+        self._eps = 0.4
+
+    def select_clusters(self, features):
+        # number of points to cluster
+        num_records = features.shape[0]
+
+        matrix_size = max(int(_np.sqrt(num_records)), 1000)
+
+        # set num_clusters = min(max(sqrt(num_records), 1000), num_records)
+        clusters_size = min(matrix_size, num_records)
+
+        # make list [0, 1, ..., num_records-1]
+        records_index = list(_np.arange(features.shape[0]))
+        # randomly choose num_clusters records as the cluster roots
+        # this randomizes both selection and order of features in the selection
+        random_clusters_selection = _random.sample(records_index, clusters_size)
+        random_clusters_selection = features[random_clusters_selection]
+
+        # now cluster the cluster roots themselves to avoid
+        # randomly separating neighbours, this probably means fewer clusters per level
+        # TODO might want to propagate the distance type to the clustering
+        db_scan_clustering = DBSCAN(eps=self._eps, min_samples=2).fit(random_clusters_selection)
+
+        # get all the noise points (label -1), i.e. roots DBSCAN left ungrouped
+        unique_indices = _np.where(db_scan_clustering.labels_ == -1)[0]
+        # and the first item from each dbscan group (also catches the first noise point)
+        _, cluster_start_indices = _np.unique(db_scan_clustering.labels_, return_index=True)
+        # merge and de-duplicate; the result is sorted
+        all_indices = _np.concatenate((unique_indices, cluster_start_indices))
+        all_indices_unique = _np.unique(all_indices)
+
+        # create a matrix where rows are the first item in each dbscan cluster,
+        # set that as the cluster selection and then allocate features to clusters
+        clusters_selection = random_clusters_selection[all_indices_unique]
+
+        # create structure to store clusters
+        item_to_clusters = _collections.defaultdict(list)
+
+        # create a distance_type object containing the cluster roots
+        root = self._distance_type(clusters_selection,
+                                   list(_np.arange(clusters_selection.shape[0])))
+
+        rng_step = matrix_size
+        # walk features in steps of matrix_size = max(sqrt(num_records), 1000)
+        for rng in range(0, features.shape[0], rng_step):
+            max_rng = min(rng + rng_step, features.shape[0])
+            records_rng = features[rng:max_rng]
+            # find the nearest cluster root for each feature in the step
+            for i, clstrs in enumerate(root.nearest_search(records_rng)):
+                # this is slow, disable until proven useful
+                # _random.shuffle(clstrs)
+                for _, cluster in _k_best(clstrs, k=1):
+                    # add each feature to its nearest cluster
+                    item_to_clusters[cluster].append(i + rng)
+
+        # row index in clusters_selection maps to key in item_to_clusters;
+        # the values are row indices into the original features matrix
+        return clusters_selection, item_to_clusters

diff --git a/tests/test_cluster_selection.py b/tests/test_cluster_selection.py
new file mode 100644
index 0000000..aaa1b24
--- /dev/null
+++ b/tests/test_cluster_selection.py
@@ -0,0 +1,52 @@
+import numpy as np
+from scipy.sparse import csr_matrix
+
+from pysparnn.cluster_selection import DefaultClusterSelector, DbscanClusterSelector
+from pysparnn.matrix_distance import CosineDistance, SlowEuclideanDistance
+
+
+class TestDefaultClusterSelector(object):
+    def test_single_item_groups(self):
+        sel = DefaultClusterSelector(CosineDistance)
+
+        features_dense = np.identity(1000)
+        features = csr_matrix(features_dense)
+        cs, i2c = sel.select_clusters(features)
+
+        assert cs.shape[0] == 1000
+        assert all(len(cl) == 1 for cl in i2c.values())
+
+    def test_non_single_item_groups(self):
+        sel = DefaultClusterSelector(SlowEuclideanDistance)
+
+        features_dense = np.identity(1001)
+        # SlowEuclideanDistance works on dense matrices, so pass the dense array
+        cs, i2c = sel.select_clusters(features_dense)
+
+        assert cs.shape[0] == 1000
+        assert all(len(cl) >= 1 for cl in i2c.values())
+        assert sum(len(cl) for cl in i2c.values()) == 1001
+
+        non_single_groups = list(cl for cl in i2c.values() if len(cl) == 2)
+        assert len(non_single_groups) == 1
+
+
+class TestDbscanClusterSelector(object):
+
+    def test_one(self):
+        sel = DbscanClusterSelector(CosineDistance)
+
+        features_dense = np.identity(100)
+        similar = [[0.9] + [0] * 99]
+        features_dense = np.append(features_dense, similar, axis=0)
+        assert features_dense.shape[0] == 101
+
+        features = csr_matrix(features_dense)
+        cs, i2c = sel.select_clusters(features)
+
+        assert cs.shape[0] == 100
+        assert all(len(cl) >= 1 for cl in i2c.values())
+        assert sum(len(cl) for cl in i2c.values()) == 101
+
+        non_single_groups = list(cl for cl in i2c.values() if len(cl) == 2)
+        assert len(non_single_groups) == 1
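
--
Reviewer note: a minimal usage sketch of the new extension point, assuming
this patch is applied on top of pysparnn master. The data below is made up
for illustration; ClusterIndex, CosineDistance, DefaultClusterSelector and
DbscanClusterSelector are the classes this patch touches.

    import numpy as np
    from scipy.sparse import csr_matrix

    from pysparnn.cluster_index import ClusterIndex
    from pysparnn.cluster_selection import DbscanClusterSelector
    from pysparnn.matrix_distance import CosineDistance

    # 5000 random sparse binary feature vectors and matching record labels;
    # more than 1000 records so the index actually builds a cluster level
    features = csr_matrix(np.random.binomial(1, 0.05, size=(5000, 300)))
    records = ['record-{}'.format(i) for i in range(5000)]

    # cluster_selector_type is the new seam introduced by this patch;
    # omitting it keeps the previous behaviour via DefaultClusterSelector
    index = ClusterIndex(features, records,
                         distance_type=CosineDistance,
                         cluster_selector_type=DbscanClusterSelector)

    # nearest stored record for each of the first 5 query vectors
    print(index.search(features[:5], k=1))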