Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed improvements to RadarDataset #85

Merged
merged 6 commits into from
May 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 70 additions & 44 deletions bird_cloud_gnn/radar_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
ValueError: If `data_folder` is not a valid folder.
"""
if not os.path.isdir(data_folder):
raise ValueError(f"'${data_folder}' is not a folder")
raise ValueError(f"'{data_folder}' is not a folder")
self._name = name
self.data_folder = data_folder
self.features = features
Expand All @@ -81,53 +81,79 @@ def __init__(
),
)

def _read_one_file(self, data_file):
    """Read one radar file and append its graphs and labels to the dataset.

    Parameters
    ----------
    data_file : str
        File name inside ``self.data_folder``. Only ``.csv``, ``.csv.gz``
        and ``.parquet`` files are processed; any other extension is
        silently ignored.

    Side effects
    ------------
    Appends one DGL graph per point of interest to ``self.graphs`` and
    extends ``self.labels`` with the matching target values. Assumes
    ``process`` initialised ``self.graphs`` (list) and ``self.labels``
    (1-D numpy array) before calling.
    """
    xyz = ["x", "y", "z"]

    # Skip anything that is not CSV, gzipped CSV or parquet.
    split_on_dots = data_file.split(".")
    if (
        split_on_dots[-1] not in ["csv", "parquet"]
        and ".".join(split_on_dots[-2:]) != "csv.gz"
    ):
        return
    if split_on_dots[-1] == "parquet":
        data = pd.read_parquet(os.path.join(self.data_folder, data_file))
    else:
        data = pd.read_csv(os.path.join(self.data_folder, data_file))

    # Keep only points with 5 km <= range <= 100 km and z <= 10 km.
    # NOTE(review): thresholds are hard-coded — consider promoting them to
    # constructor parameters.
    data = data.drop(
        data[
            np.logical_or(
                data.range > 100000,
                np.logical_or(data.z > 10000, data.range < 5000),
            )
        ].index
    ).reset_index(drop=True)

    data_xyz = data[xyz]
    data_features = data[self.features]

    # Rows with a missing target cannot become points of interest, but they
    # may still appear as neighbours of labelled points.
    na_index = data[data[self.target].isna()].index

    data_xyz_notna = data_xyz.drop(na_index)

    data_target = data[self.target]
    data_target_notna = data_target[data_xyz_notna.index]

    data_xyz_notna.reset_index(drop=True, inplace=True)

    tree = KDTree(data_xyz)
    tree_notna = KDTree(data_xyz_notna)

    # Sparse pairwise distances between labelled points (rows) and all
    # points (columns), truncated at self.max_distance.
    distance_matrix = tree_notna.sparse_distance_matrix(
        tree, self.max_distance, output_type="coo_matrix"
    )

    # Stored-entry count per labelled point; points with enough neighbours
    # become graph centres ("points of interest").
    number_neighbours = distance_matrix.getnnz(1)
    points_of_interest = np.where(number_neighbours >= self.min_neighbours)[0]

    # For each point of interest, take its min_neighbours nearest points
    # among *all* points (labelled or not) as the graph's nodes.
    # NOTE(review): when self.min_neighbours == 1, KDTree.query returns
    # 1-D index arrays — confirm that case cannot occur or handle it.
    _, poi_indexes = tree.query(
        data_xyz_notna.loc[points_of_interest], self.min_neighbours
    )
    self.labels = np.concatenate(
        (self.labels, data_target_notna.values[points_of_interest])
    )
    for indexes in poi_indexes:
        local_tree = KDTree(data_xyz.iloc[indexes])  # slow
        distances = local_tree.sparse_distance_matrix(
            local_tree, self.max_edge_distance, output_type="coo_matrix"
        )
        # Edges connect node pairs closer than max_edge_distance.
        # NOTE(review): dgl.graph infers the node count from the largest
        # edge index — if the highest-numbered node has no edge within
        # max_edge_distance, ndata assignment below may mismatch; verify.
        graph = dgl.graph((distances.row, distances.col))

        # TODO: Better fillna
        local_data = data_features.iloc[indexes].fillna(0)
        graph.ndata["x"] = torch.tensor(local_data.values)
        graph.edata["a"] = torch.tensor(distances.data)
        self.graphs.append(graph)

def process(self):
    """Internal function for the DGLDataset. Process the folder to create the graphs.

    Iterates over every entry in ``self.data_folder`` (sorted, so that the
    resulting graph/label order is deterministic across filesystems) and
    delegates to ``_read_one_file``, which appends to ``self.graphs`` and
    ``self.labels``.

    Raises:
        ValueError: If no graphs were created from any file in the folder.
    """
    self.graphs = []
    self.labels = np.array([])
    # sorted(): os.listdir order is arbitrary; fix it for reproducibility.
    for data_file in sorted(os.listdir(self.data_folder)):
        self._read_one_file(data_file)

    if len(self.graphs) == 0:
        raise ValueError("No graphs selected under rules passed")
Expand Down