Skip to content

Commit 18acb0b

Browse files
committed
Fix test and appease Stan
1 parent 715b0e7 commit 18acb0b

File tree

6 files changed

+11
-9
lines changed

6 files changed

+11
-9
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
- 2.5.1
2+
- Fix bug in SVM (SVC and SVR) inferencing
3+
14
- 2.5.0
25
- Added Vantage Point Spatial tree
36
- Blob Generator can now `simulate()` a Dataset object

src/AnomalyDetectors/OneClassSVM.php

-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ public function __construct(
7979
new ExtensionIsLoaded('svm'),
8080
new ExtensionMinimumVersion('svm', '0.2.0'),
8181
])->check();
82-
8382

8483
if ($nu < 0.0 or $nu > 1.0) {
8584
throw new InvalidArgumentException('Nu must be between'

src/Classifiers/AdaBoost.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ public function train(Dataset $dataset) : void
421421
* Make predictions from a dataset.
422422
*
423423
* @param Dataset $dataset
424-
* @return list<string>
424+
* @return list<int|string>
425425
*/
426426
public function predict(Dataset $dataset) : array
427427
{

src/Datasets/Generators/Agglomerate.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public function __construct(array $generators = [], ?array $weights = null)
104104
}
105105

106106
$this->generators = $generators;
107-
$this->weights = array_combine(array_keys($generators), $weights) ?: [];
107+
$this->weights = array_combine(array_keys($generators), $weights);
108108
$this->dimensions = $dimensions;
109109
}
110110

tests/AnomalyDetectors/OneClassSVMTest.php

+5-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
use Rubix\ML\Estimator;
88
use Rubix\ML\EstimatorType;
99
use Rubix\ML\Datasets\Unlabeled;
10-
use Rubix\ML\Kernels\SVM\Polynomial;
10+
use Rubix\ML\Kernels\SVM\RBF;
1111
use Rubix\ML\Datasets\Generators\Blob;
1212
use Rubix\ML\Datasets\Generators\Circle;
1313
use Rubix\ML\AnomalyDetectors\OneClassSVM;
@@ -43,7 +43,7 @@ class OneClassSVMTest extends TestCase
4343
*
4444
* @var float
4545
*/
46-
protected const MIN_SCORE = 0.5;
46+
protected const MIN_SCORE = 0.7;
4747

4848
/**
4949
* Constant used to see the random number generator.
@@ -77,7 +77,7 @@ protected function setUp() : void
7777
1 => new Circle(0.0, 0.0, 8.0, 1.0),
7878
], [0.9, 0.1]);
7979

80-
$this->estimator = new OneClassSVM(0.01, new Polynomial(4, 1e-3), true, 1e-4);
80+
$this->estimator = new OneClassSVM(0.3, new RBF(), true, 1e-4);
8181

8282
$this->metric = new FBeta();
8383

@@ -125,8 +125,8 @@ public function compatibility() : void
125125
public function params() : void
126126
{
127127
$expected = [
128-
'nu' => 0.01,
129-
'kernel' => new Polynomial(4, 1e-3),
128+
'nu' => 0.3,
129+
'kernel' => new RBF(),
130130
'shrinking' => true,
131131
'tolerance' => 0.0001,
132132
'cache size' => 100.0,

tests/NeuralNet/FeedForwardTest.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public function build() : void
8686
*/
8787
public function layers() : void
8888
{
89-
$this->assertCount(7, $this->network->layers());
89+
$this->assertCount(5, iterator_to_array($this->network->layers()));
9090
}
9191

9292
/**

0 commit comments

Comments
 (0)