| author | Jack O'Connor <[email protected]> | 2022-03-08 11:59:58 -0500 |
|---|---|---|
| committer | Jack O'Connor <[email protected]> | 2022-03-08 22:23:09 -0500 |
| commit | 87a9318233fe8c5347453afecaa69242e09c929c (patch) | |
| tree | dada030ceae717345e520c8aa5fdc20053b7dd1b | |
| parent | d9b803304c3ffe1fb865eb9cbe9140f7c63c3bf9 (diff) | |
try it with 4 times as many loads
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | src/kernel.rs | 172 |

1 file changed, 82 insertions, 90 deletions
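
The commit message refers to the loads at the top of the new code below: each destination vector is now assembled from four 16-byte loads (one `vmovdqu32 xmm` plus three `vinserti32x4`), 64 loads in total where the previous version used 16 full-width unaligned `vmovdqu32 zmm` loads. Here is a minimal Rust sketch (not part of the commit; names are illustrative) of the offset pattern those loads walk, using the same 1024-byte stride between the 16 inputs that the assembly uses:

```rust
// Illustrative sketch (not from the commit): print the byte offsets gathered
// by the new load pattern. Each destination register takes one 16-byte chunk
// per 128-bit lane, 64 small loads in total instead of 16 full-width ones.
fn main() {
    // dest 0..16 corresponds to zmm8..zmm23, in the order the kernel loads them
    for dest in 0..16 {
        let word_group = dest / 4; // which 16-byte group of words (0..4)
        let block = dest % 4;      // low two bits of the source block index
        for lane in 0..4 {
            // offset from rdi of the chunk inserted into this 128-bit lane
            let offset = word_group * 16 + (block + 4 * lane) * 1024;
            print!("[rdi + {:5}] ", offset);
        }
        println!("-> zmm{}", 8 + dest);
    }
}
```

Because blocks `i`, `i+4`, `i+8`, and `i+12` land in separate 128-bit lanes of the same register, the loads already do the 128-bit interleaving that the removed `vshufi32x4` passes handled, leaving only the 32-bit and 64-bit unpack passes in the diff below.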
```diff
diff --git a/src/kernel.rs b/src/kernel.rs
index 4ac76eb..8336661 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -828,50 +828,71 @@ global_asm!(
     // and invokes blake3_avx512_kernel_16.
     // --------------------------------------------------------------------------------------------
     "blake3_avx512_blocks_16:",
-    // Load the message blocks first (unaligned). See the comments immediately below for why we
-    // choose these registers.
-    "vmovdqu32 zmm24, zmmword ptr [rdi + 0 * 1024]",
-    "vmovdqu32 zmm25, zmmword ptr [rdi + 1 * 1024]",
-    "vmovdqu32 zmm26, zmmword ptr [rdi + 2 * 1024]",
-    "vmovdqu32 zmm27, zmmword ptr [rdi + 3 * 1024]",
-    "vmovdqu32 zmm28, zmmword ptr [rdi + 4 * 1024]",
-    "vmovdqu32 zmm29, zmmword ptr [rdi + 5 * 1024]",
-    "vmovdqu32 zmm30, zmmword ptr [rdi + 6 * 1024]",
-    "vmovdqu32 zmm31, zmmword ptr [rdi + 7 * 1024]",
-    "vmovdqu32 zmm8, zmmword ptr [rdi + 8 * 1024]",
-    "vmovdqu32 zmm9, zmmword ptr [rdi + 9 * 1024]",
-    "vmovdqu32 zmm10, zmmword ptr [rdi + 10 * 1024]",
-    "vmovdqu32 zmm11, zmmword ptr [rdi + 11 * 1024]",
-    "vmovdqu32 zmm12, zmmword ptr [rdi + 12 * 1024]",
-    "vmovdqu32 zmm13, zmmword ptr [rdi + 13 * 1024]",
-    "vmovdqu32 zmm14, zmmword ptr [rdi + 14 * 1024]",
-    "vmovdqu32 zmm15, zmmword ptr [rdi + 15 * 1024]",
-    // Transpose the message blocks. This requires a few different passes:
-    // 1) interleave 32-bit lanes
-    // 2) interleave 64-bit lanes
-    // 3) interleave 128-bit lanes
-    // 4) interleave 256-bit lanes (but there's no such instruction, so actually 128 bits again)
-    // The last of these passes is easier to implement if we can make use of 8 scratch registers.
-    // zmm0-zmm7 are holding the incoming CV and we don't want to touch those. But zmm8-zmm15
-    // aren't holding anything important, and we can use those as long as we reinitialize them
-    // before we run the kernel. For consistency, we'll use all 8 scratch registers for each pass
-    // (even though the earlier passes would be fine using fewer), and we'll have each pass rotate
-    // our 24 message+scratch vectors 8 places "to the left". Thus starting 8 places "to the right"
-    // in the rotation lets us end up on target after 4 passes, and that's why we loaded the
-    // message vectors in the order we did above.
-    //
-    // The first pass, interleaving 32-bit lanes. Here's the first vector before:
-    // (zmm24) a0, a1, a2, a3, a4, a5, a8, a7, a8, a9, a10, a11, a12, a13, a14, a15
-    // And after:
-    // (zmm16) a0, b0, a1, b1, a4, b4, a5, b5, a8, b8, a9, b9, a12, b12, a13, b13
-    "vpunpckldq zmm16, zmm24, zmm25",
-    "vpunpckhdq zmm17, zmm24, zmm25",
-    "vpunpckldq zmm18, zmm26, zmm27",
-    "vpunpckhdq zmm19, zmm26, zmm27",
-    "vpunpckldq zmm20, zmm28, zmm29",
-    "vpunpckhdq zmm21, zmm28, zmm29",
-    "vpunpckldq zmm22, zmm30, zmm31",
-    "vpunpckhdq zmm23, zmm30, zmm31",
+    "vmovdqu32 xmm8, xmmword ptr [rdi + 0 * 16 + 0 * 1024]",
+    "vinserti32x4 zmm8, zmm8, xmmword ptr [rdi + 0 * 16 + 4 * 1024], 1",
+    "vinserti32x4 zmm8, zmm8, xmmword ptr [rdi + 0 * 16 + 8 * 1024], 2",
+    "vinserti32x4 zmm8, zmm8, xmmword ptr [rdi + 0 * 16 + 12 * 1024], 3",
+    "vmovdqu32 xmm9, xmmword ptr [rdi + 0 * 16 + 1 * 1024]",
+    "vinserti32x4 zmm9, zmm9, xmmword ptr [rdi + 0 * 16 + 5 * 1024], 1",
+    "vinserti32x4 zmm9, zmm9, xmmword ptr [rdi + 0 * 16 + 9 * 1024], 2",
+    "vinserti32x4 zmm9, zmm9, xmmword ptr [rdi + 0 * 16 + 13 * 1024], 3",
+    "vmovdqu32 xmm10, xmmword ptr [rdi + 0 * 16 + 2 * 1024]",
+    "vinserti32x4 zmm10, zmm10, xmmword ptr [rdi + 0 * 16 + 6 * 1024], 1",
+    "vinserti32x4 zmm10, zmm10, xmmword ptr [rdi + 0 * 16 + 10 * 1024], 2",
+    "vinserti32x4 zmm10, zmm10, xmmword ptr [rdi + 0 * 16 + 14 * 1024], 3",
+    "vmovdqu32 xmm11, xmmword ptr [rdi + 0 * 16 + 3 * 1024]",
+    "vinserti32x4 zmm11, zmm11, xmmword ptr [rdi + 0 * 16 + 7 * 1024], 1",
+    "vinserti32x4 zmm11, zmm11, xmmword ptr [rdi + 0 * 16 + 11 * 1024], 2",
+    "vinserti32x4 zmm11, zmm11, xmmword ptr [rdi + 0 * 16 + 15 * 1024], 3",
+    "vmovdqu32 xmm12, xmmword ptr [rdi + 1 * 16 + 0 * 1024]",
+    "vinserti32x4 zmm12, zmm12, xmmword ptr [rdi + 1 * 16 + 4 * 1024], 1",
+    "vinserti32x4 zmm12, zmm12, xmmword ptr [rdi + 1 * 16 + 8 * 1024], 2",
+    "vinserti32x4 zmm12, zmm12, xmmword ptr [rdi + 1 * 16 + 12 * 1024], 3",
+    "vmovdqu32 xmm13, xmmword ptr [rdi + 1 * 16 + 1 * 1024]",
+    "vinserti32x4 zmm13, zmm13, xmmword ptr [rdi + 1 * 16 + 5 * 1024], 1",
+    "vinserti32x4 zmm13, zmm13, xmmword ptr [rdi + 1 * 16 + 9 * 1024], 2",
+    "vinserti32x4 zmm13, zmm13, xmmword ptr [rdi + 1 * 16 + 13 * 1024], 3",
+    "vmovdqu32 xmm14, xmmword ptr [rdi + 1 * 16 + 2 * 1024]",
+    "vinserti32x4 zmm14, zmm14, xmmword ptr [rdi + 1 * 16 + 6 * 1024], 1",
+    "vinserti32x4 zmm14, zmm14, xmmword ptr [rdi + 1 * 16 + 10 * 1024], 2",
+    "vinserti32x4 zmm14, zmm14, xmmword ptr [rdi + 1 * 16 + 14 * 1024], 3",
+    "vmovdqu32 xmm15, xmmword ptr [rdi + 1 * 16 + 3 * 1024]",
+    "vinserti32x4 zmm15, zmm15, xmmword ptr [rdi + 1 * 16 + 7 * 1024], 1",
+    "vinserti32x4 zmm15, zmm15, xmmword ptr [rdi + 1 * 16 + 11 * 1024], 2",
+    "vinserti32x4 zmm15, zmm15, xmmword ptr [rdi + 1 * 16 + 15 * 1024], 3",
+    "vmovdqu32 xmm16, xmmword ptr [rdi + 2 * 16 + 0 * 1024]",
+    "vinserti32x4 zmm16, zmm16, xmmword ptr [rdi + 2 * 16 + 4 * 1024], 1",
+    "vinserti32x4 zmm16, zmm16, xmmword ptr [rdi + 2 * 16 + 8 * 1024], 2",
+    "vinserti32x4 zmm16, zmm16, xmmword ptr [rdi + 2 * 16 + 12 * 1024], 3",
+    "vmovdqu32 xmm17, xmmword ptr [rdi + 2 * 16 + 1 * 1024]",
+    "vinserti32x4 zmm17, zmm17, xmmword ptr [rdi + 2 * 16 + 5 * 1024], 1",
+    "vinserti32x4 zmm17, zmm17, xmmword ptr [rdi + 2 * 16 + 9 * 1024], 2",
+    "vinserti32x4 zmm17, zmm17, xmmword ptr [rdi + 2 * 16 + 13 * 1024], 3",
+    "vmovdqu32 xmm18, xmmword ptr [rdi + 2 * 16 + 2 * 1024]",
+    "vinserti32x4 zmm18, zmm18, xmmword ptr [rdi + 2 * 16 + 6 * 1024], 1",
+    "vinserti32x4 zmm18, zmm18, xmmword ptr [rdi + 2 * 16 + 10 * 1024], 2",
+    "vinserti32x4 zmm18, zmm18, xmmword ptr [rdi + 2 * 16 + 14 * 1024], 3",
+    "vmovdqu32 xmm19, xmmword ptr [rdi + 2 * 16 + 3 * 1024]",
+    "vinserti32x4 zmm19, zmm19, xmmword ptr [rdi + 2 * 16 + 7 * 1024], 1",
+    "vinserti32x4 zmm19, zmm19, xmmword ptr [rdi + 2 * 16 + 11 * 1024], 2",
+    "vinserti32x4 zmm19, zmm19, xmmword ptr [rdi + 2 * 16 + 15 * 1024], 3",
+    "vmovdqu32 xmm20, xmmword ptr [rdi + 3 * 16 + 0 * 1024]",
+    "vinserti32x4 zmm20, zmm20, xmmword ptr [rdi + 3 * 16 + 4 * 1024], 1",
+    "vinserti32x4 zmm20, zmm20, xmmword ptr [rdi + 3 * 16 + 8 * 1024], 2",
+    "vinserti32x4 zmm20, zmm20, xmmword ptr [rdi + 3 * 16 + 12 * 1024], 3",
+    "vmovdqu32 xmm21, xmmword ptr [rdi + 3 * 16 + 1 * 1024]",
+    "vinserti32x4 zmm21, zmm21, xmmword ptr [rdi + 3 * 16 + 5 * 1024], 1",
+    "vinserti32x4 zmm21, zmm21, xmmword ptr [rdi + 3 * 16 + 9 * 1024], 2",
+    "vinserti32x4 zmm21, zmm21, xmmword ptr [rdi + 3 * 16 + 13 * 1024], 3",
+    "vmovdqu32 xmm22, xmmword ptr [rdi + 3 * 16 + 2 * 1024]",
+    "vinserti32x4 zmm22, zmm22, xmmword ptr [rdi + 3 * 16 + 6 * 1024], 1",
+    "vinserti32x4 zmm22, zmm22, xmmword ptr [rdi + 3 * 16 + 10 * 1024], 2",
+    "vinserti32x4 zmm22, zmm22, xmmword ptr [rdi + 3 * 16 + 14 * 1024], 3",
+    "vmovdqu32 xmm23, xmmword ptr [rdi + 3 * 16 + 3 * 1024]",
+    "vinserti32x4 zmm23, zmm23, xmmword ptr [rdi + 3 * 16 + 7 * 1024], 1",
+    "vinserti32x4 zmm23, zmm23, xmmword ptr [rdi + 3 * 16 + 11 * 1024], 2",
+    "vinserti32x4 zmm23, zmm23, xmmword ptr [rdi + 3 * 16 + 15 * 1024], 3",
+    // interleave 32 bit words
     "vpunpckldq zmm24, zmm8, zmm9",
     "vpunpckhdq zmm25, zmm8, zmm9",
     "vpunpckldq zmm26, zmm10, zmm11",
@@ -880,16 +901,15 @@ global_asm!(
     "vpunpckhdq zmm29, zmm12, zmm13",
     "vpunpckldq zmm30, zmm14, zmm15",
     "vpunpckhdq zmm31, zmm14, zmm15",
-    // The second pass, interleaving 64-bit lanes. After this the first vector will be:
-    // (zmm8) a0, b0, c0, d0, a4, b4, c4, d4, a8, b8, c8, d8, a12, b12, c12, d12
-    "vpunpcklqdq zmm8, zmm16, zmm18",
-    "vpunpckhqdq zmm9, zmm16, zmm18",
-    "vpunpcklqdq zmm10, zmm17, zmm19",
-    "vpunpckhqdq zmm11, zmm17, zmm19",
-    "vpunpcklqdq zmm12, zmm20, zmm22",
-    "vpunpckhqdq zmm13, zmm20, zmm22",
-    "vpunpcklqdq zmm14, zmm21, zmm23",
-    "vpunpckhqdq zmm15, zmm21, zmm23",
+    "vpunpckldq zmm8, zmm16, zmm17",
+    "vpunpckhdq zmm9, zmm16, zmm17",
+    "vpunpckldq zmm10, zmm18, zmm19",
+    "vpunpckhdq zmm11, zmm18, zmm19",
+    "vpunpckldq zmm12, zmm20, zmm21",
+    "vpunpckhdq zmm13, zmm20, zmm21",
+    "vpunpckldq zmm14, zmm22, zmm23",
+    "vpunpckhdq zmm15, zmm22, zmm23",
+    // interleave 64-bit words
     "vpunpcklqdq zmm16, zmm24, zmm26",
     "vpunpckhqdq zmm17, zmm24, zmm26",
     "vpunpcklqdq zmm18, zmm25, zmm27",
@@ -898,42 +918,14 @@ global_asm!(
     "vpunpckhqdq zmm21, zmm28, zmm30",
     "vpunpcklqdq zmm22, zmm29, zmm31",
    "vpunpckhqdq zmm23, zmm29, zmm31",
-    // The third pass, interleaving 128-bit lanes. After this the first vector will be:
-    // (zmm24) a0, b0, c0, d0, a8, b8, c8, d8, e0, f0, g0, h0, e8, f8, g8, h8
-    "vshufi32x4 zmm24, zmm8, zmm12, 0x88", // 0b10001000: lo 128-bit lanes A0/A2/B0/B2
-    "vshufi32x4 zmm25, zmm9, zmm13, 0x88",
-    "vshufi32x4 zmm26, zmm10, zmm14, 0x88",
-    "vshufi32x4 zmm27, zmm11, zmm15, 0x88",
-    "vshufi32x4 zmm28, zmm8, zmm12, 0xdd", // 0b11011101: hi 128-bit lanes A1/A3/B1/B3
-    "vshufi32x4 zmm29, zmm9, zmm13, 0xdd",
-    "vshufi32x4 zmm30, zmm10, zmm14, 0xdd",
-    "vshufi32x4 zmm31, zmm11, zmm15, 0xdd",
-    "vshufi32x4 zmm8, zmm16, zmm20, 0x88", // lo
-    "vshufi32x4 zmm9, zmm17, zmm21, 0x88",
-    "vshufi32x4 zmm10, zmm18, zmm22, 0x88",
-    "vshufi32x4 zmm11, zmm19, zmm23, 0x88",
-    "vshufi32x4 zmm12, zmm16, zmm20, 0xdd", // hi
-    "vshufi32x4 zmm13, zmm17, zmm21, 0xdd",
-    "vshufi32x4 zmm14, zmm18, zmm22, 0xdd",
-    "vshufi32x4 zmm15, zmm19, zmm23, 0xdd",
-    // The fourth and final pass, interleaving 128-bit lanes again. The first vector will be:
-    // (zmm16) a0, b0, c0, d0, e0, f0, g0, h0, i0, j0, k0, l0, m0, n0, o0, p0
-    "vshufi32x4 zmm16, zmm24, zmm8, 0x88", // lo
-    "vshufi32x4 zmm17, zmm25, zmm9, 0x88",
-    "vshufi32x4 zmm18, zmm26, zmm10, 0x88",
-    "vshufi32x4 zmm19, zmm27, zmm11, 0x88",
-    "vshufi32x4 zmm20, zmm28, zmm12, 0x88",
-    "vshufi32x4 zmm21, zmm29, zmm13, 0x88",
-    "vshufi32x4 zmm22, zmm30, zmm14, 0x88",
-    "vshufi32x4 zmm23, zmm31, zmm15, 0x88",
-    "vshufi32x4 zmm24, zmm24, zmm8, 0xdd", // hi
-    "vshufi32x4 zmm25, zmm25, zmm9, 0xdd",
-    "vshufi32x4 zmm26, zmm26, zmm10, 0xdd",
-    "vshufi32x4 zmm27, zmm27, zmm11, 0xdd",
-    "vshufi32x4 zmm28, zmm28, zmm12, 0xdd",
-    "vshufi32x4 zmm29, zmm29, zmm13, 0xdd",
-    "vshufi32x4 zmm30, zmm30, zmm14, 0xdd",
-    "vshufi32x4 zmm31, zmm31, zmm15, 0xdd",
+    "vpunpcklqdq zmm24, zmm8, zmm10",
+    "vpunpckhqdq zmm25, zmm8, zmm10",
+    "vpunpcklqdq zmm26, zmm9, zmm11",
+    "vpunpckhqdq zmm27, zmm9, zmm11",
+    "vpunpcklqdq zmm28, zmm12, zmm14",
+    "vpunpckhqdq zmm29, zmm12, zmm14",
+    "vpunpcklqdq zmm30, zmm13, zmm15",
+    "vpunpckhqdq zmm31, zmm13, zmm15",
     // Initialize the third and fourth rows of the state, which we just used as scratch space
     // during transposition.
     "vmovdqa32 zmm8, zmmword ptr [BLAKE3_IV0_16 + rip]", // IV constants
```
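
To see why two unpack passes now suffice, here is a scalar sketch (not from the source; helper and variable names are hypothetical) that models `vpunpckldq`/`vpunpckhdq` and `vpunpcklqdq`/`vpunpckhqdq` acting on 128-bit lanes and applies them in the same order as the patched kernel, starting from lane-interleaved loads:

```rust
// Scalar sketch (not from the source): model the dword/qword unpacks on
// 128-bit lanes and check that, starting from lane-interleaved loads, two
// passes complete the 16x16 word transpose. Helper names are hypothetical.

type Vec16 = [u32; 16]; // one ZMM register viewed as 16 dwords

// Model of VPUNPCKLDQ (hi = false) / VPUNPCKHDQ (hi = true): within each
// 128-bit lane, interleave the low (or high) two dwords of `a` and `b`.
fn unpack_dq(a: Vec16, b: Vec16, hi: bool) -> Vec16 {
    let mut out = [0u32; 16];
    for lane in 0..4 {
        let i = 4 * lane + if hi { 2 } else { 0 };
        out[4 * lane..4 * lane + 4].copy_from_slice(&[a[i], b[i], a[i + 1], b[i + 1]]);
    }
    out
}

// Model of VPUNPCKLQDQ (hi = false) / VPUNPCKHQDQ (hi = true): within each
// 128-bit lane, interleave the low (or high) qword (dword pair) of `a` and `b`.
fn unpack_qdq(a: Vec16, b: Vec16, hi: bool) -> Vec16 {
    let mut out = [0u32; 16];
    for lane in 0..4 {
        let i = 4 * lane + if hi { 2 } else { 0 };
        out[4 * lane..4 * lane + 4].copy_from_slice(&[a[i], a[i + 1], b[i], b[i + 1]]);
    }
    out
}

fn main() {
    // 16 message blocks of 16 words each, filled with a recognizable pattern.
    let mut msg = [[0u32; 16]; 16];
    for block in 0..16 {
        for word in 0..16 {
            msg[block][word] = (16 * block + word) as u32;
        }
    }

    // The lane-interleaved loads: m[4*wg + blk] (zmm8..zmm23 in the kernel)
    // holds words 4*wg..4*wg+4 of blocks blk, blk+4, blk+8, blk+12, one block
    // per 128-bit lane.
    let mut m = [[0u32; 16]; 16];
    for wg in 0..4 {
        for blk in 0..4 {
            for lane in 0..4 {
                for w in 0..4 {
                    m[4 * wg + blk][4 * lane + w] = msg[blk + 4 * lane][4 * wg + w];
                }
            }
        }
    }

    // Pass 1: interleave 32-bit words, pairing registers as the asm does.
    let mut t = [[0u32; 16]; 16];
    for p in 0..8 {
        t[2 * p] = unpack_dq(m[2 * p], m[2 * p + 1], false);
        t[2 * p + 1] = unpack_dq(m[2 * p], m[2 * p + 1], true);
    }

    // Pass 2: interleave 64-bit words, again in the asm's order.
    let mut v = [[0u32; 16]; 16];
    for g in 0..4 {
        v[4 * g] = unpack_qdq(t[4 * g], t[4 * g + 2], false);
        v[4 * g + 1] = unpack_qdq(t[4 * g], t[4 * g + 2], true);
        v[4 * g + 2] = unpack_qdq(t[4 * g + 1], t[4 * g + 3], false);
        v[4 * g + 3] = unpack_qdq(t[4 * g + 1], t[4 * g + 3], true);
    }

    // Vector j now holds word j of every block: the full transpose.
    for word in 0..16 {
        for block in 0..16 {
            assert_eq!(v[word][block], msg[block][word]);
        }
    }
    println!("two unpack passes complete the transpose");
}
```

The final assertion checks that vector `j` ends up holding word `j` of all 16 blocks, the same transposed layout that the old four-pass version produced with its extra `vshufi32x4` steps.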
