Skip to content

Commit 412662b

Browse files
Add feedforward network tests.
1 parent 29048e5 commit 412662b

File tree

1 file changed

+131
-63
lines changed

1 file changed

+131
-63
lines changed

tests/test_feedforward_network.py

Lines changed: 131 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import os
2+
3+
import neat
14
from neat import activations
5+
from neat.genes import DefaultConnectionGene, DefaultNodeGene
26
from neat.nn import FeedForwardNetwork
37

48

@@ -44,71 +48,135 @@ def test_basic():
4448
assert result[0] == r.values[0]
4549

4650

47-
# TODO: Update this test for the current implementation.
48-
# def test_simple_nohidden():
49-
# config_params = {
50-
# 'num_inputs':2,
51-
# 'num_outputs':1,
52-
# 'num_hidden':0,
53-
# 'feed_forward':True,
54-
# 'compatibility_threshold':3.0,
55-
# 'excess_coefficient':1.0,
56-
# 'disjoint_coefficient':1.0,
57-
# 'compatibility_weight_coefficient':1.0,
58-
# 'conn_add_prob':0.5,
59-
# 'conn_delete_prob':0.05,
60-
# 'node_add_prob':0.1,
61-
# 'node_delete_prob':0.05}
62-
# config = DefaultGenomeConfig(config_params)
63-
# config.genome_config.set_input_output_sizes(2, 1)
64-
# g = DefaultGenome(0, config)
65-
# g.add_node(0, 0.0, 1.0, 'sum', 'tanh')
66-
# g.add_connection(-1, 0, 1.0, True)
67-
# g.add_connection(-2, 0, -1.0, True)
68-
#
69-
# net = nn.create_feed_forward_phenotype(g, config)
70-
#
71-
# v00 = net.serial_activate([0.0, 0.0])
72-
# assert_almost_equal(v00[0], 0.0, 1e-3)
73-
#
74-
# v01 = net.serial_activate([0.0, 1.0])
75-
# assert_almost_equal(v01[0], -0.76159, 1e-3)
76-
#
77-
# v10 = net.serial_activate([1.0, 0.0])
78-
# assert_almost_equal(v10[0], 0.76159, 1e-3)
79-
#
80-
# v11 = net.serial_activate([1.0, 1.0])
81-
# assert_almost_equal(v11[0], 0.0, 1e-3)
82-
83-
84-
# TODO: Update this test for the current implementation.
85-
# def test_simple_hidden():
86-
# config = Config()
87-
# config.genome_config.set_input_output_sizes(2, 1)
88-
# g = DefaultGenome(0, config)
89-
#
90-
# g.add_node(0, 0.0, 1.0, 'sum', 'identity')
91-
# g.add_node(1, -0.5, 5.0, 'sum', 'sigmoid')
92-
# g.add_node(2, -1.5, 5.0, 'sum', 'sigmoid')
93-
# g.add_connection(-1, 1, 1.0, True)
94-
# g.add_connection(-2, 2, 1.0, True)
95-
# g.add_connection(1, 0, 1.0, True)
96-
# g.add_connection(2, 0, -1.0, True)
97-
# net = nn.create_feed_forward_phenotype(g, config)
98-
#
99-
# v00 = net.serial_activate([0.0, 0.0])
100-
# assert_almost_equal(v00[0], 0.195115, 1e-3)
101-
#
102-
# v01 = net.serial_activate([0.0, 1.0])
103-
# assert_almost_equal(v01[0], -0.593147, 1e-3)
104-
#
105-
# v10 = net.serial_activate([1.0, 0.0])
106-
# assert_almost_equal(v10[0], 0.806587, 1e-3)
107-
#
108-
# v11 = net.serial_activate([1.0, 1.0])
109-
# assert_almost_equal(v11[0], 0.018325, 1e-3)
51+
def _create_simple_nohidden_network():
52+
"""Small genome-built feedforward net: 2 inputs -> 1 tanh output, no hidden layer."""
53+
local_dir = os.path.dirname(__file__)
54+
config_path = os.path.join(local_dir, "test_configuration")
55+
config = neat.Config(
56+
neat.DefaultGenome,
57+
neat.DefaultReproduction,
58+
neat.DefaultSpeciesSet,
59+
neat.DefaultStagnation,
60+
config_path,
61+
)
62+
63+
genome = neat.DefaultGenome(0)
64+
65+
# Single output node 0 with tanh activation and sum aggregation.
66+
node0 = DefaultNodeGene(0)
67+
node0.bias = 0.0
68+
node0.response = 1.0
69+
node0.activation = "tanh"
70+
node0.aggregation = "sum"
71+
genome.nodes[0] = node0
72+
73+
# Connections: input -1 -> 0 (weight 1.0), input -2 -> 0 (weight -1.0).
74+
conn1_key = (-1, 0)
75+
conn1 = DefaultConnectionGene(conn1_key, innovation=0)
76+
conn1.weight = 1.0
77+
conn1.enabled = True
78+
79+
conn2_key = (-2, 0)
80+
conn2 = DefaultConnectionGene(conn2_key, innovation=1)
81+
conn2.weight = -1.0
82+
conn2.enabled = True
83+
84+
genome.connections[conn1_key] = conn1
85+
genome.connections[conn2_key] = conn2
86+
87+
return FeedForwardNetwork.create(genome, config)
88+
89+
90+
def test_simple_nohidden_from_genome():
91+
"""FeedForwardNetwork.create builds the expected simple no-hidden network."""
92+
net = _create_simple_nohidden_network()
93+
94+
v00 = net.activate([0.0, 0.0])
95+
assert_almost_equal(v00[0], 0.0, 1e-6)
96+
97+
v01 = net.activate([0.0, 1.0])
98+
assert_almost_equal(v01[0], -0.9866142981514303, 1e-6)
99+
100+
v10 = net.activate([1.0, 0.0])
101+
assert_almost_equal(v10[0], 0.9866142981514303, 1e-6)
102+
103+
v11 = net.activate([1.0, 1.0])
104+
assert_almost_equal(v11[0], 0.0, 1e-6)
105+
106+
107+
def _create_simple_hidden_network():
108+
"""Small genome-built feedforward net: 2 inputs -> 2 sigmoid hidden -> 1 identity output."""
109+
local_dir = os.path.dirname(__file__)
110+
config_path = os.path.join(local_dir, "test_configuration")
111+
config = neat.Config(
112+
neat.DefaultGenome,
113+
neat.DefaultReproduction,
114+
neat.DefaultSpeciesSet,
115+
neat.DefaultStagnation,
116+
config_path,
117+
)
118+
119+
genome = neat.DefaultGenome(0)
120+
121+
# Output node 0 (identity), hidden nodes 1 and 2 (sigmoid).
122+
node0 = DefaultNodeGene(0)
123+
node0.bias = 0.0
124+
node0.response = 1.0
125+
node0.activation = "identity"
126+
node0.aggregation = "sum"
127+
128+
node1 = DefaultNodeGene(1)
129+
node1.bias = -0.5
130+
node1.response = 5.0
131+
node1.activation = "sigmoid"
132+
node1.aggregation = "sum"
133+
134+
node2 = DefaultNodeGene(2)
135+
node2.bias = -1.5
136+
node2.response = 5.0
137+
node2.activation = "sigmoid"
138+
node2.aggregation = "sum"
139+
140+
genome.nodes[0] = node0
141+
genome.nodes[1] = node1
142+
genome.nodes[2] = node2
143+
144+
# Connections: -1 -> 1, -2 -> 2, and hidden 1/2 to output 0 (second with weight -1.0).
145+
connections = [
146+
((-1, 1), 1.0),
147+
((-2, 2), 1.0),
148+
((1, 0), 1.0),
149+
((2, 0), -1.0),
150+
]
151+
152+
for innovation, (key, weight) in enumerate(connections):
153+
cg = DefaultConnectionGene(key, innovation=innovation)
154+
cg.weight = weight
155+
cg.enabled = True
156+
genome.connections[key] = cg
157+
158+
return FeedForwardNetwork.create(genome, config)
159+
160+
161+
def test_simple_hidden_from_genome():
162+
"""FeedForwardNetwork.create builds a simple hidden-layer network with expected behavior."""
163+
net = _create_simple_hidden_network()
164+
165+
v00 = net.activate([0.0, 0.0])
166+
assert_almost_equal(v00[0], 0.07530540138431994, 1e-6)
167+
168+
v01 = net.activate([0.0, 1.0])
169+
assert_almost_equal(v01[0], -0.9241417948687655, 1e-6)
170+
171+
v10 = net.activate([1.0, 0.0])
172+
assert_almost_equal(v10[0], 0.9994472211938866, 1e-6)
173+
174+
v11 = net.activate([1.0, 1.0])
175+
assert_almost_equal(v11[0], 2.4940801202077978e-08, 1e-6)
110176

111177

112178
if __name__ == '__main__':
113179
test_unconnected()
114180
test_basic()
181+
test_simple_nohidden_from_genome()
182+
test_simple_hidden_from_genome()

0 commit comments

Comments
 (0)