Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flexible snapshot number for data shuffling #570

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 68 additions & 20 deletions mala/datahandling/data_shuffler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,36 @@ def __shuffle_numpy(
)
)

# if the number of new snapshots is not a divisor of the grid size
# then we have to trim the original snapshots to size
# the indicies to be removed are selected at random
if self.data_points_to_remove is not None:
if self.parameters.shuffling_seed is not None:
np.random.seed(idx * self.parameters.shuffling_seed)
ngrid = descriptor_data[idx].shape[0]
n_descriptor = descriptor_data[idx].shape[-1]
n_target = target_data[idx].shape[-1]

current_target = target_data[idx].reshape(-1, n_target)
current_descriptor = descriptor_data[idx].reshape(
-1, n_descriptor
)

indices = np.random.choice(
ngrid**3,
size=ngrid**3 - self.data_points_to_remove[idx],
)

descriptor_data[idx] = current_descriptor[indices]
target_data[idx] = current_target[indices]

# Do the actual shuffling.
target_name_openpmd = os.path.join(target_save_path,
save_name.replace("*", "%T"))
descriptor_name_openpmd = os.path.join(descriptor_save_path,
save_name.replace("*", "%T"))
target_name_openpmd = os.path.join(
target_save_path, save_name.replace("*", "%T")
)
descriptor_name_openpmd = os.path.join(
descriptor_save_path, save_name.replace("*", "%T")
)
for i in range(0, number_of_new_snapshots):
new_descriptors = np.zeros(
(int(np.prod(shuffle_dimensions)), self.input_dimension),
Expand Down Expand Up @@ -163,16 +188,12 @@ def __shuffle_numpy(
)
new_descriptors[
last_start : current_chunk + last_start
] = descriptor_data[j].reshape(
current_grid_size, self.input_dimension
)[
] = descriptor_data[j].reshape(-1, self.input_dimension)[
i * current_chunk : (i + 1) * current_chunk, :
]
new_targets[
last_start : current_chunk + last_start
] = target_data[j].reshape(
current_grid_size, self.output_dimension
)[
] = target_data[j].reshape(-1, self.output_dimension)[
i * current_chunk : (i + 1) * current_chunk, :
]

Expand Down Expand Up @@ -238,7 +259,6 @@ def __shuffle_numpy(
# It will be executed one after another for both of them.
# Use this class to parameterize which of both should be shuffled.
class __DescriptorOrTarget:

def __init__(
self,
save_path,
Expand All @@ -256,7 +276,6 @@ def __init__(
self.dimension = dimension

class __MockedMPIComm:

def __init__(self):
self.rank = 0
self.size = 1
Expand Down Expand Up @@ -363,9 +382,7 @@ def from_chunk_i(i, n, dset, slice_dimension=0):
import json

# Do the actual shuffling.
name_prefix = os.path.join(
dot.save_path, save_name.replace("*", "%T")
)
name_prefix = os.path.join(dot.save_path, save_name.replace("*", "%T"))
for i in range(my_items_start, my_items_end):
# We check above that in the non-numpy case, OpenPMD will work.
dot.calculator.grid_dimensions = list(shuffle_dimensions)
Expand Down Expand Up @@ -521,6 +538,8 @@ def shuffle_snapshots(
]
number_of_data_points = np.sum(snapshot_size_list)

self.data_points_to_remove = None

if number_of_shuffled_snapshots is None:
# If the user does not tell us how many snapshots to use,
# we have to check if the number of snapshots is straightforward.
Expand Down Expand Up @@ -584,10 +603,40 @@ def shuffle_snapshots(
del specified_number_of_new_snapshots

if number_of_data_points % number_of_new_snapshots != 0:
raise Exception(
"Cannot create this number of snapshots "
"from data provided."
)
if snapshot_type == "numpy":
self.data_points_to_remove = []
for i in range(0, self.nr_snapshots):
gridsize = self.parameters.snapshot_directories_list[
i
].grid_size
shuffled_gridsize = int(
gridsize / number_of_new_snapshots
)
self.data_points_to_remove.append(
gridsize
- shuffled_gridsize * number_of_new_snapshots
)
tot_points_missing = sum(self.data_points_to_remove)

printout(
"Warning: number of requested snapshots is not a divisor of",
"the original grid sizes.\n",
f"{tot_points_missing} / {number_of_data_points} data points",
"will be left out of the shuffled snapshots."
)

shuffle_dimensions = [
int(number_of_data_points / number_of_new_snapshots),
1,
1,
]

elif snapshot_type == "openpmd":
# TODO implement arbitrary grid sizes for openpmd
raise Exception(
"Cannot create this number of snapshots "
"from data provided."
)
else:
shuffle_dimensions = [
int(number_of_data_points / number_of_new_snapshots),
Expand All @@ -606,7 +655,6 @@ def shuffle_snapshots(
permutations = []
seeds = []
for i in range(0, number_of_new_snapshots):

# This makes the shuffling deterministic, if specified by the user.
if self.parameters.shuffling_seed is not None:
np.random.seed(i * self.parameters.shuffling_seed)
Expand Down