
Commit 93c9296

feat(lib): Allow per-column standardisation (#202)
BREAKING CHANGE: standardise no longer accepts boolean values, and mean and variance are no longer returned.
1 parent 34a5794 · commit 93c9296

7 files changed: +211 -102 lines

README.md

+5 -15
@@ -53,33 +53,23 @@ Advanced usage:
 ```js
 import loadCsv from 'tensorflow-load-csv';
 
-const {
-  features,
-  labels,
-  testFeatures,
-  testLabels,
-  mean, // tensor holding mean of features, ignores testFeatures
-  variance, // tensor holding variance of features, ignores testFeatures
-} = loadCsv('./data.csv', {
+const { features, labels, testFeatures, testLabels } = loadCsv('./data.csv', {
   featureColumns: ['lat', 'lng', 'height'],
   labelColumns: ['temperature'],
   mappings: {
     height: (ft) => ft * 0.3048, // feet to meters
     temperature: (f) => (f < 50 ? [1, 0] : [0, 1]), // cold or hot classification
   }, // Map values based on which column they are in before they are loaded into tensors.
   flatten: ['temperature'], // Flattens the array result of a mapping so that each member is a new column.
-  shuffle: true, // Pass true to shuffle with a fixed seed, or a string to use it as a seed for the shuffling.
-  splitTest: true, // Splits your data in half. You can also provide a certain row count for the test data, or a percentage string (e.g. 10%).
-  prependOnes: true, // Prepends a column of 1s to your features and testFeatures tensors, useful for linear regression.
-  standardise: true, // Calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
+  shuffle: true, // Pass true to shuffle with a fixed seed, or a string to use as a seed for the shuffling.
+  splitTest: true, // Splits your data in half. You can also provide a certain row count for the test data, or a percentage string (e.g. '10%').
+  standardise: ['height'], // Calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
+  prependOnes: true, // Prepends a column of 1s to your features and testFeatures tensors, useful for regression problems.
 });
 
 features.print();
 labels.print();
 
 testFeatures.print();
 testLabels.print();
-
-mean.print();
-variance.print();
 ```
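For callers migrating off the removed `standardise: true` form, here is a minimal sketch (not part of this commit; the file path and column names are reused from the README example above) of how to keep the old behaviour and recover the dropped mean/variance tensors:

```ts
import * as tf from '@tensorflow/tfjs';
import loadCsv from 'tensorflow-load-csv';

const featureColumns = ['lat', 'lng', 'height'];

// The old behaviour (standardise: true) standardised every feature column;
// the equivalent now is to list all feature columns explicitly.
const { features, testFeatures } = loadCsv('./data.csv', {
  featureColumns,
  labelColumns: ['temperature'],
  splitTest: true,
  standardise: featureColumns,
});

// mean and variance are no longer returned by loadCsv. If you still need the
// raw column statistics, load once without `standardise` and compute them:
const { features: rawFeatures } = loadCsv('./data.csv', {
  featureColumns,
  labelColumns: ['temperature'],
  splitTest: true,
});
const { mean, variance } = tf.moments(rawFeatures, 0); // column-wise moments
mean.print();
variance.print();
```

Computing the moments on an un-standardised load matches the old semantics, since mean and variance were previously calculated from features only.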

jest.config.js

+4 -4
@@ -12,10 +12,10 @@ module.exports = {
   coveragePathIgnorePatterns: ['/node_modules/', '/tests/'],
   coverageThreshold: {
     global: {
-      branches: 90,
-      functions: 95,
-      lines: 95,
-      statements: 95,
+      branches: 100,
+      functions: 100,
+      lines: 100,
+      statements: 100,
     },
   },
   collectCoverageFrom: ['src/*.{js,ts}'],

src/loadCsv.models.ts

+2 -2
@@ -35,9 +35,9 @@ export interface CsvReadOptions {
    */
   prependOnes?: boolean;
   /**
-   * If true, calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
+   * Calculates mean and variance for given columns using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
    */
-  standardise?: boolean | string[];
+  standardise?: string[];
   /**
    * Useful for classification problems, if you have mapped a column's values to an array using `mappings`, you can choose to flatten it here so that each element becomes a new column.
    *
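Because the option is now typed as `string[]`, passing a boolean becomes a compile-time error rather than a runtime choice. A brief illustrative sketch (the calls below are hypothetical and reuse the README's column names, they are not taken from the commit):

```ts
import loadCsv from 'tensorflow-load-csv';

// After this change the following no longer type-checks:
// loadCsv('./data.csv', {
//   featureColumns: ['lat', 'lng', 'height'],
//   labelColumns: ['temperature'],
//   standardise: true, // Type 'boolean' is not assignable to type 'string[]'
// });

// Name the columns to standardise instead:
loadCsv('./data.csv', {
  featureColumns: ['lat', 'lng', 'height'],
  labelColumns: ['temperature'],
  standardise: ['lng'], // only 'lng' gets zero mean / unit variance
});
```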

src/loadCsv.ts

+16 -12
@@ -7,6 +7,7 @@ import filterColumns from './filterColumns';
 import splitTestData from './splitTestData';
 import applyMappings from './applyMappings';
 import shuffle from './shuffle';
+import standardise from './standardise';
 
 const defaultShuffleSeed = 'mncv9340ur';
 
@@ -31,10 +32,10 @@ const loadCsv = (
     featureColumns,
     labelColumns,
     mappings = {},
-    shuffle: shouldShuffle = false,
+    shuffle: shouldShuffleOrSeed = false,
     splitTest,
     prependOnes = false,
-    standardise = false,
+    standardise: columnsToStandardise = [],
     flatten = [],
   }: CsvReadOptions
 ) => {
@@ -54,11 +55,13 @@
   };
 
   tables.labels.shift();
-  tables.features.shift();
+  const featureColumnNames = tables.features.shift() as string[];
 
-  if (shouldShuffle) {
+  if (shouldShuffleOrSeed) {
     const seed =
-      typeof shouldShuffle === 'string' ? shouldShuffle : defaultShuffleSeed;
+      typeof shouldShuffleOrSeed === 'string'
+        ? shouldShuffleOrSeed
+        : defaultShuffleSeed;
     tables.features = shuffle(tables.features, seed);
     tables.labels = shuffle(tables.labels, seed);
   }
@@ -76,11 +79,14 @@
   const labels = tf.tensor(tables.labels);
   const testLabels = tf.tensor(tables.testLabels);
 
-  const { mean, variance } = tf.moments(features, 0);
-
-  if (standardise) {
-    features = features.sub(mean).div(variance.pow(0.5));
-    testFeatures = testFeatures.sub(mean).div(variance.pow(0.5));
+  if (columnsToStandardise.length > 0) {
+    const result = standardise(
+      features,
+      testFeatures,
+      featureColumnNames.map((c) => columnsToStandardise.includes(c))
+    );
+    features = result.features;
+    testFeatures = result.testFeatures;
   }
 
   if (prependOnes) {
@@ -93,8 +99,6 @@
     labels,
     testFeatures,
     testLabels,
-    mean,
-    variance,
   };
 };
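The public option stays in terms of column names, while the new helper works on a boolean mask; the `featureColumnNames.map(...)` argument above is the bridge between the two. A small standalone illustration with made-up column names:

```ts
// Illustrative only: how the column-name option becomes the boolean mask
// that loadCsv passes to standardise().
const featureColumnNames = ['lat', 'lng', 'height'];
const columnsToStandardise = ['height'];

const mask = featureColumnNames.map((c) => columnsToStandardise.includes(c));
// mask is [false, false, true]: only the 'height' column will be standardised.
```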

src/standardise.ts

+68
@@ -0,0 +1,68 @@
+import * as tf from '@tensorflow/tfjs';
+
+const standardise = (
+  features: tf.Tensor<tf.Rank>,
+  testFeatures: tf.Tensor<tf.Rank>,
+  indicesToStandardise: boolean[]
+): {
+  features: tf.Tensor<tf.Rank>;
+  testFeatures: tf.Tensor<tf.Rank>;
+} => {
+  let newFeatures, newTestFeatures;
+
+  if (features.shape.length < 2 || testFeatures.shape.length < 2) {
+    throw new Error(
+      'features and testFeatures must have at least two dimensions'
+    );
+  }
+
+  if (features.shape[1] !== testFeatures.shape[1]) {
+    throw new Error(
+      'Length of the second dimension of features and testFeatures must be the same'
+    );
+  }
+
+  if (features.shape[1] !== indicesToStandardise.length) {
+    throw new Error(
+      'Length of indicesToStandardise must match the length of the second dimension of features'
+    );
+  }
+
+  if (features.shape[1] === 0) {
+    return { features, testFeatures };
+  }
+
+  for (let i = 0; i < features.shape[1]; i++) {
+    let featureSlice = features.slice([0, i], [features.shape[0], 1]);
+    let testFeatureSlice = testFeatures.slice(
+      [0, i],
+      [testFeatures.shape[0], 1]
+    );
+    if (indicesToStandardise[i]) {
+      const sliceMoments = tf.moments(featureSlice);
+      featureSlice = featureSlice
+        .sub(sliceMoments.mean)
+        .div(sliceMoments.variance.pow(0.5));
+      testFeatureSlice = testFeatureSlice
+        .sub(sliceMoments.mean)
+        .div(sliceMoments.variance.pow(0.5));
+    }
+    if (!newFeatures) {
+      newFeatures = featureSlice;
+    } else {
+      newFeatures = newFeatures.concat(featureSlice, 1);
+    }
+    if (!newTestFeatures) {
+      newTestFeatures = testFeatureSlice;
+    } else {
+      newTestFeatures = newTestFeatures.concat(testFeatureSlice, 1);
+    }
+  }
+
+  return {
+    features: newFeatures as tf.Tensor<tf.Rank>,
+    testFeatures: newTestFeatures as tf.Tensor<tf.Rank>,
+  };
+};
+
+export default standardise;
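The helper standardises only the masked columns, and the moments are computed from `features` alone, so `testFeatures` never leaks into the statistics. A hedged usage sketch (tensor values are invented; the relative import path mirrors the one used in loadCsv.ts):

```ts
import * as tf from '@tensorflow/tfjs';
import standardise from './standardise';

// Two feature columns; only the second is masked for standardisation.
const features = tf.tensor2d([
  [1, 10],
  [1, 20],
  [1, 30],
]);
const testFeatures = tf.tensor2d([[1, 40]]);

const { features: outFeatures, testFeatures: outTestFeatures } = standardise(
  features,
  testFeatures,
  [false, true]
);

outFeatures.print(); // column 0 untouched; column 1 becomes roughly [-1.22, 0, 1.22]
outTestFeatures.print(); // [[1, 2.45]], standardised with moments from `features` only
```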

tests/loadCsv.test.ts

+28 -69
@@ -46,95 +46,54 @@ test('Loading with only the required options should work', () => {
 ]);
 });
 
-test('Shuffling should work and preserve feature - label pairs', () => {
-  const { features, labels } = loadCsv(filePath, {
+test('Loading with all extra options should work', () => {
+  const { features, labels, testFeatures, testLabels } = loadCsv(filePath, {
     featureColumns: ['lat', 'lng'],
     labelColumns: ['country'],
+    mappings: {
+      country: (name) => (name as string).toUpperCase(),
+      lat: (lat) => ((lat as number) > 0 ? [0, 1] : [1, 0]), // South or North classification
+    },
+    flatten: ['lat'],
     shuffle: true,
+    splitTest: true,
+    prependOnes: true,
+    standardise: ['lng'],
   });
   // @ts-ignore
   expect(features.arraySync()).toBeDeepCloseTo(
     [
-      [5, 40.34],
-      [0.234, 1.47],
-      [-93.2, 103.34],
-      [102, -164],
+      [1, 0, 1, 1],
+      [1, 0, 1, -1],
     ],
     3
   );
-  expect(labels.arraySync()).toMatchObject([
-    ['Landistan'],
-    ['SomeCountria'],
-    ['SomeOtherCountria'],
-    ['Landotzka'],
-  ]);
-});
-
-test('Shuffling with a custom seed should work', () => {
-  const { features, labels } = loadCsv(filePath, {
-    featureColumns: ['lat', 'lng'],
-    labelColumns: ['country'],
-    shuffle: 'hello-is-it-me-you-are-looking-for',
-  });
+  expect(labels.arraySync()).toMatchObject([['LANDISTAN'], ['SOMECOUNTRIA']]);
   // @ts-ignore
-  expect(features.arraySync()).toBeDeepCloseTo(
+  expect(testFeatures.arraySync()).toBeDeepCloseTo(
     [
-      [-93.2, 103.34],
-      [102, -164],
-      [5, 40.34],
-      [0.234, 1.47],
+      [1, 1, 0, 4.241],
+      [1, 0, 1, -9.514],
     ],
     3
   );
-  expect(labels.arraySync()).toMatchObject([
-    ['SomeOtherCountria'],
-    ['Landotzka'],
-    ['Landistan'],
-    ['SomeCountria'],
+  expect(testLabels.arraySync()).toMatchObject([
+    ['SOMEOTHERCOUNTRIA'],
+    ['LANDOTZKA'],
   ]);
 });
 
-test('Loading with all extra options other than shuffle as true should work', () => {
-  const {
-    features,
-    labels,
-    testFeatures,
-    testLabels,
-    mean,
-    variance,
-  } = loadCsv(filePath, {
+test('Loading with custom seed should use the custom seed', () => {
+  const { features } = loadCsv(filePath, {
     featureColumns: ['lat', 'lng'],
     labelColumns: ['country'],
-    mappings: {
-      country: (name) => (name as string).toUpperCase(),
-    },
-    splitTest: true,
-    prependOnes: true,
-    standardise: true,
+    shuffle: true,
+  });
+  const { features: featuresCustom } = loadCsv(filePath, {
+    featureColumns: ['lat', 'lng'],
+    labelColumns: ['country'],
+    shuffle: 'sdhjhdf',
   });
   // @ts-ignore
-  expect(features.arraySync()).toBeDeepCloseTo(
-    [
-      [1, 1, -1],
-      [1, -1, 1],
-    ],
-    3
-  );
-  expect(labels.arraySync()).toMatchObject([
-    ['SOMECOUNTRIA'],
-    ['SOMEOTHERCOUNTRIA'],
-  ]);
-  // @ts-ignore
-  expect(testFeatures.arraySync()).toBeDeepCloseTo(
-    [
-      [1, 1.102, -0.236],
-      [1, 3.178, -4.248],
-    ],
-    3
-  );
-  expect(testLabels.arraySync()).toMatchObject([['LANDISTAN'], ['LANDOTZKA']]);
-  // @ts-ignore
-  expect(mean.arraySync()).toBeDeepCloseTo([-46.482, 52.404], 3);
-  // @ts-ignore
-  expect(variance.arraySync()).toBeDeepCloseTo([2182.478, 2594.374], 3);
+  expect(features).not.toBeDeepCloseTo(featuresCustom, 1);
 });
