File tree 1 file changed +16
-2
lines changed
1 file changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -79,6 +79,7 @@ public function __construct(
79
79
new ExtensionIsLoaded ('svm ' ),
80
80
new ExtensionMinimumVersion ('svm ' , '0.2.0 ' ),
81
81
])->check ();
82
+
82
83
83
84
if ($ nu < 0.0 or $ nu > 1.0 ) {
84
85
throw new InvalidArgumentException ('Nu must be between '
@@ -182,7 +183,14 @@ public function train(Dataset $dataset) : void
182
183
new SamplesAreCompatibleWithEstimator ($ dataset , $ this ),
183
184
])->check ();
184
185
185
- $ this ->model = $ this ->svm ->train ($ dataset ->samples ());
186
+ $ data = [];
187
+
188
+ foreach ($ dataset ->samples () as $ sample ) {
189
+ array_unshift ($ sample , 1 );
190
+ $ data [] = $ sample ;
191
+ }
192
+
193
+ $ this ->model = $ this ->svm ->train ($ data );
186
194
}
187
195
188
196
/**
@@ -211,7 +219,13 @@ public function predictSample(array $sample) : int
211
219
throw new RuntimeException ('Estimator has not been trained. ' );
212
220
}
213
221
214
- return $ this ->model ->predict ($ sample ) !== 1.0 ? 0 : 1 ;
222
+ $ sampleWithOffset = [];
223
+
224
+ foreach ($ sample as $ key => $ value ) {
225
+ $ sampleWithOffset [$ key + 1 ] = $ value ;
226
+ }
227
+
228
+ return $ this ->model ->predict ($ sampleWithOffset ) == 1 ? 0 : 1 ;
215
229
}
216
230
217
231
/**
You can’t perform that action at this time.
0 commit comments