Skip to content

Commit fd657f1

Browse files
Added tests for elite handling.
1 parent cc76c7b commit fd657f1

File tree

1 file changed

+204
-0
lines changed

1 file changed

+204
-0
lines changed

tests/test_reproduction.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import os
22
import random
33
import unittest
4+
import copy
45

56
import neat
67
from neat.reporting import ReporterSet
78
from neat.reproduction import DefaultReproduction
9+
from neat.stagnation import DefaultStagnation
10+
from neat.species import DefaultSpeciesSet
811

912

1013
class TestSpawnComputation(unittest.TestCase):
@@ -97,5 +100,206 @@ def test_reproduce_respects_pop_size(self):
97100
species_set.speciate(config, population, generation + 1)
98101

99102

103+
class TestElitism(unittest.TestCase):
104+
"""Tests for per-species elitism behavior in DefaultReproduction."""
105+
106+
def setUp(self):
107+
# Load standard test configuration.
108+
local_dir = os.path.dirname(__file__)
109+
config_path = os.path.join(local_dir, 'test_configuration')
110+
self.config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
111+
neat.DefaultSpeciesSet, neat.DefaultStagnation,
112+
config_path)
113+
114+
self.pop_size = self.config.pop_size
115+
116+
# Set up objects similarly to Population.__init__.
117+
self.reporters = ReporterSet()
118+
self.stagnation = DefaultStagnation(self.config.stagnation_config,
119+
self.reporters)
120+
self.reproduction = DefaultReproduction(self.config.reproduction_config,
121+
self.reporters,
122+
self.stagnation)
123+
self.species_set = DefaultSpeciesSet(self.config.species_set_config,
124+
self.reporters)
125+
126+
# Deterministic initial population/speciation.
127+
random.seed(123)
128+
self.population = self.reproduction.create_new(
129+
self.config.genome_type,
130+
self.config.genome_config,
131+
self.pop_size,
132+
)
133+
self.generation = 0
134+
self.species_set.speciate(self.config, self.population, self.generation)
135+
136+
def _assign_deterministic_fitness_by_species(self):
137+
"""Within each species, assign strictly increasing fitness per genome id."""
138+
for sid, s in self.species_set.species.items():
139+
members = sorted(s.members.items(), key=lambda kv: kv[0])
140+
for rank, (gid, genome) in enumerate(members):
141+
genome.fitness = float(rank + 1)
142+
143+
@staticmethod
144+
def _snapshot_genome(genome):
145+
"""Create a primitive snapshot of a genome's structure and parameters.
146+
147+
This avoids comparing object identity of node/connection genes and focuses
148+
on their attributes (weights, enabled flags, biases, etc.).
149+
"""
150+
node_snapshot = {}
151+
for nid, ng in genome.nodes.items():
152+
node_snapshot[nid] = {
153+
'bias': getattr(ng, 'bias', None),
154+
'response': getattr(ng, 'response', None),
155+
'activation': getattr(ng, 'activation', None),
156+
'aggregation': getattr(ng, 'aggregation', None),
157+
}
158+
159+
conn_snapshot = {}
160+
for key, cg in genome.connections.items():
161+
conn_snapshot[key] = {
162+
'weight': getattr(cg, 'weight', None),
163+
'enabled': getattr(cg, 'enabled', None),
164+
'innovation': getattr(cg, 'innovation', None),
165+
}
166+
167+
return {
168+
'nodes': node_snapshot,
169+
'connections': conn_snapshot,
170+
'fitness': genome.fitness,
171+
}
172+
173+
def test_elites_preserved_for_surviving_species(self):
174+
"""Elites of non-extinct species are preserved and unmodified between generations.
175+
176+
For each species that has at least one descendant in the next generation,
177+
the top min(elitism, species_size) genomes by fitness must survive with
178+
identical genome parameters.
179+
"""
180+
self._assign_deterministic_fitness_by_species()
181+
182+
elitism = self.config.reproduction_config.elitism
183+
new_ids = None
184+
185+
# Record original members, elites, and structural snapshots of elite genomes.
186+
original_members_by_species = {}
187+
elite_ids_by_species = {}
188+
elite_snapshot = {}
189+
190+
for sid, s in self.species_set.species.items():
191+
members = sorted(s.members.items(), key=lambda kv: kv[1].fitness,
192+
reverse=True)
193+
original_members_by_species[sid] = {gid for gid, _ in members}
194+
195+
expected_elites = members[:min(elitism, len(members))]
196+
elite_ids = [gid for gid, _ in expected_elites]
197+
elite_ids_by_species[sid] = elite_ids
198+
199+
for gid, genome in expected_elites:
200+
elite_snapshot[gid] = self._snapshot_genome(genome)
201+
202+
# Reproduce one generation.
203+
new_population = self.reproduction.reproduce(
204+
self.config, self.species_set, self.pop_size, self.generation
205+
)
206+
new_ids = set(new_population.keys())
207+
208+
# For each species, check surviving original genomes.
209+
for sid, original_ids in original_members_by_species.items():
210+
surviving_ids = original_ids & new_ids
211+
212+
if not surviving_ids:
213+
# Species went extinct; elites may legitimately be gone.
214+
continue
215+
216+
expected_elites = elite_ids_by_species[sid]
217+
218+
# Expect exactly the expected number of elites to survive, and no
219+
# other original members.
220+
self.assertEqual(
221+
surviving_ids,
222+
set(expected_elites),
223+
f"Species {sid} survivors {surviving_ids} do not match elites {expected_elites}",
224+
)
225+
226+
# Ensure elite genomes were not mutated (structurally or in parameters).
227+
for gid in expected_elites:
228+
self.assertIn(gid, new_population)
229+
self.assertEqual(
230+
elite_snapshot[gid],
231+
self._snapshot_genome(new_population[gid]),
232+
f"Elite genome {gid} from species {sid} was modified",
233+
)
234+
235+
def test_elites_may_disappear_only_when_species_extinct(self):
236+
"""Elites are dropped only when their entire species is removed by stagnation."""
237+
# Need at least two species to distinguish "good" vs "bad".
238+
if len(self.species_set.species) < 2:
239+
self.skipTest("Need at least two species for this test")
240+
241+
# Choose one species to be marked stagnant (bad) and others to improve (good).
242+
species_items = sorted(self.species_set.species.items(), key=lambda kv: kv[0])
243+
bad_sid, bad_species = species_items[0]
244+
other_species_items = species_items[1:]
245+
246+
# Assign fitness histories and current fitness so that the bad species
247+
# is considered stagnant while others are improving.
248+
max_stag = self.config.stagnation_config.max_stagnation
249+
generation = max_stag + 1
250+
251+
elitism = self.config.reproduction_config.elitism
252+
253+
# Record original members and elite ids for all species.
254+
original_members_by_species = {}
255+
elite_ids_by_species = {}
256+
257+
for sid, s in self.species_set.species.items():
258+
members = sorted(s.members.items(), key=lambda kv: kv[0])
259+
original_members_by_species[sid] = {gid for gid, _ in members}
260+
# Use simple per-species ranking for elites.
261+
elites = members[:min(elitism, len(members))]
262+
elite_ids_by_species[sid] = [gid for gid, _ in elites]
263+
264+
# Configure bad species: past improvement and current non-improvement.
265+
bad_species.fitness_history = [2.0]
266+
bad_species.last_improved = 0
267+
for genome in bad_species.members.values():
268+
genome.fitness = 1.0
269+
270+
# Configure other species: improvement this generation.
271+
for sid, s in other_species_items:
272+
s.fitness_history = [1.0]
273+
s.last_improved = 0
274+
for genome in s.members.values():
275+
genome.fitness = 10.0
276+
277+
# Reproduce one generation; stagnation logic inside reproduce will
278+
# remove the bad species entirely.
279+
new_population = self.reproduction.reproduce(
280+
self.config, self.species_set, self.pop_size, generation
281+
)
282+
new_ids = set(new_population.keys())
283+
284+
# Bad species should have no surviving original members (extinct).
285+
bad_original_ids = original_members_by_species[bad_sid]
286+
bad_survivors = bad_original_ids & new_ids
287+
self.assertEqual(
288+
bad_survivors,
289+
set(),
290+
f"Bad species {bad_sid} should have gone extinct but has survivors {bad_survivors}",
291+
)
292+
293+
# At least one other species must have survivors.
294+
any_other_survivors = False
295+
for sid, _ in other_species_items:
296+
orig_ids = original_members_by_species[sid]
297+
survivors = orig_ids & new_ids
298+
if survivors:
299+
any_other_survivors = True
300+
301+
self.assertTrue(any_other_survivors, "No surviving species after reproduction")
302+
303+
100304
if __name__ == '__main__':
101305
unittest.main()

0 commit comments

Comments
 (0)