-
Notifications
You must be signed in to change notification settings - Fork 333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Created class for pyramid topology #142
Changes from 1 commit
7842d9e
f595feb
1585076
ff2d63e
8100888
9d766f6
4fc4938
d69de25
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
A Pyramid Network Topology | ||
|
||
This class implements a star topology where all particles are connected in a | ||
pyramid like fashion. | ||
""" | ||
|
||
# Import from stdlib | ||
import logging | ||
|
||
# Import modules | ||
import numpy as np | ||
from scipy.spatial import Delaunay | ||
|
||
# Import from package | ||
from .. import operators as ops | ||
from .base import Topology | ||
|
||
# Create a logger | ||
logger = logging.getLogger(__name__) | ||
|
||
class Pyramid(Topology): | ||
def __init__(self): | ||
super(Pyramid, self).__init__() | ||
|
||
def compute_gbest(self, swarm): | ||
"""Updates the global best using a pyramid neighborhood approach | ||
|
||
This uses the Delaunay method from :code:`scipy` to triangulate space | ||
with simplices | ||
|
||
Parameters | ||
---------- | ||
swarm : pyswarms.backend.swarms.Swarm | ||
a Swarm instance | ||
|
||
Returns | ||
------- | ||
numpy.ndarray | ||
Best position of shape :code:`(n_dimensions, )` | ||
float | ||
Best cost | ||
""" | ||
try: | ||
# If there are less than 5 particles they are all connected | ||
if swarm.n_particles < 5: | ||
best_pos = swarm.pbest_pos[np.argmin(swarm.pbest_cost)] | ||
best_cost = np.min(swarm.pbest_cost) | ||
else: | ||
pyramid = Delaunay(swarm.position) | ||
indices, indptr = pyramid.vertex_neighbor_vertices | ||
idx = np.array() | ||
# Insert all the neighbors for each particle in the idx array | ||
for i in range(swarm.n_particles): | ||
idx.append(indptr[indices[i]:indices[i+1]]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Awesome! Can you check if this can be golfed via list-comprehension? idx = [indptr[indices[i]:indices[i+1] for i in range(swarm.n_particles)] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure I guess that is possible! I'm currently struggling with the rebasing. I rebased a wrong commit and I'm trying to resolve the mess I did 😃. Could I just refork the repo and add the files again? Or is does this mess up anything? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, nevermind got it working now! It's online now 👍. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Awesome! Yup just found it right now! |
||
idx_min = swarm.pbest_cost[idx].argmin(axis=1) | ||
best_neighbor = idx[np.arange(len(idx)), idx_min] | ||
|
||
# Obtain best cost and position | ||
best_cost = np.min(swarm.pbest_cost[best_neighbor]) | ||
best_pos = swarm.pbest_pos[ | ||
np.argmin(swarm.pbest_cost[best_neighbor]) | ||
] | ||
except AttributeError: | ||
msg = "Please pass a Swarm class. You passed {}".format( | ||
type(swarm) | ||
) | ||
logger.error(msg) | ||
raise | ||
else: | ||
return (best_pos, best_cost) | ||
|
||
def compute_velocity(self, swarm, clamp=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey @whzup , can you check the indentation here? I think you should move this 4 spaces backward 👍 |
||
"""Computes the velocity matrix | ||
|
||
This method updates the velocity matrix using the best and current | ||
positions of the swarm. The velocity matrix is computed using the | ||
cognitive and social terms of the swarm. | ||
|
||
A sample usage can be seen with the following: | ||
|
||
.. code-block :: python | ||
|
||
import pyswarms.backend as P | ||
from pyswarms.swarms.backend import Swarm | ||
from pyswarms.backend.topology import Pyramid | ||
|
||
my_swarm = P.create_swarm(n_particles, dimensions) | ||
my_topology = Pyramid() | ||
|
||
for i in range(iters): | ||
# Inside the for-loop | ||
my_swarm.velocity = my_topology.update_velocity(my_swarm, clamp) | ||
|
||
Parameters | ||
---------- | ||
swarm : pyswarms.backend.swarms.Swarm | ||
a Swarm instance | ||
clamp : tuple of floats (default is :code:`None`) | ||
a tuple of size 2 where the first entry is the minimum velocity | ||
and the second entry is the maximum velocity. It | ||
sets the limits for velocity clamping. | ||
|
||
Returns | ||
------- | ||
numpy.ndarray | ||
Updated velocity matrix | ||
""" | ||
return ops.compute_velocity(swarm, clamp) | ||
|
||
def compute_position(self, swarm, bounds=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The same issue with the indentation here 👍 |
||
"""Updates the position matrix | ||
|
||
This method updates the position matrix given the current position and | ||
the velocity. If bounded, it waives updating the position. | ||
|
||
Parameters | ||
---------- | ||
swarm : pyswarms.backend.swarms.Swarm | ||
a Swarm instance | ||
bounds : tuple of :code:`np.ndarray` or list (default is :code:`None`) | ||
a tuple of size 2 where the first entry is the minimum bound while | ||
the second entry is the maximum bound. Each array must be of shape | ||
:code:`(dimensions,)`. | ||
|
||
Returns | ||
------- | ||
numpy.ndarray | ||
New position-matrix | ||
""" | ||
return ops.compute_position(swarm, bounds) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
# Import modules | ||
import pytest | ||
import numpy as np | ||
|
||
# Import from package | ||
from pyswarms.backend.topology import Pyramid | ||
|
||
|
||
def test_compute_gbest_return_values(swarm): | ||
"""Test if compute_gbest() gives the expected return values""" | ||
topology = Pyramid() | ||
expected_cost = 1 | ||
expected_pos = np.array([1, 2, 3]) | ||
pos, cost = topology.compute_gbest(swarm) | ||
assert cost == expected_cost | ||
assert (pos == expected_pos).all() | ||
|
||
|
||
@pytest.mark.parametrize("clamp", [None, (0, 1), (-1, 1)]) | ||
def test_compute_velocity_return_values(swarm, clamp): | ||
"""Test if compute_velocity() gives the expected shape and range""" | ||
topology = Pyramid() | ||
v = topology.compute_velocity(swarm, clamp) | ||
assert v.shape == swarm.position.shape | ||
if clamp is not None: | ||
assert (clamp[0] <= v).all() and (clamp[1] >= v).all() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"bounds", | ||
[None, ([-5, -5, -5], [5, 5, 5]), ([-10, -10, -10], [10, 10, 10])], | ||
) | ||
def test_compute_position_return_values(swarm, bounds): | ||
"""Test if compute_position() gives the expected shape and range""" | ||
topology = Pyramid() | ||
p = topology.compute_position(swarm, bounds) | ||
assert p.shape == swarm.velocity.shape | ||
if bounds is not None: | ||
assert (bounds[0] <= p).all() and (bounds[1] >= p).all() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder what
indptr
stands for? Index pointer? 😕There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can change this into a more understandable variable? 👍