Skip to content

Commit

Permalink
Simplify sorting of individuals in mu_plus_lambda
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobj committed Jul 3, 2020
1 parent 3f6639e commit 2042586
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 20 deletions.
26 changes: 8 additions & 18 deletions cgp/ea/mu_plus_lambda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import concurrent.futures
import numpy as np

from typing import Callable, List, Tuple
from typing import Callable, List

from ..individual import IndividualBase
from ..population import Population
Expand Down Expand Up @@ -151,33 +151,23 @@ def _compute_fitness(
return combined

def _sort(self, combined: List[IndividualBase]) -> List[IndividualBase]:
# create copy of population
combined_copy = [ind.clone() for ind in combined]
def sort_func(ind: IndividualBase) -> float:
"""Return fitness of an individual, return -infinity for an individual
with fitness equal nan, or raise error if the fitness is
not a float.
# replace all nan by -inf to make sure they end up at the end
# after sorting
for ind in combined_copy:
"""
if np.isnan(ind.fitness):
ind.fitness = -np.inf
return -np.inf

def sort_func(zipped_ind: Tuple[int, IndividualBase]) -> float:
"""Return fitness of an individual or raise error if it is None.
"""
_, ind = zipped_ind
if isinstance(ind.fitness, float):
return ind.fitness
else:
raise ValueError(
f"IndividualBase fitness value is of wrong type {type(ind.fitness)}."
)

# get list of indices that sorts combined_copy ("argsort") in descending order
combined_sorted_indices = [
idx for (idx, _) in sorted(enumerate(combined_copy), key=sort_func, reverse=True)
]

# return original list of individuals sorted in descending order
return [combined[idx] for idx in combined_sorted_indices]
return sorted(combined, key=sort_func, reverse=True)

def _create_new_parent_population(
self, n_parents: int, combined: List[IndividualBase]
Expand Down
6 changes: 4 additions & 2 deletions test/test_ea_mu_plus_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def objective_with_label(individual, label):
assert pop.champion.fitness == pytest.approx(-1.0)


def test_fitness_contains_nan(population_params, genome_params):
def test_fitness_contains_and_maintains_nan(population_params, genome_params):
def objective(individual):
if np.random.rand() < 0.5:
if np.random.rand() < 0.95:
individual.fitness = np.nan
else:
individual.fitness = np.random.rand()
Expand All @@ -41,3 +41,5 @@ def objective(individual):
ea = cgp.ea.MuPlusLambda(10, 10, 1)
ea.initialize_fitness_parents(pop, objective)
ea.step(pop, objective)

assert np.nan in [ind.fitness for ind in pop]

0 comments on commit 2042586

Please sign in to comment.