Skip to content

Commit 1d89f70

Browse files
Merge pull request #12 from aika-algorithm/synapse-relations
Synapse relations
2 parents d4ae6a3 + c51091f commit 1d89f70

File tree

84 files changed

+5492
-5245
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

84 files changed

+5492
-5245
lines changed

src/main/java/org/aika/AbstractNode.java renamed to src/main/java/network/aika/AbstractNode.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
package org.aika;
17+
package network.aika;
1818

19-
import org.aika.lattice.Node;
20-
import org.aika.lattice.NodeActivation;
21-
import org.aika.neuron.INeuron;
22-
import org.aika.neuron.Neuron;
19+
import network.aika.lattice.Node;
20+
import network.aika.lattice.NodeActivation;
21+
import network.aika.neuron.INeuron;
22+
import network.aika.neuron.Neuron;
2323

2424
import java.io.DataInput;
2525
import java.io.IOException;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package network.aika;
2+
3+
4+
public enum ActivationFunction {
5+
6+
RECTIFIED_SCALED_LOGISTIC_SIGMOID(x -> Math.max(0.0, (2.0 * Utils.sigmoid(x)) - 1.0)),
7+
RECTIFIED_HYPERBOLIC_TANGENT(x -> Math.max(0.0, Math.tanh(x))),
8+
RECTIFIED_LINEAR_UNIT(x -> Math.max(0.0, x));
9+
10+
Function f;
11+
12+
ActivationFunction(Function f) {
13+
this.f = f;
14+
}
15+
16+
public double f(double x) {
17+
return f.f(x);
18+
}
19+
20+
21+
interface Function {
22+
double f(double x);
23+
}
24+
25+
}

src/main/java/org/aika/Converter.java renamed to src/main/java/network/aika/Converter.java

+121-56
Original file line numberDiff line numberDiff line change
@@ -14,48 +14,53 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
package org.aika;
17+
package network.aika;
1818

19-
import org.aika.lattice.AndNode;
20-
import org.aika.lattice.InputNode;
21-
import org.aika.lattice.Node;
22-
import org.aika.lattice.OrNode;
23-
import org.aika.neuron.INeuron;
24-
import org.aika.neuron.Synapse;
19+
import network.aika.lattice.AndNode;
20+
import network.aika.lattice.Node;
21+
import network.aika.lattice.OrNode;
22+
import network.aika.neuron.INeuron;
23+
import network.aika.neuron.relation.Relation;
24+
import network.aika.neuron.Synapse;
2525

2626
import java.util.*;
2727

28+
import static network.aika.neuron.activation.Range.Mapping.NONE;
29+
2830
/**
2931
* Converts the synapse weights of a neuron into a boolean logic representation of this neuron.
3032
*
3133
* @author Lukas Molzberger
3234
*/
3335
public class Converter {
3436

35-
public static int MAX_AND_NODE_SIZE = 4;
37+
public static int MAX_AND_NODE_SIZE = 6;
3638

3739

3840
public static Comparator<Synapse> SYNAPSE_COMP = (s1, s2) -> {
39-
int r = Double.compare(s2.weight, s1.weight);
41+
int r = Boolean.compare(
42+
s2.key.rangeOutput.begin != NONE || s2.key.rangeOutput.end != NONE || s2.key.identity,
43+
s1.key.rangeOutput.begin != NONE || s1.key.rangeOutput.end != NONE || s1.key.identity
44+
);
45+
if (r != 0) return r;
46+
r = Double.compare(s2.weight, s1.weight);
4047
if (r != 0) return r;
41-
return Synapse.INPUT_SYNAPSE_COMP.compare(s1, s2);
48+
return Integer.compare(s1.id, s2.id);
4249
};
4350

44-
private Model model;
4551
private int threadId;
4652
private INeuron neuron;
4753
private Document doc;
4854
private OrNode outputNode;
4955
private Collection<Synapse> modifiedSynapses;
5056

5157

52-
public static boolean convert(Model m, int threadId, Document doc, INeuron neuron, Collection<Synapse> modifiedSynapses) {
53-
return new Converter(m, threadId, doc, neuron, modifiedSynapses).convert();
58+
public static boolean convert(int threadId, Document doc, INeuron neuron, Collection<Synapse> modifiedSynapses) {
59+
return new Converter(threadId, doc, neuron, modifiedSynapses).convert();
5460
}
5561

5662

57-
private Converter(Model model, int threadId, Document doc, INeuron neuron, Collection<Synapse> modifiedSynapses) {
58-
this.model = model;
63+
private Converter(int threadId, Document doc, INeuron neuron, Collection<Synapse> modifiedSynapses) {
5964
this.doc = doc;
6065
this.neuron = neuron;
6166
this.threadId = threadId;
@@ -70,19 +75,13 @@ private boolean convert() {
7075

7176
if(neuron.biasSum + neuron.posDirSum + neuron.posRecSum <= 0.0) {
7277
neuron.requiredSum = neuron.posDirSum + neuron.posRecSum;
73-
outputNode.removeParents(threadId, false);
78+
outputNode.removeParents(threadId);
7479
return false;
7580
}
7681

77-
TreeSet<Synapse> tmp = new TreeSet<>(SYNAPSE_COMP);
78-
for(Synapse s: neuron.inputSynapses.values()) {
79-
if(!s.isNegative() && !s.key.isRecurrent && !s.inactive) {
80-
tmp.add(s);
81-
}
82-
}
82+
List<Synapse> candidates = prepareCandidates();
8383

84-
Integer offset = null;
85-
Node requiredNode = null;
84+
NodeContext nodeContext = null;
8685
boolean noFurtherRefinement = false;
8786
TreeSet<Synapse> reqSyns = new TreeSet<>(Synapse.INPUT_SYNAPSE_COMP);
8887
double sum = 0.0;
@@ -91,7 +90,7 @@ private boolean convert() {
9190
if(neuron.numDisjunctiveSynapses == 0) {
9291
double remainingSum = neuron.posDirSum;
9392
int i = 0;
94-
for (Synapse s : tmp) {
93+
for (Synapse s : candidates) {
9594
final boolean isOptionalInput = sum + remainingSum - s.weight + neuron.posRecSum + neuron.biasSum > 0.0;
9695
final boolean maxAndNodesReached = i >= MAX_AND_NODE_SIZE;
9796
if (isOptionalInput || maxAndNodesReached) {
@@ -102,8 +101,11 @@ private boolean convert() {
102101
neuron.requiredSum += s.weight;
103102
reqSyns.add(s);
104103

105-
requiredNode = getNextLevelNode(offset, requiredNode, s);
106-
offset = Utils.nullSafeMin(s.key.relativeRid, offset);
104+
NodeContext nlNodeContext = expandNode(nodeContext, s);
105+
if(nlNodeContext == null) {
106+
break;
107+
}
108+
nodeContext = nlNodeContext;
107109

108110
i++;
109111

@@ -114,45 +116,76 @@ private boolean convert() {
114116
noFurtherRefinement = true;
115117
break;
116118
}
117-
}
118119

119-
outputNode.removeParents(threadId, false);
120-
if (requiredNode != outputNode.requiredNode) {
121-
outputNode.requiredNode = requiredNode;
122120
}
123121

122+
outputNode.removeParents(threadId);
123+
124124
if (noFurtherRefinement || i == MAX_AND_NODE_SIZE) {
125-
outputNode.addInput(offset, threadId, requiredNode, false);
125+
outputNode.addInput(nodeContext.getSynapseIds(), threadId, nodeContext.node);
126126
} else {
127-
for (Synapse s : tmp) {
127+
for (Synapse s : candidates) {
128128
boolean belowThreshold = sum + s.weight + remainingSum + neuron.posRecSum + neuron.biasSum <= 0.0;
129129
if (belowThreshold) {
130130
break;
131131
}
132132

133133
if (!reqSyns.contains(s)) {
134-
Node nln;
135-
nln = getNextLevelNode(offset, requiredNode, s);
136-
137-
Integer nOffset = Utils.nullSafeMin(s.key.relativeRid, offset);
138-
outputNode.addInput(nOffset, threadId, nln, false);
139-
remainingSum -= s.weight;
134+
NodeContext nlNodeContext = expandNode(nodeContext, s);
135+
if(nlNodeContext != null) {
136+
outputNode.addInput(nlNodeContext.getSynapseIds(), threadId, nlNodeContext.node);
137+
remainingSum -= s.weight;
138+
}
140139
}
141140
}
142141
}
143142
} else {
144143
for (Synapse s : modifiedSynapses) {
145144
if (s.weight + neuron.posRecSum + neuron.biasSum > 0.0) {
146-
Node nln = s.inputNode.get();
147-
offset = s.key.relativeRid;
148-
outputNode.addInput(offset, threadId, nln, false);
145+
NodeContext nlNodeContext = expandNode(nodeContext, s);
146+
outputNode.addInput(nlNodeContext.getSynapseIds(), threadId, nlNodeContext.node);
149147
}
150148
}
151149
}
152150

153151
return true;
154152
}
155153

154+
private List<Synapse> prepareCandidates() {
155+
Synapse syn = getBestSynapse(neuron.inputSynapses.values());
156+
157+
TreeSet<Integer> alreadyCollected = new TreeSet<>();
158+
ArrayList<Synapse> selectedCandidates = new ArrayList<>();
159+
TreeMap<Integer, Synapse> relatedSyns = new TreeMap<>();
160+
while(syn != null && selectedCandidates.size() < MAX_AND_NODE_SIZE) {
161+
relatedSyns.remove(syn.id);
162+
selectedCandidates.add(syn);
163+
alreadyCollected.add(syn.id);
164+
for(Integer synId: syn.relations.keySet()) {
165+
if(!alreadyCollected.contains(synId)) {
166+
relatedSyns.put(synId, syn.output.getSynapseById(synId));
167+
}
168+
}
169+
170+
syn = getBestSynapse(relatedSyns.values());
171+
}
172+
173+
return selectedCandidates;
174+
}
175+
176+
177+
private Synapse getBestSynapse(Collection<Synapse> synapses) {
178+
Synapse maxSyn = null;
179+
for(Synapse s: synapses) {
180+
if(!s.isNegative() && !s.key.isRecurrent && !s.inactive) {
181+
if(maxSyn == null || SYNAPSE_COMP.compare(maxSyn, s) > 0) {
182+
maxSyn = s;
183+
}
184+
}
185+
}
186+
return maxSyn;
187+
}
188+
156189

157190
public static final int DIRECT = 0;
158191
public static final int RECURRENT = 1;
@@ -171,14 +204,6 @@ private void initInputNodesAndComputeWeightSums() {
171204
INeuron in = s.input.get();
172205
in.lock.acquireWriteLock();
173206
try {
174-
if (s.inputNode == null) {
175-
InputNode iNode = InputNode.add(model, s.key.createInputNodeKey(), s.input.get());
176-
iNode.setModified();
177-
iNode.setSynapse(s);
178-
iNode.postCreate(doc);
179-
s.inputNode = iNode.provider;
180-
}
181-
182207
if (!s.inactive) {
183208
sumDelta[s.key.isRecurrent ? RECURRENT : DIRECT][s.isNegative() ? NEGATIVE : POSITIVE] -= s.weight;
184209
sumDelta[s.key.isRecurrent ? RECURRENT : DIRECT][s.getNewWeight() <= 0.0 ? NEGATIVE : POSITIVE] += s.getNewWeight();
@@ -225,13 +250,53 @@ private void initInputNodesAndComputeWeightSums() {
225250
}
226251

227252

228-
private Node getNextLevelNode(Integer offset, Node requiredNode, Synapse s) {
229-
Node nln;
230-
if (requiredNode == null) {
231-
nln = s.inputNode.get();
253+
private NodeContext expandNode(NodeContext nc, Synapse s) {
254+
if (nc == null) {
255+
NodeContext nln = new NodeContext();
256+
nln.node = s.input.get().outputNode.get();
257+
nln.offsets = new Synapse[] {s};
258+
return nln;
232259
} else {
233-
nln = AndNode.createNextLevelNode(model, threadId, doc, requiredNode, new AndNode.Refinement(s.key.relativeRid, offset, s.inputNode), null);
260+
Relation[] relations = new Relation[nc.offsets.length];
261+
for(int i = 0; i < nc.offsets.length; i++) {
262+
Synapse linkedSynapse = nc.offsets[i];
263+
relations[i] = s.relations.get(linkedSynapse.id);
264+
}
265+
266+
NodeContext nln = new NodeContext();
267+
nln.offsets = new Synapse[nc.offsets.length + 1];
268+
AndNode.Refinement ref = new AndNode.Refinement(new AndNode.RelationsMap(relations), s.input.get().outputNode);
269+
AndNode.RefValue rv = nc.node.extend(threadId, doc, ref);
270+
if(rv == null) {
271+
return null;
272+
}
273+
274+
nln.node = rv.child.get(doc);
275+
276+
for (int i = 0; i < nc.offsets.length; i++) {
277+
nln.offsets[rv.offsets[i]] = nc.offsets[i];
278+
}
279+
for (int i = 0; i < nln.offsets.length; i++) {
280+
if (nln.offsets[i] == null) {
281+
nln.offsets[i] = s;
282+
}
283+
}
284+
return nln;
285+
}
286+
}
287+
288+
289+
private class NodeContext {
290+
Node node;
291+
292+
Synapse[] offsets;
293+
294+
int[] getSynapseIds() {
295+
int[] result = new int[offsets.length];
296+
for(int i = 0; i < result.length; i++) {
297+
result[i] = offsets[i].id;
298+
}
299+
return result;
234300
}
235-
return nln;
236301
}
237302
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package network.aika;
2+
3+
4+
import network.aika.neuron.activation.Activation;
5+
6+
public enum DistanceFunction {
7+
8+
NONE((iAct, oAct) -> 1.0),
9+
DEGRADING((iAct, oAct) -> 1.0 / ((double) (1 + Math.abs(oAct.range.begin - iAct.range.begin))));
10+
11+
Function f;
12+
13+
DistanceFunction(Function f) {
14+
this.f = f;
15+
}
16+
17+
public double f(Activation iAct, Activation oAct) {
18+
return f.f(iAct, oAct);
19+
}
20+
21+
22+
interface Function {
23+
double f(Activation iAct, Activation oAct);
24+
}
25+
26+
}

0 commit comments

Comments
 (0)