From 9242631ae1444941f6dddc8da52dff59964151e7 Mon Sep 17 00:00:00 2001 From: Wen-Ding Li Date: Tue, 10 Jan 2023 23:07:29 -0500 Subject: [PATCH] construct the hashtables iteratively to save memory usage --- pile/processing/dedup/grouped_dedup.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/pile/processing/dedup/grouped_dedup.py b/pile/processing/dedup/grouped_dedup.py index d46b7b4..db0a6f9 100644 --- a/pile/processing/dedup/grouped_dedup.py +++ b/pile/processing/dedup/grouped_dedup.py @@ -225,7 +225,6 @@ def run( start_time = time.time() B, R = optimal_param(threshold, num_perm) HASH_RANGES = [(i * R, (i + 1) * R) for i in range(B)] - HASH_TABLES = [defaultdict(set) for _ in range(B)] group = [] time_measures["load_dataset"] = time.time() for name in data[group_name][:2]: @@ -279,24 +278,28 @@ def run( with_indices=True, desc="Fingerprinting...", ) + time_measures["minhash"] = time.time() - time_measures["minhash"] time_measures["clustering"] = time.time() batch_size: int = 10000 - for i in tqdm( - range(0, len(embedded), batch_size), dynamic_ncols=True, desc="Iterating MinHashes..." # noqa: E501 - ): - batch = embedded[i : i + batch_size] - for key, Hs in zip(batch["__id__"], batch["__signatures__"]): - for H, hashtable in zip(Hs, HASH_TABLES): - hashtable[H].add(key) - for table in tqdm(HASH_TABLES, dynamic_ncols=True, desc="Clustering..."): - for cluster in table.values(): + + for table_idx in range(B): + new_hash_table = defaultdict(set) + for i in tqdm( + range(0, len(embedded), batch_size), dynamic_ncols=True, desc="Iterating MinHashes..." # noqa: E501 + ): + batch = embedded[i : i + batch_size] + for key, Hs in zip(batch["__id__"], batch["__signatures__"]): + new_hash_table[Hs[table_idx]].add(key) + + for cluster in new_hash_table.values(): if len(cluster) <= 1: continue idx = min(cluster) for x in cluster: uf.union(x, idx) + time_measures["clustering"] = time.time() - time_measures["clustering"] time_measures["filtering"] = time.time()