
Commit 6dfe3de

Added new network tests.
1 parent 2ceadcd commit 6dfe3de


5 files changed: +334, -285 lines


tests/test_iznn.py

Lines changed: 119 additions & 129 deletions
@@ -1,141 +1,131 @@
 import neat
+import pytest


-def test_basic():
-    p = neat.iznn.REGULAR_SPIKING_PARAMS
-    n = neat.iznn.IZNeuron(10, p['a'], p['b'], p['c'], p['d'], [])
-    spike_train = []
-    for i in range(1000):
-        spike_train.append(n.v)
-        n.advance(0.25)
+def _make_neuron(params, bias=0.0, inputs=None):
+    if inputs is None:
+        inputs = []
+    return neat.iznn.IZNeuron(bias, params["a"], params["b"], params["c"], params["d"], inputs)


-def test_network():
-    p = neat.iznn.INTRINSICALLY_BURSTING_PARAMS
-    neurons = {0: neat.iznn.IZNeuron(0, p['a'], p['b'], p['c'], p['d'], []),
-               1: neat.iznn.IZNeuron(0, p['a'], p['b'], p['c'], p['d'], []),
-               2: neat.iznn.IZNeuron(0, p['a'], p['b'], p['c'], p['d'], [(0, 0.123), (1, 0.234)]),
-               3: neat.iznn.IZNeuron(0, p['a'], p['b'], p['c'], p['d'], [])}
-    inputs = [0, 1]
-    outputs = [2]
+def test_single_neuron_regular_spiking_pulse_response():
+    """Basic dynamic test mirroring the demo-iznn pulse protocol."""
+    params = neat.iznn.REGULAR_SPIKING_PARAMS
+    n = _make_neuron(params, bias=0.0)
+
+    dt = 0.25
+    spikes = []
+
+    # Drive the neuron with a pulse of current between 100 and 800 steps,
+    # following the example in examples/neuron-demo/demo-iznn.py.
+    for step in range(1000):
+        n.current = 0.0 if step < 100 or step > 800 else 10.0
+        spikes.append(n.fired)
+        n.advance(dt)
+
+    # No spikes without input current.
+    assert max(spikes[:100]) == 0.0
+
+    # The neuron should spike at least once while being driven.
+    assert max(spikes[100:800]) == 1.0
+
+    # After the drive is removed the neuron should eventually become quiet again.
+    # Allow some transient after the input turns off.
+    assert max(spikes[900:]) == 0.0
+
+    # Membrane potential should relax back near the reset value c.
+    # Use a loose bound here since the exact value depends on integration details.
+    assert abs(n.v - params["c"]) < 20.0
+
+
+def test_izneuron_reset_restores_initial_state():
+    """IZNeuron.reset should restore v, u, fired, and current to defaults."""
+    params = neat.iznn.FAST_SPIKING_PARAMS
+    n = _make_neuron(params, bias=1.5)
+
+    # Perturb state.
+    n.current = 3.0
+    n.advance(0.5)
+
+    # Also modify outputs explicitly so reset has work to do.
+    n.fired = 1.0
+    n.current = 7.0
+
+    n.reset()
+
+    assert n.v == n.c
+    assert n.u == n.b * n.v
+    assert n.fired == 0.0
+    assert n.current == n.bias
+
+
+def test_iznn_uses_external_inputs_and_resets_neurons():
+    """IZNN should aggregate external inputs correctly and support reset()."""
+    params = neat.iznn.REGULAR_SPIKING_PARAMS
+
+    # Single output neuron (key 0) receiving two external inputs -1 and -2.
+    neuron = _make_neuron(params, bias=0.0, inputs=[(-1, 1.0), (-2, -1.0)])
+    neurons = {0: neuron}
+    inputs = [-1, -2]
+    outputs = [0]

     net = neat.iznn.IZNN(neurons, inputs, outputs)
+
+    # With zero inputs the synaptic current should equal the bias.
+    net.set_inputs([0.0, 0.0])
+    net.advance(0.25)
+    assert neuron.current == pytest.approx(neuron.bias)
+
+    # A positive value on input -1 and zero on -2 should increase current.
+    net.reset()
     net.set_inputs([1.0, 0.0])
     net.advance(0.25)
+    assert neuron.current == pytest.approx(neuron.bias + 1.0)
+
+    # A positive value on input -2 (with negative weight) should decrease current.
+    net.reset()
+    net.set_inputs([0.0, 1.0])
     net.advance(0.25)
+    assert neuron.current == pytest.approx(neuron.bias - 1.0)
+
+    # Reset should restore neuron state for all neurons in the network.
+    neuron.v = 0.0
+    neuron.u = 0.0
+    neuron.current = 10.0
+    neuron.fired = 1.0
+
+    net.reset()
+
+    assert neuron.v == neuron.c
+    assert neuron.u == neuron.b * neuron.v
+    assert neuron.current == neuron.bias
+    assert neuron.fired == 0.0
+
+
+def test_iznn_set_inputs_length_mismatch_raises():
+    """set_inputs should enforce input length and raise RuntimeError on mismatch."""
+    params = neat.iznn.REGULAR_SPIKING_PARAMS
+    neuron = _make_neuron(params)
+    neurons = {0: neuron}
+    inputs = [-1, -2]
+    outputs = [0]
+    net = neat.iznn.IZNN(neurons, inputs, outputs)
+
+    # Too few inputs.
+    with pytest.raises(RuntimeError, match="Number of inputs"):
+        net.set_inputs([1.0])
+
+    # Too many inputs.
+    with pytest.raises(RuntimeError, match="Number of inputs"):
+        net.set_inputs([1.0, 2.0, 3.0])
+

+def test_get_time_step_positive():
+    """get_time_step_msec should return a positive float."""
+    params = neat.iznn.REGULAR_SPIKING_PARAMS
+    neuron = _make_neuron(params)
+    net = neat.iznn.IZNN({0: neuron}, inputs=[-1], outputs=[0])

-# # TODO: Update this test to work with the current implementation.
-# # def test_iznn_evolve():
-# #     """This is a stripped-down copy of the XOR2 spiking example."""
-# #
-# #     # Network inputs and expected outputs.
-# #     xor_inputs = ((0, 0), (0, 1), (1, 0), (1, 1))
-# #     xor_outputs = (0, 1, 1, 0)
-# #
-# #     # Maximum amount of simulated time (in milliseconds) to wait for the network to produce an output.
-# #     max_time = 50.0
-# #
-# #     def compute_output(t0, t1):
-# #         '''Compute the network's output based on the "time to first spike" of the two output neurons.'''
-# #         if t0 is None or t1 is None:
-# #             # If one of the output neurons failed to fire within the allotted time,
-# #             # give a response which produces a large error.
-# #             return -1.0
-# #         else:
-# #             # If the output neurons fire within 1.0 milliseconds of each other,
-# #             # the output is 1, and if they fire more than 11 milliseconds apart,
-# #             # the output is 0, with linear interpolation between 1 and 11 milliseconds.
-# #             response = 1.1 - 0.1 * abs(t0 - t1)
-# #             return max(0.0, min(1.0, response))
-# #
-# #     def simulate(genome):
-# #         # Create a network of Izhikevich neurons based on the given genome.
-# #         net = iznn.create_phenotype(genome, **iznn.THALAMO_CORTICAL_PARAMS)
-# #         dt = 0.25
-# #         sum_square_error = 0.0
-# #         simulated = []
-# #         for inputData, outputData in zip(xor_inputs, xor_outputs):
-# #             neuron_data = {}
-# #             for i, n in net.neurons.items():
-# #                 neuron_data[i] = []
-# #
-# #             # Reset the network, apply the XOR inputs, and run for the maximum allowed time.
-# #             net.reset()
-# #             net.set_inputs(inputData)
-# #             t0 = None
-# #             t1 = None
-# #             v0 = None
-# #             v1 = None
-# #             num_steps = int(max_time / dt)
-# #             for j in range(num_steps):
-# #                 t = dt * j
-# #                 output = net.advance(dt)
-# #
-# #                 # Capture the time and neuron membrane potential for later use if desired.
-# #                 for i, n in net.neurons.items():
-# #                     neuron_data[i].append((t, n.v))
-# #
-# #                 # Remember time and value of the first output spikes from each neuron.
-# #                 if t0 is None and output[0] > 0:
-# #                     t0, v0 = neuron_data[net.outputs[0]][-2]
-# #
-# #                 if t1 is None and output[1] > 0:
-# #                     t1, v1 = neuron_data[net.outputs[1]][-2]
-# #
-# #             response = compute_output(t0, t1)
-# #             sum_square_error += (response - outputData) ** 2
-# #
-# #             simulated.append(
-# #                 (inputData, outputData, t0, t1, v0, v1, neuron_data))
-# #
-# #         return sum_square_error, simulated
-# #
-# #     def eval_fitness(genomes):
-# #         for genome in genomes:
-# #             sum_square_error, simulated = simulate(genome)
-# #             genome.fitness = 1 - sum_square_error
-# #
-# #     # Load the config file, which is assumed to live in
-# #     # the same directory as this script.
-# #     local_dir = os.path.dirname(__file__)
-# #     config = Config(os.path.join(local_dir, 'test_configuration'))
-# #
-# #     # TODO: This is a little hackish, but will a user ever want to do it?
-# #     # If so, provide a convenience method on Config for it.
-# #     for i, tc in enumerate(config.type_config['DefaultStagnation']):
-# #         if tc[0] == 'species_fitness_func':
-# #             config.type_config['DefaultStagnation'][i] = (tc[0], 'median')
-# #
-# #     # For this network, we use two output neurons and use the difference between
-# #     # the "time to first spike" to determine the network response. There are
-# #     # probably a great many different choices one could make for an output encoding,
-# #     # and this choice may not be the best for tackling a real problem.
-# #     config.output_nodes = 2
-# #
-# #     pop = population.Population(config)
-# #     pop.run(eval_fitness, 10)
-# #
-# #     print('Number of evaluations: {0}'.format(pop.total_evaluations))
-# #
-# #     # Visualize the winner network and plot statistics.
-# #     winner = pop.statistics.best_genome()
-# #
-# #     # Verify network output against training data.
-# #     print('\nBest network output:')
-# #     net = iznn.create_phenotype(winner, **iznn.RESONATOR_PARAMS)
-# #     sum_square_error, simulated = simulate(winner)
-# #
-# #     repr(winner)
-# #     str(winner)
-# #     for g in winner.node_genes:
-# #         repr(g)
-# #         str(g)
-# #     for g in winner.conn_genes:
-# #         repr(g)
-# #         str(g)
-#
-#
-if __name__ == '__main__':
-    test_basic()
-    test_network()
+    dt = net.get_time_step_msec()
+    assert isinstance(dt, float)
+    assert dt > 0.0
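
For reference, the iznn neurons exercised by these tests follow Izhikevich's (2003) two-variable spiking model, with a, b, c, and d playing the same roles as in parameter sets such as REGULAR_SPIKING_PARAMS above. Below is a minimal, self-contained sketch of that model with a forward-Euler step; the names izhikevich_step and REGULAR_SPIKING are illustrative only, and the library's own integration and reset details may differ.

# Illustrative sketch of the Izhikevich (2003) neuron model that the iznn
# tests exercise. This is NOT neat-python's implementation; the parameter
# names (a, b, c, d) and the Euler update mirror the published model.

REGULAR_SPIKING = {"a": 0.02, "b": 0.2, "c": -65.0, "d": 8.0}

def izhikevich_step(v, u, current, a, b, c, d, dt=0.25):
    """Advance (v, u) by one Euler step of size dt (milliseconds)."""
    fired = 0.0
    v += dt * (0.04 * v ** 2 + 5.0 * v + 140.0 - u + current)
    u += dt * a * (b * v - u)
    if v >= 30.0:  # spike threshold
        fired = 1.0
        v = c      # reset membrane potential
        u += d     # reset recovery variable
    return v, u, fired

if __name__ == "__main__":
    p = REGULAR_SPIKING
    v, u = p["c"], p["b"] * p["c"]  # same initial state as IZNeuron.reset (v = c, u = b*v)
    spikes = 0
    for step in range(1000):
        # Same pulse window as test_single_neuron_regular_spiking_pulse_response.
        current = 10.0 if 100 <= step <= 800 else 0.0
        v, u, fired = izhikevich_step(v, u, current, **p)
        spikes += int(fired)
    print("spikes during pulse run:", spikes)

The pulse drive mirrors the 100-800 step window used in the test above, so a run of this sketch should report at least a few spikes during the driven interval and none before it.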
