
Commit

Add PyTorch test file (#720)
lutzroeder committed Apr 25, 2022
1 parent f9beec5 · commit 798b35d
Showing 2 changed files with 19 additions and 0 deletions.
source/pytorch.js: 12 additions, 0 deletions
@@ -693,6 +693,7 @@ pytorch.Tensor = class {
case 'float16':
case 'float32':
case 'float64':
case 'bfloat16':
break;
default:
context.state = "Tensor data type '" + this._type.dataType + "' is not supported.";
@@ -778,6 +779,11 @@ pytorch.Tensor = class {
context.index += 8;
context.count++;
break;
case 'bfloat16':
results.push(context.dataView.getBfloat16(context.index, this._littleEndian));
context.index += 2;
context.count++;
break;
default:
throw new pytorch.Error("Unsupported tensor data type '" + context.dataType + "'.");
}
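Note that getBfloat16 is a helper Netron adds to DataView elsewhere in the codebase; it is not part of the standard DataView API and its definition is not part of this diff. A minimal sketch of what such a reader could do, assuming the usual bfloat16 layout (the upper 16 bits of an IEEE-754 float32: 1 sign bit, 8 exponent bits, 7 mantissa bits):

// Sketch only (assumed helper, not code from this commit): bfloat16 keeps the
// upper 16 bits of a float32, so widening it is a 16-bit left shift into the
// high half of a 32-bit buffer, then a float32 read.
const readBfloat16 = (dataView, byteOffset, littleEndian) => {
    const bits = dataView.getUint16(byteOffset, littleEndian);
    const buffer = new DataView(new ArrayBuffer(4));
    buffer.setUint32(0, bits << 16, false);
    return buffer.getFloat32(0, false);
};
// Example: the bfloat16 bit pattern 0x3F80 decodes to 1.0, e.g.
// readBfloat16(new DataView(new Uint8Array([0x80, 0x3F]).buffer), 0, true) === 1.0

The 2-byte element size assumed here matches the context.index += 2 step in the hunk above.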
@@ -1949,6 +1955,11 @@ pytorch.Execution = class extends python.Execution {
super(size, torch.qint32);
}
});
this.registerType('torch.BFloat16Storage', class extends torch.storage._StorageBase {
constructor(size) {
super(size, torch.bfloat16);
}
});
this.registerType('torch.Size', class extends Array {
constructor(size) {
super(size.length);
@@ -2043,6 +2054,7 @@ pytorch.Execution = class extends python.Execution {
this.registerType('torch.QInt8Tensor', class extends torch.Tensor {});
this.registerType('torch.QUInt8Tensor', class extends torch.Tensor {});
this.registerType('torch.QInt32Tensor', class extends torch.Tensor {});
this.registerType('torch.BFloat16Tensor', class extends torch.Tensor {});
this.registerType('torch.cuda.FloatTensor', class extends torch.Tensor {});
this.registerType('torch.cuda.DoubleTensor', class extends torch.Tensor {});
torch.uint8 = new torch.dtype(pytorch.ScalarType.uint8);
test/models.json: 7 additions, 0 deletions
@@ -4600,6 +4600,13 @@
"format": "PyTorch v0.1.10",
"link": "https://github.com/facebookresearch/kill-the-bits/tree/master/src/models/compressed"
},
{
"type": "pytorch",
"target": "mnist_bfloat16.pt",
"source": "https://github.com/lutzroeder/netron/files/8556279/mnist_bfloat16.pt.zip[mnist_bfloat16.pt]",
"format": "PyTorch v1.6",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "mnist_linear.ckpt",
