Skip to content

Commit 717f236

Browse files
authored
Implement NumberConverter (#377)
1 parent 1e1d794 commit 717f236

File tree

10 files changed

+102
-8
lines changed

10 files changed

+102
-8
lines changed

Diff for: src/FeatureExtraction/TfIdfTransformer.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public function fit(array $samples, ?array $targets = null): void
3030
}
3131
}
3232

33-
public function transform(array &$samples): void
33+
public function transform(array &$samples, ?array &$targets = null): void
3434
{
3535
foreach ($samples as &$sample) {
3636
foreach ($sample as $index => &$feature) {

Diff for: src/FeatureExtraction/TokenCountVectorizer.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public function fit(array $samples, ?array $targets = null): void
4646
$this->buildVocabulary($samples);
4747
}
4848

49-
public function transform(array &$samples): void
49+
public function transform(array &$samples, ?array &$targets = null): void
5050
{
5151
array_walk($samples, function (string &$sample): void {
5252
$this->transformSample($sample);

Diff for: src/FeatureSelection/SelectKBest.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public function fit(array $samples, ?array $targets = null): void
5656
$this->keepColumns = array_slice($sorted, 0, $this->k, true);
5757
}
5858

59-
public function transform(array &$samples): void
59+
public function transform(array &$samples, ?array &$targets = null): void
6060
{
6161
if ($this->keepColumns === null) {
6262
return;

Diff for: src/FeatureSelection/VarianceThreshold.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public function fit(array $samples, ?array $targets = null): void
4848
}
4949
}
5050

51-
public function transform(array &$samples): void
51+
public function transform(array &$samples, ?array &$targets = null): void
5252
{
5353
foreach ($samples as &$sample) {
5454
$sample = array_values(array_intersect_key($sample, $this->keepColumns));

Diff for: src/Preprocessing/Imputer.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public function fit(array $samples, ?array $targets = null): void
4949
$this->samples = $samples;
5050
}
5151

52-
public function transform(array &$samples): void
52+
public function transform(array &$samples, ?array &$targets = null): void
5353
{
5454
if ($this->samples === []) {
5555
throw new InvalidOperationException('Missing training samples for Imputer.');

Diff for: src/Preprocessing/LabelEncoder.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public function fit(array $samples, ?array $targets = null): void
2222
}
2323
}
2424

25-
public function transform(array &$samples): void
25+
public function transform(array &$samples, ?array &$targets = null): void
2626
{
2727
foreach ($samples as &$sample) {
2828
$sample = $this->classes[(string) $sample];

Diff for: src/Preprocessing/Normalizer.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public function fit(array $samples, ?array $targets = null): void
6666
$this->fitted = true;
6767
}
6868

69-
public function transform(array &$samples): void
69+
public function transform(array &$samples, ?array &$targets = null): void
7070
{
7171
$methods = [
7272
self::NORM_L1 => 'normalizeL1',

Diff for: src/Preprocessing/NumberConverter.php

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Phpml\Preprocessing;
6+
7+
final class NumberConverter implements Preprocessor
8+
{
9+
/**
10+
* @var bool
11+
*/
12+
private $transformTargets;
13+
14+
/**
15+
* @var mixed
16+
*/
17+
private $nonNumericPlaceholder;
18+
19+
/**
20+
* @param mixed $nonNumericPlaceholder
21+
*/
22+
public function __construct(bool $transformTargets = false, $nonNumericPlaceholder = null)
23+
{
24+
$this->transformTargets = $transformTargets;
25+
$this->nonNumericPlaceholder = $nonNumericPlaceholder;
26+
}
27+
28+
public function fit(array $samples, ?array $targets = null): void
29+
{
30+
//nothing to do
31+
}
32+
33+
public function transform(array &$samples, ?array &$targets = null): void
34+
{
35+
foreach ($samples as &$sample) {
36+
foreach ($sample as &$feature) {
37+
$feature = is_numeric($feature) ? (float) $feature : $this->nonNumericPlaceholder;
38+
}
39+
}
40+
41+
if ($this->transformTargets && is_array($targets)) {
42+
foreach ($targets as &$target) {
43+
$target = is_numeric($target) ? (float) $target : $this->nonNumericPlaceholder;
44+
}
45+
}
46+
}
47+
}

Diff for: src/Transformer.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ interface Transformer
1111
*/
1212
public function fit(array $samples, ?array $targets = null): void;
1313

14-
public function transform(array &$samples): void;
14+
public function transform(array &$samples, ?array &$targets = null): void;
1515
}

Diff for: tests/Preprocessing/NumberConverterTest.php

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Phpml\Tests\Preprocessing;
6+
7+
use Phpml\Preprocessing\NumberConverter;
8+
use PHPUnit\Framework\TestCase;
9+
10+
/**
 * Tests for NumberConverter.
 *
 * Every assertion uses assertSame() rather than assertEquals(): the latter
 * compares scalars loosely ('1' equals 1.0), so it cannot tell whether a
 * string was actually converted to float or left untouched.
 */
final class NumberConverterTest extends TestCase
{
    /**
     * With default settings only samples are converted; targets keep their
     * original string values.
     */
    public function testConvertSamples(): void
    {
        $samples = [['1', '-4'], ['2.0', 3.0], ['3', '112.5'], ['5', '0.0004']];
        $targets = ['1', '1', '2', '2'];

        $converter = new NumberConverter();
        $converter->transform($samples, $targets);

        self::assertSame([[1.0, -4.0], [2.0, 3.0], [3.0, 112.5], [5.0, 0.0004]], $samples);
        // Strict comparison proves the targets are still strings.
        self::assertSame(['1', '1', '2', '2'], $targets);
    }

    /**
     * With target conversion enabled, numeric targets become floats and
     * non-numeric ones become the default placeholder (null).
     */
    public function testConvertTargets(): void
    {
        $samples = [['1', '-4'], ['2.0', 3.0], ['3', '112.5'], ['5', '0.0004']];
        $targets = ['1', '1', '2', 'not'];

        $converter = new NumberConverter(true);
        $converter->transform($samples, $targets);

        self::assertSame([[1.0, -4.0], [2.0, 3.0], [3.0, 112.5], [5.0, 0.0004]], $samples);
        self::assertSame([1.0, 1.0, 2.0, null], $targets);
    }

    /**
     * A custom placeholder replaces non-numeric values in both samples and
     * targets.
     */
    public function testConvertWithPlaceholder(): void
    {
        $samples = [['invalid'], ['13.5']];
        $targets = ['invalid', '2'];

        $converter = new NumberConverter(true, 'missing');
        $converter->transform($samples, $targets);

        self::assertSame([['missing'], [13.5]], $samples);
        self::assertSame(['missing', 2.0], $targets);
    }
}

0 commit comments

Comments
 (0)