Skip to content

Commit c1e63c8

Browse files
Replace lunar lander with a working example.
1 parent 5d524ad commit c1e63c8

File tree

7 files changed

+248
-270
lines changed

7 files changed

+248
-270
lines changed
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
#!/usr/bin/env bash
22
rm *.csv *.gv *.svg
3-
rm winner*
43
rm neat-checkpoint-*

examples/openai-lander/config renamed to examples/lunar-lander/config-feedforward

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
[NEAT]
2-
pop_size = 150
3-
# Note: the fitness threshold will never be reached because
4-
# we are controlling the termination ourselves based on simulation performance.
52
fitness_criterion = max
6-
fitness_threshold = 1000.0
3+
# Terminate when we reliably solve the task (LunarLander is considered solved
4+
# around an average reward of 200).
5+
fitness_threshold = 200.0
6+
pop_size = 150
77
reset_on_extinction = 0
88

99
no_fitness_termination = False
1010

11-
[LanderGenome]
11+
[DefaultGenome]
12+
# LunarLander observations: x, y, x_dot, y_dot, angle, angular velocity,
13+
# left leg contact, right leg contact.
1214
num_inputs = 8
1315
num_hidden = 0
1416
num_outputs = 4
@@ -73,4 +75,3 @@ elitism = 2
7375
survival_threshold = 0.2
7476

7577
min_species_size = 2
76-
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""\
2+
Feed-forward LunarLander-v3 control example.
3+
4+
This example is structured similarly to examples/xor/evolve-feedforward.py and
5+
produces the same kinds of visual artifacts:
6+
7+
* Fitness curve over generations
8+
* Species size stack plot
9+
* Network diagrams (full and pruned) of the winning genome
10+
"""
11+
12+
import multiprocessing
13+
import os
14+
import pickle
15+
16+
import gymnasium as gym
17+
import neat
18+
import visualize
19+
20+
# Evaluation parameters.
21+
runs_per_net = 5
22+
max_steps = 1000
23+
24+
25+
def eval_genome(genome, config):
26+
"""Evaluate a single genome on the LunarLander-v3 environment."""
27+
net = neat.nn.FeedForwardNetwork.create(genome, config)
28+
fitnesses = []
29+
30+
for _ in range(runs_per_net):
31+
# Create a fresh environment for each run (no rendering during training).
32+
env = gym.make("LunarLander-v3")
33+
observation, info = env.reset()
34+
35+
total_reward = 0.0
36+
for _ in range(max_steps):
37+
# Network outputs four action values; take the argmax as the discrete action.
38+
action_values = net.activate(observation)
39+
action = max(range(len(action_values)), key=lambda i: action_values[i])
40+
41+
observation, reward, terminated, truncated, info = env.step(action)
42+
total_reward += reward
43+
44+
if terminated or truncated:
45+
break
46+
47+
env.close()
48+
fitnesses.append(total_reward)
49+
50+
# Use the average reward over runs as the fitness.
51+
return sum(fitnesses) / len(fitnesses)
52+
53+
54+
def eval_genomes(genomes, config):
55+
for genome_id, genome in genomes:
56+
genome.fitness = eval_genome(genome, config)
57+
58+
59+
def run(config_file):
60+
# Load configuration.
61+
config = neat.Config(
62+
neat.DefaultGenome,
63+
neat.DefaultReproduction,
64+
neat.DefaultSpeciesSet,
65+
neat.DefaultStagnation,
66+
config_file,
67+
)
68+
69+
# Create the population, which is the top-level object for a NEAT run.
70+
p = neat.Population(config)
71+
72+
# Add a stdout reporter to show progress in the terminal.
73+
p.add_reporter(neat.StdOutReporter(True))
74+
stats = neat.StatisticsReporter()
75+
p.add_reporter(stats)
76+
# Periodic checkpoints, similar to other examples.
77+
p.add_reporter(neat.Checkpointer(10))
78+
79+
# Use parallel evaluation across available CPU cores.
80+
pe = neat.ParallelEvaluator(multiprocessing.cpu_count(), eval_genome)
81+
82+
# Run until solution or fitness threshold is reached (see config).
83+
winner = p.run(pe.evaluate, 500)
84+
85+
# Display the winning genome.
86+
print(f"\nBest genome:\n{winner!s}")
87+
88+
# Save the winner for later reuse in test-feedforward.py.
89+
with open("winner-feedforward.pickle", "wb") as f:
90+
pickle.dump(winner, f)
91+
92+
# Visualization artifacts analogous to examples/xor/evolve-feedforward.py.
93+
# Fitness & species plots.
94+
visualize.plot_stats(
95+
stats,
96+
ylog=False,
97+
view=True,
98+
filename="feedforward-fitness.svg",
99+
)
100+
visualize.plot_species(
101+
stats,
102+
view=True,
103+
filename="feedforward-speciation.svg",
104+
)
105+
106+
# Node labels for easier interpretation of the evolved controller.
107+
node_names = {
108+
# Observations
109+
-1: "x",
110+
-2: "y",
111+
-3: "x_dot",
112+
-4: "y_dot",
113+
-5: "angle",
114+
-6: "ang_vel",
115+
-7: "left_leg",
116+
-8: "right_leg",
117+
# Discrete actions
118+
0: "do_nothing",
119+
1: "fire_left",
120+
2: "fire_main",
121+
3: "fire_right",
122+
}
123+
124+
# Full and pruned network diagrams for the winning genome.
125+
visualize.draw_net(
126+
config,
127+
winner,
128+
view=True,
129+
node_names=node_names,
130+
filename="winner-feedforward.gv",
131+
)
132+
visualize.draw_net(
133+
config,
134+
winner,
135+
view=True,
136+
node_names=node_names,
137+
filename="winner-feedforward-pruned.gv",
138+
prune_unused=True,
139+
)
140+
141+
return winner, stats
142+
143+
144+
if __name__ == "__main__":
145+
# Determine path to configuration file. This path manipulation is
146+
# here so that the script will run successfully regardless of the
147+
# current working directory.
148+
local_dir = os.path.dirname(__file__)
149+
config_path = os.path.join(local_dir, "config-feedforward")
150+
run(config_path)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""\
2+
Test and visualize the performance of the best genome produced by
3+
examples/lunar-lander/evolve-feedforward.py on the LunarLander-v3 environment.
4+
"""
5+
6+
import os
7+
import pickle
8+
import sys
9+
10+
import gymnasium as gym
11+
import neat
12+
13+
14+
def run_episodes(net, episodes=3, render=True):
15+
"""Run a few episodes using the provided network and optionally render."""
16+
if render:
17+
env = gym.make("LunarLander-v3", render_mode="human")
18+
else:
19+
env = gym.make("LunarLander-v3")
20+
21+
try:
22+
rewards = []
23+
for episode in range(episodes):
24+
observation, info = env.reset()
25+
total_reward = 0.0
26+
step = 0
27+
28+
while True:
29+
step += 1
30+
action_values = net.activate(observation)
31+
action = max(range(len(action_values)), key=lambda i: action_values[i])
32+
33+
observation, reward, terminated, truncated, info = env.step(action)
34+
total_reward += reward
35+
36+
if terminated or truncated:
37+
break
38+
39+
rewards.append(total_reward)
40+
print(
41+
f"Episode {episode + 1}: steps={step}, total_reward={total_reward:.2f}",
42+
)
43+
finally:
44+
env.close()
45+
46+
if rewards:
47+
avg = sum(rewards) / len(rewards)
48+
print(f"\nAverage reward over {len(rewards)} episodes: {avg:.2f}")
49+
50+
51+
def load_and_test(genome_path, config_path, episodes=3, render=True):
52+
"""Load a saved genome and test it on LunarLander-v3."""
53+
# Load the config.
54+
config = neat.Config(
55+
neat.DefaultGenome,
56+
neat.DefaultReproduction,
57+
neat.DefaultSpeciesSet,
58+
neat.DefaultStagnation,
59+
config_path,
60+
)
61+
62+
# Load the genome.
63+
with open(genome_path, "rb") as f:
64+
genome = pickle.load(f)
65+
66+
print("Loaded genome:")
67+
print(genome)
68+
69+
# Create the network and run episodes.
70+
net = neat.nn.FeedForwardNetwork.create(genome, config)
71+
run_episodes(net, episodes=episodes, render=render)
72+
73+
74+
if __name__ == "__main__":
75+
# Determine local paths.
76+
local_dir = os.path.dirname(__file__)
77+
config_path = os.path.join(local_dir, "config-feedforward")
78+
79+
# Optional argument: custom path to winner genome.
80+
if len(sys.argv) > 1:
81+
genome_path = sys.argv[1]
82+
else:
83+
genome_path = os.path.join(local_dir, "winner-feedforward.pickle")
84+
85+
if not os.path.exists(genome_path):
86+
print(f"Error: Genome file not found at {genome_path}")
87+
print("Please train a network first by running evolve-feedforward.py")
88+
sys.exit(1)
89+
90+
print(f"Testing genome from: {genome_path}\n")
91+
load_and_test(genome_path, config_path, episodes=3, render=True)
File renamed without changes.

examples/openai-lander/clean.bat

Lines changed: 0 additions & 3 deletions
This file was deleted.

0 commit comments

Comments
 (0)