author     Jack O'Connor <[email protected]>  2022-03-08 11:59:58 -0500
committer  Jack O'Connor <[email protected]>  2022-03-08 22:23:09 -0500
commit     87a9318233fe8c5347453afecaa69242e09c929c (patch)
tree       dada030ceae717345e520c8aa5fdc20053b7dd1b
parent     d9b803304c3ffe1fb865eb9cbe9140f7c63c3bf9 (diff)
try it with 4 times as many loads
-rw-r--r--  src/kernel.rs | 172
1 file changed, 82 insertions, 90 deletions
diff --git a/src/kernel.rs b/src/kernel.rs
index 4ac76eb..8336661 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -828,50 +828,71 @@ global_asm!(
// and invokes blake3_avx512_kernel_16.
// --------------------------------------------------------------------------------------------
"blake3_avx512_blocks_16:",
- // Load the message blocks first (unaligned). See the comments immediately below for why we
- // choose these registers.
- "vmovdqu32 zmm24, zmmword ptr [rdi + 0 * 1024]",
- "vmovdqu32 zmm25, zmmword ptr [rdi + 1 * 1024]",
- "vmovdqu32 zmm26, zmmword ptr [rdi + 2 * 1024]",
- "vmovdqu32 zmm27, zmmword ptr [rdi + 3 * 1024]",
- "vmovdqu32 zmm28, zmmword ptr [rdi + 4 * 1024]",
- "vmovdqu32 zmm29, zmmword ptr [rdi + 5 * 1024]",
- "vmovdqu32 zmm30, zmmword ptr [rdi + 6 * 1024]",
- "vmovdqu32 zmm31, zmmword ptr [rdi + 7 * 1024]",
- "vmovdqu32 zmm8, zmmword ptr [rdi + 8 * 1024]",
- "vmovdqu32 zmm9, zmmword ptr [rdi + 9 * 1024]",
- "vmovdqu32 zmm10, zmmword ptr [rdi + 10 * 1024]",
- "vmovdqu32 zmm11, zmmword ptr [rdi + 11 * 1024]",
- "vmovdqu32 zmm12, zmmword ptr [rdi + 12 * 1024]",
- "vmovdqu32 zmm13, zmmword ptr [rdi + 13 * 1024]",
- "vmovdqu32 zmm14, zmmword ptr [rdi + 14 * 1024]",
- "vmovdqu32 zmm15, zmmword ptr [rdi + 15 * 1024]",
- // Transpose the message blocks. This requires a few different passes:
- // 1) interleave 32-bit lanes
- // 2) interleave 64-bit lanes
- // 3) interleave 128-bit lanes
- // 4) interleave 256-bit lanes (but there's no such instruction, so actually 128 bits again)
- // The last of these passes is easier to implement if we can make use of 8 scratch registers.
- // zmm0-zmm7 are holding the incoming CV and we don't want to touch those. But zmm8-zmm15
- // aren't holding anything important, and we can use those as long as we reinitialize them
- // before we run the kernel. For consistency, we'll use all 8 scratch registers for each pass
- // (even though the earlier passes would be fine using fewer), and we'll have each pass rotate
- // our 24 message+scratch vectors 8 places "to the left". Thus starting 8 places "to the right"
- // in the rotation lets us end up on target after 4 passes, and that's why we loaded the
- // message vectors in the order we did above.
- //
- // The first pass, interleaving 32-bit lanes. Here's the first vector before:
- // (zmm24) a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15
- // And after:
- // (zmm16) a0, b0, a1, b1, a4, b4, a5, b5, a8, b8, a9, b9, a12, b12, a13, b13
- "vpunpckldq zmm16, zmm24, zmm25",
- "vpunpckhdq zmm17, zmm24, zmm25",
- "vpunpckldq zmm18, zmm26, zmm27",
- "vpunpckhdq zmm19, zmm26, zmm27",
- "vpunpckldq zmm20, zmm28, zmm29",
- "vpunpckhdq zmm21, zmm28, zmm29",
- "vpunpckldq zmm22, zmm30, zmm31",
- "vpunpckhdq zmm23, zmm30, zmm31",
+ "vmovdqu32 xmm8, xmmword ptr [rdi + 0 * 16 + 0 * 1024]",
+ "vinserti32x4 zmm8, zmm8, xmmword ptr [rdi + 0 * 16 + 4 * 1024], 1",
+ "vinserti32x4 zmm8, zmm8, xmmword ptr [rdi + 0 * 16 + 8 * 1024], 2",
+ "vinserti32x4 zmm8, zmm8, xmmword ptr [rdi + 0 * 16 + 12 * 1024], 3",
+ "vmovdqu32 xmm9, xmmword ptr [rdi + 0 * 16 + 1 * 1024]",
+ "vinserti32x4 zmm9, zmm9, xmmword ptr [rdi + 0 * 16 + 5 * 1024], 1",
+ "vinserti32x4 zmm9, zmm9, xmmword ptr [rdi + 0 * 16 + 9 * 1024], 2",
+ "vinserti32x4 zmm9, zmm9, xmmword ptr [rdi + 0 * 16 + 13 * 1024], 3",
+ "vmovdqu32 xmm10, xmmword ptr [rdi + 0 * 16 + 2 * 1024]",
+ "vinserti32x4 zmm10, zmm10, xmmword ptr [rdi + 0 * 16 + 6 * 1024], 1",
+ "vinserti32x4 zmm10, zmm10, xmmword ptr [rdi + 0 * 16 + 10 * 1024], 2",
+ "vinserti32x4 zmm10, zmm10, xmmword ptr [rdi + 0 * 16 + 14 * 1024], 3",
+ "vmovdqu32 xmm11, xmmword ptr [rdi + 0 * 16 + 3 * 1024]",
+ "vinserti32x4 zmm11, zmm11, xmmword ptr [rdi + 0 * 16 + 7 * 1024], 1",
+ "vinserti32x4 zmm11, zmm11, xmmword ptr [rdi + 0 * 16 + 11 * 1024], 2",
+ "vinserti32x4 zmm11, zmm11, xmmword ptr [rdi + 0 * 16 + 15 * 1024], 3",
+ "vmovdqu32 xmm12, xmmword ptr [rdi + 1 * 16 + 0 * 1024]",
+ "vinserti32x4 zmm12, zmm12, xmmword ptr [rdi + 1 * 16 + 4 * 1024], 1",
+ "vinserti32x4 zmm12, zmm12, xmmword ptr [rdi + 1 * 16 + 8 * 1024], 2",
+ "vinserti32x4 zmm12, zmm12, xmmword ptr [rdi + 1 * 16 + 12 * 1024], 3",
+ "vmovdqu32 xmm13, xmmword ptr [rdi + 1 * 16 + 1 * 1024]",
+ "vinserti32x4 zmm13, zmm13, xmmword ptr [rdi + 1 * 16 + 5 * 1024], 1",
+ "vinserti32x4 zmm13, zmm13, xmmword ptr [rdi + 1 * 16 + 9 * 1024], 2",
+ "vinserti32x4 zmm13, zmm13, xmmword ptr [rdi + 1 * 16 + 13 * 1024], 3",
+ "vmovdqu32 xmm14, xmmword ptr [rdi + 1 * 16 + 2 * 1024]",
+ "vinserti32x4 zmm14, zmm14, xmmword ptr [rdi + 1 * 16 + 6 * 1024], 1",
+ "vinserti32x4 zmm14, zmm14, xmmword ptr [rdi + 1 * 16 + 10 * 1024], 2",
+ "vinserti32x4 zmm14, zmm14, xmmword ptr [rdi + 1 * 16 + 14 * 1024], 3",
+ "vmovdqu32 xmm15, xmmword ptr [rdi + 1 * 16 + 3 * 1024]",
+ "vinserti32x4 zmm15, zmm15, xmmword ptr [rdi + 1 * 16 + 7 * 1024], 1",
+ "vinserti32x4 zmm15, zmm15, xmmword ptr [rdi + 1 * 16 + 11 * 1024], 2",
+ "vinserti32x4 zmm15, zmm15, xmmword ptr [rdi + 1 * 16 + 15 * 1024], 3",
+ "vmovdqu32 xmm16, xmmword ptr [rdi + 2 * 16 + 0 * 1024]",
+ "vinserti32x4 zmm16, zmm16, xmmword ptr [rdi + 2 * 16 + 4 * 1024], 1",
+ "vinserti32x4 zmm16, zmm16, xmmword ptr [rdi + 2 * 16 + 8 * 1024], 2",
+ "vinserti32x4 zmm16, zmm16, xmmword ptr [rdi + 2 * 16 + 12 * 1024], 3",
+ "vmovdqu32 xmm17, xmmword ptr [rdi + 2 * 16 + 1 * 1024]",
+ "vinserti32x4 zmm17, zmm17, xmmword ptr [rdi + 2 * 16 + 5 * 1024], 1",
+ "vinserti32x4 zmm17, zmm17, xmmword ptr [rdi + 2 * 16 + 9 * 1024], 2",
+ "vinserti32x4 zmm17, zmm17, xmmword ptr [rdi + 2 * 16 + 13 * 1024], 3",
+ "vmovdqu32 xmm18, xmmword ptr [rdi + 2 * 16 + 2 * 1024]",
+ "vinserti32x4 zmm18, zmm18, xmmword ptr [rdi + 2 * 16 + 6 * 1024], 1",
+ "vinserti32x4 zmm18, zmm18, xmmword ptr [rdi + 2 * 16 + 10 * 1024], 2",
+ "vinserti32x4 zmm18, zmm18, xmmword ptr [rdi + 2 * 16 + 14 * 1024], 3",
+ "vmovdqu32 xmm19, xmmword ptr [rdi + 2 * 16 + 3 * 1024]",
+ "vinserti32x4 zmm19, zmm19, xmmword ptr [rdi + 2 * 16 + 7 * 1024], 1",
+ "vinserti32x4 zmm19, zmm19, xmmword ptr [rdi + 2 * 16 + 11 * 1024], 2",
+ "vinserti32x4 zmm19, zmm19, xmmword ptr [rdi + 2 * 16 + 15 * 1024], 3",
+ "vmovdqu32 xmm20, xmmword ptr [rdi + 3 * 16 + 0 * 1024]",
+ "vinserti32x4 zmm20, zmm20, xmmword ptr [rdi + 3 * 16 + 4 * 1024], 1",
+ "vinserti32x4 zmm20, zmm20, xmmword ptr [rdi + 3 * 16 + 8 * 1024], 2",
+ "vinserti32x4 zmm20, zmm20, xmmword ptr [rdi + 3 * 16 + 12 * 1024], 3",
+ "vmovdqu32 xmm21, xmmword ptr [rdi + 3 * 16 + 1 * 1024]",
+ "vinserti32x4 zmm21, zmm21, xmmword ptr [rdi + 3 * 16 + 5 * 1024], 1",
+ "vinserti32x4 zmm21, zmm21, xmmword ptr [rdi + 3 * 16 + 9 * 1024], 2",
+ "vinserti32x4 zmm21, zmm21, xmmword ptr [rdi + 3 * 16 + 13 * 1024], 3",
+ "vmovdqu32 xmm22, xmmword ptr [rdi + 3 * 16 + 2 * 1024]",
+ "vinserti32x4 zmm22, zmm22, xmmword ptr [rdi + 3 * 16 + 6 * 1024], 1",
+ "vinserti32x4 zmm22, zmm22, xmmword ptr [rdi + 3 * 16 + 10 * 1024], 2",
+ "vinserti32x4 zmm22, zmm22, xmmword ptr [rdi + 3 * 16 + 14 * 1024], 3",
+ "vmovdqu32 xmm23, xmmword ptr [rdi + 3 * 16 + 3 * 1024]",
+ "vinserti32x4 zmm23, zmm23, xmmword ptr [rdi + 3 * 16 + 7 * 1024], 1",
+ "vinserti32x4 zmm23, zmm23, xmmword ptr [rdi + 3 * 16 + 11 * 1024], 2",
+ "vinserti32x4 zmm23, zmm23, xmmword ptr [rdi + 3 * 16 + 15 * 1024], 3",
+ // interleave 32-bit words
"vpunpckldq zmm24, zmm8, zmm9",
"vpunpckhdq zmm25, zmm8, zmm9",
"vpunpckldq zmm26, zmm10, zmm11",
@@ -880,16 +901,15 @@ global_asm!(
"vpunpckhdq zmm29, zmm12, zmm13",
"vpunpckldq zmm30, zmm14, zmm15",
"vpunpckhdq zmm31, zmm14, zmm15",
- // The second pass, interleaving 64-bit lanes. After this the first vector will be:
- // (zmm8) a0, b0, c0, d0, a4, b4, c4, d4, a8, b8, c8, d8, a12, b12, c12, d12
- "vpunpcklqdq zmm8, zmm16, zmm18",
- "vpunpckhqdq zmm9, zmm16, zmm18",
- "vpunpcklqdq zmm10, zmm17, zmm19",
- "vpunpckhqdq zmm11, zmm17, zmm19",
- "vpunpcklqdq zmm12, zmm20, zmm22",
- "vpunpckhqdq zmm13, zmm20, zmm22",
- "vpunpcklqdq zmm14, zmm21, zmm23",
- "vpunpckhqdq zmm15, zmm21, zmm23",
+ "vpunpckldq zmm8, zmm16, zmm17",
+ "vpunpckhdq zmm9, zmm16, zmm17",
+ "vpunpckldq zmm10, zmm18, zmm19",
+ "vpunpckhdq zmm11, zmm18, zmm19",
+ "vpunpckldq zmm12, zmm20, zmm21",
+ "vpunpckhdq zmm13, zmm20, zmm21",
+ "vpunpckldq zmm14, zmm22, zmm23",
+ "vpunpckhdq zmm15, zmm22, zmm23",
+ // interleave 64-bit words
"vpunpcklqdq zmm16, zmm24, zmm26",
"vpunpckhqdq zmm17, zmm24, zmm26",
"vpunpcklqdq zmm18, zmm25, zmm27",
@@ -898,42 +918,14 @@ global_asm!(
"vpunpckhqdq zmm21, zmm28, zmm30",
"vpunpcklqdq zmm22, zmm29, zmm31",
"vpunpckhqdq zmm23, zmm29, zmm31",
- // The third pass, interleaving 128-bit lanes. After this the first vector will be:
- // (zmm24) a0, b0, c0, d0, a8, b8, c8, d8, e0, f0, g0, h0, e8, f8, g8, h8
- "vshufi32x4 zmm24, zmm8, zmm12, 0x88", // 0b10001000: lo 128-bit lanes A0/A2/B0/B2
- "vshufi32x4 zmm25, zmm9, zmm13, 0x88",
- "vshufi32x4 zmm26, zmm10, zmm14, 0x88",
- "vshufi32x4 zmm27, zmm11, zmm15, 0x88",
- "vshufi32x4 zmm28, zmm8, zmm12, 0xdd", // 0b11011101: hi 128-bit lanes A1/A3/B1/B3
- "vshufi32x4 zmm29, zmm9, zmm13, 0xdd",
- "vshufi32x4 zmm30, zmm10, zmm14, 0xdd",
- "vshufi32x4 zmm31, zmm11, zmm15, 0xdd",
- "vshufi32x4 zmm8, zmm16, zmm20, 0x88", // lo
- "vshufi32x4 zmm9, zmm17, zmm21, 0x88",
- "vshufi32x4 zmm10, zmm18, zmm22, 0x88",
- "vshufi32x4 zmm11, zmm19, zmm23, 0x88",
- "vshufi32x4 zmm12, zmm16, zmm20, 0xdd", // hi
- "vshufi32x4 zmm13, zmm17, zmm21, 0xdd",
- "vshufi32x4 zmm14, zmm18, zmm22, 0xdd",
- "vshufi32x4 zmm15, zmm19, zmm23, 0xdd",
- // The fourth and final pass, interleaving 128-bit lanes again. The first vector will be:
- // (zmm16) a0, b0, c0, d0, e0, f0, g0, h0, i0, j0, k0, l0, m0, n0, o0, p0
- "vshufi32x4 zmm16, zmm24, zmm8, 0x88", // lo
- "vshufi32x4 zmm17, zmm25, zmm9, 0x88",
- "vshufi32x4 zmm18, zmm26, zmm10, 0x88",
- "vshufi32x4 zmm19, zmm27, zmm11, 0x88",
- "vshufi32x4 zmm20, zmm28, zmm12, 0x88",
- "vshufi32x4 zmm21, zmm29, zmm13, 0x88",
- "vshufi32x4 zmm22, zmm30, zmm14, 0x88",
- "vshufi32x4 zmm23, zmm31, zmm15, 0x88",
- "vshufi32x4 zmm24, zmm24, zmm8, 0xdd", // hi
- "vshufi32x4 zmm25, zmm25, zmm9, 0xdd",
- "vshufi32x4 zmm26, zmm26, zmm10, 0xdd",
- "vshufi32x4 zmm27, zmm27, zmm11, 0xdd",
- "vshufi32x4 zmm28, zmm28, zmm12, 0xdd",
- "vshufi32x4 zmm29, zmm29, zmm13, 0xdd",
- "vshufi32x4 zmm30, zmm30, zmm14, 0xdd",
- "vshufi32x4 zmm31, zmm31, zmm15, 0xdd",
+ "vpunpcklqdq zmm24, zmm8, zmm10",
+ "vpunpckhqdq zmm25, zmm8, zmm10",
+ "vpunpcklqdq zmm26, zmm9, zmm11",
+ "vpunpckhqdq zmm27, zmm9, zmm11",
+ "vpunpcklqdq zmm28, zmm12, zmm14",
+ "vpunpckhqdq zmm29, zmm12, zmm14",
+ "vpunpcklqdq zmm30, zmm13, zmm15",
+ "vpunpckhqdq zmm31, zmm13, zmm15",
// Initialize the third and fourth rows of the state, which we just used as scratch space
// during transposition.
"vmovdqa32 zmm8, zmmword ptr [BLAKE3_IV0_16 + rip]", // IV constants