|
1 | 1 | import os |
2 | 2 | import random |
3 | 3 | import unittest |
| 4 | +import copy |
4 | 5 |
|
5 | 6 | import neat |
6 | 7 | from neat.reporting import ReporterSet |
7 | 8 | from neat.reproduction import DefaultReproduction |
| 9 | +from neat.stagnation import DefaultStagnation |
| 10 | +from neat.species import DefaultSpeciesSet |
8 | 11 |
|
9 | 12 |
|
10 | 13 | class TestSpawnComputation(unittest.TestCase): |
@@ -97,5 +100,206 @@ def test_reproduce_respects_pop_size(self): |
97 | 100 | species_set.speciate(config, population, generation + 1) |
98 | 101 |
|
99 | 102 |
|
| 103 | +class TestElitism(unittest.TestCase): |
| 104 | + """Tests for per-species elitism behavior in DefaultReproduction.""" |
| 105 | + |
| 106 | + def setUp(self): |
| 107 | + # Load standard test configuration. |
| 108 | + local_dir = os.path.dirname(__file__) |
| 109 | + config_path = os.path.join(local_dir, 'test_configuration') |
| 110 | + self.config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, |
| 111 | + neat.DefaultSpeciesSet, neat.DefaultStagnation, |
| 112 | + config_path) |
| 113 | + |
| 114 | + self.pop_size = self.config.pop_size |
| 115 | + |
| 116 | + # Set up objects similarly to Population.__init__. |
| 117 | + self.reporters = ReporterSet() |
| 118 | + self.stagnation = DefaultStagnation(self.config.stagnation_config, |
| 119 | + self.reporters) |
| 120 | + self.reproduction = DefaultReproduction(self.config.reproduction_config, |
| 121 | + self.reporters, |
| 122 | + self.stagnation) |
| 123 | + self.species_set = DefaultSpeciesSet(self.config.species_set_config, |
| 124 | + self.reporters) |
| 125 | + |
| 126 | + # Deterministic initial population/speciation. |
| 127 | + random.seed(123) |
| 128 | + self.population = self.reproduction.create_new( |
| 129 | + self.config.genome_type, |
| 130 | + self.config.genome_config, |
| 131 | + self.pop_size, |
| 132 | + ) |
| 133 | + self.generation = 0 |
| 134 | + self.species_set.speciate(self.config, self.population, self.generation) |
| 135 | + |
| 136 | + def _assign_deterministic_fitness_by_species(self): |
| 137 | + """Within each species, assign strictly increasing fitness per genome id.""" |
| 138 | + for sid, s in self.species_set.species.items(): |
| 139 | + members = sorted(s.members.items(), key=lambda kv: kv[0]) |
| 140 | + for rank, (gid, genome) in enumerate(members): |
| 141 | + genome.fitness = float(rank + 1) |
| 142 | + |
| 143 | + @staticmethod |
| 144 | + def _snapshot_genome(genome): |
| 145 | + """Create a primitive snapshot of a genome's structure and parameters. |
| 146 | +
|
| 147 | + This avoids comparing object identity of node/connection genes and focuses |
| 148 | + on their attributes (weights, enabled flags, biases, etc.). |
| 149 | + """ |
| 150 | + node_snapshot = {} |
| 151 | + for nid, ng in genome.nodes.items(): |
| 152 | + node_snapshot[nid] = { |
| 153 | + 'bias': getattr(ng, 'bias', None), |
| 154 | + 'response': getattr(ng, 'response', None), |
| 155 | + 'activation': getattr(ng, 'activation', None), |
| 156 | + 'aggregation': getattr(ng, 'aggregation', None), |
| 157 | + } |
| 158 | + |
| 159 | + conn_snapshot = {} |
| 160 | + for key, cg in genome.connections.items(): |
| 161 | + conn_snapshot[key] = { |
| 162 | + 'weight': getattr(cg, 'weight', None), |
| 163 | + 'enabled': getattr(cg, 'enabled', None), |
| 164 | + 'innovation': getattr(cg, 'innovation', None), |
| 165 | + } |
| 166 | + |
| 167 | + return { |
| 168 | + 'nodes': node_snapshot, |
| 169 | + 'connections': conn_snapshot, |
| 170 | + 'fitness': genome.fitness, |
| 171 | + } |
| 172 | + |
| 173 | + def test_elites_preserved_for_surviving_species(self): |
| 174 | + """Elites of non-extinct species are preserved and unmodified between generations. |
| 175 | +
|
| 176 | + For each species that has at least one descendant in the next generation, |
| 177 | + the top min(elitism, species_size) genomes by fitness must survive with |
| 178 | + identical genome parameters. |
| 179 | + """ |
| 180 | + self._assign_deterministic_fitness_by_species() |
| 181 | + |
| 182 | + elitism = self.config.reproduction_config.elitism |
| 183 | + new_ids = None |
| 184 | + |
| 185 | + # Record original members, elites, and structural snapshots of elite genomes. |
| 186 | + original_members_by_species = {} |
| 187 | + elite_ids_by_species = {} |
| 188 | + elite_snapshot = {} |
| 189 | + |
| 190 | + for sid, s in self.species_set.species.items(): |
| 191 | + members = sorted(s.members.items(), key=lambda kv: kv[1].fitness, |
| 192 | + reverse=True) |
| 193 | + original_members_by_species[sid] = {gid for gid, _ in members} |
| 194 | + |
| 195 | + expected_elites = members[:min(elitism, len(members))] |
| 196 | + elite_ids = [gid for gid, _ in expected_elites] |
| 197 | + elite_ids_by_species[sid] = elite_ids |
| 198 | + |
| 199 | + for gid, genome in expected_elites: |
| 200 | + elite_snapshot[gid] = self._snapshot_genome(genome) |
| 201 | + |
| 202 | + # Reproduce one generation. |
| 203 | + new_population = self.reproduction.reproduce( |
| 204 | + self.config, self.species_set, self.pop_size, self.generation |
| 205 | + ) |
| 206 | + new_ids = set(new_population.keys()) |
| 207 | + |
| 208 | + # For each species, check surviving original genomes. |
| 209 | + for sid, original_ids in original_members_by_species.items(): |
| 210 | + surviving_ids = original_ids & new_ids |
| 211 | + |
| 212 | + if not surviving_ids: |
| 213 | + # Species went extinct; elites may legitimately be gone. |
| 214 | + continue |
| 215 | + |
| 216 | + expected_elites = elite_ids_by_species[sid] |
| 217 | + |
| 218 | + # Expect exactly the expected number of elites to survive, and no |
| 219 | + # other original members. |
| 220 | + self.assertEqual( |
| 221 | + surviving_ids, |
| 222 | + set(expected_elites), |
| 223 | + f"Species {sid} survivors {surviving_ids} do not match elites {expected_elites}", |
| 224 | + ) |
| 225 | + |
| 226 | + # Ensure elite genomes were not mutated (structurally or in parameters). |
| 227 | + for gid in expected_elites: |
| 228 | + self.assertIn(gid, new_population) |
| 229 | + self.assertEqual( |
| 230 | + elite_snapshot[gid], |
| 231 | + self._snapshot_genome(new_population[gid]), |
| 232 | + f"Elite genome {gid} from species {sid} was modified", |
| 233 | + ) |
| 234 | + |
| 235 | + def test_elites_may_disappear_only_when_species_extinct(self): |
| 236 | + """Elites are dropped only when their entire species is removed by stagnation.""" |
| 237 | + # Need at least two species to distinguish "good" vs "bad". |
| 238 | + if len(self.species_set.species) < 2: |
| 239 | + self.skipTest("Need at least two species for this test") |
| 240 | + |
| 241 | + # Choose one species to be marked stagnant (bad) and others to improve (good). |
| 242 | + species_items = sorted(self.species_set.species.items(), key=lambda kv: kv[0]) |
| 243 | + bad_sid, bad_species = species_items[0] |
| 244 | + other_species_items = species_items[1:] |
| 245 | + |
| 246 | + # Assign fitness histories and current fitness so that the bad species |
| 247 | + # is considered stagnant while others are improving. |
| 248 | + max_stag = self.config.stagnation_config.max_stagnation |
| 249 | + generation = max_stag + 1 |
| 250 | + |
| 251 | + elitism = self.config.reproduction_config.elitism |
| 252 | + |
| 253 | + # Record original members and elite ids for all species. |
| 254 | + original_members_by_species = {} |
| 255 | + elite_ids_by_species = {} |
| 256 | + |
| 257 | + for sid, s in self.species_set.species.items(): |
| 258 | + members = sorted(s.members.items(), key=lambda kv: kv[0]) |
| 259 | + original_members_by_species[sid] = {gid for gid, _ in members} |
| 260 | + # Use simple per-species ranking for elites. |
| 261 | + elites = members[:min(elitism, len(members))] |
| 262 | + elite_ids_by_species[sid] = [gid for gid, _ in elites] |
| 263 | + |
| 264 | + # Configure bad species: past improvement and current non-improvement. |
| 265 | + bad_species.fitness_history = [2.0] |
| 266 | + bad_species.last_improved = 0 |
| 267 | + for genome in bad_species.members.values(): |
| 268 | + genome.fitness = 1.0 |
| 269 | + |
| 270 | + # Configure other species: improvement this generation. |
| 271 | + for sid, s in other_species_items: |
| 272 | + s.fitness_history = [1.0] |
| 273 | + s.last_improved = 0 |
| 274 | + for genome in s.members.values(): |
| 275 | + genome.fitness = 10.0 |
| 276 | + |
| 277 | + # Reproduce one generation; stagnation logic inside reproduce will |
| 278 | + # remove the bad species entirely. |
| 279 | + new_population = self.reproduction.reproduce( |
| 280 | + self.config, self.species_set, self.pop_size, generation |
| 281 | + ) |
| 282 | + new_ids = set(new_population.keys()) |
| 283 | + |
| 284 | + # Bad species should have no surviving original members (extinct). |
| 285 | + bad_original_ids = original_members_by_species[bad_sid] |
| 286 | + bad_survivors = bad_original_ids & new_ids |
| 287 | + self.assertEqual( |
| 288 | + bad_survivors, |
| 289 | + set(), |
| 290 | + f"Bad species {bad_sid} should have gone extinct but has survivors {bad_survivors}", |
| 291 | + ) |
| 292 | + |
| 293 | + # At least one other species must have survivors. |
| 294 | + any_other_survivors = False |
| 295 | + for sid, _ in other_species_items: |
| 296 | + orig_ids = original_members_by_species[sid] |
| 297 | + survivors = orig_ids & new_ids |
| 298 | + if survivors: |
| 299 | + any_other_survivors = True |
| 300 | + |
| 301 | + self.assertTrue(any_other_survivors, "No surviving species after reproduction") |
| 302 | + |
| 303 | + |
100 | 304 | if __name__ == '__main__': |
101 | 305 | unittest.main() |
0 commit comments