14
14
* See the License for the specific language governing permissions and
15
15
* limitations under the License.
16
16
*/
17
- package org .aika ;
17
+ package network .aika ;
18
18
19
- import org .aika .lattice .AndNode ;
20
- import org .aika .lattice .InputNode ;
21
- import org .aika .lattice .Node ;
22
- import org .aika .lattice . OrNode ;
23
- import org .aika .neuron .INeuron ;
24
- import org .aika .neuron .Synapse ;
19
+ import network .aika .lattice .AndNode ;
20
+ import network .aika .lattice .Node ;
21
+ import network .aika .lattice .OrNode ;
22
+ import network .aika .neuron . INeuron ;
23
+ import network .aika .neuron .relation . Relation ;
24
+ import network .aika .neuron .Synapse ;
25
25
26
26
import java .util .*;
27
27
28
+ import static network .aika .neuron .activation .Range .Mapping .NONE ;
29
+
28
30
/**
29
31
* Converts the synapse weights of a neuron into a boolean logic representation of this neuron.
30
32
*
31
33
* @author Lukas Molzberger
32
34
*/
33
35
public class Converter {
34
36
35
- public static int MAX_AND_NODE_SIZE = 4 ;
37
+ public static int MAX_AND_NODE_SIZE = 6 ;
36
38
37
39
38
40
public static Comparator <Synapse > SYNAPSE_COMP = (s1 , s2 ) -> {
39
- int r = Double .compare (s2 .weight , s1 .weight );
41
+ int r = Boolean .compare (
42
+ s2 .key .rangeOutput .begin != NONE || s2 .key .rangeOutput .end != NONE || s2 .key .identity ,
43
+ s1 .key .rangeOutput .begin != NONE || s1 .key .rangeOutput .end != NONE || s1 .key .identity
44
+ );
45
+ if (r != 0 ) return r ;
46
+ r = Double .compare (s2 .weight , s1 .weight );
40
47
if (r != 0 ) return r ;
41
- return Synapse . INPUT_SYNAPSE_COMP . compare (s1 , s2 );
48
+ return Integer . compare (s1 . id , s2 . id );
42
49
};
43
50
44
- private Model model ;
45
51
private int threadId ;
46
52
private INeuron neuron ;
47
53
private Document doc ;
48
54
private OrNode outputNode ;
49
55
private Collection <Synapse > modifiedSynapses ;
50
56
51
57
52
- public static boolean convert (Model m , int threadId , Document doc , INeuron neuron , Collection <Synapse > modifiedSynapses ) {
53
- return new Converter (m , threadId , doc , neuron , modifiedSynapses ).convert ();
58
+ public static boolean convert (int threadId , Document doc , INeuron neuron , Collection <Synapse > modifiedSynapses ) {
59
+ return new Converter (threadId , doc , neuron , modifiedSynapses ).convert ();
54
60
}
55
61
56
62
57
- private Converter (Model model , int threadId , Document doc , INeuron neuron , Collection <Synapse > modifiedSynapses ) {
58
- this .model = model ;
63
+ private Converter (int threadId , Document doc , INeuron neuron , Collection <Synapse > modifiedSynapses ) {
59
64
this .doc = doc ;
60
65
this .neuron = neuron ;
61
66
this .threadId = threadId ;
@@ -70,19 +75,13 @@ private boolean convert() {
70
75
71
76
if (neuron .biasSum + neuron .posDirSum + neuron .posRecSum <= 0.0 ) {
72
77
neuron .requiredSum = neuron .posDirSum + neuron .posRecSum ;
73
- outputNode .removeParents (threadId , false );
78
+ outputNode .removeParents (threadId );
74
79
return false ;
75
80
}
76
81
77
- TreeSet <Synapse > tmp = new TreeSet <>(SYNAPSE_COMP );
78
- for (Synapse s : neuron .inputSynapses .values ()) {
79
- if (!s .isNegative () && !s .key .isRecurrent && !s .inactive ) {
80
- tmp .add (s );
81
- }
82
- }
82
+ List <Synapse > candidates = prepareCandidates ();
83
83
84
- Integer offset = null ;
85
- Node requiredNode = null ;
84
+ NodeContext nodeContext = null ;
86
85
boolean noFurtherRefinement = false ;
87
86
TreeSet <Synapse > reqSyns = new TreeSet <>(Synapse .INPUT_SYNAPSE_COMP );
88
87
double sum = 0.0 ;
@@ -91,7 +90,7 @@ private boolean convert() {
91
90
if (neuron .numDisjunctiveSynapses == 0 ) {
92
91
double remainingSum = neuron .posDirSum ;
93
92
int i = 0 ;
94
- for (Synapse s : tmp ) {
93
+ for (Synapse s : candidates ) {
95
94
final boolean isOptionalInput = sum + remainingSum - s .weight + neuron .posRecSum + neuron .biasSum > 0.0 ;
96
95
final boolean maxAndNodesReached = i >= MAX_AND_NODE_SIZE ;
97
96
if (isOptionalInput || maxAndNodesReached ) {
@@ -102,8 +101,11 @@ private boolean convert() {
102
101
neuron .requiredSum += s .weight ;
103
102
reqSyns .add (s );
104
103
105
- requiredNode = getNextLevelNode (offset , requiredNode , s );
106
- offset = Utils .nullSafeMin (s .key .relativeRid , offset );
104
+ NodeContext nlNodeContext = expandNode (nodeContext , s );
105
+ if (nlNodeContext == null ) {
106
+ break ;
107
+ }
108
+ nodeContext = nlNodeContext ;
107
109
108
110
i ++;
109
111
@@ -114,45 +116,76 @@ private boolean convert() {
114
116
noFurtherRefinement = true ;
115
117
break ;
116
118
}
117
- }
118
119
119
- outputNode .removeParents (threadId , false );
120
- if (requiredNode != outputNode .requiredNode ) {
121
- outputNode .requiredNode = requiredNode ;
122
120
}
123
121
122
+ outputNode .removeParents (threadId );
123
+
124
124
if (noFurtherRefinement || i == MAX_AND_NODE_SIZE ) {
125
- outputNode .addInput (offset , threadId , requiredNode , false );
125
+ outputNode .addInput (nodeContext . getSynapseIds () , threadId , nodeContext . node );
126
126
} else {
127
- for (Synapse s : tmp ) {
127
+ for (Synapse s : candidates ) {
128
128
boolean belowThreshold = sum + s .weight + remainingSum + neuron .posRecSum + neuron .biasSum <= 0.0 ;
129
129
if (belowThreshold ) {
130
130
break ;
131
131
}
132
132
133
133
if (!reqSyns .contains (s )) {
134
- Node nln ;
135
- nln = getNextLevelNode (offset , requiredNode , s );
136
-
137
- Integer nOffset = Utils .nullSafeMin (s .key .relativeRid , offset );
138
- outputNode .addInput (nOffset , threadId , nln , false );
139
- remainingSum -= s .weight ;
134
+ NodeContext nlNodeContext = expandNode (nodeContext , s );
135
+ if (nlNodeContext != null ) {
136
+ outputNode .addInput (nlNodeContext .getSynapseIds (), threadId , nlNodeContext .node );
137
+ remainingSum -= s .weight ;
138
+ }
140
139
}
141
140
}
142
141
}
143
142
} else {
144
143
for (Synapse s : modifiedSynapses ) {
145
144
if (s .weight + neuron .posRecSum + neuron .biasSum > 0.0 ) {
146
- Node nln = s .inputNode .get ();
147
- offset = s .key .relativeRid ;
148
- outputNode .addInput (offset , threadId , nln , false );
145
+ NodeContext nlNodeContext = expandNode (nodeContext , s );
146
+ outputNode .addInput (nlNodeContext .getSynapseIds (), threadId , nlNodeContext .node );
149
147
}
150
148
}
151
149
}
152
150
153
151
return true ;
154
152
}
155
153
154
+ private List <Synapse > prepareCandidates () {
155
+ Synapse syn = getBestSynapse (neuron .inputSynapses .values ());
156
+
157
+ TreeSet <Integer > alreadyCollected = new TreeSet <>();
158
+ ArrayList <Synapse > selectedCandidates = new ArrayList <>();
159
+ TreeMap <Integer , Synapse > relatedSyns = new TreeMap <>();
160
+ while (syn != null && selectedCandidates .size () < MAX_AND_NODE_SIZE ) {
161
+ relatedSyns .remove (syn .id );
162
+ selectedCandidates .add (syn );
163
+ alreadyCollected .add (syn .id );
164
+ for (Integer synId : syn .relations .keySet ()) {
165
+ if (!alreadyCollected .contains (synId )) {
166
+ relatedSyns .put (synId , syn .output .getSynapseById (synId ));
167
+ }
168
+ }
169
+
170
+ syn = getBestSynapse (relatedSyns .values ());
171
+ }
172
+
173
+ return selectedCandidates ;
174
+ }
175
+
176
+
177
+ private Synapse getBestSynapse (Collection <Synapse > synapses ) {
178
+ Synapse maxSyn = null ;
179
+ for (Synapse s : synapses ) {
180
+ if (!s .isNegative () && !s .key .isRecurrent && !s .inactive ) {
181
+ if (maxSyn == null || SYNAPSE_COMP .compare (maxSyn , s ) > 0 ) {
182
+ maxSyn = s ;
183
+ }
184
+ }
185
+ }
186
+ return maxSyn ;
187
+ }
188
+
156
189
157
190
public static final int DIRECT = 0 ;
158
191
public static final int RECURRENT = 1 ;
@@ -171,14 +204,6 @@ private void initInputNodesAndComputeWeightSums() {
171
204
INeuron in = s .input .get ();
172
205
in .lock .acquireWriteLock ();
173
206
try {
174
- if (s .inputNode == null ) {
175
- InputNode iNode = InputNode .add (model , s .key .createInputNodeKey (), s .input .get ());
176
- iNode .setModified ();
177
- iNode .setSynapse (s );
178
- iNode .postCreate (doc );
179
- s .inputNode = iNode .provider ;
180
- }
181
-
182
207
if (!s .inactive ) {
183
208
sumDelta [s .key .isRecurrent ? RECURRENT : DIRECT ][s .isNegative () ? NEGATIVE : POSITIVE ] -= s .weight ;
184
209
sumDelta [s .key .isRecurrent ? RECURRENT : DIRECT ][s .getNewWeight () <= 0.0 ? NEGATIVE : POSITIVE ] += s .getNewWeight ();
@@ -225,13 +250,53 @@ private void initInputNodesAndComputeWeightSums() {
225
250
}
226
251
227
252
228
- private Node getNextLevelNode (Integer offset , Node requiredNode , Synapse s ) {
229
- Node nln ;
230
- if (requiredNode == null ) {
231
- nln = s .inputNode .get ();
253
+ private NodeContext expandNode (NodeContext nc , Synapse s ) {
254
+ if (nc == null ) {
255
+ NodeContext nln = new NodeContext ();
256
+ nln .node = s .input .get ().outputNode .get ();
257
+ nln .offsets = new Synapse [] {s };
258
+ return nln ;
232
259
} else {
233
- nln = AndNode .createNextLevelNode (model , threadId , doc , requiredNode , new AndNode .Refinement (s .key .relativeRid , offset , s .inputNode ), null );
260
+ Relation [] relations = new Relation [nc .offsets .length ];
261
+ for (int i = 0 ; i < nc .offsets .length ; i ++) {
262
+ Synapse linkedSynapse = nc .offsets [i ];
263
+ relations [i ] = s .relations .get (linkedSynapse .id );
264
+ }
265
+
266
+ NodeContext nln = new NodeContext ();
267
+ nln .offsets = new Synapse [nc .offsets .length + 1 ];
268
+ AndNode .Refinement ref = new AndNode .Refinement (new AndNode .RelationsMap (relations ), s .input .get ().outputNode );
269
+ AndNode .RefValue rv = nc .node .extend (threadId , doc , ref );
270
+ if (rv == null ) {
271
+ return null ;
272
+ }
273
+
274
+ nln .node = rv .child .get (doc );
275
+
276
+ for (int i = 0 ; i < nc .offsets .length ; i ++) {
277
+ nln .offsets [rv .offsets [i ]] = nc .offsets [i ];
278
+ }
279
+ for (int i = 0 ; i < nln .offsets .length ; i ++) {
280
+ if (nln .offsets [i ] == null ) {
281
+ nln .offsets [i ] = s ;
282
+ }
283
+ }
284
+ return nln ;
285
+ }
286
+ }
287
+
288
+
289
+ private class NodeContext {
290
+ Node node ;
291
+
292
+ Synapse [] offsets ;
293
+
294
+ int [] getSynapseIds () {
295
+ int [] result = new int [offsets .length ];
296
+ for (int i = 0 ; i < result .length ; i ++) {
297
+ result [i ] = offsets [i ].id ;
298
+ }
299
+ return result ;
234
300
}
235
- return nln ;
236
301
}
237
302
}
0 commit comments