author    Jack O'Connor <[email protected]>  2022-03-08 14:26:31 -0500
committer Jack O'Connor <[email protected]>  2022-03-08 22:23:09 -0500
commit    9fdea0db7cee98343c920c8f28d8e88dc6a3a500 (patch)
tree      ee24226017b5bb0a37e1a1fca5a7805b4295b8e2
parent    bcbbcc8d2c61e2653d3434b049e3c7d857984bd3 (diff)
describe the transposition in comments
-rw-r--r--  src/kernel.rs | 95
1 file changed, 61 insertions(+), 34 deletions(-)
diff --git a/src/kernel.rs b/src/kernel.rs
index 6d06186..115ec76 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -828,6 +828,31 @@ global_asm!(
// and invokes blake3_avx512_kernel_16.
// --------------------------------------------------------------------------------------------
"blake3_avx512_blocks_16:",
+ // Load and transpose the message words. Because operations that cross 128-bit lanes are
+ // relatively expensive, we split each 512-bit load into four 128-bit loads. This results in
+ // vectors like:
+ //
+ // a0, a1, a2, a3, e0, e1, e2, e3, i0, i1, i2, i3, m0, m1, m2, m3
+ //
+ // Here a, b, c, and so on are the 1024-byte-strided blocks provided by the caller,
+ // and *0, *1, *2, and so on represent the consecutive 32-bit words of each block. Our goal in
+ // transposition is to produce the vectors (a0, b0, c0, ...), (a1, b1, c1, ...), and so on.
+ //
+ // After the loads, we need to do two interleaving passes. First we interleave 32-bit words.
+ // This produces vectors like:
+ //
+ // a0, b0, a1, b1, e0, f0, e1, f1, i0, j0, i1, j1, m0, n0, m1, n1
+ //
+ // Finally we interleave 64-bit words. This gives us our goal, which is vectors like:
+ //
+ // a0, b0, c0, d0, e0, f0, g0, h0, i0, j0, k0, l0, m0, n0, o0, p0
+ //
+ // The interleavings can be done mostly in place, but the first interleaving requires a single
+ // scratch vector, and the second interleaving requires two scratch vectors, so three
+ // scratch vectors are needed in total. Thus we load each of the message vectors three register
+ // positions "higher" than its final destination. We want the transposed results to reside in
+ // zmm16-zmm31, so we initially load into zmm19-"zmm34" (except zmm32-zmm34 don't exist, so we
+ // substitute zmm13-zmm15 for this range).
"vmovdqu32 xmm19, xmmword ptr [rdi + 0 * 16 + 0 * 1024]",
"vinserti32x4 zmm19, zmm19, xmmword ptr [rdi + 0 * 16 + 4 * 1024], 1",
"vinserti32x4 zmm19, zmm19, xmmword ptr [rdi + 0 * 16 + 8 * 1024], 2",
@@ -836,8 +861,8 @@ global_asm!(
"vinserti32x4 zmm20, zmm20, xmmword ptr [rdi + 0 * 16 + 5 * 1024], 1",
"vinserti32x4 zmm20, zmm20, xmmword ptr [rdi + 0 * 16 + 9 * 1024], 2",
"vinserti32x4 zmm20, zmm20, xmmword ptr [rdi + 0 * 16 + 13 * 1024], 3",
- "vpunpckldq zmm18, zmm19, zmm20",
- "vpunpckhdq zmm19, zmm19, zmm20",
+ "vpunpckldq zmm18, zmm19, zmm20",
+ "vpunpckhdq zmm19, zmm19, zmm20",
"vmovdqu32 xmm21, xmmword ptr [rdi + 0 * 16 + 2 * 1024]",
"vinserti32x4 zmm21, zmm21, xmmword ptr [rdi + 0 * 16 + 6 * 1024], 1",
"vinserti32x4 zmm21, zmm21, xmmword ptr [rdi + 0 * 16 + 10 * 1024], 2",
@@ -846,12 +871,12 @@ global_asm!(
"vinserti32x4 zmm22, zmm22, xmmword ptr [rdi + 0 * 16 + 7 * 1024], 1",
"vinserti32x4 zmm22, zmm22, xmmword ptr [rdi + 0 * 16 + 11 * 1024], 2",
"vinserti32x4 zmm22, zmm22, xmmword ptr [rdi + 0 * 16 + 15 * 1024], 3",
- "vpunpckldq zmm20, zmm21, zmm22",
- "vpunpckhdq zmm21, zmm21, zmm22",
- "vpunpcklqdq zmm16, zmm18, zmm20",
- "vpunpckhqdq zmm17, zmm18, zmm20",
- "vpunpcklqdq zmm18, zmm19, zmm21",
- "vpunpckhqdq zmm19, zmm19, zmm21",
+ "vpunpckldq zmm20, zmm21, zmm22",
+ "vpunpckhdq zmm21, zmm21, zmm22",
+ "vpunpcklqdq zmm16, zmm18, zmm20",
+ "vpunpckhqdq zmm17, zmm18, zmm20",
+ "vpunpcklqdq zmm18, zmm19, zmm21",
+ "vpunpckhqdq zmm19, zmm19, zmm21",
"vmovdqu32 xmm23, xmmword ptr [rdi + 1 * 16 + 0 * 1024]",
"vinserti32x4 zmm23, zmm23, xmmword ptr [rdi + 1 * 16 + 4 * 1024], 1",
"vinserti32x4 zmm23, zmm23, xmmword ptr [rdi + 1 * 16 + 8 * 1024], 2",
@@ -860,8 +885,8 @@ global_asm!(
"vinserti32x4 zmm24, zmm24, xmmword ptr [rdi + 1 * 16 + 5 * 1024], 1",
"vinserti32x4 zmm24, zmm24, xmmword ptr [rdi + 1 * 16 + 9 * 1024], 2",
"vinserti32x4 zmm24, zmm24, xmmword ptr [rdi + 1 * 16 + 13 * 1024], 3",
- "vpunpckldq zmm22, zmm23, zmm24",
- "vpunpckhdq zmm23, zmm23, zmm24",
+ "vpunpckldq zmm22, zmm23, zmm24",
+ "vpunpckhdq zmm23, zmm23, zmm24",
"vmovdqu32 xmm25, xmmword ptr [rdi + 1 * 16 + 2 * 1024]",
"vinserti32x4 zmm25, zmm25, xmmword ptr [rdi + 1 * 16 + 6 * 1024], 1",
"vinserti32x4 zmm25, zmm25, xmmword ptr [rdi + 1 * 16 + 10 * 1024], 2",
@@ -870,12 +895,12 @@ global_asm!(
"vinserti32x4 zmm26, zmm26, xmmword ptr [rdi + 1 * 16 + 7 * 1024], 1",
"vinserti32x4 zmm26, zmm26, xmmword ptr [rdi + 1 * 16 + 11 * 1024], 2",
"vinserti32x4 zmm26, zmm26, xmmword ptr [rdi + 1 * 16 + 15 * 1024], 3",
- "vpunpckldq zmm24, zmm25, zmm26",
- "vpunpckhdq zmm25, zmm25, zmm26",
- "vpunpcklqdq zmm20, zmm22, zmm24",
- "vpunpckhqdq zmm21, zmm22, zmm24",
- "vpunpcklqdq zmm22, zmm23, zmm25",
- "vpunpckhqdq zmm23, zmm23, zmm25",
+ "vpunpckldq zmm24, zmm25, zmm26",
+ "vpunpckhdq zmm25, zmm25, zmm26",
+ "vpunpcklqdq zmm20, zmm22, zmm24",
+ "vpunpckhqdq zmm21, zmm22, zmm24",
+ "vpunpcklqdq zmm22, zmm23, zmm25",
+ "vpunpckhqdq zmm23, zmm23, zmm25",
"vmovdqu32 xmm27, xmmword ptr [rdi + 2 * 16 + 0 * 1024]",
"vinserti32x4 zmm27, zmm27, xmmword ptr [rdi + 2 * 16 + 4 * 1024], 1",
"vinserti32x4 zmm27, zmm27, xmmword ptr [rdi + 2 * 16 + 8 * 1024], 2",
@@ -884,8 +909,8 @@ global_asm!(
"vinserti32x4 zmm28, zmm28, xmmword ptr [rdi + 2 * 16 + 5 * 1024], 1",
"vinserti32x4 zmm28, zmm28, xmmword ptr [rdi + 2 * 16 + 9 * 1024], 2",
"vinserti32x4 zmm28, zmm28, xmmword ptr [rdi + 2 * 16 + 13 * 1024], 3",
- "vpunpckldq zmm26, zmm27, zmm28",
- "vpunpckhdq zmm27, zmm27, zmm28",
+ "vpunpckldq zmm26, zmm27, zmm28",
+ "vpunpckhdq zmm27, zmm27, zmm28",
"vmovdqu32 xmm29, xmmword ptr [rdi + 2 * 16 + 2 * 1024]",
"vinserti32x4 zmm29, zmm29, xmmword ptr [rdi + 2 * 16 + 6 * 1024], 1",
"vinserti32x4 zmm29, zmm29, xmmword ptr [rdi + 2 * 16 + 10 * 1024], 2",
@@ -894,22 +919,24 @@ global_asm!(
"vinserti32x4 zmm30, zmm30, xmmword ptr [rdi + 2 * 16 + 7 * 1024], 1",
"vinserti32x4 zmm30, zmm30, xmmword ptr [rdi + 2 * 16 + 11 * 1024], 2",
"vinserti32x4 zmm30, zmm30, xmmword ptr [rdi + 2 * 16 + 15 * 1024], 3",
- "vpunpckldq zmm28, zmm29, zmm30",
- "vpunpckhdq zmm29, zmm29, zmm30",
- "vpunpcklqdq zmm24, zmm26, zmm28",
- "vpunpckhqdq zmm25, zmm26, zmm28",
- "vpunpcklqdq zmm26, zmm27, zmm29",
- "vpunpckhqdq zmm27, zmm27, zmm29",
+ "vpunpckldq zmm28, zmm29, zmm30",
+ "vpunpckhdq zmm29, zmm29, zmm30",
+ "vpunpcklqdq zmm24, zmm26, zmm28",
+ "vpunpckhqdq zmm25, zmm26, zmm28",
+ "vpunpcklqdq zmm26, zmm27, zmm29",
+ "vpunpckhqdq zmm27, zmm27, zmm29",
"vmovdqu32 xmm31, xmmword ptr [rdi + 3 * 16 + 0 * 1024]",
"vinserti32x4 zmm31, zmm31, xmmword ptr [rdi + 3 * 16 + 4 * 1024], 1",
"vinserti32x4 zmm31, zmm31, xmmword ptr [rdi + 3 * 16 + 8 * 1024], 2",
"vinserti32x4 zmm31, zmm31, xmmword ptr [rdi + 3 * 16 + 12 * 1024], 3",
+ // There are no registers "above" zmm31, so for the next twenty operations we use zmm13-zmm15
+ // to stand in for zmm32-zmm34, but otherwise the pattern is the same.
"vmovdqu32 xmm13, xmmword ptr [rdi + 3 * 16 + 1 * 1024]",
"vinserti32x4 zmm13, zmm13, xmmword ptr [rdi + 3 * 16 + 5 * 1024], 1",
"vinserti32x4 zmm13, zmm13, xmmword ptr [rdi + 3 * 16 + 9 * 1024], 2",
"vinserti32x4 zmm13, zmm13, xmmword ptr [rdi + 3 * 16 + 13 * 1024], 3",
- "vpunpckldq zmm30, zmm31, zmm13",
- "vpunpckhdq zmm31, zmm31, zmm13",
+ "vpunpckldq zmm30, zmm31, zmm13",
+ "vpunpckhdq zmm31, zmm31, zmm13",
"vmovdqu32 xmm14, xmmword ptr [rdi + 3 * 16 + 2 * 1024]",
"vinserti32x4 zmm14, zmm14, xmmword ptr [rdi + 3 * 16 + 6 * 1024], 1",
"vinserti32x4 zmm14, zmm14, xmmword ptr [rdi + 3 * 16 + 10 * 1024], 2",
@@ -918,14 +945,14 @@ global_asm!(
"vinserti32x4 zmm15, zmm15, xmmword ptr [rdi + 3 * 16 + 7 * 1024], 1",
"vinserti32x4 zmm15, zmm15, xmmword ptr [rdi + 3 * 16 + 11 * 1024], 2",
"vinserti32x4 zmm15, zmm15, xmmword ptr [rdi + 3 * 16 + 15 * 1024], 3",
- "vpunpckldq zmm13, zmm14, zmm15",
- "vpunpckhdq zmm14, zmm14, zmm15",
- "vpunpcklqdq zmm28, zmm30, zmm13",
- "vpunpckhqdq zmm29, zmm30, zmm13",
- "vpunpcklqdq zmm30, zmm31, zmm14",
- "vpunpckhqdq zmm31, zmm31, zmm14",
- // Initialize the third and fourth rows of the state, which we just used as scratch space
- // during transposition.
+ "vpunpckldq zmm13, zmm14, zmm15",
+ "vpunpckhdq zmm14, zmm14, zmm15",
+ "vpunpcklqdq zmm28, zmm30, zmm13",
+ "vpunpckhqdq zmm29, zmm30, zmm13",
+ "vpunpcklqdq zmm30, zmm31, zmm14",
+ "vpunpckhqdq zmm31, zmm31, zmm14",
+ // Initialize the third and fourth rows of the state, part of which we just used as scratch
+ // space during transposition.
"vmovdqa32 zmm8, zmmword ptr [BLAKE3_IV0_16 + rip]", // IV constants
"vmovdqa32 zmm9, zmmword ptr [BLAKE3_IV1_16 + rip]",
"vmovdqa32 zmm10, zmmword ptr [BLAKE3_IV2_16 + rip]",