author    Jack O'Connor <[email protected]>    2022-04-01 15:43:04 -0400
committer Jack O'Connor <[email protected]>    2022-04-09 13:31:19 -0700
commit    e17743e8fdf2845be6dc85ad339bf45feeefc564 (patch)
tree      705b7204f9522a666476ad38f49393c31e11a2de /asm/asm.py
parent    35ad4ededdbf259c507c49b2e7ac529b43b61671 (diff)

kernel_3d_16 and xof functions (kernel)
Diffstat (limited to 'asm/asm.py')
-rwxr-xr-x    asm/asm.py    637
1 file changed, 604 insertions(+), 33 deletions(-)
diff --git a/asm/asm.py b/asm/asm.py
index 75c7012..014c420 100755
--- a/asm/asm.py
+++ b/asm/asm.py
@@ -1,6 +1,11 @@
#! /usr/bin/env python3
# Generate asm!
+#
+# TODOs:
+# - vzeroupper
+# - CET
+# - prefetches
from dataclasses import dataclass, replace
@@ -11,6 +16,16 @@ SSE41 = "sse41"
SSE2 = "sse2"
LINUX = "linux"
+MESSAGE_SCHEDULE = [
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+ [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8],
+ [3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1],
+ [10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6],
+ [12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4],
+ [9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7],
+ [11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13],
+]
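+
+# The first row above is the identity; each later row is the previous row
+# permuted by MESSAGE_SCHEDULE[1]. A quick illustrative self-check (sketch
+# only; _PERM and _r are just local names for the check):
+#
+#     _PERM = MESSAGE_SCHEDULE[1]
+#     for _r in range(6):
+#         assert MESSAGE_SCHEDULE[_r + 1] == [MESSAGE_SCHEDULE[_r][i] for i in _PERM]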
+
@dataclass
class Target:
@@ -47,9 +62,9 @@ def add_row(target, output, degree, dest, src):
if target.extension == AVX512:
if degree == 1:
output.append(f"vpaddd xmm{dest}, xmm{dest}, xmm{src}")
- elif degree == 2:
+ elif degree in (2, 8):
output.append(f"vpaddd ymm{dest}, ymm{dest}, ymm{src}")
- elif degree == 4:
+ elif degree in (4, 16):
output.append(f"vpaddd zmm{dest}, zmm{dest}, zmm{src}")
else:
raise NotImplementedError
@@ -68,9 +83,9 @@ def xor_row(target, output, degree, dest, src):
if target.extension == AVX512:
if degree == 1:
output.append(f"vpxord xmm{dest}, xmm{dest}, xmm{src}")
- elif degree == 2:
+ elif degree in (2, 8):
output.append(f"vpxord ymm{dest}, ymm{dest}, ymm{src}")
- elif degree == 4:
+ elif degree in (4, 16):
output.append(f"vpxord zmm{dest}, zmm{dest}, zmm{src}")
else:
raise NotImplementedError
@@ -90,9 +105,9 @@ def bitrotate_row(target, output, degree, reg, bits):
if target.extension == AVX512:
if degree == 1:
output.append(f"vprord xmm{reg}, xmm{reg}, {bits}")
- elif degree == 2:
+ elif degree in (2, 8):
output.append(f"vprord ymm{reg}, ymm{reg}, {bits}")
- elif degree == 4:
+ elif degree in (4, 16):
output.append(f"vprord zmm{reg}, zmm{reg}, {bits}")
else:
raise NotImplementedError
@@ -138,7 +153,7 @@ def bitrotate_row(target, output, degree, reg, bits):
raise NotImplementedError
-# See the comments above kernel2d().
+# See the comments above kernel_2d().
def diagonalize_state_rows(target, output, degree):
if target.extension == AVX512:
if degree == 1:
@@ -169,7 +184,7 @@ def diagonalize_state_rows(target, output, degree):
raise NotImplementedError
-# See the comments above kernel2d().
+# See the comments above kernel_2d().
def undiagonalize_state_rows(target, output, degree):
if target.extension == AVX512:
if degree == 1:
@@ -200,7 +215,7 @@ def undiagonalize_state_rows(target, output, degree):
raise NotImplementedError
-# See the comments above kernel2d().
+# See the comments above kernel_2d().
def permute_message_rows(target, output, degree):
if target.extension == AVX512:
if degree == 1:
@@ -309,8 +324,8 @@ def permute_message_rows(target, output, degree):
raise NotImplementedError
-def kernel2d_name(target, degree):
- return f"blake3_{target.extension}_kernel2d_{degree}"
+def kernel_2d_name(target, degree):
+ return f"blake3_{target.extension}_kernel_2d_{degree}"
# The two-dimensional kernel packs one or more *rows* of the state into a
@@ -353,22 +368,29 @@ def kernel2d_name(target, degree):
# ymm5: a1, a3, a5, a7, b1, b3, b5, b7
# ymm6: a14, a8, a10, a12, b14, b8, b10, b12
# ymm7: a15, a9, a11, a13, b15, b9, b11, b13
-def kernel2d(target, output, degree):
- label = kernel2d_name(target, degree)
+def kernel_2d(target, output, degree):
+ label = kernel_2d_name(target, degree)
output.append(f"{label}:")
# vpshufb indexes
- if target.extension == SSE41:
+ if target.extension == SSE2:
+ assert degree == 1
+ elif target.extension == SSE41:
+ assert degree == 1
output.append(f"movaps xmm14, xmmword ptr [ROT8+rip]")
output.append(f"movaps xmm15, xmmword ptr [ROT16+rip]")
- if target.extension == AVX2:
+ elif target.extension == AVX2:
+ assert degree == 2
output.append(f"vbroadcasti128 ymm14, xmmword ptr [ROT16+rip]")
output.append(f"vbroadcasti128 ymm15, xmmword ptr [ROT8+rip]")
- if target.extension == AVX512:
+ elif target.extension == AVX512:
+ assert degree in (1, 2, 4)
if degree == 4:
output.append(f"mov {target.scratch32(0)}, 43690")
output.append(f"kmovw k3, {target.scratch32(0)}")
output.append(f"mov {target.scratch32(0)}, 34952")
output.append(f"kmovw k4, {target.scratch32(0)}")
+ else:
+ raise NotImplementedError
for round_number in range(7):
if round_number > 0:
# Un-diagonalize and permute before each round except the first.
@@ -482,12 +504,12 @@ def compress(target, output):
output.append(f".global {label}")
output.append(f"{label}:")
compress_setup(target, output)
- output.append(f"call {kernel2d_name(target, 1)}")
+ output.append(f"call {kernel_2d_name(target, 1)}")
compress_finish(target, output)
output.append(target.ret())
-def xof_setup2d(target, output, degree):
+def xof_setup_2d(target, output, degree):
if target.extension == AVX512:
if degree == 1:
# state words
@@ -641,7 +663,7 @@ def xof_setup2d(target, output, degree):
raise NotImplementedError
-def xof_stream_finish2d(target, output, degree):
+def xof_stream_finish_2d(target, output, degree):
if target.extension == AVX512:
if degree == 1:
output.append(f"vpxor xmm2, xmm2, [{target.arg64(0)}]")
@@ -706,7 +728,7 @@ def xof_stream_finish2d(target, output, degree):
raise NotImplementedError
-def xof_xor_finish2d(target, output, degree):
+def xof_xor_finish_2d(target, output, degree):
if target.extension == AVX512:
if degree == 1:
output.append(f"vpxor xmm2, xmm2, [{target.arg64(0)}]")
@@ -842,28 +864,558 @@ def xof_xor_finish2d(target, output, degree):
raise NotImplementedError
+def g_function_3d(target, output, degree, columns, msg_words1, msg_words2):
+ if target.extension == SSE41:
+ assert degree == 4
+ elif target.extension == AVX2:
+ assert degree == 8
+ elif target.extension == AVX512:
+ assert degree in (8, 16)
+ else:
+ raise NotImplementedError
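+    # Each entry of `columns` is a list of four register indexes [a, b, c, d],
+    # and msg_words1/msg_words2 supply the two message inputs per column. Per
+    # lane, the sequence below is the standard BLAKE3 G function; roughly, as a
+    # sketch (rotr32 is shorthand for a 32-bit right rotation):
+    #
+    #     a += b + mx;  d = rotr32(d ^ a, 16);  c += d;  b = rotr32(b ^ c, 12)
+    #     a += b + my;  d = rotr32(d ^ a, 8);   c += d;  b = rotr32(b ^ c, 7)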
+ for (column, m1) in zip(columns, msg_words1):
+ add_row(target, output, degree, dest=column[0], src=m1)
+ for column in columns:
+ add_row(target, output, degree, dest=column[0], src=column[1])
+ for column in columns:
+ xor_row(target, output, degree, dest=column[3], src=column[0])
+ for column in columns:
+ bitrotate_row(target, output, degree, reg=column[3], bits=16)
+ for column in columns:
+ add_row(target, output, degree, dest=column[2], src=column[3])
+ for column in columns:
+ xor_row(target, output, degree, dest=column[1], src=column[2])
+ for column in columns:
+ bitrotate_row(target, output, degree, reg=column[1], bits=12)
+ for (column, m2) in zip(columns, msg_words2):
+ add_row(target, output, degree, dest=column[0], src=m2)
+ for column in columns:
+ add_row(target, output, degree, dest=column[0], src=column[1])
+ for column in columns:
+ xor_row(target, output, degree, dest=column[3], src=column[0])
+ for column in columns:
+ bitrotate_row(target, output, degree, reg=column[3], bits=8)
+ for column in columns:
+ add_row(target, output, degree, dest=column[2], src=column[3])
+ for column in columns:
+ xor_row(target, output, degree, dest=column[1], src=column[2])
+ for column in columns:
+ bitrotate_row(target, output, degree, reg=column[1], bits=7)
+
+
+def kernel_3d_name(target, degree):
+ return f"blake3_{target.extension}_kernel_3d_{degree}"
+
+
+def kernel_3d(target, output, degree):
+ label = kernel_3d_name(target, degree)
+ output.append(f"{label}:")
+ if target.extension == SSE41:
+ assert degree == 4
+ elif target.extension == AVX2:
+ assert degree == 8
+ elif target.extension == AVX512:
+ assert degree in (8, 16)
+ else:
+ raise NotImplementedError
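+    # "3D" register layout: one vector register per state word, with one
+    # independent input per lane. For AVX-512 at degree 16, for example, the
+    # state lives in zmm0-zmm15 and the sixteen message words in zmm16-zmm31.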
+ for round_number in range(7):
+ straight_columns = [
+ [0, 4, 8, 12],
+ [1, 5, 9, 13],
+ [2, 6, 10, 14],
+ [3, 7, 11, 15],
+ ]
+ msg_words1 = [16 + MESSAGE_SCHEDULE[round_number][i] for i in [0, 2, 4, 6]]
+ msg_words2 = [16 + MESSAGE_SCHEDULE[round_number][i] for i in [1, 3, 5, 7]]
+ g_function_3d(target, output, degree, straight_columns, msg_words1, msg_words2)
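+        # The second half of the round works on the diagonals of the 4x4
+        # state. In this layout that just means indexing different registers,
+        # so no diagonalize/undiagonalize shuffles are needed (compare
+        # kernel_2d).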
+ diagonal_columns = [
+ [0, 5, 10, 15],
+ [1, 6, 11, 12],
+ [2, 7, 8, 13],
+ [3, 4, 9, 14],
+ ]
+ msg_words1 = [16 + MESSAGE_SCHEDULE[round_number][i] for i in [8, 10, 12, 14]]
+ msg_words2 = [16 + MESSAGE_SCHEDULE[round_number][i] for i in [9, 11, 13, 15]]
+ g_function_3d(target, output, degree, diagonal_columns, msg_words1, msg_words2)
+ # Xor the last two rows into the first two, but don't do the feed forward
+ # here. That's only done in the XOF case.
+ for dest in range(8):
+ xor_row(target, output, degree, dest=dest, src=dest + 8)
+ output.append(target.ret())
+
+
+def xof_setup_3d(target, output, degree):
+ if target.extension == AVX512:
+ if degree == 16:
+ # Load vpermi2d indexes into the counter registers.
+ output.append(f"vmovdqa32 zmm12, zmmword ptr [rip + EVEN_INDEXES]")
+ output.append(f"vmovdqa32 zmm13, zmmword ptr [rip + ODD_INDEXES]")
+ # Load the state words.
+ for i in range(8):
+ output.append(
+ f"vpbroadcastd zmm{i}, dword ptr [{target.arg64(0)}+{4*i}]"
+ )
+ # Load the message words.
+ for i in range(16):
+ output.append(
+ f"vpbroadcastd zmm{i+16}, dword ptr [{target.arg64(1)}+{4*i}]"
+ )
+ # Load the 64-bit counter increments into a temporary register.
+ output.append(f"vmovdqa64 zmm8, zmmword ptr [INCREMENT_3D+rip]")
+ # Broadcast the counter and add it to the increments. This gives
+ # the first 8 counter values.
+ output.append(f"vpbroadcastq zmm9, {target.arg64(2)}")
+ output.append(f"vpaddq zmm9, zmm9, zmm8")
+ # Increment the counter and repeat that for the last 8 counter values.
+ output.append(f"add {target.arg64(2)}, 8")
+ output.append(f"vpbroadcastq zmm10, {target.arg64(2)}")
+ output.append(f"vpaddq zmm10, zmm10, zmm8")
+ # Extract the lower and upper halves of the counter words, using
+ # the permutation tables loaded above.
+ output.append(f"vpermi2d zmm12, zmm9, zmm10")
+ output.append(f"vpermi2d zmm13, zmm9, zmm10")
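+            # Net effect per lane n (0-15): zmm12 holds the low 32 bits of
+            # counter t+n and zmm13 holds the high 32 bits, where t is the
+            # starting counter argument.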
+ # Load the block length.
+ output.append(f"vpbroadcastd zmm14, {target.arg32(3)}")
+ # Load the domain flags.
+ output.append(f"vpbroadcastd zmm15, {target.arg32(4)}")
+ # Load the IV constants.
+ for i in range(4):
+ output.append(f"vpbroadcastd zmm{i+8}, dword ptr [BLAKE3_IV+rip+{4*i}]")
+ else:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+
+def xof_stream_finish_3d(target, output, degree):
+ if target.extension == AVX512:
+ if degree == 16:
+ # Re-broadcast the input CV and feed it forward into the second half of the state.
+ output.append(f"vpbroadcastd zmm16, dword ptr [{target.arg64(0)} + 0 * 4]")
+ output.append(f"vpxord zmm8, zmm8, zmm16")
+ output.append(f"vpbroadcastd zmm17, dword ptr [{target.arg64(0)} + 1 * 4]")
+ output.append(f"vpxord zmm9, zmm9, zmm17")
+ output.append(f"vpbroadcastd zmm18, dword ptr [{target.arg64(0)} + 2 * 4]")
+ output.append(f"vpxord zmm10, zmm10, zmm18")
+ output.append(f"vpbroadcastd zmm19, dword ptr [{target.arg64(0)} + 3 * 4]")
+ output.append(f"vpxord zmm11, zmm11, zmm19")
+ output.append(f"vpbroadcastd zmm20, dword ptr [{target.arg64(0)} + 4 * 4]")
+ output.append(f"vpxord zmm12, zmm12, zmm20")
+ output.append(f"vpbroadcastd zmm21, dword ptr [{target.arg64(0)} + 5 * 4]")
+ output.append(f"vpxord zmm13, zmm13, zmm21")
+ output.append(f"vpbroadcastd zmm22, dword ptr [{target.arg64(0)} + 6 * 4]")
+ output.append(f"vpxord zmm14, zmm14, zmm22")
+ output.append(f"vpbroadcastd zmm23, dword ptr [{target.arg64(0)} + 7 * 4]")
+ output.append(f"vpxord zmm15, zmm15, zmm23")
+ # zmm0-zmm15 now contain the final extended state vectors, transposed. We need to un-transpose
+ # them before we write them out. As with blake3_avx512_blocks_16, we prefer to avoid expensive
+ # operations across 128-bit lanes, so we do a couple of interleaving passes and then write out
+ # 128 bits at a time.
+ #
+ # First, interleave 32-bit words. Use zmm16-zmm31 to hold the intermediate results. This
+ # takes the input vectors like:
+ #
+ # a0, b0, c0, d0, e0, f0, g0, h0, i0, j0, k0, l0, m0, n0, o0, p0
+ #
+ # And produces vectors like:
+ #
+            # a0, a1, b0, b1, e0, e1, f0, f1, i0, i1, j0, j1, m0, m1, n0, n1
+ #
+ # Then interleave 64-bit words back into zmm0-zmm15, producing vectors like:
+ #
+ # a0, a1, a2, a3, e0, e1, e2, e3, i0, i1, i2, i3, m0, m1, m2, m3
+ #
+ # Finally, write out each 128-bit group, unaligned.
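+            #
+            # (As a sketch of the addressing: the 128-bit chunk holding words
+            # 4*g..4*g+3 of lane n's output block goes to byte offset
+            # n*64 + g*16, so each "k * 16" constant below is k = 4*n + g.)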
+ output.append(f"vpunpckldq zmm16, zmm0, zmm1")
+ output.append(f"vpunpckhdq zmm17, zmm0, zmm1")
+ output.append(f"vpunpckldq zmm18, zmm2, zmm3")
+ output.append(f"vpunpckhdq zmm19, zmm2, zmm3")
+ output.append(f"vpunpcklqdq zmm0, zmm16, zmm18")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 0 * 16], xmm0")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 16 * 16], zmm0, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 32 * 16], zmm0, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 48 * 16], zmm0, 3"
+ )
+ output.append(f"vpunpckhqdq zmm1, zmm16, zmm18")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 4 * 16], xmm1")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 20 * 16], zmm1, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 36 * 16], zmm1, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 52 * 16], zmm1, 3"
+ )
+ output.append(f"vpunpcklqdq zmm2, zmm17, zmm19")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 8 * 16], xmm2")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 24 * 16], zmm2, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 40 * 16], zmm2, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 56 * 16], zmm2, 3"
+ )
+ output.append(f"vpunpckhqdq zmm3, zmm17, zmm19")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 12 * 16], xmm3")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 28 * 16], zmm3, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 44 * 16], zmm3, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 60 * 16], zmm3, 3"
+ )
+ output.append(f"vpunpckldq zmm20, zmm4, zmm5")
+ output.append(f"vpunpckhdq zmm21, zmm4, zmm5")
+ output.append(f"vpunpckldq zmm22, zmm6, zmm7")
+ output.append(f"vpunpckhdq zmm23, zmm6, zmm7")
+ output.append(f"vpunpcklqdq zmm4, zmm20, zmm22")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 1 * 16], xmm4")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 17 * 16], zmm4, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 33 * 16], zmm4, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 49 * 16], zmm4, 3"
+ )
+ output.append(f"vpunpckhqdq zmm5, zmm20, zmm22")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 5 * 16], xmm5")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 21 * 16], zmm5, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 37 * 16], zmm5, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 53 * 16], zmm5, 3"
+ )
+ output.append(f"vpunpcklqdq zmm6, zmm21, zmm23")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 9 * 16], xmm6")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 25 * 16], zmm6, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 41 * 16], zmm6, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 57 * 16], zmm6, 3"
+ )
+ output.append(f"vpunpckhqdq zmm7, zmm21, zmm23")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 13 * 16], xmm7")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 29 * 16], zmm7, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 45 * 16], zmm7, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 61 * 16], zmm7, 3"
+ )
+ output.append(f"vpunpckldq zmm24, zmm8, zmm9")
+ output.append(f"vpunpckhdq zmm25, zmm8, zmm9")
+ output.append(f"vpunpckldq zmm26, zmm10, zmm11")
+ output.append(f"vpunpckhdq zmm27, zmm10, zmm11")
+ output.append(f"vpunpcklqdq zmm8, zmm24, zmm26")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 2 * 16], xmm8")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 18 * 16], zmm8, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 34 * 16], zmm8, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 50 * 16], zmm8, 3"
+ )
+ output.append(f"vpunpckhqdq zmm9, zmm24, zmm26")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 6 * 16], xmm9")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 22 * 16], zmm9, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 38 * 16], zmm9, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 54 * 16], zmm9, 3"
+ )
+ output.append(f"vpunpcklqdq zmm10, zmm25, zmm27")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 10 * 16], xmm10")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 26 * 16], zmm10, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 42 * 16], zmm10, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 58 * 16], zmm10, 3"
+ )
+ output.append(f"vpunpckhqdq zmm11, zmm25, zmm27")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 14 * 16], xmm11")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 30 * 16], zmm11, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 46 * 16], zmm11, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 62 * 16], zmm11, 3"
+ )
+ output.append(f"vpunpckldq zmm28, zmm12, zmm13")
+ output.append(f"vpunpckhdq zmm29, zmm12, zmm13")
+ output.append(f"vpunpckldq zmm30, zmm14, zmm15")
+ output.append(f"vpunpckhdq zmm31, zmm14, zmm15")
+ output.append(f"vpunpcklqdq zmm12, zmm28, zmm30")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 3 * 16], xmm12")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 19 * 16], zmm12, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 35 * 16], zmm12, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 51 * 16], zmm12, 3"
+ )
+ output.append(f"vpunpckhqdq zmm13, zmm28, zmm30")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 7 * 16], xmm13")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 23 * 16], zmm13, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 39 * 16], zmm13, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 55 * 16], zmm13, 3"
+ )
+ output.append(f"vpunpcklqdq zmm14, zmm29, zmm31")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 11 * 16], xmm14")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 27 * 16], zmm14, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 43 * 16], zmm14, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 59 * 16], zmm14, 3"
+ )
+ output.append(f"vpunpckhqdq zmm15, zmm29, zmm31")
+ output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 15 * 16], xmm15")
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 31 * 16], zmm15, 1"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 47 * 16], zmm15, 2"
+ )
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 63 * 16], zmm15, 3"
+ )
+ else:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+
+def xof_xor_finish_3d(target, output, degree):
+ if target.extension == AVX512:
+ if degree == 16:
+ # Re-broadcast the input CV and feed it forward into the second half of the state.
+ output.append(f"vpbroadcastd zmm16, dword ptr [{target.arg64(0)} + 0 * 4]")
+ output.append(f"vpxord zmm8, zmm8, zmm16")
+ output.append(f"vpbroadcastd zmm17, dword ptr [{target.arg64(0)} + 1 * 4]")
+ output.append(f"vpxord zmm9, zmm9, zmm17")
+ output.append(f"vpbroadcastd zmm18, dword ptr [{target.arg64(0)} + 2 * 4]")
+ output.append(f"vpxord zmm10, zmm10, zmm18")
+ output.append(f"vpbroadcastd zmm19, dword ptr [{target.arg64(0)} + 3 * 4]")
+ output.append(f"vpxord zmm11, zmm11, zmm19")
+ output.append(f"vpbroadcastd zmm20, dword ptr [{target.arg64(0)} + 4 * 4]")
+ output.append(f"vpxord zmm12, zmm12, zmm20")
+ output.append(f"vpbroadcastd zmm21, dword ptr [{target.arg64(0)} + 5 * 4]")
+ output.append(f"vpxord zmm13, zmm13, zmm21")
+ output.append(f"vpbroadcastd zmm22, dword ptr [{target.arg64(0)} + 6 * 4]")
+ output.append(f"vpxord zmm14, zmm14, zmm22")
+ output.append(f"vpbroadcastd zmm23, dword ptr [{target.arg64(0)} + 7 * 4]")
+ output.append(f"vpxord zmm15, zmm15, zmm23")
+ # zmm0-zmm15 now contain the final extended state vectors, transposed. We need to un-transpose
+ # them before we write them out. Unlike blake3_avx512_xof_stream_16, we do a complete
+ # un-transpose here, to make the xor step easier.
+ #
+ # First interleave 32-bit words. This takes vectors like:
+ #
+ # a0, b0, c0, d0, e0, f0, g0, h0, i0, j0, k0, l0, m0, n0, o0, p0
+ #
+ # And produces vectors like:
+ #
+            # a0, a1, b0, b1, e0, e1, f0, f1, i0, i1, j0, j1, m0, m1, n0, n1
+ output.append(f"vpunpckldq zmm16, zmm0, zmm1")
+ output.append(f"vpunpckhdq zmm17, zmm0, zmm1")
+ output.append(f"vpunpckldq zmm18, zmm2, zmm3")
+ output.append(f"vpunpckhdq zmm19, zmm2, zmm3")
+ output.append(f"vpunpckldq zmm20, zmm4, zmm5")
+ output.append(f"vpunpckhdq zmm21, zmm4, zmm5")
+ output.append(f"vpunpckldq zmm22, zmm6, zmm7")
+ output.append(f"vpunpckhdq zmm23, zmm6, zmm7")
+ output.append(f"vpunpckldq zmm24, zmm8, zmm9")
+ output.append(f"vpunpckhdq zmm25, zmm8, zmm9")
+ output.append(f"vpunpckldq zmm26, zmm10, zmm11")
+ output.append(f"vpunpckhdq zmm27, zmm10, zmm11")
+ output.append(f"vpunpckldq zmm28, zmm12, zmm13")
+ output.append(f"vpunpckhdq zmm29, zmm12, zmm13")
+ output.append(f"vpunpckldq zmm30, zmm14, zmm15")
+ output.append(f"vpunpckhdq zmm31, zmm14, zmm15")
+ # Then interleave 64-bit words, producing vectors like:
+ #
+ # a0, a1, a2, a3, e0, e1, e2, e3, i0, i1, i2, i3, m0, m1, m2, m3
+ output.append(f"vpunpcklqdq zmm0, zmm16, zmm18")
+ output.append(f"vpunpckhqdq zmm1, zmm16, zmm18")
+ output.append(f"vpunpcklqdq zmm2, zmm17, zmm19")
+ output.append(f"vpunpckhqdq zmm3, zmm17, zmm19")
+ output.append(f"vpunpcklqdq zmm4, zmm20, zmm22")
+ output.append(f"vpunpckhqdq zmm5, zmm20, zmm22")
+ output.append(f"vpunpcklqdq zmm6, zmm21, zmm23")
+ output.append(f"vpunpckhqdq zmm7, zmm21, zmm23")
+ output.append(f"vpunpcklqdq zmm8, zmm24, zmm26")
+ output.append(f"vpunpckhqdq zmm9, zmm24, zmm26")
+ output.append(f"vpunpcklqdq zmm10, zmm25, zmm27")
+ output.append(f"vpunpckhqdq zmm11, zmm25, zmm27")
+ output.append(f"vpunpcklqdq zmm12, zmm28, zmm30")
+ output.append(f"vpunpckhqdq zmm13, zmm28, zmm30")
+ output.append(f"vpunpcklqdq zmm14, zmm29, zmm31")
+ output.append(f"vpunpckhqdq zmm15, zmm29, zmm31")
+ # Then interleave 128-bit lanes, producing vectors like:
+ #
+ # a0, a1, a2, a3, i0, i1, i2, i3, a4, a5, a6, a7, i4, i5, i6, i7
+ output.append(
+ "vshufi32x4 zmm16, zmm0, zmm4, 0x88"
+ ) # lo lanes: 0x88 = 0b10001000 = (0, 2, 0, 2)
+ output.append(f"vshufi32x4 zmm17, zmm1, zmm5, 0x88")
+ output.append(f"vshufi32x4 zmm18, zmm2, zmm6, 0x88")
+ output.append(f"vshufi32x4 zmm19, zmm3, zmm7, 0x88")
+ output.append(
+ "vshufi32x4 zmm20, zmm0, zmm4, 0xdd"
+ ) # hi lanes: 0xdd = 0b11011101 = (1, 3, 1, 3)
+ output.append(f"vshufi32x4 zmm21, zmm1, zmm5, 0xdd")
+ output.append(f"vshufi32x4 zmm22, zmm2, zmm6, 0xdd")
+ output.append(f"vshufi32x4 zmm23, zmm3, zmm7, 0xdd")
+ output.append(f"vshufi32x4 zmm24, zmm8, zmm12, 0x88") # lo lanes
+ output.append(f"vshufi32x4 zmm25, zmm9, zmm13, 0x88")
+ output.append(f"vshufi32x4 zmm26, zmm10, zmm14, 0x88")
+ output.append(f"vshufi32x4 zmm27, zmm11, zmm15, 0x88")
+ output.append(f"vshufi32x4 zmm28, zmm8, zmm12, 0xdd") # hi lanes
+ output.append(f"vshufi32x4 zmm29, zmm9, zmm13, 0xdd")
+ output.append(f"vshufi32x4 zmm30, zmm10, zmm14, 0xdd")
+ output.append(f"vshufi32x4 zmm31, zmm11, zmm15, 0xdd")
+ # Finally interleave 128-bit lanes again (the same permutation as the previous pass, but
+ # different inputs), producing vectors like:
+ #
+ # a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15
+ output.append(f"vshufi32x4 zmm0, zmm16, zmm24, 0x88") # lo lanes
+ output.append(f"vshufi32x4 zmm1, zmm17, zmm25, 0x88")
+ output.append(f"vshufi32x4 zmm2, zmm18, zmm26, 0x88")
+ output.append(f"vshufi32x4 zmm3, zmm19, zmm27, 0x88")
+ output.append(f"vshufi32x4 zmm4, zmm20, zmm28, 0x88")
+ output.append(f"vshufi32x4 zmm5, zmm21, zmm29, 0x88")
+ output.append(f"vshufi32x4 zmm6, zmm22, zmm30, 0x88")
+ output.append(f"vshufi32x4 zmm7, zmm23, zmm31, 0x88")
+ output.append(f"vshufi32x4 zmm8, zmm16, zmm24, 0xdd") # hi lanes
+ output.append(f"vshufi32x4 zmm9, zmm17, zmm25, 0xdd")
+ output.append(f"vshufi32x4 zmm10, zmm18, zmm26, 0xdd")
+ output.append(f"vshufi32x4 zmm11, zmm19, zmm27, 0xdd")
+ output.append(f"vshufi32x4 zmm12, zmm20, zmm28, 0xdd")
+ output.append(f"vshufi32x4 zmm13, zmm21, zmm29, 0xdd")
+ output.append(f"vshufi32x4 zmm14, zmm22, zmm30, 0xdd")
+ output.append(f"vshufi32x4 zmm15, zmm23, zmm31, 0xdd")
+            # zmm0-zmm15 now contain the fully un-transposed state words, i.e.
+            # each register holds one lane's 64-byte output block. Load each
+            # 64-byte block of input (unaligned), xor it in, and write out the
+            # result (again unaligned).
+ output.append(f"vmovdqu32 zmm16, zmmword ptr [{target.arg64(5)} + 0 * 64]")
+ output.append(f"vpxord zmm0, zmm0, zmm16")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 0 * 64], zmm0")
+ output.append(f"vmovdqu32 zmm17, zmmword ptr [{target.arg64(5)} + 1 * 64]")
+ output.append(f"vpxord zmm1, zmm1, zmm17")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 1 * 64], zmm1")
+ output.append(f"vmovdqu32 zmm18, zmmword ptr [{target.arg64(5)} + 2 * 64]")
+ output.append(f"vpxord zmm2, zmm2, zmm18")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 2 * 64], zmm2")
+ output.append(f"vmovdqu32 zmm19, zmmword ptr [{target.arg64(5)} + 3 * 64]")
+ output.append(f"vpxord zmm3, zmm3, zmm19")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 3 * 64], zmm3")
+ output.append(f"vmovdqu32 zmm20, zmmword ptr [{target.arg64(5)} + 4 * 64]")
+ output.append(f"vpxord zmm4, zmm4, zmm20")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 4 * 64], zmm4")
+ output.append(f"vmovdqu32 zmm21, zmmword ptr [{target.arg64(5)} + 5 * 64]")
+ output.append(f"vpxord zmm5, zmm5, zmm21")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 5 * 64], zmm5")
+ output.append(f"vmovdqu32 zmm22, zmmword ptr [{target.arg64(5)} + 6 * 64]")
+ output.append(f"vpxord zmm6, zmm6, zmm22")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 6 * 64], zmm6")
+ output.append(f"vmovdqu32 zmm23, zmmword ptr [{target.arg64(5)} + 7 * 64]")
+ output.append(f"vpxord zmm7, zmm7, zmm23")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 7 * 64], zmm7")
+ output.append(f"vmovdqu32 zmm24, zmmword ptr [{target.arg64(5)} + 8 * 64]")
+ output.append(f"vpxord zmm8, zmm8, zmm24")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 8 * 64], zmm8")
+ output.append(f"vmovdqu32 zmm25, zmmword ptr [{target.arg64(5)} + 9 * 64]")
+ output.append(f"vpxord zmm9, zmm9, zmm25")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 9 * 64], zmm9")
+ output.append(f"vmovdqu32 zmm26, zmmword ptr [{target.arg64(5)} + 10 * 64]")
+ output.append(f"vpxord zmm10, zmm10, zmm26")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 10 * 64], zmm10")
+ output.append(f"vmovdqu32 zmm27, zmmword ptr [{target.arg64(5)} + 11 * 64]")
+ output.append(f"vpxord zmm11, zmm11, zmm27")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 11 * 64], zmm11")
+ output.append(f"vmovdqu32 zmm28, zmmword ptr [{target.arg64(5)} + 12 * 64]")
+ output.append(f"vpxord zmm12, zmm12, zmm28")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 12 * 64], zmm12")
+ output.append(f"vmovdqu32 zmm29, zmmword ptr [{target.arg64(5)} + 13 * 64]")
+ output.append(f"vpxord zmm13, zmm13, zmm29")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 13 * 64], zmm13")
+ output.append(f"vmovdqu32 zmm30, zmmword ptr [{target.arg64(5)} + 14 * 64]")
+ output.append(f"vpxord zmm14, zmm14, zmm30")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 14 * 64], zmm14")
+ output.append(f"vmovdqu32 zmm31, zmmword ptr [{target.arg64(5)} + 15 * 64]")
+ output.append(f"vpxord zmm15, zmm15, zmm31")
+ output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 15 * 64], zmm15")
+ else:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+
def xof_fn(target, output, degree, xor):
variant = "xor" if xor else "stream"
- finish_fn_2d = xof_xor_finish2d if xor else xof_stream_finish2d
+ finish_fn_2d = xof_xor_finish_2d if xor else xof_stream_finish_2d
+ finish_fn_3d = xof_xor_finish_3d if xor else xof_stream_finish_3d
label = f"blake3_{target.extension}_xof_{variant}_{degree}"
output.append(f".global {label}")
output.append(f"{label}:")
if target.extension == AVX512:
if degree in (1, 2, 4):
- xof_setup2d(target, output, degree)
- output.append(f"call {kernel2d_name(target, degree)}")
+ xof_setup_2d(target, output, degree)
+ output.append(f"call {kernel_2d_name(target, degree)}")
finish_fn_2d(target, output, degree)
+ elif degree in (8, 16):
+ xof_setup_3d(target, output, degree)
+ output.append(f"call {kernel_3d_name(target, degree)}")
+ finish_fn_3d(target, output, degree)
else:
raise NotImplementedError
elif target.extension == AVX2:
assert degree == 2
- xof_setup2d(target, output, degree)
- output.append(f"call {kernel2d_name(target, degree)}")
+ xof_setup_2d(target, output, degree)
+ output.append(f"call {kernel_2d_name(target, degree)}")
finish_fn_2d(target, output, degree)
elif target.extension in (SSE41, SSE2):
assert degree == 1
- xof_setup2d(target, output, degree)
- output.append(f"call {kernel2d_name(target, degree)}")
+ xof_setup_2d(target, output, degree)
+ output.append(f"call {kernel_2d_name(target, degree)}")
finish_fn_2d(target, output, degree)
else:
raise NotImplementedError
@@ -877,7 +1429,7 @@ def emit_prelude(target, output):
def emit_sse2(target, output):
target = replace(target, extension=SSE2)
- kernel2d(target, output, 1)
+ kernel_2d(target, output, 1)
compress(target, output)
xof_fn(target, output, 1, xor=False)
xof_fn(target, output, 1, xor=True)
@@ -894,7 +1446,7 @@ def emit_sse2(target, output):
def emit_sse41(target, output):
target = replace(target, extension=SSE41)
- kernel2d(target, output, 1)
+ kernel_2d(target, output, 1)
compress(target, output)
xof_fn(target, output, 1, xor=False)
xof_fn(target, output, 1, xor=True)
@@ -902,24 +1454,35 @@ def emit_sse41(target, output):
def emit_avx2(target, output):
target = replace(target, extension=AVX2)
- kernel2d(target, output, 2)
+ kernel_2d(target, output, 2)
xof_fn(target, output, 2, xor=False)
xof_fn(target, output, 2, xor=True)
def emit_avx512(target, output):
target = replace(target, extension=AVX512)
- kernel2d(target, output, 1)
- kernel2d(target, output, 2)
- kernel2d(target, output, 4)
+
+ # degree 1
+ kernel_2d(target, output, 1)
compress(target, output)
xof_fn(target, output, 1, xor=False)
xof_fn(target, output, 1, xor=True)
+
+ # degree 2
+ kernel_2d(target, output, 2)
xof_fn(target, output, 2, xor=False)
xof_fn(target, output, 2, xor=True)
+
+ # degree 4
+ kernel_2d(target, output, 4)
xof_fn(target, output, 4, xor=False)
xof_fn(target, output, 4, xor=True)
+ # degree 16
+ kernel_3d(target, output, 16)
+ xof_fn(target, output, 16, xor=False)
+ xof_fn(target, output, 16, xor=True)
+
def emit_footer(target, output):
output.append(".balign 16")
@@ -942,6 +1505,14 @@ def emit_footer(target, output):
output.append(".balign 64")
output.append("INCREMENT_2D:")
output.append(".quad 0, 0, 1, 0, 2, 0, 3, 0")
+ output.append("INCREMENT_3D:")
+ output.append(".quad 0, 1, 2, 3, 4, 5, 6, 7")
+
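+    # EVEN_INDEXES/ODD_INDEXES select the even and odd dwords from a pair of
+    # registers holding sixteen packed 64-bit counters, i.e. the low and high
+    # 32-bit halves that xof_setup_3d moves into state words 12 and 13 with
+    # vpermi2d.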
+ output.append(".balign 64")
+ output.append("EVEN_INDEXES:")
+ output.append(".long 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30")
+ output.append("ODD_INDEXES:")
+ output.append(".long 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31")
def format(output):