Quick summary

| Instruction | General theme | Writemask | Optional special features |
|---|---|---|---|
| extrh (26=0) | `x[i] = z[_][i]` | 7-bit | |
| extrh (26=1, 10=0) | `x[i] = f(z[_][i])` | 9-bit | Integer right shift, integer saturation |
| extrh (26=1, 10=1) | `y[i] = f(z[_][i])` | 9-bit | Integer right shift, integer saturation |

Instruction encoding

| Bit | Width | Meaning | Notes |
|---:|---:|---|---|
| 10 | 22 | A64 reserved instruction | Must be `0x201000 >> 10` |
| 5 | 5 | Instruction | Must be `8` |
| 0 | 5 | 5-bit GPR index | See below for the meaning of the 64 bits in the GPR |
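The three fields above combine into the 32-bit instruction word. Here is a minimal sketch that assembles that word for a given GPR index. The helper name `amx_extrh_insn` is hypothetical and does not come from any existing header.

```c
#include <stdint.h>

/* Hypothetical helper: build the 32-bit extrh instruction word.
 * Bits 10-31 hold 0x201000 >> 10 (i.e. 0x201000 once shifted back into place),
 * bits 5-9 hold the instruction number 8, and bits 0-4 hold the GPR index
 * of the register that carries the 64-bit operand. */
static uint32_t amx_extrh_insn(unsigned gpr) {
    const uint32_t reserved = 0x201000u;   /* already positioned at bit 10 */
    const uint32_t op = 8u << 5;           /* instruction number in bits 5-9 */
    return reserved | op | (gpr & 0x1Fu);  /* GPR index in bits 0-4 */
}
```

For example, with the operand in `x5` the word is `0x00201105`.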

Operand bitfields when 26=1

| Bit | Width | Meaning | Notes |
|---:|---:|---|---|
| 63 | 1 | Lane width mode (hi) | See bit 11 |
| (63=1) 62 | 1 | Destination is bf16 (1) or f16 (0) | Only applies in mixed lane-width modes, ignored otherwise |
| (63=1) 54 | 8 | Ignored | Only applies in mixed lane-width modes, ignored otherwise |
| (63=0) 58 | 5 | Right shift amount | Only applies in mixed lane-width modes, ignored otherwise |
| (63=0) 57 | 1 | Z is signed (1) or unsigned (0) | Only applies in mixed lane-width modes, ignored otherwise |
| (63=0) 56 | 1 | Z saturation is signed (1) or unsigned (0) | Only applies in mixed lane-width modes, ignored otherwise |
| (63=0) 55 | 1 | Saturate Z (1) or truncate Z (0) | Only applies in mixed lane-width modes, ignored otherwise |
| (63=0) 54 | 1 | Right shift is rounding (1) or truncating (0) | Only applies in mixed lane-width modes, ignored otherwise |
| 41 | 13 | Ignored | |
| (31=0) 38 | 3 | Write enable mode | |
| (31=0) 32 | 6 | Write enable value | Meaning dependent upon associated mode |
| 31 | 1 | Perform operation for multiple vectors (1) or just one vector (0) | M2 only (always reads as 0 on M1) |
| 27 | 4 | Ignored | |
| 26 | 1 | Must be 1 for this decode variant | |
| (31=1) 25 | 1 | "Multiple" means four vectors (1) or two vectors (0) | Top two bits of Z row ignored if operating on four vectors |
| 20 | 6 | Z row | When 31=1, top bit or top two bits ignored |
| 15 | 5 | Ignored | |
| 11 | 4 | Lane width mode (lo) | See bit 63 |
| 10 | 1 | Destination is Y (1) or is X (0) | |
| 9 | 1 | Ignored | |
| 0 | 9 | Destination offset (in bytes) | On M4, when 31=1, low four bits ignored |
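The table can also be read as a recipe for building an operand. The sketch below packs a set of fields into the 64-bit GPR value for the 26=1 variant. The struct and function names are hypothetical. Bit 31 (repeat) is deliberately left clear. The integer controls are packed only when bit 63 is 0, because in float modes the same bits hold the bf16/f16 selector and ignored bits.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical field view of an extrh operand with bit 26 set.
 * Bit positions follow the table above; bit 31 (repeat) is left clear. */
typedef struct {
    unsigned lane_mode;   /* 5 bits: bit 4 goes to bit 63, bits 0-3 go to bits 11-14 */
    bool     to_y;        /* bit 10: destination is Y (true) or X (false) */
    unsigned dst_offset;  /* bits 0-8: destination offset in bytes */
    unsigned z_row;       /* bits 20-25 */
    unsigned wm_mode;     /* bits 38-40 */
    unsigned wm_value;    /* bits 32-37 */
    bool     dst_bf16;    /* bit 62, float modes (bit 63 = 1) only */
    unsigned shift;       /* bits 58-62, integer modes (bit 63 = 0) only */
    bool     z_signed;    /* bit 57 */
    bool     sat_signed;  /* bit 56 */
    bool     saturate;    /* bit 55 */
    bool     rounding;    /* bit 54 */
} extrh_fields;

static uint64_t extrh_pack(const extrh_fields *f) {
    uint64_t is_float = (f->lane_mode >> 4) & 1;
    uint64_t op = 1ull << 26;                          /* select this decode variant */
    op |= is_float << 63;
    op |= (uint64_t)(f->lane_mode & 0xF) << 11;
    op |= (uint64_t)f->to_y << 10;
    op |= (uint64_t)(f->dst_offset & 0x1FF);
    op |= (uint64_t)(f->z_row & 0x3F) << 20;
    op |= (uint64_t)(f->wm_mode & 7) << 38;
    op |= (uint64_t)(f->wm_value & 0x3F) << 32;
    if (is_float) {
        op |= (uint64_t)f->dst_bf16 << 62;             /* bits 54-61 stay zero */
    } else {
        op |= (uint64_t)(f->shift & 0x1F) << 58;
        op |= (uint64_t)f->z_signed << 57;
        op |= (uint64_t)f->sat_signed << 56;
        op |= (uint64_t)f->saturate << 55;
        op |= (uint64_t)f->rounding << 54;
    }
    return op;
}
```

For instance, mode 9 (i32 to i16) with a rounding shift of 4 and signed saturation, written to X at offset 0 from Z row 0, packs to `0x13C0000004004800`.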

Lane widths:

| X (or Y) | Z | Bit 63 | Bits 11-14 | Notes |
|---|---|---:|---:|---|
| i8 or u8 | i8 or u8 | 0 | 0 | |
| i32 or u32 | i32 or u32 | 0 | 8 | |
| i16 or u16 | i32 or u32 (two rows, interleaved pair) | 0 | 9 | Shift and saturation supported |
| i16 or u16 | i32 or u32 (four rows, interleaved pair from those) | 0 | 10 | Shift and saturation supported |
| i8 or u8 | i32 or u32 (four rows, interleaved quartet) | 0 | 11 | Shift and saturation supported |
| i8 or u8 | i16 or u16 (two rows, interleaved pair) | 0 | 13 | Shift and saturation supported |
| i16 or u16 | i16 or u16 | 0 | anything else | |
| f64 | f64 | 1 | 1 | |
| f32 | f32 | 1 | 8 | |
| f16 or bf16 | f32 (two rows, interleaved pair) | 1 | 9 | M2 only. Bit 62 determines X (or Y) format |
| f16 or bf16 | f32 (four rows, interleaved pair from those) | 1 | 10 | M2 only. Bit 62 determines X (or Y) format |
| f16 | f16 | 1 | anything else | |

Write enable modes (with regard to X or Y):

| Mode | Meaning of value (N) |
|---:|---|
| 0 | Enable all lanes (0 or 4 or 5), or odd lanes only (1), or even lanes only (2), or enable all lanes but write 0 to them regardless of Z (3), or no lanes enabled (anything else) |
| 1 | Only enable lane #N |
| 2 | Only enable the first N lanes, or all lanes when N is zero |
| 3 | Only enable the last N lanes, or all lanes when N is zero |
| 4 | Only enable the first N lanes (no lanes when N is zero) |
| 5 | Only enable the last N lanes (no lanes when N is zero) |
| 6 | No lanes enabled |
| 7 | No lanes enabled |
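The mode table behaves like a function from the 9-bit write enable field to a per-lane enable mask. In that field the mode sits in bits 6-8 and N in bits 0-5, matching operand bits 38-40 and 32-37. The sketch below is hypothetical in three respects. It returns a lane mask rather than the byte mask the emulation's `parse_writemask` builds. It assumes that an N at or beyond the lane count enables all lanes (modes 2-5) or no lane (mode 1). It leaves the "write zeros" behaviour of mode 0, N = 3 to the caller.

```c
#include <stdint.h>

/* Hypothetical helper: 9-bit write enable field -> per-lane enable mask.
 * nlanes is 64 for 8-bit lanes, 32 for 16-bit, 16 for 32-bit, 8 for 64-bit.
 * Assumption: N >= nlanes saturates to "all lanes" in modes 2-5 and
 * selects no lane in mode 1. Mode 0 with N == 3 enables every lane;
 * the caller is responsible for writing zeros in that case. */
static uint64_t lane_enable_9bit(uint32_t wm, unsigned nlanes) {
    unsigned mode = (wm >> 6) & 7, n = wm & 63;
    uint64_t all = nlanes >= 64 ? ~0ull : (1ull << nlanes) - 1;
    switch (mode) {
    case 0:
        if (n == 0 || n == 3 || n == 4 || n == 5) return all;
        if (n == 1) return all & 0xAAAAAAAAAAAAAAAAull;   /* odd lanes */
        if (n == 2) return all & 0x5555555555555555ull;   /* even lanes */
        return 0;
    case 1: return n < nlanes ? 1ull << n : 0;            /* lane #N only */
    case 2: return (n == 0 || n >= nlanes) ? all : (1ull << n) - 1;
    case 3: return (n == 0 || n >= nlanes) ? all : all & ~((1ull << (nlanes - n)) - 1);
    case 4: return n >= nlanes ? all : (1ull << n) - 1;   /* N == 0 gives 0 */
    case 5:
        if (n == 0) return 0;
        return n >= nlanes ? all : all & ~((1ull << (nlanes - n)) - 1);
    default: return 0;                                     /* modes 6 and 7 */
    }
}
```

With 16-bit lanes (32 of them), mode 3 with N = 4 yields `0xF0000000`, which enables only the last four lanes.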

Operand bitfields when 26=0

| Bit | Width | Meaning | Notes |
|---:|---:|---|---|
| 48 | 16 | Ignored | |
| 46 | 2 | Write enable mode | |
| 41 | 5 | Write enable value | Meaning dependent upon associated mode |
| 30 | 11 | Ignored | |
| 28 | 2 | Lane width mode | |
| 27 | 1 | Must be 0 | Otherwise decodes as extrx |
| 26 | 1 | Must be 0 for this decode variant | |
| 20 | 6 | Z row | |
| 19 | 1 | Ignored | |
| 10 | 9 | Destination offset (in bytes) | Destination is always X for this decode variant |
| 0 | 10 | Ignored | |
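Each field in this table comes out of the 64-bit operand with a shift and a mask. The struct and function names below are hypothetical. The rejection of bits 26 and 27 mirrors the two "must be 0" rows: a set bit 26 selects the other decode variant, and a set bit 27 decodes as extrx.

```c
#include <stdint.h>

/* Hypothetical field view of an extrh operand when bit 26 is 0. */
typedef struct {
    unsigned dst_offset;  /* bits 10-18: byte offset into X */
    unsigned z_row;       /* bits 20-25 */
    unsigned lane_mode;   /* bits 28-29 */
    unsigned wm_value;    /* bits 41-45 */
    unsigned wm_mode;     /* bits 46-47 */
} extrh_simple_fields;

/* Returns 1 and fills *out if the operand uses this decode variant, else 0. */
static int extrh_simple_decode(uint64_t op, extrh_simple_fields *out) {
    if (op & (3ull << 26)) return 0;   /* bit 26 or bit 27 set: not this variant */
    out->dst_offset = (op >> 10) & 0x1FF;
    out->z_row      = (op >> 20) & 0x3F;
    out->lane_mode  = (op >> 28) & 0x3;
    out->wm_value   = (op >> 41) & 0x1F;
    out->wm_mode    = (op >> 46) & 0x3;
    return 1;
}
```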

Lane width modes:

| X, Z | Bits 28-29 |
|---|---:|
| any 64-bit | 0 |
| any 32-bit | 1 |
| any 16-bit | 2 |
| any 16-bit, but with high 8 bits of each lane disabled | 3 |

Write enable modes (with regard to X):

| Mode | Meaning of value (N) |
|---:|---|
| 0 | Enable all lanes (0), or odd lanes only (1), or even lanes only (2), or no lanes (anything else) |
| 1 | Only enable lane #N |
| 2 | Only enable the first N lanes, or all lanes when N is zero |
| 3 | Only enable the last N lanes, or all lanes when N is zero |

Description

When X/Y/Z all have the same lane width (which is always the case when 26=0), this operation is simple. The 6-bit field at bits 20-25 identifies a Z row, and that row is copied to X, or to Y when 26=1 and bit 10 is set. The lane width only affects the write-enable logic.

When Z is wider than X/Y, the operation is more complex because it must narrow each element. The four mixed-width modes are 9, 10, 11 and 13. For integer operands, all of these modes support a right shift and optional saturation of the Z values, and then keep the low bits. For floating-point operands (modes 9 and 10 with bit 63 set, M2 only), these modes canonicalise NaNs and round to nearest, ties to even.
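The emulation sample calls `bf16_from_f32` without showing it. Here is a portable sketch of an f32 to bf16 narrowing with round to nearest, ties to even. It is not the extr.c implementation. The name `bf16_from_f32_bits` is hypothetical. Canonicalising every NaN to `0x7FC0` (the Arm default NaN, truncated to 16 bits) is an assumption, and any flushing of subnormals is not modelled.

```c
#include <stdint.h>

/* Hypothetical f32 -> bf16 narrowing on raw bit patterns.
 * Round to nearest, ties to even: add 0x7FFF plus the LSB of the kept half,
 * then truncate. Overflow naturally rounds up to infinity.
 * Assumption: every NaN becomes the canonical 0x7FC0. */
static uint16_t bf16_from_f32_bits(uint32_t f) {
    if ((f & 0x7FFFFFFFu) > 0x7F800000u) return 0x7FC0;   /* any NaN */
    uint32_t lsb = (f >> 16) & 1;                          /* LSB of the result */
    return (uint16_t)((f + 0x7FFFu + lsb) >> 16);          /* cannot wrap: max input 0xFF800000 */
}
```

Ties go to the even result. `0x3F808000` (exactly halfway) stays at `0x3F80`, while `0x3F818000` rounds up to `0x3F82`.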

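Mode 9 (32-bit Z elements, 16-bit X/Y elements), correspondence between X/Y lanes and a pair of Z registers:

| Z register | X/Y lanes, in Z lane order |
|---|---|
| Z0 | 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 |
| Z1 | 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 |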

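Mode 10 (32-bit Z elements, 16-bit X/Y elements), correspondence between X/Y lanes and a quartet of Z registers:

| Z register | X/Y lanes, in Z lane order |
|---|---|
| Z0 | 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 |
| Z1 | (none) |
| Z2 | 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 |
| Z3 | (none) |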

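Mode 11 (32-bit Z elements, 8-bit X/Y elements), correspondence between X/Y lanes and a quartet of Z registers:

| Z register | X/Y lanes, in Z lane order |
|---|---|
| Z0 | 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60 |
| Z1 | 1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61 |
| Z2 | 2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62 |
| Z3 | 3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47, 51, 55, 59, 63 |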

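Mode 13 (16-bit Z elements, 8-bit X/Y elements), correspondence between X/Y lanes and a pair of Z registers:

| Z register | X/Y lanes, in Z lane order |
|---|---|
| Z0 | 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62 |
| Z1 | 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63 |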
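The four correspondence tables follow from one piece of index arithmetic, the same arithmetic the emulation code uses (`zoff = (i & (zbytes - 1)) / xybytes * stride`). The hypothetical helper below reports, for X/Y lane `j`, the Z row offset (relative to the Z row field) and the Z lane it comes from. It assumes that base row plus offset stays within one aligned group of four Z rows (two for mode 13). The emulation's `bit_select` wraps around within that group, and this helper does not model the wrap.

```c
/* Hypothetical helper: which Z row offset and Z lane feed X/Y lane j
 * in mixed-width mode 9, 10, 11 or 13. Returns 0 for any other mode.
 * Assumes no wrap-around within the aligned group of Z rows. */
static int extrh_lane_source(unsigned mode, unsigned j,
                             unsigned *row_off, unsigned *z_lane) {
    unsigned xyb, zb, stride;                  /* X/Y bytes, Z bytes, row stride */
    switch (mode) {
    case 9:  xyb = 2; zb = 4; stride = 1; break;
    case 10: xyb = 2; zb = 4; stride = 2; break;
    case 11: xyb = 1; zb = 4; stride = 1; break;
    case 13: xyb = 1; zb = 2; stride = 1; break;
    default: return 0;
    }
    unsigned i = j * xyb;                      /* byte offset of lane j in X/Y */
    *row_off = (i & (zb - 1)) / xyb * stride;  /* which row of the pair/quartet */
    *z_lane  = i / zb;                         /* lane within that Z row */
    return 1;
}
```

For example, mode 11 lane 6 resolves to Z2 lane 1, matching the second entry of the Z2 row in the mode 11 table.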

On M2 and later, when 26=1, the whole operation can optionally be repeated by setting bit 31. Bit 25 then selects the repetition count: two times (0) or four times (1). Consecutive X or Y registers (64 bytes apart) are used as the destinations. In this mode the write enable fields are ignored and all lanes are written. If repeated twice, the top bit of Z row is ignored, and Z row is incremented by 32 for the second iteration. If repeated four times, the top two bits of Z row are ignored, and Z row is incremented by 16 on each iteration.
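The repetition rule can be traced with a small hypothetical helper. It lists the base Z rows read and the X/Y byte offsets written when bit 31 is set. The input is the 6-bit Z row field, whose top bit is bit 25, the two-versus-four selector. Offsets wrap modulo 512, like the `dst_offset & 0x1FF` in the emulation code. The M4 masking of the destination offset's low bits is not modelled.

```c
/* Hypothetical helper: base Z rows and X/Y byte offsets visited when bit 31
 * is set. z_row_field is bits 20-25 of the operand; its top bit is bit 25.
 * Returns the number of iterations (2 or 4). M4 offset masking not modelled. */
static unsigned extrh_repeat_plan(unsigned z_row_field, unsigned dst_offset,
                                  unsigned rows[4], unsigned offsets[4]) {
    unsigned step = (z_row_field & 32) ? 16 : 32;   /* four passes : two passes */
    unsigned n = 0;
    for (unsigned r = z_row_field & (step - 1); r <= 63; r += step) {
        rows[n] = r;                                /* top bit(s) of the field ignored */
        offsets[n] = (dst_offset + 64 * n) & 0x1FF; /* next register, wrapping at 512 */
        n++;
    }
    return n;
}
```

A Z row field of 35 (binary 100011) selects four passes over Z rows 3, 19, 35 and 51. A field of 5 selects two passes over rows 5 and 37.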

Emulation code

See extr.c.

A representative sample is:

```c
void emulate_AMX_EXTRX(amx_state* state, uint64_t operand) {
    void* dst;
    uint64_t dst_offset;
    uint64_t z_row = operand >> 20;
    uint64_t z_step = 64;               // Z row distance between passes (64 means a single pass)
    uint64_t store_enable = ~(uint64_t)0;
    uint8_t buffer[64];
    uint32_t stride = 0;                // nonzero only in the mixed-width (narrowing) modes
    uint32_t zbytes, xybytes;

    if (operand & EXTR_HV) {            // the 26=1 decode variant
        dst = (operand & EXTR_HV_TO_Y) ? state->y : state->x;
        dst_offset = operand;
        // Lane width mode: bit 63 supplies bit 4, bits 11-14 supply bits 0-3
        switch (((operand >> 63) << 4) | ((operand >> 11) & 0xF)) {
        case  0: xybytes = 1; zbytes = 1; break;
        case  8: xybytes = 4; zbytes = 4; break;
        case  9: xybytes = 2; zbytes = 4; stride = 1; break;
        case 10: xybytes = 2; zbytes = 4; stride = 2; break;
        case 11: xybytes = 1; zbytes = 4; stride = 1; break;
        case 13: xybytes = 1; zbytes = 2; stride = 1; break;
        case 17: xybytes = 8; zbytes = 8; break;
        case 24: xybytes = 4; zbytes = 4; break;
        case 25: xybytes = 2; if (AMX_VER >= AMX_VER_M2) { zbytes = 4; stride = 1; } else { zbytes = 2; } break;
        case 26: xybytes = 2; if (AMX_VER >= AMX_VER_M2) { zbytes = 4; stride = 2; } else { zbytes = 2; } break;
        default: xybytes = 2; zbytes = 2; break;
        }
        if ((AMX_VER >= AMX_VER_M2) && (operand & (1ull << 31))) {
            operand &= ~(0x1ffull << 32);   // repeat mode: write enable fields ignored
            z_step = z_row & 32 ? 16 : 32;  // bit 25: four passes (16) or two passes (32)
            if (AMX_VER >= AMX_VER_M4) {
                dst_offset &= -64u;
            }
        }
        store_enable &= parse_writemask(operand >> 32, xybytes, 9);
    } else if (operand & EXTR_BETWEEN_XY) {
        ...                             // decodes as extrx; not shown here
    } else {                            // the 26=0 decode variant: destination is always X
        dst = state->x;
        dst_offset = operand >> 10;
        xybytes = 8 >> ((operand >> 28) & 3);
        if (xybytes == 1) {
            // lane width mode 3: 16-bit lanes with the high byte of each lane disabled
            xybytes = 2;
            store_enable &= 0x5555555555555555ull;
        }
        store_enable &= parse_writemask(operand >> 41, xybytes, 7);
        zbytes = xybytes;
    }

    uint32_t signext = (operand & EXTR_SIGNED_INPUT) ? 64 - zbytes*8 : 0;
    for (z_row &= z_step - 1; z_row <= 63; z_row += z_step) {      // one, two or four passes
        for (uint32_t i = 0; i < 64; i += xybytes) {                // i: byte offset of the X/Y lane
            uint64_t zoff = (i & (zbytes - 1)) / xybytes * stride;  // which row of the pair/quartet
            int64_t val = load_int(&state->z[bit_select(z_row, z_row + zoff, zbytes - 1)].u8[i & -zbytes],
                                   zbytes, signext);
            if (stride) val = extr_alu(val, operand, xybytes*8);   // narrowing modes only
            store_int(buffer + i, xybytes, val);
        }
        if ((operand & EXTR_HV) && (((operand >> 32) & 0x1ff) == 3)) {
            memset(buffer, 0, sizeof(buffer));  // write enable mode 0, value 3: write zeros
        }
        store_xy_row(dst, dst_offset & 0x1FF, buffer, store_enable);
        dst_offset += 64;               // next X/Y register on the next pass
    }
}

int64_t extr_alu(int64_t val, uint64_t operand, uint32_t outbits) {
    uint32_t shift = (operand >> 58) & 0x1f;
    if (operand & (1ull << 63)) {
        // Floating-point: bit 62 set (shift >= 16) selects bf16, otherwise f16
        if (shift >= 16) {
            val = bf16_from_f32((uint32_t)val);
        } else {
            __asm("fcvt %h0, %s0" : "=w"(val) : "0"(val));
        }
        return val;
    }
    // Integer: optional rounding, then arithmetic right shift
    if (shift && (operand & EXTR_ROUNDING_SHIFT)) {
        val += 1 << (shift - 1);
    }
    val >>= shift;
    if (operand & EXTR_SATURATE) {
        if (operand & EXTR_SIGNED_OUTPUT) outbits -= 1;
        int64_t hi = 1ull << outbits;   // exclusive upper bound of the output range
        if (operand & EXTR_SIGNED_INPUT) {
            int64_t lo = (operand & EXTR_SIGNED_OUTPUT) ? -hi : 0;
            if (val < lo) val = lo;
            if (val >= hi) val = hi - 1;
        } else {
            if ((uint64_t)val >= (uint64_t)hi) val = hi - 1;
        }
    }
    return val;
}
```