Skip to content

Commit 6f13aca

Browse files
Fix orphaned nodes bug.
1 parent 8f79267 commit 6f13aca

File tree

6 files changed

+296
-32
lines changed

6 files changed

+296
-32
lines changed

neat/aggregations.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,28 @@ def sum_aggregation(x):
2020

2121

2222
def max_aggregation(x):
23-
return max(x)
23+
# Handle empty input (for orphaned nodes with no incoming connections)
24+
return max(x) if x else 0.0
2425

2526

2627
def min_aggregation(x):
27-
return min(x)
28+
# Handle empty input (for orphaned nodes with no incoming connections)
29+
return min(x) if x else 0.0
2830

2931

3032
def maxabs_aggregation(x):
31-
return max(x, key=abs)
33+
# Handle empty input (for orphaned nodes with no incoming connections)
34+
return max(x, key=abs) if x else 0.0
3235

3336

3437
def median_aggregation(x):
35-
return median2(x)
38+
# Handle empty input (for orphaned nodes with no incoming connections)
39+
return median2(x) if x else 0.0
3640

3741

3842
def mean_aggregation(x):
39-
return mean(x)
43+
# Handle empty input (for orphaned nodes with no incoming connections)
44+
return mean(x) if x else 0.0
4045

4146

4247
class InvalidAggregationFunction(TypeError):

neat/graphs.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,30 +38,14 @@ def required_for_output(inputs, outputs, connections):
3838
"""
3939
assert not set(inputs).intersection(outputs)
4040

41-
# Create a graph representation of the connections
42-
graph = defaultdict(list)
43-
reverse_graph = defaultdict(list)
44-
for a, b in connections:
45-
graph[a].append(b)
46-
reverse_graph[b].append(a)
47-
48-
# Perform a breadth-first search (BFS) from each input to find all reachable nodes
49-
reachable = set(inputs)
50-
queue = deque(inputs)
51-
52-
while queue:
53-
node = queue.popleft()
54-
for neighbor in graph[node]:
55-
if neighbor not in reachable:
56-
reachable.add(neighbor)
57-
queue.append(neighbor)
58-
59-
# Now, traverse from the outputs and find all nodes that are required to reach the outputs
41+
# Traverse backwards from outputs to find all nodes that feed into outputs.
42+
# This includes orphaned nodes (nodes with no incoming connections) that
43+
# connect to outputs, as they are required to compute the output.
6044
required = set(outputs)
6145
s = set(outputs)
6246
while True:
63-
# Find nodes not in s whose output is consumed by a node in s and is reachable from inputs
64-
t = set(a for (a, b) in connections if b in s and a not in s and a in reachable)
47+
# Find nodes not in s whose output is consumed by a node in s
48+
t = set(a for (a, b) in connections if b in s and a not in s)
6549

6650
if not t:
6751
break
@@ -90,8 +74,23 @@ def feed_forward_layers(inputs, outputs, connections):
9074

9175
required = required_for_output(inputs, outputs, connections)
9276

77+
# Find required nodes that have no incoming connections.
78+
# These are "bias neurons" that output activation(bias) independent of inputs.
79+
nodes_with_inputs = set()
80+
for a, b in connections:
81+
nodes_with_inputs.add(b)
82+
83+
# Bias neurons are required nodes with no incoming connections
84+
bias_neurons = required - nodes_with_inputs
85+
9386
layers = []
94-
potential_input = set(inputs)
87+
# Start with inputs AND bias neurons in the ready set
88+
potential_input = set(inputs) | bias_neurons
89+
90+
# If there are bias neurons, add them as the first layer
91+
if bias_neurons:
92+
layers.append(bias_neurons.copy())
93+
9594
while True:
9695
# Find candidate nodes c for the next layer. These nodes should connect
9796
# a node in s to a node not in s.

tests/test_graphs.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,85 @@ def test_fuzz_feed_forward_layers():
152152
feed_forward_layers(inputs, outputs, connections)
153153

154154

155+
def test_orphaned_single_node():
156+
"""Test a single orphaned hidden node that connects to output."""
157+
inputs = [0, 1]
158+
outputs = [3]
159+
# Node 2 has no incoming connections (orphaned) but connects to output 3
160+
connections = [(2, 3)]
161+
layers, required = feed_forward_layers(inputs, outputs, connections)
162+
# Node 2 should appear in the first layer as a bias neuron
163+
# Node 3 should appear in the second layer
164+
assert [{2}, {3}] == layers
165+
assert {2, 3} == required
166+
167+
168+
def test_orphaned_multiple_nodes():
169+
"""Test multiple orphaned nodes in the same network."""
170+
inputs = [0]
171+
outputs = [4]
172+
# Nodes 2 and 3 are both orphaned (no incoming connections)
173+
# Both connect to output node 4
174+
connections = [(2, 4), (3, 4)]
175+
layers, required = feed_forward_layers(inputs, outputs, connections)
176+
# Both orphaned nodes should appear in the first layer
177+
# Output node 4 should appear in the second layer
178+
assert [{2, 3}, {4}] == layers
179+
assert {2, 3, 4} == required
180+
181+
182+
def test_orphaned_mixed():
183+
"""Test a mix of normal nodes and orphaned nodes."""
184+
inputs = [0, 1]
185+
outputs = [5]
186+
# Node 2 gets input from 0 (normal)
187+
# Node 3 is orphaned (no inputs)
188+
# Node 4 gets inputs from both 2 and 3
189+
# Node 5 (output) gets input from 4
190+
connections = [(0, 2), (2, 4), (3, 4), (4, 5)]
191+
layers, required = feed_forward_layers(inputs, outputs, connections)
192+
# First layer: orphaned node 3
193+
# Second layer: node 2 (has input from 0)
194+
# Third layer: node 4 (has inputs from 2 and 3)
195+
# Fourth layer: node 5 (has input from 4)
196+
assert [{3}, {2}, {4}, {5}] == layers
197+
assert {2, 3, 4, 5} == required
198+
199+
200+
def test_orphaned_output_node():
201+
"""Test an output node with no incoming connections."""
202+
inputs = [0, 1]
203+
outputs = [2, 3]
204+
# Output node 2 has a connection from input 0
205+
# Output node 3 has no connections at all (orphaned output)
206+
connections = [(0, 2)]
207+
layers, required = feed_forward_layers(inputs, outputs, connections)
208+
# First layer: orphaned output node 3
209+
# Second layer: normal output node 2
210+
assert [{3}, {2}] == layers
211+
assert {2, 3} == required
212+
213+
214+
def test_orphaned_with_self_loop_prevention():
215+
"""Test orphaned nodes don't interfere with cycle detection."""
216+
inputs = [0]
217+
outputs = [3]
218+
# Node 2 is orphaned and connects to output 3
219+
connections = [(2, 3)]
220+
layers, required = feed_forward_layers(inputs, outputs, connections)
221+
assert [{2}, {3}] == layers
222+
# Verify that adding a self-loop would still be detected as a cycle
223+
assert creates_cycle(connections, (2, 2))
224+
225+
155226
if __name__ == '__main__':
156227
test_creates_cycle()
157228
test_required_for_output()
158229
test_fuzz_required()
159230
test_feed_forward_layers()
160231
test_fuzz_feed_forward_layers()
232+
test_orphaned_single_node()
233+
test_orphaned_multiple_nodes()
234+
test_orphaned_mixed()
235+
test_orphaned_output_node()
236+
test_orphaned_with_self_loop_prevention()

tests/test_graphs_edge_cases.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,13 @@ def test_partially_connected_outputs(self):
202202
outputs = [5, 6]
203203
connections = [
204204
(0, 2), (2, 5), # Path to output 5
205-
(3, 6) # Output 6 not connected to inputs
205+
(3, 6) # Output 6 connected via orphaned node 3
206206
]
207207

208208
required = required_for_output(inputs, outputs, connections)
209-
# Only output 5 and its dependencies
210-
self.assertEqual(required, {2, 5, 6}) # 6 is in outputs so included
209+
# Node 3 is required even though it's orphaned (not reachable from inputs)
210+
# because it feeds into output 6. It acts as a "bias neuron".
211+
self.assertEqual(required, {2, 3, 5, 6})
211212

212213
def test_recurrent_connections(self):
213214
"""Test with recurrent (cyclic) connections."""
@@ -424,8 +425,8 @@ def test_empty_network(self):
424425
connections = []
425426

426427
layers, required = feed_forward_layers(inputs, outputs, connections)
427-
# No layers if no connections
428-
self.assertEqual(layers, [])
428+
# Output node 2 is orphaned (no incoming connections), so it appears as first layer
429+
self.assertEqual(layers, [{2}])
429430
self.assertEqual(required, {2})
430431

431432
def test_bottleneck_structure(self):

tests/test_issue_188.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""
2+
Tests for GitHub issue #188: Class attribute bug in neat.attributes
3+
4+
This tests that _config_items is properly isolated between instances
5+
rather than being shared at the class level.
6+
"""
7+
8+
import unittest
9+
from neat.attributes import StringAttribute, FloatAttribute, BoolAttribute, IntegerAttribute
10+
11+
12+
class TestIssue188AttributeIsolation(unittest.TestCase):
13+
"""Test that attribute instances don't share _config_items."""
14+
15+
def test_string_attribute_isolation(self):
16+
"""Test that StringAttribute instances have separate _config_items."""
17+
attr1 = StringAttribute('activation', options='sigmoid')
18+
attr2 = StringAttribute('aggregation', options='sum')
19+
20+
# Should NOT share the same object
21+
self.assertIsNot(attr1._config_items, attr2._config_items,
22+
"StringAttribute instances should not share _config_items")
23+
24+
# Should have different option values
25+
self.assertEqual(attr1._config_items['options'][1], 'sigmoid')
26+
self.assertEqual(attr2._config_items['options'][1], 'sum')
27+
28+
def test_string_attribute_mutation_isolation(self):
29+
"""Test that modifying one instance doesn't affect another."""
30+
attr1 = StringAttribute('foo', options='option_a')
31+
attr2 = StringAttribute('bar', options='option_b')
32+
33+
# Modify attr1's _config_items
34+
attr1._config_items['options'] = ['MODIFIED', 'TEST']
35+
36+
# attr2 should NOT be affected
37+
self.assertEqual(attr2._config_items['options'][1], 'option_b',
38+
"Modifying attr1 should not affect attr2")
39+
40+
def test_float_attribute_isolation(self):
41+
"""Test that FloatAttribute instances have separate _config_items."""
42+
attr1 = FloatAttribute('weight', init_mean=0.0)
43+
attr2 = FloatAttribute('bias', init_mean=1.0)
44+
45+
self.assertIsNot(attr1._config_items, attr2._config_items)
46+
self.assertEqual(attr1._config_items['init_mean'][1], 0.0)
47+
self.assertEqual(attr2._config_items['init_mean'][1], 1.0)
48+
49+
def test_bool_attribute_isolation(self):
50+
"""Test that BoolAttribute instances have separate _config_items."""
51+
attr1 = BoolAttribute('enabled', default='True')
52+
attr2 = BoolAttribute('flag', default='False')
53+
54+
self.assertIsNot(attr1._config_items, attr2._config_items)
55+
self.assertEqual(attr1._config_items['default'][1], 'True')
56+
self.assertEqual(attr2._config_items['default'][1], 'False')
57+
58+
def test_integer_attribute_isolation(self):
59+
"""Test that IntegerAttribute instances have separate _config_items."""
60+
attr1 = IntegerAttribute('count', min_value=0)
61+
attr2 = IntegerAttribute('index', min_value=1)
62+
63+
self.assertIsNot(attr1._config_items, attr2._config_items)
64+
self.assertEqual(attr1._config_items['min_value'][1], 0)
65+
self.assertEqual(attr2._config_items['min_value'][1], 1)
66+
67+
def test_mixed_attribute_types_isolation(self):
68+
"""Test that different attribute types don't interfere with each other."""
69+
float_attr = FloatAttribute('weight')
70+
string_attr = StringAttribute('activation')
71+
bool_attr = BoolAttribute('enabled')
72+
int_attr = IntegerAttribute('count')
73+
74+
# Each should have its own _config_items
75+
items_list = [
76+
float_attr._config_items,
77+
string_attr._config_items,
78+
bool_attr._config_items,
79+
int_attr._config_items
80+
]
81+
82+
# All should be different objects
83+
for i, items_i in enumerate(items_list):
84+
for j, items_j in enumerate(items_list):
85+
if i != j:
86+
self.assertIsNot(items_i, items_j,
87+
f"Attribute instances {i} and {j} should not share _config_items")
88+
89+
def test_genes_pattern(self):
90+
"""Test the actual usage pattern from DefaultNodeGene."""
91+
# This is how it's used in genes.py
92+
attr_activation = StringAttribute('activation', options='')
93+
attr_aggregation = StringAttribute('aggregation', options='')
94+
95+
# Should have separate _config_items even with same default value
96+
self.assertIsNot(attr_activation._config_items, attr_aggregation._config_items,
97+
"activation and aggregation attributes should have separate _config_items")
98+
99+
# Both should have empty string as options default
100+
self.assertEqual(attr_activation._config_items['options'][1], '')
101+
self.assertEqual(attr_aggregation._config_items['options'][1], '')
102+
103+
# Modifying one should not affect the other
104+
attr_activation._config_items['options'][1] = 'modified'
105+
self.assertEqual(attr_aggregation._config_items['options'][1], '',
106+
"Modifying activation should not affect aggregation")
107+
108+
109+
if __name__ == '__main__':
110+
unittest.main()

0 commit comments

Comments
 (0)