From e8aaaa4d685c5e438bef3df021ba60ae2f85c63c Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Thu, 4 Aug 2022 22:42:49 -0700 Subject: [PATCH] Add PyTorch test file (#720) --- source/python.js | 62 ++++++++++++++++++++++++++++++++++++++++++++++- source/pytorch.js | 10 +++++++- source/view.js | 6 ++--- test/models.json | 7 ++++++ 4 files changed, 80 insertions(+), 5 deletions(-) diff --git a/source/python.js b/source/python.js index f424a5bac3..3ee54a560d 100644 --- a/source/python.js +++ b/source/python.js @@ -1947,10 +1947,67 @@ python.Execution = class { this.registerType('numpy.inexact', class {}); this.registerType('numpy.number', class extends numpy.generic {}); this.registerType('numpy.integer', class extends numpy.number {}); - this.registerType('numpy.signedinteger', class extends numpy.integer {}); this.registerType('numpy.floating', class extends numpy.inexact {}); + this.registerType('numpy.float32', class extends numpy.floating {}); this.registerType('numpy.float64', class extends numpy.floating {}); + this.registerType('numpy.signedinteger', class extends numpy.integer {}); + this.registerType('numpy.int8', class extends numpy.signedinteger {}); + this.registerType('numpy.int16', class extends numpy.signedinteger {}); + this.registerType('numpy.int32', class extends numpy.signedinteger {}); this.registerType('numpy.int64', class extends numpy.signedinteger {}); + this.registerType('numpy.unsignedinteger', class extends numpy.integer {}); + this.registerType('numpy.uint8', class extends numpy.unsignedinteger {}); + this.registerType('numpy.uint16', class extends numpy.unsignedinteger {}); + this.registerType('numpy.uint32', class extends numpy.unsignedinteger {}); + this.registerType('numpy.uint64', class extends numpy.unsignedinteger {}); + this.registerType('fastai.callback.core.TrainEvalCallback', class {}); + this.registerType('fastai.callback.progress.ProgressCallback', class {}); + this.registerType('fastai.data.core.DataLoaders', class {}); + this.registerType('fastai.data.core.Datasets', class {}); + this.registerType('fastai.data.core.TfmdDL', class {}); + this.registerType('fastai.data.core.TfmdLists', class {}); + this.registerType('fastai.data.load._FakeLoader', class {}); + this.registerType('fastai.data.load._wif', class {}); + this.registerType('fastai.data.transforms.Categorize', class {}); + this.registerType('fastai.data.transforms.CategoryMap', class {}); + this.registerType('fastai.data.transforms.IntToFloatTensor', class {}); + this.registerType('fastai.data.transforms.Normalize', class {}); + this.registerType('fastai.data.transforms.parent_label', class {}); + this.registerType('fastai.data.transforms.ToTensor', class {}); + this.registerType('fastai.imports.noop', class {}); + this.registerType('fastai.layers.AdaptiveConcatPool2d', class {}); + this.registerType('fastai.layers.Flatten', class {}); + this.registerType('fastai.learner.AvgLoss', class {}); + this.registerType('fastai.learner.AvgMetric', class {}); + this.registerType('fastai.learner.AvgSmoothLoss', class {}); + this.registerType('fastai.learner.Learner', class {}); + this.registerType('fastai.learner.Recorder', class {}); + this.registerType('fastai.losses.CrossEntropyLossFlat', class {}); + this.registerType('fastai.metrics.error_rate', class {}); + this.registerType('fastai.optimizer.Adam', class {}); + this.registerType('fastai.torch_core._fa_rebuild_tensor', class {}); + this.registerType('fastai.torch_core.TensorBase', class {}); + this.registerType('fastai.torch_core.TensorCategory', class {}); + this.registerType('fastai.torch_core.TensorImage', class {}); + this.registerType('fastai.vision.augment._BrightnessLogit', class {}); + this.registerType('fastai.vision.augment._ContrastLogit', class {}); + this.registerType('fastai.vision.augment._WarpCoord', class {}); + this.registerType('fastai.vision.augment.Brightness', class {}); + this.registerType('fastai.vision.augment.Flip', class {}); + this.registerType('fastai.vision.augment.flip_mat', class {}); + this.registerType('fastai.vision.augment.RandomResizedCropGPU', class {}); + this.registerType('fastai.vision.augment.Resize', class {}); + this.registerType('fastai.vision.augment.rotate_mat', class {}); + this.registerType('fastai.vision.augment.zoom_mat', class {}); + this.registerType('fastai.vision.core.PILImage', class {}); + this.registerType('fastai.vision.learner._resnet_split', class {}); + this.registerType('fastcore.basics.fastuple', class {}); + this.registerType('fastcore.dispatch._TypeDict', class {}); + this.registerType('fastcore.dispatch.TypeDispatch', class {}); + this.registerType('fastcore.foundation.L', class {}); + this.registerType('fastcore.transform.Pipeline', class {}); + this.registerType('fastcore.transform.Transform', class {}); + this.registerType('functools.partial', class {}); this.registerType('gensim.models.doc2vec.Doctag', class {}); this.registerType('gensim.models.doc2vec.Doc2Vec', class {}); this.registerType('gensim.models.doc2vec.Doc2VecTrainables', class {}); @@ -2433,6 +2490,7 @@ python.Execution = class { this.registerType('sklearn.preprocessing._function_transformer.FunctionTransformer', class {}); this.registerType('sklearn.preprocessing._label.LabelBinarizer', class {}); this.registerType('sklearn.preprocessing._label.LabelEncoder', class {}); + this.registerType('sklearn.preprocessing._label.MultiLabelBinarizer', class {}); this.registerType('sklearn.preprocessing._polynomial.PolynomialFeatures', class {}); this.registerType('sklearn.preprocessing.data.Binarizer', class {}); this.registerType('sklearn.preprocessing.data.MaxAbsScaler', class {}); @@ -2841,6 +2899,7 @@ python.Execution = class { return this._reader.stream(size); } }); + this.registerType('random.Random', class {}); this.registerType('re.Pattern', class { constructor(pattern, flags) { this.pattern = pattern; @@ -3776,6 +3835,7 @@ python.Execution = class { this.registerType('torchvision.ops.feature_pyramid_network.FeaturePyramidNetwork', class {}); this.registerType('torchvision.ops.feature_pyramid_network.LastLevelMaxPool', class {}); this.registerType('torchvision.ops.feature_pyramid_network.LastLevelP6P7', class {}); + this.registerType('torchvision.ops.misc.Conv2dNormActivation', class {}); this.registerType('torchvision.ops.misc.ConvNormActivation', class {}); this.registerType('torchvision.ops.misc.ConvTranspose2d', class {}); this.registerType('torchvision.ops.misc.FrozenBatchNorm2d', class {}); diff --git a/source/pytorch.js b/source/pytorch.js index d95d7507ab..7ca07e499e 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1770,7 +1770,15 @@ pytorch.Container.Zip.Pickle = class extends pytorch.Container.Zip { if (!this._graphs) { const execution = new pytorch.Container.Zip.Execution(null, this._exceptionCallback, this._metadata); const graph = new pytorch.Container.Zip.Pickle.Script(this._entries, execution); - this._graphs = graph.data.forward ? [ graph ] : pytorch.Utility.find(graph.data); + if (graph.data && graph.data.forward) { + this._graphs = [ graph ]; + } + else if (graph.data && graph.data.__class__ && graph.data.__class__.__module__ == 'fastai.learner' && graph.data.__class__.__name__ == 'Learner') { + this._graphs = pytorch.Utility.find(graph.data.model); + } + else { + this._graphs = pytorch.Utility.find(graph.data); + } } return this._graphs; } diff --git a/source/view.js b/source/view.js index 851c408a15..2802d9a05e 100644 --- a/source/view.js +++ b/source/view.js @@ -951,8 +951,8 @@ view.Graph = class extends grapher.Graph { } } if (groupName) { - createCluster(groupName); - this.setParent(viewNode.name, groupName); + createCluster(groupName + '\ngroup'); + this.setParent(viewNode.name, groupName + '\ngroup'); } } } @@ -1100,7 +1100,7 @@ view.Node = class extends grapher.Node { this._add(node.inner); } if (node.nodes) { - this.canvas = this.canvas(); + // this.canvas = this.canvas(); } } diff --git a/test/models.json b/test/models.json index dc5312e657..c476512fc2 100644 --- a/test/models.json +++ b/test/models.json @@ -4198,6 +4198,13 @@ "error": "Unsupported torch.add expression type in 'fasterrcnn_resnet50_fpn.pt'.", "link": "https://github.com/lutzroeder/netron/issues/689" }, + { + "type": "pytorch", + "target": "fruit_veg_model.pkl", + "source": "https://github.com/lutzroeder/netron/files/9265633/fruit_veg_model.pkl.zip[fruit_veg_model.pkl]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/720" + }, { "type": "pytorch", "target": "gcn2_tiny_320x240.pb",