Skip to content

Commit

Permalink
Matlab fixes (#238)
Browse files Browse the repository at this point in the history
* encode/decode column-major

* parse char tensors

* Fix code formatting

---------

Co-authored-by: Format Bot <[email protected]>
  • Loading branch information
dragazo and Format Bot authored May 9, 2024
1 parent fb0281e commit 2e4fb2b
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 135 deletions.
10 changes: 10 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,15 @@
"touchpad",
"touchpads",
"upsert"
],
"cSpell.words": [
"colcat",
"feval",
"mwdata",
"mwsize",
"mwtype",
"nargout",
"unflatten",
"vals"
]
}
44 changes: 0 additions & 44 deletions src/procedures/matlab/keep-warm.js

This file was deleted.

166 changes: 110 additions & 56 deletions src/procedures/matlab/matlab.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
*
* For more information, check out https://www.mathworks.com/products/matlab.html.
*
* @service
* @alpha
* @service
*/

const logger = require("../utils/logger")("matlab");
const axios = require("axios");

const { MATLAB_KEY, MATLAB_URL = "" } = process.env;
const KeepWarm = require("./keep-warm");
const request = axios.create({
headers: {
"X-NetsBlox-Auth-Token": MATLAB_KEY,
Expand All @@ -21,18 +20,6 @@ const request = axios.create({
const MATLAB = {};
MATLAB.serviceName = "MATLAB";

const warmer = new KeepWarm(async () => {
logger.info("warming is disabled");
return;

const body = [...new Array(10)].map(() => ({
function: "ver",
arguments: [],
nargout: 1,
}));
request.post(`${MATLAB_URL}/feval-fast`, body);
});

async function requestWithRetry(url, body, numRetries = 0) {
try {
return await request.post(url, body, {
Expand All @@ -50,9 +37,11 @@ async function requestWithRetry(url, body, numRetries = 0) {
* Evaluate a MATLAB function with the given arguments and number of return
* values.
*
* @param{String} fn Name of the function to call
* @param{Array<Any>=} args arguments to pass to the function
* @param{BoundedInteger<1>=} numReturnValues Number of return values expected.
* For a list of all MATLAB functions, see the `Reference Manual <https://www.mathworks.com/help/matlab/referencelist.html?type=function>`__.
*
* @param {String} fn Name of the function to call
* @param {Array<Any>=} args arguments to pass to the function
* @param {BoundedInteger<1>=} numReturnValues Number of return values expected.
*/
MATLAB.function = async function (fn, args = [], numReturnValues = 1) {
const body = [{
Expand All @@ -75,7 +64,7 @@ MATLAB.function = async function (fn, args = [], numReturnValues = 1) {
JSON.stringify(resp.data)
}`,
);
warmer.keepWarm();

const results = resp.data.FEvalResponse;
// TODO: add batching queue
return this._parseResult(results[0]);
Expand All @@ -97,8 +86,7 @@ MATLAB._parseArgument = function (arg) {
arg = [arg];
}

const shape = MATLAB._shape(arg);
const flatValues = MATLAB._flatten(arg);
const [flatValues, shape] = MATLAB._flatten(arg);
const mwtype = MATLAB._getMwType(flatValues);
const mwdata = flatValues
.map((v) => {
Expand All @@ -118,7 +106,7 @@ MATLAB._parseArgument = function (arg) {

return {
mwdata,
mwsize: shape,
mwsize: shape.length >= 2 ? shape : [1, ...shape],
mwtype,
};
};
Expand Down Expand Up @@ -149,28 +137,29 @@ MATLAB._parseResult = (result) => {
};

MATLAB._parseResultData = (result) => {
// reshape the data
let data = result.mwdata;
let size = result.mwsize;
if (!Array.isArray(data)) {
data = [data];
}
return MATLAB._squeeze(
MATLAB._reshape(data, result.mwsize),
);
};

MATLAB._take = function* (iter, num) {
let chunk = [];
for (const v of iter) {
chunk.push(v);
if (chunk.length === num) {
yield chunk;
chunk = [];
if (result.mwtype === "char") {
if (
!Array.isArray(result.mwdata) || result.mwdata.length !== 1 ||
typeof (result.mwdata[0]) !== "string"
) {
throw Error("error parsing character string result");
}
function rejoin(x) {
if (x.length !== 0 && !Array.isArray(x[0])) {
return x.join("");
}
return x.map((y) => rejoin(y));
}
return MATLAB._squeeze(rejoin(MATLAB._unflatten([...data[0]], size)));
}
if (chunk.length) {
return chunk;
}

return MATLAB._squeeze(MATLAB._unflatten(data, size));
};

MATLAB._squeeze = (data) => {
Expand All @@ -180,38 +169,103 @@ MATLAB._squeeze = (data) => {
return data;
};

MATLAB._reshape = (data, shape) => {
return [
...shape.reverse().reduce(
(iterable, num) => MATLAB._take(iterable, num),
data,
),
].pop();
MATLAB._product = (vals) => {
let res = 1;
for (const v of vals) {
res *= v;
}
return res;
};

MATLAB._colcat = (cols) => {
if (cols.length === 0) {
return [];
}
if (!Array.isArray(cols[0])) {
return cols.reduce((acc, v) => acc.concat(v), []);
}

const rows = cols[0].length;
const res = [];
for (let i = 0; i < rows; ++i) {
res.push(MATLAB._colcat(cols.map((row) => row[i])));
}
return res;
};

MATLAB._unflatten = (data, shape) => {
if (!Array.isArray(data)) {
throw Error("internal usage error");
}

if (shape.length <= 1) {
return data;
}

const colCount = shape[shape.length - 1];
const colShape = shape.slice(0, shape.length - 1);
const colSize = MATLAB._product(colShape);

const cols = [];
for (let i = 0; i < colCount; ++i) {
cols.push(
MATLAB._unflatten(data.slice(i * colSize, (i + 1) * colSize), colShape),
);
}
return MATLAB._colcat(cols);
};

MATLAB._deepEq = (a, b) => {
if (Array.isArray(a) && Array.isArray(b)) {
return a.length === b.length && a.every((x, i) => MATLAB._deepEq(x, b[i]));
}
return a === b;
};

MATLAB._shape = (data) => {
const shape = [];
let item = data;
while (Array.isArray(item)) {
shape.push(item.length);
item = item[0];
if (!Array.isArray(data)) {
throw Error("internal usage error");
}

while (shape.length < 2) {
shape.unshift(1);
if (data.length === 0 || !Array.isArray(data[0])) {
if (data.some((x) => Array.isArray(x))) {
throw Error("input must be rectangular");
}
return [data.length];
}
if (data.some((x) => !Array.isArray(x))) {
throw Error("input must be rectangular");
}

return shape;
const shapes = data.map((x) => MATLAB._shape(x));
if (shapes.some((x) => !MATLAB._deepEq(x, shapes[0]))) {
throw Error("input must be rectangular");
}
return [data.length, ...shapes[0]];
};

// returns [flattened, shape] so that shape can be reused
MATLAB._flatten = (data) => {
return data.flatMap((item) => {
if (Array.isArray(item)) {
return MATLAB._flatten(item);
const shape = MATLAB._shape(data);
if (shape.some((x) => x === 0)) return [[], shape];

const shapeCumProd = [1, ...shape];
for (let i = 1; i < shapeCumProd.length; ++i) {
shapeCumProd[i] *= shapeCumProd[i - 1];
}

const res = new Array(shapeCumProd[shapeCumProd.length - 1]);
function visit(x, pos, depth) {
if (depth === shape.length) {
res[pos] = x;
} else {
return item;
for (let i = 0; i < x.length; ++i) {
visit(x[i], pos + i * shapeCumProd[depth], depth + 1);
}
}
});
}
visit(data, 0, 0);
return [res, shape];
};

MATLAB.isSupported = () => {
Expand Down
Loading

0 comments on commit 2e4fb2b

Please sign in to comment.