Skip to content

Commit

Permalink
Update tflite.js (#783) (#1238)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Mar 3, 2024
1 parent 4249ecc commit cfea826
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 313 deletions.
218 changes: 84 additions & 134 deletions source/circle.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import * as flatbuffers from '../source/flatbuffers.js';
import * as flexbuffers from '../source/flexbuffers.js';
import * as zip from '../source/zip.js';
import * as flatbuffers from './flatbuffers.js';
import * as flexbuffers from './flexbuffers.js';
import * as zip from './zip.js';

const circle = {};

Expand Down Expand Up @@ -71,11 +71,11 @@ circle.ModelFactory = class {
circle.Model = class {

constructor(metadata, model) {
this._graphs = [];
this._format = 'Circle';
this._format = `${this._format} v${model.version}`;
this._description = model.description || '';
this._metadata = [];
this.graphs = [];
this.format = 'Circle';
this.format = `${this.format} v${model.version}`;
this.description = model.description || '';
this.metadata = [];
const builtinOperators = new Map();
const upperCase = new Set(['2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM']);
for (const key of Object.keys(circle.schema.BuiltinOperator)) {
Expand All @@ -85,8 +85,8 @@ circle.Model = class {
builtinOperators.set(index, name);
}
const operators = model.operator_codes.map((operator) => {
const code = Math.max(operator.deprecated_builtin_code, operator.builtin_code || 0);
const value = {};
const code = operator.builtin_code || 0;
if (code === circle.schema.BuiltinOperator.CUSTOM) {
value.name = operator.custom_code ? operator.custom_code : 'Custom';
value.version = operator.version;
Expand All @@ -105,7 +105,7 @@ circle.Model = class {
switch (metadata.name) {
case 'min_runtime_version': {
const data = buffer.data || new Uint8Array(0);
this._runtime = new TextDecoder().decode(data);
this.runtime = new TextDecoder().decode(data);
break;
}
case 'TFLITE_METADATA': {
Expand All @@ -114,19 +114,19 @@ circle.Model = class {
if (circle.schema.ModelMetadata.identifier(reader)) {
modelMetadata = circle.schema.ModelMetadata.create(reader);
if (modelMetadata.name) {
this._name = modelMetadata.name;
this.name = modelMetadata.name;
}
if (modelMetadata.version) {
this._version = modelMetadata.version;
this.version = modelMetadata.version;
}
if (modelMetadata.description) {
this._description = this._description ? [this._description, modelMetadata.description].join(' ') : modelMetadata.description;
this.description = this._description ? [this._description, modelMetadata.description].join(' ') : modelMetadata.description;
}
if (modelMetadata.author) {
this._metadata.push(new circle.Argument('author', modelMetadata.author));
this.metadata.push(new circle.Argument('author', modelMetadata.author));
}
if (modelMetadata.license) {
this._metadata.set(new circle.Argument('license', modelMetadata.license));
this.metadata.push(new circle.Argument('license', modelMetadata.license));
}
}
break;
Expand All @@ -143,46 +143,17 @@ circle.Model = class {
const subgraph = subgraphs[i];
const name = subgraphs.length > 1 ? i.toString() : '';
const subgraphMetadata = subgraphsMetadata && i < subgraphsMetadata.length ? subgraphsMetadata[i] : null;
this._graphs.push(new circle.Graph(metadata, subgraph, subgraphMetadata, name, operators, model));
const signatures = model.signature_defs.filter((signature) => signature.subgraph_index === i);
const graph = new circle.Graph(metadata, subgraph, signatures, subgraphMetadata, name, operators, model);
this.graphs.push(graph);
}
}

get format() {
return this._format;
}

get runtime() {
return this._runtime;
}

get name() {
return this._name;
}

get version() {
return this._version;
}

get description() {
return this._description;
}

get metadata() {
return this._metadata;
}

get graphs() {
return this._graphs;
}
};

circle.Graph = class {

constructor(metadata, subgraph, subgraphMetadata, name, operators, model) {
this._nodes = [];
this._inputs = [];
this._outputs = [];
this._name = subgraph.name || name;
constructor(metadata, subgraph, signatures, subgraphMetadata, name, operators, model) {
this.name = subgraph.name || name;
const tensors = new Map();
tensors.map = (index, metadata) => {
if (index === -1) {
Expand Down Expand Up @@ -227,47 +198,49 @@ circle.Graph = class {
}
return tensors.get(index);
};
const inputs = subgraph.inputs;
for (let i = 0; i < inputs.length; i++) {
const input = inputs[i];
const metadata = subgraphMetadata && i < subgraphMetadata.input_tensor_metadata.length ? subgraphMetadata.input_tensor_metadata[i] : null;
const value = tensors.map(input, metadata);
const inputs = Array.from(subgraph.inputs).map((tensor_index, index) => {
const metadata = subgraphMetadata && index < subgraphMetadata.input_tensor_metadata.length ? subgraphMetadata.input_tensor_metadata[index] : null;
const value = tensors.map(tensor_index, metadata);
const name = value ? value.name : '?';
const argument = new circle.Argument(name, value ? [value] : []);
this._inputs.push(argument);
}
const outputs = subgraph.outputs;
for (let i = 0; i < outputs.length; i++) {
const output = outputs[i];
const metadata = subgraphMetadata && i < subgraphMetadata.output_tensor_metadata.length ? subgraphMetadata.output_tensor_metadata[i] : null;
const value = tensors.map(output, metadata);
return { name: name, tensor_index: tensor_index };
});
const outputs = Array.from(subgraph.outputs).map((tensor_index, index) => {
const metadata = subgraphMetadata && index < subgraphMetadata.output_tensor_metadata.length ? subgraphMetadata.output_tensor_metadata[index] : null;
const value = tensors.map(tensor_index, metadata);
const name = value ? value.name : '?';
const argument = new circle.Argument(name, value ? [value] : []);
this._outputs.push(argument);
}
for (let i = 0; i < subgraph.operators.length; i++) {
const operator = subgraph.operators[i];
const index = operator.opcode_index;
const opcode = index < operators.length ? operators[index] : { name: `(${index})` };
const node = new circle.Node(metadata, operator, opcode, i.toString(), tensors);
this._nodes.push(node);
}
}

get name() {
return this._name;
return { name: name, tensor_index: tensor_index };
});
const signature = {
signature_key: '',
inputs: inputs,
outputs: outputs
};
signatures = signatures.length === 0 ? [signature] : signatures;
this.signatures = signatures.map((signature) => {
return new circle.Signature(signature, tensors);
});
this.nodes = Array.from(subgraph.operators).map((operator, index) => {
const opcode_index = operator.opcode_index;
const opcode = opcode_index < operators.length ? operators[opcode_index] : { name: `(${opcode_index})` };
return new circle.Node(metadata, operator, opcode, index.toString(), tensors);
});
}
};

get inputs() {
return this._inputs;
}
circle.Signature = class {

get outputs() {
return this._outputs;
}

get nodes() {
return this._nodes;
constructor(signature, tensors) {
this.name = signature.signature_key;
this.inputs = signature.inputs.map((input) => {
const value = tensors.map(input.tensor_index);
const values = value ? [value] : [];
return new circle.Argument(input.name, values);
});
this.outputs = signature.outputs.map((output) => {
const value = tensors.map(output.tensor_index);
const values = value ? [value] : [];
return new circle.Argument(output.name, values);
});
}
};

Expand All @@ -284,18 +257,18 @@ circle.Node = class {
let outputs = [];
inputs = Array.from(node.inputs || new Int32Array(0));
outputs = Array.from(node.outputs || new Int32Array(0));
for (let i = 0; i < inputs.length; i++) {
for (let i = 0; i < inputs.length;) {
let count = 1;
let name = null;
let visible = true;
const values = [];
if (this._type && this._type.inputs && i < this._type.inputs.length) {
const input = this._type.inputs[i];
name = input.name;
if (input.option === 'variadic') {
if (input.list) {
count = inputs.length - i;
}
if (input && input.visible === false) {
if (input.visible === false) {
visible = false;
}
}
Expand Down Expand Up @@ -351,22 +324,25 @@ circle.Node = class {
}
if (!decoded) {
const schema = metadata.attribute(type.name, 'custom');
this._attributes.push(new circle.Attribute(schema, 'custom', Array.from(node.custom_options)));
const attribute = new circle.Attribute(schema, 'custom', Array.from(node.custom_options));
this._attributes.push(attribute);
}
}
const options = node.builtin_options;
if (options) {
for (const [name, value] of Object.entries(options)) {
if (name === 'fused_activation_function' && value !== 0) {
const activationFunctionMap = { 1: 'Relu', 2: 'ReluN1To1', 3: 'Relu6', 4: 'Tanh', 5: 'SignBit' };
if (!activationFunctionMap[value]) {
throw new circle.Error(`Unsupported activation funtion index '${JSON.stringify(value)}'.`);
if (name === 'fused_activation_function' && value) {
if (value < 1 || value > 5) {
throw new circle.Error(`Unsupported activation funtion index '${value}'.`);
}
const type = activationFunctionMap[value];
this._chain = [new circle.Node(metadata, null, { name: type }, null, [])];
const list = ['Unknown', 'Relu', 'ReluN1To1', 'Relu6', 'Tanh', 'SignBit'];
const type = list[value];
const node = new circle.Node(metadata, null, { name: type }, null, []);
this._chain = [node];
}
const schema = metadata.attribute(type.name, name);
this._attributes.push(new circle.Attribute(schema, name, value));
const attribute = new circle.Attribute(schema, name, value);
this._attributes.push(attribute);
}
}
}
Expand Down Expand Up @@ -448,21 +424,9 @@ circle.Attribute = class {
circle.Argument = class {

constructor(name, value, visible) {
this._name = name;
this._value = value;
this._visible = visible === false ? false : true;
}

get name() {
return this._name;
}

get visible() {
return this._visible;
}

get value() {
return this._value;
this.name = name;
this.value = value;
this.visible = visible === false ? false : true;
}
};

Expand Down Expand Up @@ -492,38 +456,22 @@ circle.Value = class {
circle.Tensor = class {

constructor(index, tensor, buffer, is_variable) {
this._location = index.toString();
this._type = new circle.TensorType(tensor);
this._is_variable = is_variable;
this._name = tensor.name;
this.location = index.toString();
this.name = tensor.name;
this.type = new circle.TensorType(tensor);
this.category = is_variable ? 'Variable' : '';
this._data = buffer.data.slice(0);
}

get category() {
return this._is_variable ? 'Variable' : '';
}

get name() {
return this._name;
}

get location() {
return this._location;
}

get type() {
return this._type;
}

get encoding() {
switch (this._type.dataType) {
switch (this.type.dataType) {
case 'string': return '|';
default: return '<';
}
}

get values() {
switch (this._type.dataType) {
switch (this.type.dataType) {
case 'string': {
let offset = 0;
const data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
Expand All @@ -543,7 +491,9 @@ circle.Tensor = class {
}
return stringTable;
}
default: return this._data;
default: {
return this._data;
}
}
}
};
Expand Down
Loading

0 comments on commit cfea826

Please sign in to comment.