Instruction | General theme | Writemask | Optional special features |
---|---|---|---|
extrh (26=0) | x[i] = z[_][i] | 7 bit | |
extrh (26=1, 10=0) | x[i] = f(z[_][i]) | 9 bit | Integer right shift, integer saturation |
extrh (26=1, 10=1) | y[i] = f(z[_][i]) | 9 bit | Integer right shift, integer saturation |

Instruction encoding:

Bit | Width | Meaning | Notes |
---|---|---|---|
10 | 22 | A64 reserved instruction | Must be 0x201000 >> 10 |
5 | 5 | Instruction | Must be 8 |
0 | 5 | 5-bit GPR index | See below for the meaning of the 64 bits in the GPR |
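
The following sketch shows how the three encoding fields combine into the 32-bit A64 word for a given GPR index. The helper name `amx_extr_insn_word` is hypothetical. It performs only the bit arithmetic implied by the table above and does not handle loading the 64-bit operand into that register.

```c
#include <stdint.h>

/* Hypothetical helper: assemble the 32-bit instruction word for extr.
 * Bits 10..31 come from the reserved pattern 0x201000, bits 5..9 hold
 * the AMX instruction number (8), and bits 0..4 name the GPR that
 * carries the 64-bit operand. */
static uint32_t amx_extr_insn_word(unsigned gpr) {
    const uint32_t reserved = 0x00201000u;  /* already has zeros in bits 0..9 */
    const uint32_t op = 8u;                 /* instruction number for extr */
    return reserved | (op << 5) | (gpr & 0x1Fu);
}
```

For example, using x0 as the operand register gives 0x00201100.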

Operand bitfields (26=1):

Bit | Width | Meaning | Notes |
---|---|---|---|
63 | 1 | Lane width mode (hi) | See bit 11 |
(63=1) 62 | 1 | Destination is bf16 (1) or f16 (0) | Only applies in mixed lane-width modes, ignored otherwise |
(63=1) 54 | 8 | Ignored | Only applies in mixed lane-width modes, ignored otherwise |
(63=0) 58 | 5 | Right shift amount | Only applies in mixed lane-width modes, ignored otherwise |
(63=0) 57 | 1 | Z is signed (1) or unsigned (0) | Only applies in mixed lane-width modes, ignored otherwise |
(63=0) 56 | 1 | Z saturation is signed (1) or unsigned (0) | Only applies in mixed lane-width modes, ignored otherwise |
(63=0) 55 | 1 | Saturate Z (1) or truncate Z (0) | Only applies in mixed lane-width modes, ignored otherwise |
(63=0) 54 | 1 | Right shift is rounding (1) or truncating (0) | Only applies in mixed lane-width modes, ignored otherwise |
41 | 13 | Ignored | |
(31=0) 38 | 3 | Write enable mode | |
(31=0) 32 | 6 | Write enable value | Meaning dependent upon associated mode |
31 | 1 | Perform operation for multiple vectors (1) or just one vector (0) | M2 only (always reads as 0 on M1) |
27 | 4 | Ignored | |
26 | 1 | Must be 1 for this decode variant | |
(31=1) 25 | 1 | "Multiple" means four vectors (1) or two vectors (0) | Top two bits of Z row ignored if operating on four vectors |
20 | 6 | Z row | When 31=1, top bit or top two bits ignored |
15 | 5 | Ignored | |
11 | 4 | Lane width mode (lo) | See bit 63 |
10 | 1 | Destination is Y (1) or is X (0) | |
9 | 1 | Ignored | |
0 | 9 | Destination offset (in bytes) | On M4, when 31=1, low four bits ignored |
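
The sketch below packs the 26=1 fields into a 64-bit operand, using the bit positions from the table above. The struct `extrh_fields` and the function `extrh_operand` are hypothetical names used for illustration. The 5-bit `lane_mode` stores bit 63 in its bit 4 and bits 11-14 in its low four bits. This produces the same key that the emulator sample further down switches on, for example 9 for i32 to i16 from two rows and 25 for f32 to f16/bf16 from two rows.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical field set for the 26=1 variant of extrh. */
typedef struct {
    unsigned dst_offset;  /* bits 0..8: byte offset into X or Y */
    bool     to_y;        /* bit 10: destination is Y (true) or X (false) */
    unsigned lane_mode;   /* bit 4 -> operand bit 63, bits 0..3 -> operand bits 11..14 */
    unsigned z_row;       /* bits 20..25; bit 5 lands on operand bit 25 */
    bool     multi;       /* bit 31: repeat over multiple vectors (M2 and later) */
    unsigned we_value;    /* bits 32..37 (ignored when multi) */
    unsigned we_mode;     /* bits 38..40 (ignored when multi) */
    bool     bf16;        /* bit 62 when bit 63 = 1: bf16 (true) or f16 (false) */
    /* Integer-only fields, used when bit 63 = 0. */
    unsigned shift;       /* bits 58..62 */
    bool     z_signed;    /* bit 57 */
    bool     sat_signed;  /* bit 56 */
    bool     saturate;    /* bit 55 */
    bool     rounding;    /* bit 54 */
} extrh_fields;

static uint64_t extrh_operand(const extrh_fields *f) {
    uint64_t op = 0;
    op |= (uint64_t)(f->dst_offset & 0x1FFu);
    op |= (uint64_t)f->to_y << 10;
    op |= (uint64_t)(f->lane_mode & 0xFu) << 11;
    op |= (uint64_t)(f->z_row & 0x3Fu) << 20;
    op |= (uint64_t)1 << 26;                  /* selects this decode variant */
    op |= (uint64_t)f->multi << 31;
    op |= (uint64_t)(f->we_value & 0x3Fu) << 32;
    op |= (uint64_t)(f->we_mode & 0x7u) << 38;
    if (f->lane_mode & 0x10u) {               /* floating-point lane modes */
        op |= (uint64_t)1 << 63;
        op |= (uint64_t)f->bf16 << 62;
    } else {                                  /* integer lane modes */
        op |= (uint64_t)f->rounding << 54;
        op |= (uint64_t)f->saturate << 55;
        op |= (uint64_t)f->sat_signed << 56;
        op |= (uint64_t)f->z_signed << 57;
        op |= (uint64_t)(f->shift & 0x1Fu) << 58;
    }
    return op;
}
```

Splitting the integer and floating-point cases avoids writing shift bits into the range that the floating-point modes reuse for the bf16 selector.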

Lane widths:

X (or Y) | Z | Bit 63 | Bits 11-14 | Notes |
---|---|---|---|---|
i8 or u8 | i8 or u8 | 0 | 0 | |
i32 or u32 | i32 or u32 | 0 | 8 | |
i16 or u16 | i32 or u32 (two rows, interleaved pair) | 0 | 9 | Shift and saturation supported |
i16 or u16 | i32 or u32 (four rows, of which rows 0 and 2 form an interleaved pair) | 0 | 10 | Shift and saturation supported |
i8 or u8 | i32 or u32 (four rows, interleaved quartet) | 0 | 11 | Shift and saturation supported |
i8 or u8 | i16 or u16 (two rows, interleaved pair) | 0 | 13 | Shift and saturation supported |
i16 or u16 | i16 or u16 | 0 | anything else | |
f64 | f64 | 1 | 1 | |
f32 | f32 | 1 | 8 | |
f16 or bf16 | f32 (two rows, interleaved pair) | 1 | 9 | M2 only. Bit 62 determines X (or Y) format |
f16 or bf16 | f32 (four rows, of which rows 0 and 2 form an interleaved pair) | 1 | 10 | M2 only. Bit 62 determines X (or Y) format |
f16 | f16 | 1 | anything else | |
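
The table can be read as a lookup on a 5-bit key built from bit 63 and bits 11-14. The sketch below implements that lookup. It returns the X/Y lane size, the Z lane size, and how many consecutive Z rows one output row draws from. The type and function names are hypothetical. The M1 fallback for keys 25 and 26, which are treated as plain 16-bit, is taken from the emulator sample further down and not from hardware documentation.

```c
#include <stdbool.h>
#include <stdint.h>

/* Sizes are in bytes; z_rows_spanned is how many consecutive Z rows one
 * output row draws from. */
typedef struct {
    unsigned xy_bytes;
    unsigned z_bytes;
    unsigned z_rows_spanned;
} lane_widths;

static lane_widths decode_lane_widths(uint64_t operand, bool m2_or_later) {
    /* Key = (bit 63 << 4) | bits 11..14: integer modes are 0..15,
     * floating-point modes are 16..31. */
    unsigned key = (unsigned)(((operand >> 63) << 4) | ((operand >> 11) & 0xF));
    switch (key) {
    case 0:  return (lane_widths){1, 1, 1};   /* i8/u8 from i8/u8 */
    case 8:  return (lane_widths){4, 4, 1};   /* i32/u32 from i32/u32 */
    case 9:  return (lane_widths){2, 4, 2};   /* i16 from two i32 rows */
    case 10: return (lane_widths){2, 4, 4};   /* i16 from rows 0 and 2 of four */
    case 11: return (lane_widths){1, 4, 4};   /* i8 from four i32 rows */
    case 13: return (lane_widths){1, 2, 2};   /* i8 from two i16 rows */
    case 17: return (lane_widths){8, 8, 1};   /* f64 */
    case 24: return (lane_widths){4, 4, 1};   /* f32 */
    case 25: return m2_or_later ? (lane_widths){2, 4, 2} : (lane_widths){2, 2, 1};
    case 26: return m2_or_later ? (lane_widths){2, 4, 4} : (lane_widths){2, 2, 1};
    default: return (lane_widths){2, 2, 1};   /* i16/u16 or f16 */
    }
}
```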

Write enable modes (with regard to X or Y):

Mode | Meaning of value (N) |
---|---|
0 | Enable all lanes (0 or 4 or 5), or odd lanes only (1), or even lanes only (2), or enable all lanes but write 0 to them regardless of Z (3), or no lanes enabled (anything else) |
1 | Only enable lane #N |
2 | Only enable the first N lanes, or all lanes when N is zero |
3 | Only enable the last N lanes, or all lanes when N is zero |
4 | Only enable the first N lanes (no lanes when N is zero) |
5 | Only enable the last N lanes (no lanes when N is zero) |
6 | No lanes enabled |
7 | No lanes enabled |
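
Read as a per-lane predicate, the 9-bit write-enable rules can be written as in the sketch below. Here `nlanes` is the number of X/Y lanes for the chosen lane width, for example 32 for 16-bit lanes. The enum and function names are hypothetical. Treating an N larger than the lane count as "all lanes" for modes 3 and 5 is an assumption.

```c
#include <stdbool.h>

typedef enum { LANE_SKIP, LANE_WRITE, LANE_WRITE_ZERO } lane_action;

/* Hypothetical per-lane decoding of the 9-bit write enable:
 * mode = operand bits 38..40, n = operand bits 32..37. */
static lane_action extrh_lane_action(unsigned mode, unsigned n,
                                     unsigned lane, unsigned nlanes) {
    bool first_n = lane < n;            /* lane is among the first n lanes */
    bool last_n  = lane + n >= nlanes;  /* lane is among the last n lanes */
    switch (mode) {
    case 0:
        if (n == 0 || n == 4 || n == 5) return LANE_WRITE;
        if (n == 1) return (lane & 1) ? LANE_WRITE : LANE_SKIP;  /* odd lanes */
        if (n == 2) return (lane & 1) ? LANE_SKIP : LANE_WRITE;  /* even lanes */
        if (n == 3) return LANE_WRITE_ZERO;                      /* write 0 regardless of Z */
        return LANE_SKIP;
    case 1: return lane == n ? LANE_WRITE : LANE_SKIP;
    case 2: return (n == 0 || first_n) ? LANE_WRITE : LANE_SKIP;
    case 3: return (n == 0 || last_n) ? LANE_WRITE : LANE_SKIP;
    case 4: return first_n ? LANE_WRITE : LANE_SKIP;
    case 5: return (n != 0 && last_n) ? LANE_WRITE : LANE_SKIP;
    default: return LANE_SKIP;          /* modes 6 and 7 */
    }
}
```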

Operand bitfields (26=0):

Bit | Width | Meaning | Notes |
---|---|---|---|
48 | 16 | Ignored | |
46 | 2 | Write enable mode | |
41 | 5 | Write enable value | Meaning dependent upon associated mode |
30 | 11 | Ignored | |
28 | 2 | Lane width mode | |
27 | 1 | Must be 0 | Otherwise decodes as extrx |
26 | 1 | Must be 0 for this decode variant | |
20 | 6 | Z row | |
19 | 1 | Ignored | |
10 | 9 | Destination offset (in bytes) | Destination is always X for this decode variant |
0 | 10 | Ignored | |

Lane width modes:

X, Z | Bits 28-29 |
---|---|
any 64-bit | 0 |
any 32-bit | 1 |
any 16-bit | 2 |
any 16-bit, but with the high 8 bits of each lane disabled | 3 |

Write enable modes (with regard to X):

Mode | Meaning of value (N) |
---|---|
0 | Enable all lanes (0), or odd lanes only (1), or even lanes only (2), or no lanes (anything else) |
1 | Only enable lane #N |
2 | Only enable the first N lanes, or all lanes when N is zero |
3 | Only enable the last N lanes, or all lanes when N is zero |

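
In the 26=0 variant, the lane width and the 7-bit write enable together decide which bytes of the 64-byte destination row are written. The sketch below combines both into a 64-bit byte mask, where bit k covers byte k, similar in spirit to the `store_enable` mask in the emulator sample. The function name is hypothetical. Two points are assumptions: the little-endian byte order (so the "high 8 bits" of a 16-bit lane is its odd byte), and treating N beyond the lane count as "all lanes" in mode 3.

```c
#include <stdint.h>

/* Byte-level store mask (bit k = byte k of the 64-byte row) for the 26=0
 * variant: lane_mode = bits 28..29, we_mode = bits 46..47, n = bits 41..45. */
static uint64_t extrh_26_0_byte_mask(unsigned lane_mode, unsigned we_mode, unsigned n) {
    unsigned lane_bytes = 8u >> (lane_mode & 3u);  /* 8, 4, 2 or 1 */
    uint64_t base = ~(uint64_t)0;
    if (lane_bytes == 1) {            /* mode 3: 16-bit lanes, high byte of each disabled */
        lane_bytes = 2;
        base = 0x5555555555555555ull; /* keep only the low byte of each 16-bit lane */
    }
    unsigned nlanes = 64 / lane_bytes;
    uint64_t lane_bits = ((uint64_t)1 << lane_bytes) - 1;  /* one bit per byte of a lane */
    uint64_t mask = 0;
    for (unsigned lane = 0; lane < nlanes; lane++) {
        int on;
        switch (we_mode & 3u) {
        case 0:  on = n == 0 || (n == 1 && (lane & 1)) || (n == 2 && !(lane & 1)); break;
        case 1:  on = lane == n; break;
        case 2:  on = n == 0 || lane < n; break;
        default: on = n == 0 || lane + n >= nlanes; break;  /* mode 3: last N lanes */
        }
        if (on)
            mask |= lane_bits << (lane * lane_bytes);
    }
    return mask & base;
}
```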
When X/Y/Z all have the same lane width (which is always the case when 26=0), this operation is simple: the field at bit 20 identifies a Z row, and that row is copied to X (or to Y). The lane width only affects the write-enable logic.
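
In the same-width case the data movement is just a masked byte copy. The sketch below models X as a single 512-byte buffer (eight 64-byte registers). It assumes the destination offset wraps modulo 512, consistent with the `dst_offset & 0x1FF` in the emulator sample. The type and function names are hypothetical, and Y would be handled the same way.

```c
#include <stdint.h>

/* Hypothetical register-file model: X is eight 64-byte rows (512 bytes),
 * Z is 64 rows of 64 bytes. */
typedef struct {
    uint8_t x[512];
    uint8_t z[64][64];
} amx_regs_sketch;

/* Copy Z row z_row into X starting at byte dst_offset, writing only the
 * bytes whose bit is set in byte_mask; offsets wrap around the 512 bytes. */
static void copy_z_row_to_x(amx_regs_sketch *r, unsigned z_row,
                            unsigned dst_offset, uint64_t byte_mask) {
    for (unsigned i = 0; i < 64; i++) {
        if ((byte_mask >> i) & 1u)
            r->x[(dst_offset + i) & 0x1FFu] = r->z[z_row & 63u][i];
    }
}
```

With `dst_offset = 500`, the first 12 bytes land at the end of X and the rest wrap to the start.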
When Z is wider than X/Y, the operation is more complex, because it must perform narrowing. The four mixed-width modes are 9, 10, 11, and 13 (for floating-point operands, only 9 and 10 exist, and only on M2). For integer operands, all of these modes support a right shift and optional saturation of the Z values, after which the low bits of each result are kept. For floating-point operands, these modes canonicalise NaNs and round to nearest, ties to even.
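
The integer narrowing steps above can be written as a standalone function. This sketch follows the shift, rounding, and saturation steps of the `extr_alu` routine in the emulator sample further down, and makes the final truncation to `out_bits` explicit. The function and parameter names are hypothetical. The code relies on `>>` of a negative `int64_t` being an arithmetic shift. Mainstream compilers behave this way, but the C standard leaves it implementation-defined.

```c
#include <stdbool.h>
#include <stdint.h>

/* val is the Z element, already sign- or zero-extended to 64 bits
 * according to z_signed. out_bits is 8 or 16 in the mixed-width modes. */
static uint64_t narrow_z_value(int64_t val, unsigned shift, bool rounding,
                               bool saturate, bool z_signed, bool sat_signed,
                               unsigned out_bits) {
    shift &= 0x1Fu;                                /* 5-bit field */
    if (shift && rounding)
        val += (int64_t)1 << (shift - 1);          /* add half, then shift: round half up */
    val >>= shift;
    if (saturate) {
        unsigned bits = sat_signed ? out_bits - 1 : out_bits;
        int64_t hi = (int64_t)1 << bits;           /* exclusive upper bound */
        if (z_signed) {
            int64_t lo = sat_signed ? -hi : 0;
            if (val < lo) val = lo;
            if (val >= hi) val = hi - 1;
        } else if ((uint64_t)val >= (uint64_t)hi) {
            val = hi - 1;
        }
    }
    /* Keep only the low out_bits bits, as the narrow store does. */
    return (uint64_t)val & (((uint64_t)1 << out_bits) - 1);
}
```

For example, narrowing the i32 value 0x12348 to i16 with a rounding shift of 4 gives 0x1235, and saturating -100000 to signed 16 bits gives 0x8000.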
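
Mode 9 (32-bit Z elements, 16-bit X or Y elements): correspondence between X/Y lanes and a pair of Z rows. Each cell gives the X/Y lane that receives that Z element, with the columns in ascending Z element order:
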
Z0 | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | 16 | 18 | 20 | 22 | 24 | 26 | 28 | 30 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Z1 | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 | 17 | 19 | 21 | 23 | 25 | 27 | 29 | 31 |

Mode 10 (32-bit Z elements, 16-bit X/Y elements): correspondence between X/Y lanes and a quartet of Z rows (only Z0 and Z2 are read):

Z0 | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | 16 | 18 | 20 | 22 | 24 | 26 | 28 | 30 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Z1 | | | | | | | | | | | | | | | | |
Z2 | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 | 17 | 19 | 21 | 23 | 25 | 27 | 29 | 31 |
Z3 | | | | | | | | | | | | | | | | |

Mode 11 (32-bit Z elements, 8-bit X/Y elements): correspondence between X/Y lanes and a quartet of Z rows:

Z0 | 0 | 4 | 8 | 12 | 16 | 20 | 24 | 28 | 32 | 36 | 40 | 44 | 48 | 52 | 56 | 60 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Z1 | 1 | 5 | 9 | 13 | 17 | 21 | 25 | 29 | 33 | 37 | 41 | 45 | 49 | 53 | 57 | 61 |
Z2 | 2 | 6 | 10 | 14 | 18 | 22 | 26 | 30 | 34 | 38 | 42 | 46 | 50 | 54 | 58 | 62 |
Z3 | 3 | 7 | 11 | 15 | 19 | 23 | 27 | 31 | 35 | 39 | 43 | 47 | 51 | 55 | 59 | 63 |

Mode 13 (16-bit Z elements, 8-bit X/Y elements): correspondence between X/Y lanes and a pair of Z rows:

Z0 | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | 16 | 18 | 20 | 22 | 24 | 26 | 28 | 30 | 32 | 34 | 36 | 38 | 40 | 42 | 44 | 46 | 48 | 50 | 52 | 54 | 56 | 58 | 60 | 62 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Z1 | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 | 17 | 19 | 21 | 23 | 25 | 27 | 29 | 31 | 33 | 35 | 37 | 39 | 41 | 43 | 45 | 47 | 49 | 51 | 53 | 55 | 57 | 59 | 61 | 63 |
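
All four tables follow from the index arithmetic in the emulator sample. Lane j of X/Y starts at byte j times the X/Y lane size. The Z element index is that byte offset divided by the Z lane size. The Z row offset comes from the low bits of the byte offset. The sketch below reproduces the tables; its names are hypothetical. It assumes the sample's `bit_select(a, b, mask)` takes the bits of `b` where `mask` is set and the bits of `a` elsewhere. Under that assumption, the row offset wraps within an aligned group of Z rows, so mode 10 with Z row 3 reads rows 3 and 1.

```c
#include <stdint.h>

typedef struct {
    unsigned row;   /* Z row index */
    unsigned elem;  /* element index within that Z row */
} z_source;

/* Where X/Y lane `lane` comes from in mixed-width mode 9, 10, 11 or 13,
 * given the Z row field z_row. Any other mode value is treated as 13. */
static z_source extrh_z_source(unsigned mode, unsigned z_row, unsigned lane) {
    unsigned xy_bytes, z_bytes, stride;
    switch (mode) {
    case 9:  xy_bytes = 2; z_bytes = 4; stride = 1; break;
    case 10: xy_bytes = 2; z_bytes = 4; stride = 2; break;
    case 11: xy_bytes = 1; z_bytes = 4; stride = 1; break;
    default: xy_bytes = 1; z_bytes = 2; stride = 1; break;  /* mode 13 */
    }
    unsigned i = lane * xy_bytes;                             /* byte offset of the lane */
    unsigned zoff = (i & (z_bytes - 1)) / xy_bytes * stride;  /* which interleaved row */
    unsigned mask = z_bytes - 1;
    /* Assumed bit_select behaviour: low bits from z_row + zoff, high bits from z_row. */
    unsigned row = (z_row & ~mask) | ((z_row + zoff) & mask);
    return (z_source){ row & 63u, i / z_bytes };
}
```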

On M2 and later, when 26=1, the whole operation can optionally be repeated by setting bit 31. Bit 25 selects the repetition count: two times (0) or four times (1). Consecutive X or Y registers are used as the destination. The write-enable fields are ignored, so all lanes are written. If repeated twice, the top bit of Z row is ignored, and Z row is incremented by 32 for the second iteration. If repeated four times, the top two bits of Z row are ignored, and Z row is incremented by 16 on each iteration.
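
The sketch below lists which Z rows one repeated operation visits, following the description above. The function name is hypothetical. The destination advances by one 64-byte X or Y register for each visited row.

```c
/* Z rows visited when bit 31 = 1 (M2 and later). z_row_field is operand
 * bits 20..25, so its bit 5 is operand bit 25. Returns the row count. */
static unsigned extrh_multi_rows(unsigned z_row_field, unsigned rows_out[4]) {
    unsigned step = (z_row_field & 32u) ? 16u : 32u;  /* four vectors : two vectors */
    unsigned count = 0;
    for (unsigned r = z_row_field & (step - 1); r <= 63; r += step)
        rows_out[count++] = r;  /* iteration k writes 64 * k bytes further into X or Y */
    return count;
}
```

For example, a Z row field of 5 with bit 25 clear visits rows 5 and 37. The same row with bit 25 set (field value 37) visits rows 5, 21, 37, and 53.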
See extr.c.
A representative sample is:
```c
int64_t extr_alu(int64_t val, uint64_t operand, uint32_t outbits);  // defined below

void emulate_AMX_EXTRX(amx_state* state, uint64_t operand) {
    void* dst;
    uint64_t dst_offset;
    uint64_t z_row = operand >> 20;
    uint64_t z_step = 64;                  // row increment; 64 means a single iteration
    uint64_t store_enable = ~(uint64_t)0;  // one bit per destination byte
    uint8_t buffer[64];
    uint32_t stride = 0;                   // non-zero only in the mixed-width (narrowing) modes
    uint32_t zbytes, xybytes;
    if (operand & EXTR_HV) {               // 26=1
        dst = (operand & EXTR_HV_TO_Y) ? state->y : state->x;
        dst_offset = operand;
        // Lane width mode: (bit 63 << 4) | bits 11..14
        switch (((operand >> 63) << 4) | ((operand >> 11) & 0xF)) {
        case 0: xybytes = 1; zbytes = 1; break;
        case 8: xybytes = 4; zbytes = 4; break;
        case 9: xybytes = 2; zbytes = 4; stride = 1; break;
        case 10: xybytes = 2; zbytes = 4; stride = 2; break;
        case 11: xybytes = 1; zbytes = 4; stride = 1; break;
        case 13: xybytes = 1; zbytes = 2; stride = 1; break;
        case 17: xybytes = 8; zbytes = 8; break;
        case 24: xybytes = 4; zbytes = 4; break;
        case 25: xybytes = 2; if (AMX_VER >= AMX_VER_M2) { zbytes = 4; stride = 1; } else { zbytes = 2; } break;
        case 26: xybytes = 2; if (AMX_VER >= AMX_VER_M2) { zbytes = 4; stride = 2; } else { zbytes = 2; } break;
        default: xybytes = 2; zbytes = 2; break;
        }
        if ((AMX_VER >= AMX_VER_M2) && (operand & (1ull << 31))) {
            operand &= ~(0x1ffull << 32);  // write enable ignored: all lanes written
            z_step = z_row & 32 ? 16 : 32; // bit 25: four vectors (1) or two (0)
            if (AMX_VER >= AMX_VER_M4) {
                dst_offset &= -64u;
            }
        }
        store_enable &= parse_writemask(operand >> 32, xybytes, 9);
    } else if (operand & EXTR_BETWEEN_XY) {
        ...
    } else {                               // 26=0: same-width copy to X
        dst = state->x;
        dst_offset = operand >> 10;
        xybytes = 8 >> ((operand >> 28) & 3);
        if (xybytes == 1) {                // lane width mode 3: 16-bit lanes, high byte of each disabled
            xybytes = 2;
            store_enable &= 0x5555555555555555ull;
        }
        store_enable &= parse_writemask(operand >> 41, xybytes, 7);
        zbytes = xybytes;
    }
    uint32_t signext = (operand & EXTR_SIGNED_INPUT) ? 64 - zbytes*8 : 0;
    for (z_row &= z_step - 1; z_row <= 63; z_row += z_step) {
        for (uint32_t i = 0; i < 64; i += xybytes) {
            // zoff selects which of the interleaved Z rows feeds this lane
            uint64_t zoff = (i & (zbytes - 1)) / xybytes * stride;
            int64_t val = load_int(&state->z[bit_select(z_row, z_row + zoff, zbytes - 1)].u8[i & -zbytes],
                                   zbytes, signext);
            if (stride) val = extr_alu(val, operand, xybytes*8);  // narrow to the X/Y lane width
            store_int(buffer + i, xybytes, val);
        }
        if ((operand & EXTR_HV) && (((operand >> 32) & 0x1ff) == 3)) {
            memset(buffer, 0, sizeof(buffer));  // write enable mode 0, value 3: write zeros
        }
        store_xy_row(dst, dst_offset & 0x1FF, buffer, store_enable);
        dst_offset += 64;                  // next X/Y register when repeating
    }
}

int64_t extr_alu(int64_t val, uint64_t operand, uint32_t outbits) {
    uint32_t shift = (operand >> 58) & 0x1f;
    if (operand & (1ull << 63)) {
        // Floating-point modes: bit 62 (top bit of the shift field) selects bf16 over f16
        if (shift >= 16) {
            val = bf16_from_f32((uint32_t)val);
        } else {
            __asm("fcvt %h0, %s0" : "=w"(val) : "0"(val));  // f32 -> f16
        }
        return val;
    }
    if (shift && (operand & EXTR_ROUNDING_SHIFT)) {
        val += 1 << (shift - 1);           // add half before shifting
    }
    val >>= shift;
    if (operand & EXTR_SATURATE) {
        if (operand & EXTR_SIGNED_OUTPUT) outbits -= 1;
        int64_t hi = 1ull << outbits;      // exclusive upper bound
        if (operand & EXTR_SIGNED_INPUT) {
            int64_t lo = (operand & EXTR_SIGNED_OUTPUT) ? -hi : 0;
            if (val < lo) val = lo;
            if (val >= hi) val = hi - 1;
        } else {
            if ((uint64_t)val >= (uint64_t)hi) val = hi - 1;
        }
    }
    return val;
}
```
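
The sample calls `bf16_from_f32`, which is not shown. The sketch below is one possible f32 to bf16 conversion consistent with the description above (round to nearest, ties to even, NaNs canonicalised). It is a hypothetical stand-in, not the emulator's code. It rests on two assumptions that are not verified hardware behaviour: the canonical bf16 NaN is 0x7FC0 (the top half of the AArch64 default f32 NaN), and subnormals are not flushed.

```c
#include <stdint.h>

/* Hypothetical stand-in for bf16_from_f32: takes the raw bits of an f32
 * and returns the raw bits of a bf16. */
static uint16_t bf16_from_f32_sketch(uint32_t f) {
    if ((f & 0x7FFFFFFFu) > 0x7F800000u)
        return 0x7FC0;                 /* any NaN -> assumed canonical quiet NaN */
    /* Round to nearest, ties to even, on the 16 bits being dropped:
     * add 0x7FFF plus the lowest kept bit, then truncate. */
    uint32_t lsb = (f >> 16) & 1u;
    f += 0x7FFFu + lsb;
    return (uint16_t)(f >> 16);
}
```

Adding `0x7FFF + lsb` before truncating rounds exact ties toward an even result. Values just below the largest finite f32 round up to infinity, as round-to-nearest requires.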