From 2e4fb2b1835cd48bd46e80175d60d25d888ae585 Mon Sep 17 00:00:00 2001 From: Devin Jean Date: Thu, 9 May 2024 08:06:41 -0500 Subject: [PATCH] Matlab fixes (#238) * encode/decode column-major * parse char tensors * Fix code formatting --------- Co-authored-by: Format Bot --- .vscode/settings.json | 10 ++ src/procedures/matlab/keep-warm.js | 44 ------ src/procedures/matlab/matlab.js | 166 ++++++++++++++------- test/procedures/matlab/matlab.spec.js | 201 +++++++++++++++++++++----- 4 files changed, 286 insertions(+), 135 deletions(-) delete mode 100644 src/procedures/matlab/keep-warm.js diff --git a/.vscode/settings.json b/.vscode/settings.json index 420813ea..a536473a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -11,5 +11,15 @@ "touchpad", "touchpads", "upsert" + ], + "cSpell.words": [ + "colcat", + "feval", + "mwdata", + "mwsize", + "mwtype", + "nargout", + "unflatten", + "vals" ] } diff --git a/src/procedures/matlab/keep-warm.js b/src/procedures/matlab/keep-warm.js deleted file mode 100644 index 6fa7c423..00000000 --- a/src/procedures/matlab/keep-warm.js +++ /dev/null @@ -1,44 +0,0 @@ -const { setTimeout, setInterval, clearInterval } = require("../../timers"); -const seconds = 1000; -const minutes = 60 * seconds; - -/** - * This is made to "keep the matlab cluster warm" and prevent worker nodes from being - * released back to the cluster. This class takes an action and, once started, call it - * on a given interval for a specified duration. - * - * Essentially, it is a stateful setInterval w/ an end duration that restarts on each call - * - * Please use responsibly. - */ -class KeepWarm { - constructor(action, interval = 10 * seconds) { - this.action = action; - this.interval = interval; - this.currentInterval = null; - } - - async keepWarm(duration = 15 * minutes) { - this.stop(); - const intervalId = setInterval(this.action, this.interval); - this.currentInterval = intervalId; - setTimeout(() => { - if (this.currentInterval === intervalId) { - this.stop(); - } - }, duration); - } - - stop() { - if (this.currentInterval) { - clearInterval(this.currentInterval); - this.currentInterval = null; - } - } - - isStillWarm() { - return !!this.currentInterval; - } -} - -module.exports = KeepWarm; diff --git a/src/procedures/matlab/matlab.js b/src/procedures/matlab/matlab.js index 785abaa2..03a09fbd 100644 --- a/src/procedures/matlab/matlab.js +++ b/src/procedures/matlab/matlab.js @@ -3,15 +3,14 @@ * * For more information, check out https://www.mathworks.com/products/matlab.html. * - * @service * @alpha + * @service */ const logger = require("../utils/logger")("matlab"); const axios = require("axios"); const { MATLAB_KEY, MATLAB_URL = "" } = process.env; -const KeepWarm = require("./keep-warm"); const request = axios.create({ headers: { "X-NetsBlox-Auth-Token": MATLAB_KEY, @@ -21,18 +20,6 @@ const request = axios.create({ const MATLAB = {}; MATLAB.serviceName = "MATLAB"; -const warmer = new KeepWarm(async () => { - logger.info("warming is disabled"); - return; - - const body = [...new Array(10)].map(() => ({ - function: "ver", - arguments: [], - nargout: 1, - })); - request.post(`${MATLAB_URL}/feval-fast`, body); -}); - async function requestWithRetry(url, body, numRetries = 0) { try { return await request.post(url, body, { @@ -50,9 +37,11 @@ async function requestWithRetry(url, body, numRetries = 0) { * Evaluate a MATLAB function with the given arguments and number of return * values. * - * @param{String} fn Name of the function to call - * @param{Array=} args arguments to pass to the function - * @param{BoundedInteger<1>=} numReturnValues Number of return values expected. + * For a list of all MATLAB functions, see the `Reference Manual `__. + * + * @param {String} fn Name of the function to call + * @param {Array=} args arguments to pass to the function + * @param {BoundedInteger<1>=} numReturnValues Number of return values expected. */ MATLAB.function = async function (fn, args = [], numReturnValues = 1) { const body = [{ @@ -75,7 +64,7 @@ MATLAB.function = async function (fn, args = [], numReturnValues = 1) { JSON.stringify(resp.data) }`, ); - warmer.keepWarm(); + const results = resp.data.FEvalResponse; // TODO: add batching queue return this._parseResult(results[0]); @@ -97,8 +86,7 @@ MATLAB._parseArgument = function (arg) { arg = [arg]; } - const shape = MATLAB._shape(arg); - const flatValues = MATLAB._flatten(arg); + const [flatValues, shape] = MATLAB._flatten(arg); const mwtype = MATLAB._getMwType(flatValues); const mwdata = flatValues .map((v) => { @@ -118,7 +106,7 @@ MATLAB._parseArgument = function (arg) { return { mwdata, - mwsize: shape, + mwsize: shape.length >= 2 ? shape : [1, ...shape], mwtype, }; }; @@ -149,28 +137,29 @@ MATLAB._parseResult = (result) => { }; MATLAB._parseResultData = (result) => { - // reshape the data let data = result.mwdata; + let size = result.mwsize; if (!Array.isArray(data)) { data = [data]; } - return MATLAB._squeeze( - MATLAB._reshape(data, result.mwsize), - ); -}; -MATLAB._take = function* (iter, num) { - let chunk = []; - for (const v of iter) { - chunk.push(v); - if (chunk.length === num) { - yield chunk; - chunk = []; + if (result.mwtype === "char") { + if ( + !Array.isArray(result.mwdata) || result.mwdata.length !== 1 || + typeof (result.mwdata[0]) !== "string" + ) { + throw Error("error parsing character string result"); } + function rejoin(x) { + if (x.length !== 0 && !Array.isArray(x[0])) { + return x.join(""); + } + return x.map((y) => rejoin(y)); + } + return MATLAB._squeeze(rejoin(MATLAB._unflatten([...data[0]], size))); } - if (chunk.length) { - return chunk; - } + + return MATLAB._squeeze(MATLAB._unflatten(data, size)); }; MATLAB._squeeze = (data) => { @@ -180,38 +169,103 @@ MATLAB._squeeze = (data) => { return data; }; -MATLAB._reshape = (data, shape) => { - return [ - ...shape.reverse().reduce( - (iterable, num) => MATLAB._take(iterable, num), - data, - ), - ].pop(); +MATLAB._product = (vals) => { + let res = 1; + for (const v of vals) { + res *= v; + } + return res; +}; + +MATLAB._colcat = (cols) => { + if (cols.length === 0) { + return []; + } + if (!Array.isArray(cols[0])) { + return cols.reduce((acc, v) => acc.concat(v), []); + } + + const rows = cols[0].length; + const res = []; + for (let i = 0; i < rows; ++i) { + res.push(MATLAB._colcat(cols.map((row) => row[i]))); + } + return res; +}; + +MATLAB._unflatten = (data, shape) => { + if (!Array.isArray(data)) { + throw Error("internal usage error"); + } + + if (shape.length <= 1) { + return data; + } + + const colCount = shape[shape.length - 1]; + const colShape = shape.slice(0, shape.length - 1); + const colSize = MATLAB._product(colShape); + + const cols = []; + for (let i = 0; i < colCount; ++i) { + cols.push( + MATLAB._unflatten(data.slice(i * colSize, (i + 1) * colSize), colShape), + ); + } + return MATLAB._colcat(cols); +}; + +MATLAB._deepEq = (a, b) => { + if (Array.isArray(a) && Array.isArray(b)) { + return a.length === b.length && a.every((x, i) => MATLAB._deepEq(x, b[i])); + } + return a === b; }; MATLAB._shape = (data) => { - const shape = []; - let item = data; - while (Array.isArray(item)) { - shape.push(item.length); - item = item[0]; + if (!Array.isArray(data)) { + throw Error("internal usage error"); } - while (shape.length < 2) { - shape.unshift(1); + if (data.length === 0 || !Array.isArray(data[0])) { + if (data.some((x) => Array.isArray(x))) { + throw Error("input must be rectangular"); + } + return [data.length]; + } + if (data.some((x) => !Array.isArray(x))) { + throw Error("input must be rectangular"); } - return shape; + const shapes = data.map((x) => MATLAB._shape(x)); + if (shapes.some((x) => !MATLAB._deepEq(x, shapes[0]))) { + throw Error("input must be rectangular"); + } + return [data.length, ...shapes[0]]; }; +// returns [flattened, shape] so that shape can be reused MATLAB._flatten = (data) => { - return data.flatMap((item) => { - if (Array.isArray(item)) { - return MATLAB._flatten(item); + const shape = MATLAB._shape(data); + if (shape.some((x) => x === 0)) return [[], shape]; + + const shapeCumProd = [1, ...shape]; + for (let i = 1; i < shapeCumProd.length; ++i) { + shapeCumProd[i] *= shapeCumProd[i - 1]; + } + + const res = new Array(shapeCumProd[shapeCumProd.length - 1]); + function visit(x, pos, depth) { + if (depth === shape.length) { + res[pos] = x; } else { - return item; + for (let i = 0; i < x.length; ++i) { + visit(x[i], pos + i * shapeCumProd[depth], depth + 1); + } } - }); + } + visit(data, 0, 0); + return [res, shape]; }; MATLAB.isSupported = () => { diff --git a/test/procedures/matlab/matlab.spec.js b/test/procedures/matlab/matlab.spec.js index 3ff14b2a..4167eb05 100644 --- a/test/procedures/matlab/matlab.spec.js +++ b/test/procedures/matlab/matlab.spec.js @@ -113,6 +113,121 @@ describe(utils.suiteName(__filename), function () { /CLASSNAME argument must be a class/, ); }); + + it("should parse strings - 1", function () { + const result = MATLAB._parseResult( + { + "results": [{ + "mwdata": ["hello world"], + "mwsize": [1, 1], + "mwtype": "string", + }], + "isError": false, + "uuid": "", + "messageFaults": [], + }, + ); + const expected = "hello world"; + assert.deepEqual(result, expected); + }); + it("should parse strings - 2", function () { + const result = MATLAB._parseResult( + { + "results": [{ + "mwdata": ["hellotest", "worldtest"], + "mwsize": [1, 2], + "mwtype": "string", + }], + "isError": false, + "uuid": "", + "messageFaults": [], + }, + ); + const expected = ["hellotest", "worldtest"]; + assert.deepEqual(result, expected); + }); + + it("should parse in column major order - 1", function () { + const result = MATLAB._parseResult( + { + "results": [{ + "mwdata": [1, 1, 1, 1, 1, 1, 1, 1, 1], + "mwsize": [3, 3], + "mwtype": "double", + }], + "isError": false, + "uuid": "", + "messageFaults": [], + }, + ); + const expected = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]; + assert.deepEqual(result, expected); + }); + it("should parse in column major order - 2", function () { + const result = MATLAB._parseResult( + { + "results": [{ + "mwdata": [1, 1, 1, 0, 1, 1, 0, 0, 1], + "mwsize": [3, 3], + "mwtype": "double", + }], + "isError": false, + "uuid": "", + "messageFaults": [], + }, + ); + const expected = [[1, 0, 0], [1, 1, 0], [1, 1, 1]]; + assert.deepEqual(result, expected); + }); + it("should parse in column major order - 3", function () { + const result = MATLAB._parseResult( + { + "results": [{ + "mwdata": [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + "mwsize": [3, 4], + "mwtype": "double", + }], + "isError": false, + "uuid": "", + "messageFaults": [], + }, + ); + const expected = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]; + assert.deepEqual(result, expected); + }); + it("should parse in column major order - 4", function () { + const result = MATLAB._parseResult( + { + "results": [{ + "mwdata": [140, 320, 146, 335], + "mwsize": [2, 2], + "mwtype": "double", + }], + "isError": false, + "uuid": "", + "messageFaults": [], + }, + ); + const expected = [[140, 146], [320, 335]]; + assert.deepEqual(result, expected); + }); + + it("should parse character tensors", function () { + const result = MATLAB._parseResult( + { + "results": [{ + "mwdata": ["00110101"], + "mwsize": [4, 2], + "mwtype": "char", + }], + "isError": false, + "uuid": "", + "messageFaults": [], + }, + ); + const expected = ["00", "01", "10", "11"]; + assert.deepEqual(result, expected); + }); }); describe("_parseArgument", function () { @@ -137,9 +252,9 @@ describe(utils.suiteName(__filename), function () { }); it("should coerce nested lists", function () { - const example = [["5", "6"], "7"]; + const example = [["5", "6"], ["7", "9"]]; const actual = MATLAB._parseArgument(example); - const expected = [5, 6, 7]; + const expected = [5, 7, 6, 9]; assert.deepEqual(actual.mwdata, expected); }); @@ -176,70 +291,72 @@ describe(utils.suiteName(__filename), function () { }); }); - describe("_flatten", function () { - it("should flatten recursively", function () { + describe("_flatten/_shape", function () { + it("should flatten recursively - 1", function () { const tensor = [ - [[1, 2]], - [[3, 4]], - [[5, 6]], - [[7, 8]], + [[1, 5]], + [[2, 6]], + [[3, 7]], + [[4, 8]], ]; - const flat = MATLAB._flatten(tensor); + const [flat, shape] = MATLAB._flatten(tensor); assert.deepEqual(flat, range(8)); + assert.deepEqual(shape, [4, 1, 2]); }); - }); - describe("_shape", function () { - it("should detect shape in [4,1,2] tensor", function () { + it("should flatten recursively - 2", function () { const tensor = [ - [[1, 2]], - [[1, 2]], - [[1, 2]], - [[1, 2]], + [[1, 9], [5, 13]], + [[2, 10], [6, 14]], + [[3, 11], [7, 15]], + [[4, 12], [8, 16]], ]; - const actual = MATLAB._shape(tensor); - assert.deepEqual(actual, [4, 1, 2]); + const [flat, shape] = MATLAB._flatten(tensor); + assert.deepEqual(flat, range(16)); + assert.deepEqual(shape, [4, 2, 2]); }); - it("should detect shape in [1 2] tensor", function () { + it("should flatten a 1x2 tensor", function () { const tensor = [1, 2]; - const actual = MATLAB._shape(tensor); - assert.deepEqual(actual, [1, 2]); + const [flat, shape] = MATLAB._flatten(tensor); + assert.deepEqual(flat, [1, 2]); + assert.deepEqual(shape, [2]); }); - it("should detect shape in [3, 4] tensor", function () { + it("should flatten a 3x4 tensor", function () { const tensor = [ [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], ]; - const actual = MATLAB._shape(tensor); - assert.deepEqual(actual, [3, 4]); + const [flat, shape] = MATLAB._flatten(tensor); + assert.deepEqual(flat, [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]); + assert.deepEqual(shape, [3, 4]); }); }); - describe("_reshape", function () { + describe("_unflatten", function () { it("should reconstruct a 2x2 matrix", function () { const example = [1, 2, 3, 4]; - const actual = MATLAB._reshape(example, [2, 2]); - const expected = [[1, 2], [3, 4]]; + const actual = MATLAB._unflatten(example, [2, 2]); + const expected = [[1, 3], [2, 4]]; assert.deepEqual(actual, expected); }); it("should reconstruct a 3x2 matrix", function () { const example = range(6); - const actual = MATLAB._reshape(example, [3, 2]); - const expected = [[1, 2], [3, 4], [5, 6]]; + const actual = MATLAB._unflatten(example, [3, 2]); + const expected = [[1, 4], [2, 5], [3, 6]]; assert.deepEqual(actual, expected); }); it("should reconstruct a 3x2x2 tensor", function () { const example = range(12); - const actual = MATLAB._reshape(example, [3, 2, 2]); + const actual = MATLAB._unflatten(example, [3, 2, 2]); const expected = [ - [[1, 2], [3, 4]], - [[5, 6], [7, 8]], - [[9, 10], [11, 12]], + [[1, 7], [4, 10]], + [[2, 8], [5, 11]], + [[3, 9], [6, 12]], ]; assert.deepEqual(actual, expected); }); @@ -249,10 +366,24 @@ describe(utils.suiteName(__filename), function () { [[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]], ]; - const shape = MATLAB._shape(input); - const reconstructed = MATLAB._reshape(MATLAB._flatten(input), shape); + const [flat, shape] = MATLAB._flatten(input); + assert.deepEqual(flat, [1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12]); + assert.deepEqual(shape, [2, 2, 3]); + const reconstructed = MATLAB._unflatten(flat, shape); assert.deepEqual(input, reconstructed); }); + + it("should reconstruct a character tensor", function () { + const example = ["0", "0", "1", "1", "0", "1", "0", "1"]; + const actual = MATLAB._unflatten(example, [4, 2]); + const expected = [ + ["0", "0"], + ["0", "1"], + ["1", "0"], + ["1", "1"], + ]; + assert.deepEqual(actual, expected); + }); }); function range(end) {