Skip to content

Commit 0399060

Browse files
authored
fix(lib): Lower tensor memory usage (#228)
1 parent 80ef53c commit 0399060

File tree

1 file changed

+29
-25
lines changed

1 file changed

+29
-25
lines changed

src/loadCsv.ts

+29-25
Original file line numberDiff line numberDiff line change
@@ -76,33 +76,37 @@ const loadCsv = (
7676
);
7777
}
7878

79-
let features = tf.tensor(tables.features);
80-
let testFeatures = tf.tensor(tables.testFeatures);
81-
82-
const labels = tf.tensor(tables.labels);
83-
const testLabels = tf.tensor(tables.testLabels);
84-
85-
if (columnsToStandardise.length > 0) {
86-
const result = standardise(
79+
return tf.tidy(() => {
80+
let features = tf.tensor(tables.features);
81+
let testFeatures = tf.tensor(tables.testFeatures);
82+
83+
const labels = tf.tensor(tables.labels);
84+
const testLabels = tf.tensor(tables.testLabels);
85+
86+
if (columnsToStandardise.length > 0) {
87+
const result = standardise(
88+
features,
89+
testFeatures,
90+
featureColumnNames.map((c) => columnsToStandardise.includes(c))
91+
);
92+
features = result.features;
93+
testFeatures = result.testFeatures;
94+
}
95+
96+
if (prependOnes) {
97+
features = tf.ones([features.shape[0], 1]).concat(features, 1);
98+
testFeatures = tf
99+
.ones([testFeatures.shape[0], 1])
100+
.concat(testFeatures, 1);
101+
}
102+
103+
return {
87104
features,
105+
labels,
88106
testFeatures,
89-
featureColumnNames.map((c) => columnsToStandardise.includes(c))
90-
);
91-
features = result.features;
92-
testFeatures = result.testFeatures;
93-
}
94-
95-
if (prependOnes) {
96-
features = tf.ones([features.shape[0], 1]).concat(features, 1);
97-
testFeatures = tf.ones([testFeatures.shape[0], 1]).concat(testFeatures, 1);
98-
}
99-
100-
return {
101-
features,
102-
labels,
103-
testFeatures,
104-
testLabels,
105-
};
107+
testLabels,
108+
};
109+
});
106110
};
107111

108112
export default loadCsv;

0 commit comments

Comments
 (0)