Skip to content

Commit 4590d5c

Browse files
authored
Implement OneHotEncoder (#384)
1 parent 3baf152 commit 4590d5c

File tree

4 files changed

+141
-1
lines changed

4 files changed

+141
-1
lines changed

Diff for: CHANGELOG.md

+8-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,16 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7-
## [Unreleased]
7+
## [0.9.0] - Unreleased
88
### Added
99
- [Preprocessing] Implement LabelEncoder
10+
- [Preprocessing] Implement ColumnFilter
11+
- [Preprocessing] Implement LambdaTransformer
12+
- [Preprocessing] Implement NumberConverter
13+
- [Preprocessing] Implement OneHotEncoder
14+
- [Workflow] Implement FeatureUnion
15+
- [Metric] Add Regression metrics: meanSquaredError, meanSquaredLogarithmicError, meanAbsoluteError, medianAbsoluteError, r2Score, maxError
16+
- [Regression] Implement DecisionTreeRegressor
1017

1118
## [0.8.0] - 2019-03-20
1219
### Added

Diff for: README.md

+1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets](
107107
* LambdaTransformer
108108
* NumberConverter
109109
* ColumnFilter
110+
* OneHotEncoder
110111
* Feature Extraction
111112
* [Token Count Vectorizer](http://php-ml.readthedocs.io/en/latest/machine-learning/feature-extraction/token-count-vectorizer/)
112113
* NGramTokenizer

Diff for: src/Preprocessing/OneHotEncoder.php

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Phpml\Preprocessing;
6+
7+
use Phpml\Exception\InvalidArgumentException;
8+
9+
final class OneHotEncoder implements Preprocessor
10+
{
11+
/**
12+
* @var bool
13+
*/
14+
private $ignoreUnknown;
15+
16+
/**
17+
* @var array
18+
*/
19+
private $categories = [];
20+
21+
public function __construct(bool $ignoreUnknown = false)
22+
{
23+
$this->ignoreUnknown = $ignoreUnknown;
24+
}
25+
26+
public function fit(array $samples, ?array $targets = null): void
27+
{
28+
foreach (array_keys(array_values(current($samples))) as $column) {
29+
$this->fitColumn($column, array_values(array_unique(array_column($samples, $column))));
30+
}
31+
}
32+
33+
public function transform(array &$samples, ?array &$targets = null): void
34+
{
35+
foreach ($samples as &$sample) {
36+
$sample = $this->transformSample(array_values($sample));
37+
}
38+
}
39+
40+
private function fitColumn(int $column, array $values): void
41+
{
42+
$count = count($values);
43+
foreach ($values as $index => $value) {
44+
$map = array_fill(0, $count, 0);
45+
$map[$index] = 1;
46+
$this->categories[$column][$value] = $map;
47+
}
48+
}
49+
50+
private function transformSample(array $sample): array
51+
{
52+
$encoded = [];
53+
foreach ($sample as $column => $feature) {
54+
if (!isset($this->categories[$column][$feature]) && !$this->ignoreUnknown) {
55+
throw new InvalidArgumentException(sprintf('Missing category "%s" for column %s in trained encoder', $feature, $column));
56+
}
57+
58+
$encoded = array_merge(
59+
$encoded,
60+
$this->categories[$column][$feature] ?? array_fill(0, count($this->categories[$column]), 0)
61+
);
62+
}
63+
64+
return $encoded;
65+
}
66+
}

Diff for: tests/Preprocessing/OneHotEncoderTest.php

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Phpml\Tests\Preprocessing;
6+
7+
use Phpml\Exception\InvalidArgumentException;
8+
use Phpml\Preprocessing\OneHotEncoder;
9+
use PHPUnit\Framework\TestCase;
10+
11+
final class OneHotEncoderTest extends TestCase
12+
{
13+
public function testOneHotEncodingWithoutIgnoreUnknown(): void
14+
{
15+
$samples = [
16+
['fish', 'New York', 'regression'],
17+
['dog', 'New York', 'regression'],
18+
['fish', 'Vancouver', 'classification'],
19+
['dog', 'Vancouver', 'regression'],
20+
];
21+
22+
$encoder = new OneHotEncoder();
23+
$encoder->fit($samples);
24+
$encoder->transform($samples);
25+
26+
self::assertEquals([
27+
[1, 0, 1, 0, 1, 0],
28+
[0, 1, 1, 0, 1, 0],
29+
[1, 0, 0, 1, 0, 1],
30+
[0, 1, 0, 1, 1, 0],
31+
], $samples);
32+
}
33+
34+
public function testThrowExceptionWhenUnknownCategory(): void
35+
{
36+
$encoder = new OneHotEncoder();
37+
$encoder->fit([
38+
['fish', 'New York', 'regression'],
39+
['dog', 'New York', 'regression'],
40+
['fish', 'Vancouver', 'classification'],
41+
['dog', 'Vancouver', 'regression'],
42+
]);
43+
$samples = [['fish', 'New York', 'ka boom']];
44+
45+
$this->expectException(InvalidArgumentException::class);
46+
47+
$encoder->transform($samples);
48+
}
49+
50+
public function testIgnoreMissingCategory(): void
51+
{
52+
$encoder = new OneHotEncoder(true);
53+
$encoder->fit([
54+
['fish', 'New York', 'regression'],
55+
['dog', 'New York', 'regression'],
56+
['fish', 'Vancouver', 'classification'],
57+
['dog', 'Vancouver', 'regression'],
58+
]);
59+
$samples = [['ka', 'boom', 'riko']];
60+
$encoder->transform($samples);
61+
62+
self::assertEquals([
63+
[0, 0, 0, 0, 0, 0],
64+
], $samples);
65+
}
66+
}

0 commit comments

Comments
 (0)