diff --git a/source/onnx.js b/source/onnx.js index f859e47acc..681259f298 100644 --- a/source/onnx.js +++ b/source/onnx.js @@ -703,8 +703,7 @@ onnx.Attribute = class { } if (metadata.type === 'DataType') { this._type = metadata.type; - const value = this._value ? parseInt(this._value.toString(), 10) : this._value; - this._value = Number.isInteger(value) ? context.createDataType(value) : value; + this._value = context.createDataType(this._value); } } } @@ -821,6 +820,9 @@ onnx.Tensor = class { this._location = context.createLocation(tensor.data_location); if (tensor.data_location === onnx.DataLocation.DEFAULT) { switch (tensor.data_type) { + case onnx.DataType.UNDEFINED: { + break; + } case onnx.DataType.FLOAT16: if (tensor.int32_data && tensor.int32_data.length > 0) { const buffer = new Uint8Array(tensor.int32_data.length << 1); @@ -1321,7 +1323,7 @@ onnx.GraphContext = class { constructor(context, nodes) { this._context = context; this._dataTypes = new Map(Object.entries(onnx.DataType).map((entry) => [ entry[1], entry[0].toLowerCase() ])); - this._dataTypes.set(onnx.DataType.UNDEFINED, 'UNDEFINED'); + this._dataTypes.set(onnx.DataType.UNDEFINED, 'undefined'); this._dataTypes.set(onnx.DataType.BOOL, 'boolean'); this._dataTypes.set(onnx.DataType.FLOAT, 'float32'); this._dataTypes.set(onnx.DataType.DOUBLE, 'float64'); @@ -1474,7 +1476,21 @@ onnx.GraphContext = class { } createDataType(value) { - return this._dataTypes.has(value) ? this._dataTypes.get(value) : this._dataTypes.get(onnx.DataType.UNDEFINED); + if (!Number.isInteger(value)) { + if (value && value.toNumber) { + value = value.toNumber(); + } + else if (value && typeof value === 'string' && onnx.DataType[value.toUpperCase()] !== undefined) { + value = onnx.DataType[value.toUpperCase()]; + } + else { + throw new onnx.Error("Unsupported data type '" + JSON.stringify(value) + "'."); + } + } + if (this._dataTypes.has(value)) { + return this._dataTypes.get(value); + } + throw new onnx.Error("Unsupported data type '" + JSON.stringify(value) + "'."); } createLocation(value) { diff --git a/test/models.json b/test/models.json index 0dd0a46bdb..4a006b4b9e 100644 --- a/test/models.json +++ b/test/models.json @@ -3238,14 +3238,6 @@ "format": "ONNX v3", "link": "https://github.com/onnx/models/tree/main/models/face_recognition/ArcFace" }, - { - "type": "onnx", - "target": "bidaf-9.onnx.zip", - "source": "https://github.com/lutzroeder/netron/files/6572387/bidaf-9.onnx.zip", - "format": "ONNX v4", - "producer": "CNTK 2.7", - "link": "https://github.com/lutzroeder/netron/issues/6" - }, { "type": "onnx", "target": "bert-base-uncased.onnx.zip", @@ -3263,10 +3255,10 @@ }, { "type": "onnx", - "target": "mnist_bfloat16.onnx", - "source": "https://github.com/lutzroeder/netron/files/8556399/mnist_bfloat16.onnx.zip[mnist_bfloat16.onnx]", + "target": "bidaf-9.onnx.zip", + "source": "https://github.com/lutzroeder/netron/files/6572387/bidaf-9.onnx.zip", "format": "ONNX v4", - "producer": "pytorch 1.12.0", + "producer": "CNTK 2.7", "link": "https://github.com/lutzroeder/netron/issues/6" }, { @@ -3507,6 +3499,14 @@ "format": "ONNX v3", "link": "https://github.com/Microsoft/Windows-Machine-Learning/tree/master/Samples/MNIST/Tutorial/cs/Assets" }, + { + "type": "onnx", + "target": "mnist_bfloat16.onnx", + "source": "https://github.com/lutzroeder/netron/files/8556399/mnist_bfloat16.onnx.zip[mnist_bfloat16.onnx]", + "format": "ONNX v4", + "producer": "pytorch 1.12.0", + "link": "https://github.com/lutzroeder/netron/issues/6" + }, { "type": "onnx", "target": "mlnet_encoder.onnx",