
Add PyTorch test files (#543)
lutzroeder committed Jul 28, 2024
1 parent e74e942 commit ffa1a70
Showing 3 changed files with 205 additions and 86 deletions.
147 changes: 73 additions & 74 deletions source/pytorch.js
@@ -6,6 +6,7 @@ import * as flatbuffers from './flatbuffers.js';
import * as python from './python.js';

const pytorch = {};
const numpy = {};

pytorch.ModelFactory = class {

@@ -299,9 +300,10 @@ pytorch.Node = class {
if (name instanceof pytorch.nnapi.Graph) {
return name;
}
const value = metadata.type(name);
const key = name.startsWith('__torch__.') ? name.substring(10) : name;
const value = metadata.type(key);
const type = value ? { ...value } : { name };
type.identifier = type.name;
type.identifier = name;
type.name = type.name.indexOf('::') === -1 ? type.name : type.name.split('::').pop().split('.')[0];
return type;
};
@@ -341,6 +343,11 @@ pytorch.Node = class {
}
return false;
};
const isArray = (obj) => {
return obj && obj.__class__ &&
((obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray') ||
(obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'matrix'));
};
if (!item.module && !item.node) {
this.type = type(metadata, item.type);
this.inputs = item.inputs || [];
@@ -354,8 +361,14 @@
const entries = [];
const attributes = new Map();
stack = stack || new Set();
if (obj) {
for (const [name, value] of Object.entries(obj)) {
if (obj && pytorch.Utility.isInstance(obj, 'fastai.data.core.DataLoaders')) {
// continue
} else if (obj && item.type === 'builtins.bytearray') {
const argument = new pytorch.Argument('value', Array.from(obj), 'byte[]');
this.inputs.push(argument);
} else if (obj) {
const list = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
for (const [name, value] of list) {
if (name === '__class__' || name === '__hide__') {
continue;
} else if (name === '_parameters' && value instanceof Map) {
@@ -411,6 +424,10 @@
const tensors = value.map((value) => new pytorch.Tensor('', value));
const argument = new pytorch.Argument(name, tensors, 'tensor[]');
this.inputs.push(argument);
} else if (isArray(value)) {
const tensor = new numpy.Tensor(value);
const argument = new pytorch.Argument(name, tensor, 'tensor');
this.inputs.push(argument);
} else if (Array.isArray(value) && value.every((value) => typeof value === 'string')) {
const argument = new pytorch.Argument(name, value, 'string[]');
this.inputs.push(argument);
@@ -793,7 +810,7 @@ pytorch.Container.Tar = class extends pytorch.Container {
const torch = execution.__import__('torch');
const obj = torch.load(this.entries);
delete this.entries;
this.modules = pytorch.Utility.findWeights(obj);
this.modules = pytorch.Utility.find(obj);
if (!this.modules) {
throw new pytorch.Error('File does not contain root module or state dictionary.');
}
@@ -881,7 +898,7 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
switch (this._type) {
case 'module': {
if (this._data) {
this.modules = pytorch.Utility.findModule(this._data);
this.modules = new Map([['', this._data]]);
delete this._data;
}
if (!this.modules) {
@@ -893,7 +910,7 @@
case 'tensor[]':
case 'tensor<>': {
if (this._data) {
this.modules = pytorch.Utility.findWeights(this._data);
this.modules = pytorch.Utility.find(this._data);
delete this._data;
}
if (!this.modules) {
@@ -1043,16 +1060,15 @@ pytorch.Container.Zip = class extends pytorch.Container {
}
const torch = execution.__import__('torch');
const reader = new torch.PyTorchFileReader(this._entries);
const torchscript = reader.has_record('constants.pkl');
const name = torchscript ? 'TorchScript' : 'PyTorch';
let torchscript = reader.has_record('constants.pkl');
const version = reader.version();
this.format = pytorch.Utility.format(name, version);
if (torchscript) {
const module = torch.jit.load(reader);
execution.trace = true;
if (module.data && module.data.forward) {
this.modules = new Map([['', module]]);
} else {
torchscript = false;
this.modules = pytorch.Utility.find(module.data);
}
} else {
@@ -1061,6 +1077,8 @@
const module = torch.load(entries);
this.modules = pytorch.Utility.find(module);
}
const name = torchscript ? 'TorchScript' : 'PyTorch';
this.format = pytorch.Utility.format(name, version);
delete this._model;
delete this._entries;
}
@@ -1180,7 +1198,7 @@ pytorch.Container.Index = class extends pytorch.Container {
}
}
}
this.modules = pytorch.Utility.findWeights(entries);
this.modules = pytorch.Utility.find(entries);
delete this.context;
delete this._entries;
}
@@ -3472,17 +3490,19 @@ pytorch.Utility = class {
}

static isSubclass(value, name) {
if (value.__module__ && value.__name__) {
if (name === `${value.__module__}.${value.__name__}`) {
return true;
}
}
if (value.__bases__) {
for (const base of value.__bases__) {
if (pytorch.Utility.isSubclass(base, name)) {
if (value) {
if (value.__module__ && value.__name__) {
if (name === `${value.__module__}.${value.__name__}`) {
return true;
}
}
if (value.__bases__) {
for (const base of value.__bases__) {
if (pytorch.Utility.isSubclass(base, name)) {
return true;
}
}
}
}
return false;
}
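Editor's note, not part of the commit: the added guard makes isSubclass tolerant of empty inputs. A hypothetical call illustrating the change in behavior:

pytorch.Utility.isSubclass(null, 'torch.nn.Module'); // now returns false; previously this threw when reading value.__module__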
@@ -3525,56 +3545,7 @@
return `${name} ${versions.get(value)}`;
}

static find(data) {
const root = pytorch.Utility.findModule(data);
if (root) {
return root;
}
const weights = pytorch.Utility.findWeights(data);
if (weights) {
return weights;
}
if (data && Array.isArray(data) && data === Object(data) && Object.entries(data).length === 0) {
return [];
}
throw new pytorch.Error('File does not contain root module or state dictionary.');
}

static findModule(root) {
if (root) {
const keys = ['', 'model', 'net'];
for (const key of keys) {
const obj = key === '' ? root : root[key];
if (obj) {
if (obj instanceof Map && obj.has('engine')) {
// https://github.com/NVIDIA-AI-IOT/torch2trt/blob/master/torch2trt/torch2trt.py
const data = obj.get('engine');
const signatures = [
[0x70, 0x74, 0x72, 0x74], // ptrt
[0x66, 0x74, 0x72, 0x74] // ftrt
];
for (const signature of signatures) {
if (data instanceof Uint8Array && data.length > signature.length && signature.every((value, index) => value === data[index])) {
// const buffer = data.slice(0, 24);
// const content = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('');
throw new pytorch.Error('Invalid file content. File contains undocumented PyTorch TensorRT engine data.');
}
}
}
if (obj._modules) {
return new Map([['', obj]]);
}
const entries = Object.entries(obj).filter(([name, obj]) => name && obj && obj._modules);
if (entries.length > 1) {
return new Map(entries);
}
}
}
}
return null;
}

static findWeights(obj) {
static find(obj) {
if (obj) {
if (pytorch.Utility.isTensor(obj)) {
const module = {};
@@ -3610,7 +3581,7 @@
}
}
}
return null;
return new Map([['', obj]]);
}

static _convertObjectList(obj) {
Expand Down Expand Up @@ -3658,6 +3629,24 @@ pytorch.Utility = class {
}
return count > 0;
};
const isLayer = (obj) => {
if (obj instanceof Map === false) {
obj = new Map(Object.entries(obj));
}
for (const [key, value] of Array.from(obj)) {
if (pytorch.Utility.isTensor(value)) {
continue;
}
if (key === '_metadata') {
continue;
}
if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
continue;
}
return false;
}
return true;
};
const flatten = (obj) => {
if (!obj || Array.isArray(obj) || ArrayBuffer.isView(obj)) {
return null;
@@ -3698,6 +3687,8 @@
}
} else if (obj instanceof Map && validate(obj)) {
map.set('', flatten(obj));
} else if ((Object(obj) === obj && Object.entries(obj).every(([, value]) => value && isLayer(value)))) {
return new Map([['', { _modules: new Map(Object.entries(obj)) }]]);
} else if (Object(obj) === obj && Object.entries(obj).every(([, value]) => validate(value))) {
for (const [name, value] of Object.entries(obj)) {
if (Object(value) === value) {
@@ -3707,7 +3698,7 @@
}
}
} else if (Object(obj) === obj && Object.entries(obj).some(([, value]) => pytorch.Utility.isTensor(value))) {
map.set('', new Map(Object.entries(obj).map(([key, value]) => [key, value])));
map.set('', new Map(Object.entries(obj)));
} else {
const value = flatten(obj);
if (value) {
@@ -3754,9 +3745,7 @@
} else if (value && Array.isArray(value) && value.every((item) => pytorch.Utility.isTensor(item))) {
layer._parameters = layer._parameters || new Map();
layer._parameters.set(parameter, value);
} else if (value && Array.isArray(value) && value.every((item) => typeof item === 'string' || typeof item === 'number')) {
layer[parameter] = value;
} else if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
} else {
layer[parameter] = value;
}
}
@@ -4226,6 +4215,16 @@ pytorch.Metadata = class {
}
};

numpy.Tensor = class {

constructor(array) {
this.type = new pytorch.TensorType(array.dtype.__name__, new pytorch.TensorShape(array.shape));
this.stride = array.strides.map((stride) => stride / array.itemsize);
this.values = this.type.dataType === 'string' || this.type.dataType === 'object' || this.type.dataType === 'void' ? array.flatten().tolist() : array.tobytes();
this.encoding = this.type.dataType === 'string' || this.type.dataType === 'object' ? '|' : array.dtype.byteorder;
}
};

pytorch.Error = class extends Error {

constructor(message) {
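For context, a minimal usage sketch of the new numpy.Tensor wrapper introduced above (not part of the commit). The ndarray-like object and all of its field values are assumptions, chosen only to mirror the properties the constructor reads (dtype, shape, strides, itemsize, tobytes, flatten):

// Hypothetical ndarray-like object, similar in shape to what the pickle
// emulation produces and what isArray(value) accepts; values are illustrative.
const data = new Float32Array([1, 2, 3, 4, 5, 6]);
const array = {
    dtype: { __name__: 'float32', byteorder: '<' },
    shape: [2, 3],
    strides: [12, 4], // byte strides, numpy-style
    itemsize: 4,
    tobytes: () => new Uint8Array(data.buffer),
    flatten: () => ({ tolist: () => Array.from(data) })
};
const tensor = new numpy.Tensor(array);
// tensor.type     -> float32 tensor of shape [2, 3]
// tensor.stride   -> [3, 1] (byte strides divided by itemsize)
// tensor.values   -> raw bytes from tobytes(), since the dtype is numeric
// tensor.encoding -> '<' (the dtype byte order)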
12 changes: 10 additions & 2 deletions source/view.js
@@ -809,7 +809,15 @@ view.View = class {
const layout = {};
layout.nodesep = 20;
layout.ranksep = 20;
const rotate = graph.nodes.every((node) => node.inputs.filter((input) => (input.type && !input.type.endsWith('*')) || input.value.every((value) => !value.initializer)).length === 0 && node.outputs.length === 0);
const rotate = graph.nodes.every((node) => {
if (node.inputs.filter((input) => !input.type || input.type.endsWith('*')).some((input) => input.value.every((value) => !value.initializer))) {
return false;
}
if (node.outputs.length > 0) {
return false;
}
return true;
});
const horizontal = rotate ? options.direction === 'vertical' : options.direction !== 'vertical';
if (horizontal) {
layout.rankdir = 'LR';
@@ -2039,7 +2047,7 @@ view.Node = class extends grapher.Node {
const type = argument.type;
if (type === 'graph' || type === 'object' || type === 'object[]' || type === 'function' || type === 'function[]') {
objects.push(argument);
} else if (options.weights && argument.visible !== false && Array.isArray(argument.value) && argument.value.length === 1 && argument.value[0].initializer) {
} else if (options.weights && argument.visible !== false && argument.type !== 'attribute' && Array.isArray(argument.value) && argument.value.length === 1 && argument.value[0].initializer) {
const item = this.context.createArgument(argument);
list().add(item);
} else if (options.weights && (argument.visible === false || Array.isArray(argument.value) && argument.value.length > 1) && (!argument.type || argument.type.endsWith('*')) && argument.value.some((value) => value.initializer)) {
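For context, a standalone restatement of the rotation heuristic refactored above (not part of the commit); the toy node objects are assumptions for illustration. A graph is laid out rotated only when no node has a real (non-initializer) input of default type and no node produces outputs:

const shouldRotate = (nodes) => nodes.every((node) => {
    const inputs = node.inputs.filter((input) => !input.type || input.type.endsWith('*'));
    if (inputs.some((input) => input.value.every((value) => !value.initializer))) {
        return false; // at least one input carries only non-initializer values
    }
    return node.outputs.length === 0;
});
// Example: nodes that only hold initializers (e.g. a rendered state dict) rotate the layout.
shouldRotate([{ inputs: [{ type: 'tensor*', value: [{ initializer: {} }] }], outputs: [] }]); // true
shouldRotate([{ inputs: [], outputs: [{ value: [] }] }]); // false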
(diff for the third changed file not shown)
