From b97c367165fb305b05282855cf0ef69e7fb9051c Mon Sep 17 00:00:00 2001 From: Winston H <56998716+winstxnhdw@users.noreply.github.com> Date: Fri, 17 May 2024 08:48:18 +0800 Subject: [PATCH] perf: avoid square root computation when possible (#19) Signed-off-by: winstxnhdw --- mapf-viz/src/visibility_visual.rs | 13 +++++++------ mapf/src/motion/conflict.rs | 2 +- mapf/src/motion/waypoint.rs | 2 +- mapf/src/negotiation/mod.rs | 8 +++++--- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/mapf-viz/src/visibility_visual.rs b/mapf-viz/src/visibility_visual.rs index b150dea..1e43600 100644 --- a/mapf-viz/src/visibility_visual.rs +++ b/mapf-viz/src/visibility_visual.rs @@ -150,18 +150,19 @@ impl VisibilityVisual { fn find_closest(&self, p: iced::Point) -> Option { let mut closest: Option<(Cell, f64)> = None; let r = self.visibility.agent_radius(); + let r_squared = r * r; let p = Point::new(p.x as f64, p.y as f64); for (cell, _) in self.visibility.iter_points() { let p_cell = cell.center_point(self.grid().cell_size()); - let dist = (p_cell - p).norm(); - if dist <= r { - if let Some((_, old_dist)) = closest { - if dist < old_dist { - closest = Some((*cell, dist)); + let dist_squared = (p_cell - p).norm_squared(); + if dist_squared <= r_squared { + if let Some((_, old_dist_squared)) = closest { + if dist_squared < old_dist_squared { + closest = Some((*cell, dist_squared)); } } else { - closest = Some((*cell, dist)); + closest = Some((*cell, dist_squared)); } } } diff --git a/mapf/src/motion/conflict.rs b/mapf/src/motion/conflict.rs index 1e1413f..609b26a 100644 --- a/mapf/src/motion/conflict.rs +++ b/mapf/src/motion/conflict.rs @@ -148,7 +148,7 @@ where } // The final state should be almost exactly the same as the last move - assert!((to_state.point() - last_p).norm() < 1e-3); + assert!((to_state.point() - last_p).norm_squared() < 1e-6); Arclength { translational, rotational, diff --git a/mapf/src/motion/waypoint.rs b/mapf/src/motion/waypoint.rs index c62f251..69d712b 100644 --- a/mapf/src/motion/waypoint.rs +++ b/mapf/src/motion/waypoint.rs @@ -99,7 +99,7 @@ where } // The final state should be almost exactly the same as the last move - assert!((to_state.point() - last_p).norm() < 1e-3); + assert!((to_state.point() - last_p).norm_squared() < 1e-6); Arclength { translational, rotational, diff --git a/mapf/src/negotiation/mod.rs b/mapf/src/negotiation/mod.rs index a6fccce..49f7f78 100644 --- a/mapf/src/negotiation/mod.rs +++ b/mapf/src/negotiation/mod.rs @@ -68,15 +68,17 @@ pub fn negotiate( let cs = scenario.cell_size; let mut conflicts = HashMap::new(); triangular_for(scenario.agents.iter(), |(n_a, a), (n_b, b)| { + let min_dist = a.radius + b.radius; + let min_dist_squared = min_dist * min_dist; + for (cell_a, cell_b) in [ (a.start_cell(), b.start_cell()), (a.goal_cell(), b.goal_cell()), ] { let pa = cell_a.center_point(cs); let pb = cell_b.center_point(cs); - let dist = (pa - pb).norm(); - let min_dist = a.radius + b.radius; - if dist < min_dist { + let dist_squared = (pa - pb).norm_squared(); + if dist_squared < min_dist_squared { conflicts.insert( (**n_a).clone().min((*n_b).clone()), (**n_a).clone().max((*n_b).clone()),