@@ -30,6 +30,10 @@ class LinearLayer(ThetaLayer):
     if premul_input is not None:
       x = x * premul_input
     matmul(x, weight.T) + bias
+
+    fake_quant exists to allow export without adding dequant ops.
+    When fake_quant is True, the op will run in quant/dequant fashion.
+    When False, it will keep quantized types.
     ```
     """
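As context for reviewers, the two modes named in the new docstring lines differ only in whether the quantized value is immediately dequantized again. Below is a minimal sketch of that difference using plain PyTorch with a hypothetical per-tensor `scale`; it is not sharktank's quantizer API, just an illustration of the intent.

```python
import torch

def quantize_input(x: torch.Tensor, scale: float, fake_quant: bool) -> torch.Tensor:
    # Hypothetical per-tensor scale; sharktank's quantizers carry this metadata
    # themselves, so treat this as an illustration only.
    q = (x / scale).to(torch.float8_e4m3fnuz)   # quantize
    if fake_quant:
        # Quant/dequant: immediately return to the original float dtype, so the
        # value keeps the quantization error but stays a float tensor.
        return q.to(x.dtype) * scale
    # fake_quant=False: keep the quantized type for downstream consumers.
    return q

x = torch.randn(4, 8)
print(quantize_input(x, 0.05, fake_quant=True).dtype)   # torch.float32
print(quantize_input(x, 0.05, fake_quant=False).dtype)  # torch.float8_e4m3fnuz
```

With `fake_quant=True` the tensor stays in its original float dtype and only carries quantization error; with `fake_quant=False` the low-precision dtype is kept, which is what lets export proceed without extra dequant ops.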
@@ -70,21 +74,21 @@ def forward(self, x):
             x = q_input.quantize(x)
             if self.fake_quant:
                 x = x.unpack().dequant()
+        elif qdq_input is not None and self.fake_quant:
+            x = qdq_input.quantize(x).unpack().dequant()

         y = ops.linear(x, weight, bias)

         # Unconditionally dequantize.
-        # TODO: Support a q_output specifier that signals the layer to let
-        # the QuantizedTensor escape.
         if isinstance(y, QuantizedTensor) and not self.fake_quant:
             y = y.unpack().dequant()
         # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
         # We can truncate to fp16 in iree, so we do a cast here
-        # to account for this in the IR.
+        # to account for this in the IR. This may not be the right
+        # level to do this, but for now it's here.
         if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz:
             y = ops.to(y, torch.float16)
             return y
-        if qdq_output is not None:
-            # TODO: same as above.
+        if qdq_output is not None and self.fake_quant:
             y = qdq_output.quantize(y).unpack().dequant()
         return y
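The net effect of the forward-pass changes is that both QDQ paths (`qdq_input` on the way in, `qdq_output` on the way out) now run only when `fake_quant` is enabled. The sketch below makes that gate runnable; `_QDQ` and `_Packed` are hypothetical stand-ins for sharktank's quantizer and quantized-tensor types, not the real API.

```python
import torch


class _Packed:
    """Hypothetical packed quantized value with unpack()/dequant()."""

    def __init__(self, data: torch.Tensor, scale: float, orig_dtype: torch.dtype):
        self.data, self.scale, self.orig_dtype = data, scale, orig_dtype

    def unpack(self) -> "_Packed":
        return self

    def dequant(self) -> torch.Tensor:
        return self.data.to(self.orig_dtype) * self.scale


class _QDQ:
    """Hypothetical per-tensor quantizer standing in for qdq_input / qdq_output."""

    def __init__(self, scale: float):
        self.scale = scale

    def quantize(self, x: torch.Tensor) -> _Packed:
        return _Packed((x / self.scale).to(torch.float8_e4m3fnuz), self.scale, x.dtype)


def apply_output_qdq(y: torch.Tensor, qdq_output, fake_quant: bool) -> torch.Tensor:
    # Mirrors the new `if qdq_output is not None and self.fake_quant:` gate:
    # the output quant/dequant round trip only runs in fake-quant mode.
    if qdq_output is not None and fake_quant:
        y = qdq_output.quantize(y).unpack().dequant()
    return y


y = torch.randn(2, 2)
q = _QDQ(scale=0.1)
print(apply_output_qdq(y, q, fake_quant=True).dtype)             # float32, with quant error
print(torch.equal(apply_output_qdq(y, q, fake_quant=False), y))  # True: gate skipped
```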