From ffa1a70f563ae813a3a7a9ec94a648ec2958e4f8 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 27 Jul 2024 18:16:15 -0700 Subject: [PATCH] Add PyTorch test files (#543) --- source/pytorch.js | 147 +++++++++++++++++++++++----------------------- source/view.js | 12 +++- test/models.json | 132 +++++++++++++++++++++++++++++++++++++---- 3 files changed, 205 insertions(+), 86 deletions(-) diff --git a/source/pytorch.js b/source/pytorch.js index f26ee08172..749c7d7d7e 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -6,6 +6,7 @@ import * as flatbuffers from './flatbuffers.js'; import * as python from './python.js'; const pytorch = {}; +const numpy = {}; pytorch.ModelFactory = class { @@ -299,9 +300,10 @@ pytorch.Node = class { if (name instanceof pytorch.nnapi.Graph) { return name; } - const value = metadata.type(name); + const key = name.startsWith('__torch__.') ? name.substring(10) : name; + const value = metadata.type(key); const type = value ? { ...value } : { name }; - type.identifier = type.name; + type.identifier = name; type.name = type.name.indexOf('::') === -1 ? type.name : type.name.split('::').pop().split('.')[0]; return type; }; @@ -341,6 +343,11 @@ pytorch.Node = class { } return false; }; + const isArray = (obj) => { + return obj && obj.__class__ && + ((obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray') || + (obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'matrix')); + }; if (!item.module && !item.node) { this.type = type(metadata, item.type); this.inputs = item.inputs || []; @@ -354,8 +361,14 @@ pytorch.Node = class { const entries = []; const attributes = new Map(); stack = stack || new Set(); - if (obj) { - for (const [name, value] of Object.entries(obj)) { + if (obj && pytorch.Utility.isInstance(obj, 'fastai.data.core.DataLoaders')) { + // continue + } else if (obj && item.type === 'builtins.bytearray') { + const argument = new pytorch.Argument('value', Array.from(obj), 'byte[]'); + this.inputs.push(argument); + } else if (obj) { + const list = obj instanceof Map ? Array.from(obj) : Object.entries(obj); + for (const [name, value] of list) { if (name === '__class__' || name === '__hide__') { continue; } else if (name === '_parameters' && value instanceof Map) { @@ -411,6 +424,10 @@ pytorch.Node = class { const tensors = value.map((value) => new pytorch.Tensor('', value)); const argument = new pytorch.Argument(name, tensors, 'tensor[]'); this.inputs.push(argument); + } else if (isArray(value)) { + const tensor = new numpy.Tensor(value); + const argument = new pytorch.Argument(name, tensor, 'tensor'); + this.inputs.push(argument); } else if (Array.isArray(value) && value.every((value) => typeof value === 'string')) { const argument = new pytorch.Argument(name, value, 'string[]'); this.inputs.push(argument); @@ -793,7 +810,7 @@ pytorch.Container.Tar = class extends pytorch.Container { const torch = execution.__import__('torch'); const obj = torch.load(this.entries); delete this.entries; - this.modules = pytorch.Utility.findWeights(obj); + this.modules = pytorch.Utility.find(obj); if (!this.modules) { throw new pytorch.Error('File does not contain root module or state dictionary.'); } @@ -881,7 +898,7 @@ pytorch.Container.data_pkl = class extends pytorch.Container { switch (this._type) { case 'module': { if (this._data) { - this.modules = pytorch.Utility.findModule(this._data); + this.modules = new Map([['', this._data]]); delete this._data; } if (!this.modules) { @@ -893,7 +910,7 @@ pytorch.Container.data_pkl = class extends pytorch.Container { case 'tensor[]': case 'tensor<>': { if (this._data) { - this.modules = pytorch.Utility.findWeights(this._data); + this.modules = pytorch.Utility.find(this._data); delete this._data; } if (!this.modules) { @@ -1043,16 +1060,15 @@ pytorch.Container.Zip = class extends pytorch.Container { } const torch = execution.__import__('torch'); const reader = new torch.PyTorchFileReader(this._entries); - const torchscript = reader.has_record('constants.pkl'); - const name = torchscript ? 'TorchScript' : 'PyTorch'; + let torchscript = reader.has_record('constants.pkl'); const version = reader.version(); - this.format = pytorch.Utility.format(name, version); if (torchscript) { const module = torch.jit.load(reader); execution.trace = true; if (module.data && module.data.forward) { this.modules = new Map([['', module]]); } else { + torchscript = false; this.modules = pytorch.Utility.find(module.data); } } else { @@ -1061,6 +1077,8 @@ pytorch.Container.Zip = class extends pytorch.Container { const module = torch.load(entries); this.modules = pytorch.Utility.find(module); } + const name = torchscript ? 'TorchScript' : 'PyTorch'; + this.format = pytorch.Utility.format(name, version); delete this._model; delete this._entries; } @@ -1180,7 +1198,7 @@ pytorch.Container.Index = class extends pytorch.Container { } } } - this.modules = pytorch.Utility.findWeights(entries); + this.modules = pytorch.Utility.find(entries); delete this.context; delete this._entries; } @@ -3472,17 +3490,19 @@ pytorch.Utility = class { } static isSubclass(value, name) { - if (value.__module__ && value.__name__) { - if (name === `${value.__module__}.${value.__name__}`) { - return true; - } - } - if (value.__bases__) { - for (const base of value.__bases__) { - if (pytorch.Utility.isSubclass(base, name)) { + if (value) { + if (value.__module__ && value.__name__) { + if (name === `${value.__module__}.${value.__name__}`) { return true; } } + if (value.__bases__) { + for (const base of value.__bases__) { + if (pytorch.Utility.isSubclass(base, name)) { + return true; + } + } + } } return false; } @@ -3525,56 +3545,7 @@ pytorch.Utility = class { return `${name} ${versions.get(value)}`; } - static find(data) { - const root = pytorch.Utility.findModule(data); - if (root) { - return root; - } - const weights = pytorch.Utility.findWeights(data); - if (weights) { - return weights; - } - if (data && Array.isArray(data) && data === Object(data) && Object.entries(data).length === 0) { - return []; - } - throw new pytorch.Error('File does not contain root module or state dictionary.'); - } - - static findModule(root) { - if (root) { - const keys = ['', 'model', 'net']; - for (const key of keys) { - const obj = key === '' ? root : root[key]; - if (obj) { - if (obj instanceof Map && obj.has('engine')) { - // https://github.com/NVIDIA-AI-IOT/torch2trt/blob/master/torch2trt/torch2trt.py - const data = obj.get('engine'); - const signatures = [ - [0x70, 0x74, 0x72, 0x74], // ptrt - [0x66, 0x74, 0x72, 0x74] // ftrt - ]; - for (const signature of signatures) { - if (data instanceof Uint8Array && data.length > signature.length && signature.every((value, index) => value === data[index])) { - // const buffer = data.slice(0, 24); - // const content = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join(''); - throw new pytorch.Error('Invalid file content. File contains undocumented PyTorch TensorRT engine data.'); - } - } - } - if (obj._modules) { - return new Map([['', obj]]); - } - const entries = Object.entries(obj).filter(([name, obj]) => name && obj && obj._modules); - if (entries.length > 1) { - return new Map(entries); - } - } - } - } - return null; - } - - static findWeights(obj) { + static find(obj) { if (obj) { if (pytorch.Utility.isTensor(obj)) { const module = {}; @@ -3610,7 +3581,7 @@ pytorch.Utility = class { } } } - return null; + return new Map([['', obj]]); } static _convertObjectList(obj) { @@ -3658,6 +3629,24 @@ pytorch.Utility = class { } return count > 0; }; + const isLayer = (obj) => { + if (obj instanceof Map === false) { + obj = new Map(Object.entries(obj)); + } + for (const [key, value] of Array.from(obj)) { + if (pytorch.Utility.isTensor(value)) { + continue; + } + if (key === '_metadata') { + continue; + } + if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') { + continue; + } + return false; + } + return true; + }; const flatten = (obj) => { if (!obj || Array.isArray(obj) || ArrayBuffer.isView(obj)) { return null; @@ -3698,6 +3687,8 @@ pytorch.Utility = class { } } else if (obj instanceof Map && validate(obj)) { map.set('', flatten(obj)); + } else if ((Object(obj) === obj && Object.entries(obj).every(([, value]) => value && isLayer(value)))) { + return new Map([['', { _modules: new Map(Object.entries(obj)) }]]); } else if (Object(obj) === obj && Object.entries(obj).every(([, value]) => validate(value))) { for (const [name, value] of Object.entries(obj)) { if (Object(value) === value) { @@ -3707,7 +3698,7 @@ pytorch.Utility = class { } } } else if (Object(obj) === obj && Object.entries(obj).some(([, value]) => pytorch.Utility.isTensor(value))) { - map.set('', new Map(Object.entries(obj).map(([key, value]) => [key, value]))); + map.set('', new Map(Object.entries(obj))); } else { const value = flatten(obj); if (value) { @@ -3754,9 +3745,7 @@ pytorch.Utility = class { } else if (value && Array.isArray(value) && value.every((item) => pytorch.Utility.isTensor(item))) { layer._parameters = layer._parameters || new Map(); layer._parameters.set(parameter, value); - } else if (value && Array.isArray(value) && value.every((item) => typeof item === 'string' || typeof item === 'number')) { - layer[parameter] = value; - } else if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') { + } else { layer[parameter] = value; } } @@ -4226,6 +4215,16 @@ pytorch.Metadata = class { } }; +numpy.Tensor = class { + + constructor(array) { + this.type = new pytorch.TensorType(array.dtype.__name__, new pytorch.TensorShape(array.shape)); + this.stride = array.strides.map((stride) => stride / array.itemsize); + this.values = this.type.dataType === 'string' || this.type.dataType === 'object' || this.type.dataType === 'void' ? array.flatten().tolist() : array.tobytes(); + this.encoding = this.type.dataType === 'string' || this.type.dataType === 'object' ? '|' : array.dtype.byteorder; + } +}; + pytorch.Error = class extends Error { constructor(message) { diff --git a/source/view.js b/source/view.js index 6e4032a821..cef5195b1a 100644 --- a/source/view.js +++ b/source/view.js @@ -809,7 +809,15 @@ view.View = class { const layout = {}; layout.nodesep = 20; layout.ranksep = 20; - const rotate = graph.nodes.every((node) => node.inputs.filter((input) => (input.type && !input.type.endsWith('*')) || input.value.every((value) => !value.initializer)).length === 0 && node.outputs.length === 0); + const rotate = graph.nodes.every((node) => { + if (node.inputs.filter((input) => !input.type || input.type.endsWith('*')).some((input) => input.value.every((value) => !value.initializer))) { + return false; + } + if (node.outputs.length > 0) { + return false; + } + return true; + }); const horizontal = rotate ? options.direction === 'vertical' : options.direction !== 'vertical'; if (horizontal) { layout.rankdir = 'LR'; @@ -2039,7 +2047,7 @@ view.Node = class extends grapher.Node { const type = argument.type; if (type === 'graph' || type === 'object' || type === 'object[]' || type === 'function' || type === 'function[]') { objects.push(argument); - } else if (options.weights && argument.visible !== false && Array.isArray(argument.value) && argument.value.length === 1 && argument.value[0].initializer) { + } else if (options.weights && argument.visible !== false && argument.type !== 'attribute' && Array.isArray(argument.value) && argument.value.length === 1 && argument.value[0].initializer) { const item = this.context.createArgument(argument); list().add(item); } else if (options.weights && (argument.visible === false || Array.isArray(argument.value) && argument.value.length > 1) && (!argument.type || argument.type.endsWith('*')) && argument.value.some((value) => value.initializer)) { diff --git a/test/models.json b/test/models.json index 59ad71d474..1c92b03a2f 100644 --- a/test/models.json +++ b/test/models.json @@ -4887,12 +4887,18 @@ "format": "PyTorch v1.6", "link": "https://github.com/lutzroeder/netron/issues/133" }, + { + "type": "pytorch", + "target": "bad-base_libri.pt", + "source": "https://github.com/user-attachments/files/16401716/base_libri.pt.zip[base_libri.pt]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "bad-hands-5.pt", "source": "https://github.com/lutzroeder/netron/files/14471657/bad-hands-5.pt.zip[bad-hands-5.pt]", "format": "PyTorch v1.6", - "error": "File does not contain root module or state dictionary.", "link": "https://github.com/lutzroeder/netron/issues/720" }, { @@ -4902,6 +4908,20 @@ "format": "TorchScript v1.6", "link": "https://github.com/lutzroeder/netron/issues/842" }, + { + "type": "pytorch", + "target": "best_mask.pth", + "source": "https://github.com/user-attachments/files/16401712/best_mask.pth.zip[best_mask.pth]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, + { + "type": "pytorch", + "target": "best.pt", + "source": "https://github.com/user-attachments/files/16401713/best.pt.zip[best.pt]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "blitz_cifar10_tutorial.pt", @@ -4937,6 +4957,13 @@ "format": "PyTorch v0.1.10", "link": "https://github.com/lutzroeder/netron/issues/472" }, + { + "type": "pytorch", + "target": "best.pt", + "source": "https://github.com/user-attachments/files/16401547/best.pt.zip[best.pt]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "coco128-yolov8n-seg_output.torchscript.ptl", @@ -4944,6 +4971,13 @@ "format": "TorchScript v1.6", "link": "https://github.com/lutzroeder/netron/issues/1067" }, + { + "type": "pytorch", + "target": "checkpoint_best.pth", + "source": "https://github.com/user-attachments/files/16401715/checkpoint_best.pth.zip[checkpoint_best.pth]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "ckpt.t7", @@ -4951,6 +4985,20 @@ "format": "PyTorch v0.1.10", "link": "https://github.com/babajide07/Redundant-Feature-Pruning-Pytorch-Implementation" }, + { + "type": "pytorch", + "target": "cpu_jit.pt", + "source": "https://github.com/user-attachments/files/16401711/cpu_jit.pt.zip[cpu_jit.pt]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, + { + "type": "pytorch", + "target": "complex_tensor.pt", + "source": "https://github.com/lutzroeder/netron/files/9108149/complex_tensor.pt.zip[complex_tensor.pt]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/720" + }, { "type": "pytorch", "target": "cruise_cutin_vehicle_model.pt", @@ -4966,20 +5014,12 @@ "format": "TorchScript v1.0", "link": "https://github.com/ApolloAuto/apollo" }, - { - "type": "pytorch", - "target": "complex_tensor.pt", - "source": "https://github.com/lutzroeder/netron/files/9108149/complex_tensor.pt.zip[complex_tensor.pt]", - "format": "PyTorch v1.6", - "link": "https://github.com/lutzroeder/netron/issues/720" - }, { "type": "pytorch", "target": "cup_wild_vit_l_1img.ckpt", "source": "https://github.com/user-attachments/files/15752923/cup_wild_vit_l_1img.ckpt.zip[cup_wild_vit_l_1img.ckpt]", "format": "PyTorch v1.6", "tags": "skip-tensor-value", - "assert": "model.graphs[1].name == 'ema_model'", "link": "https://github.com/lutzroeder/netron/issues/720" }, { @@ -5062,6 +5102,7 @@ "target": "DRNL4x_dual_model.pth", "source": "https://github.com/lutzroeder/netron/files/5505677/DRNL4x_dual_model.pth.zip[DRNL4x_dual_model.pth]", "format": "PyTorch v0.1.10", + "assert": "model.graphs[0].nodes[0].inputs.length == 1", "link": "https://github.com/lutzroeder/netron/issues/543" }, { @@ -5248,8 +5289,16 @@ "target": "mask_r_cnn.pth", "source": "https://raw.githubusercontent.com/facebookresearch/kill-the-bits/master/src/models/compressed/mask_r_cnn.pth", "format": "PyTorch v0.1.10", + "assert": "model.graphs[0].nodes.length == 127", "link": "https://github.com/facebookresearch/kill-the-bits/tree/master/src/models/compressed" }, + { + "type": "pytorch", + "target": "mcunet-5fps.pkl", + "source": "https://github.com/user-attachments/files/16401553/mcunet-5fps.pkl.zip[mcunet-5fps.pkl]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "mnist_bfloat16.pt", @@ -5343,6 +5392,20 @@ "format": "TorchScript v1.6", "link": "https://github.com/lutzroeder/netron/issues/647" }, + { + "type": "pytorch", + "target": "model_10_10_10.pth", + "source": "https://github.com/user-attachments/files/16401752/model_10_10_10.pth.zip[model_10_10_10.pth]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, + { + "type": "pytorch", + "target": "model_final.ckpt", + "source": "https://github.com/user-attachments/files/16401708/model_final.ckpt.zip[model_final.ckpt]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "model_fnet.pt", @@ -5350,6 +5413,13 @@ "format": "TorchScript v1.6", "link": "https://github.com/lutzroeder/netron/issues/882" }, + { + "type": "pytorch", + "target": "model_trt.pth", + "source": "https://github.com/user-attachments/files/16401618/model_trt.pth.zip[model_trt.pth]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "model-reddit16-f140225004_2.pt1", @@ -5402,6 +5472,20 @@ "format": "PyTorch v0.1.10", "link": "https://github.com/deepware/dface/tree/master/dface" }, + { + "type": "pytorch", + "target": "model_scripted.pt", + "source": "https://github.com/user-attachments/files/16401554/model_scripted.pt.zip[model_scripted.pt]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, + { + "type": "pytorch", + "target": "muzero_models.pb", + "source": "https://github.com/user-attachments/files/16401557/muzero_models.pb.zip[muzero_models.pb]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "multi_return.pt", @@ -5643,6 +5727,13 @@ "error": "'torch.export' not supported.", "link": "https://github.com/lutzroeder/netron/issues/1211" }, + { + "type": "pytorch", + "target": "resnet18_cifar10_quantized.pt", + "source": "https://github.com/user-attachments/files/16401717/resnet18_cifar10_quantized.pt.zip[resnet18_cifar10_quantized.pt]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "resnet18_fbgemm_16fa66dd.pth", @@ -5713,6 +5804,13 @@ "format": "PyTorch v0.1.10", "link": "https://github.com/larry0123du/Decompose-CNN" }, + { + "type": "pytorch", + "target": "rng_state.pth", + "source": "https://github.com/user-attachments/files/16401709/rng_state.pth.zip[rng_state.pth]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "rpn_model.pt", @@ -5734,11 +5832,18 @@ "format": "PyTorch Mobile v1.11", "link": "https://github.com/lutzroeder/netron/issues/1023" }, + { + "type": "pytorch", + "target": "s2a-q4-small-en+pl.model", + "source": "https://github.com/user-attachments/files/16401714/s2a-q4-small-en%2Bpl.model.zip[s2a-q4-small-en+pl.model]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "segmentor.pt", "source": "https://github.com/lutzroeder/netron/files/7663953/segmentor.pt.zip[segmentor.pt]", - "format": "TorchScript v1.6", + "format": "PyTorch v1.6", "link": "https://github.com/lutzroeder/netron/issues/686" }, { @@ -6145,6 +6250,13 @@ "format": "TorchScript v1.6", "link": "https://github.com/lutzroeder/netron/issues/827" }, + { + "type": "pytorch", + "target": "weigths_0000000.pth", + "source": "https://github.com/user-attachments/files/16401710/weigths_0000000.pth.zip[weigths_0000000.pth]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/543" + }, { "type": "pytorch", "target": "yolox_m.torchscript.pt",