Diffstat (limited to 'asm/asm.py')
| -rwxr-xr-x | asm/asm.py | 637 |
1 file changed, 604 insertions, 33 deletions
@@ -1,6 +1,11 @@
 #! /usr/bin/env python3
 
 # Generate asm!
+#
+# TODOs:
+# - vzeroupper
+# - CET
+# - prefetches
 
 from dataclasses import dataclass, replace
@@ -11,6 +16,16 @@
 SSE41 = "sse41"
 SSE2 = "sse2"
 LINUX = "linux"
 
+MESSAGE_SCHEDULE = [
+    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+    [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8],
+    [3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1],
+    [10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6],
+    [12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4],
+    [9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7],
+    [11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13],
+]
+
 
 @dataclass
 class Target:
@@ -47,9 +62,9 @@ def add_row(target, output, degree, dest, src):
     if target.extension == AVX512:
         if degree == 1:
             output.append(f"vpaddd xmm{dest}, xmm{dest}, xmm{src}")
-        elif degree == 2:
+        elif degree in (2, 8):
             output.append(f"vpaddd ymm{dest}, ymm{dest}, ymm{src}")
-        elif degree == 4:
+        elif degree in (4, 16):
             output.append(f"vpaddd zmm{dest}, zmm{dest}, zmm{src}")
         else:
             raise NotImplementedError
@@ -68,9 +83,9 @@ def xor_row(target, output, degree, dest, src):
     if target.extension == AVX512:
         if degree == 1:
             output.append(f"vpxord xmm{dest}, xmm{dest}, xmm{src}")
-        elif degree == 2:
+        elif degree in (2, 8):
             output.append(f"vpxord ymm{dest}, ymm{dest}, ymm{src}")
-        elif degree == 4:
+        elif degree in (4, 16):
             output.append(f"vpxord zmm{dest}, zmm{dest}, zmm{src}")
         else:
             raise NotImplementedError
@@ -90,9 +105,9 @@ def bitrotate_row(target, output, degree, reg, bits):
     if target.extension == AVX512:
         if degree == 1:
             output.append(f"vprord xmm{reg}, xmm{reg}, {bits}")
-        elif degree == 2:
+        elif degree in (2, 8):
             output.append(f"vprord ymm{reg}, ymm{reg}, {bits}")
-        elif degree == 4:
+        elif degree in (4, 16):
             output.append(f"vprord zmm{reg}, zmm{reg}, {bits}")
         else:
             raise NotImplementedError
@@ -138,7 +153,7 @@ def bitrotate_row(target, output, degree, reg, bits):
         raise NotImplementedError
 
 
-# See the comments above kernel2d().
+# See the comments above kernel_2d().
 def diagonalize_state_rows(target, output, degree):
     if target.extension == AVX512:
         if degree == 1:
@@ -169,7 +184,7 @@ def diagonalize_state_rows(target, output, degree):
         raise NotImplementedError
 
 
-# See the comments above kernel2d().
+# See the comments above kernel_2d().
 def undiagonalize_state_rows(target, output, degree):
     if target.extension == AVX512:
         if degree == 1:
@@ -200,7 +215,7 @@ def undiagonalize_state_rows(target, output, degree):
         raise NotImplementedError
 
 
-# See the comments above kernel2d().
+# See the comments above kernel_2d().
 def permute_message_rows(target, output, degree):
     if target.extension == AVX512:
         if degree == 1:
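The MESSAGE_SCHEDULE table added above is the standard BLAKE3 message schedule: row 0 is the identity, and each later row is the previous row composed with the round permutation (row 1). A minimal sketch of that derivation, useful for checking the table; this is not part of the patch, and the names are this note's own:

PERMUTATION = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]

def derive_message_schedule():
    # Row r is the permutation applied r times to the identity row.
    rows = [list(range(16))]
    for _ in range(6):
        rows.append([rows[-1][i] for i in PERMUTATION])
    return rows

# derive_message_schedule() reproduces all seven rows of MESSAGE_SCHEDULE.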
@@ -309,8 +324,8 @@ def permute_message_rows(target, output, degree):
         raise NotImplementedError
 
 
-def kernel2d_name(target, degree):
-    return f"blake3_{target.extension}_kernel2d_{degree}"
+def kernel_2d_name(target, degree):
+    return f"blake3_{target.extension}_kernel_2d_{degree}"
 
 
 # The two-dimensional kernel packs one or more *rows* of the state into a
@@ -353,22 +368,29 @@
 # ymm5: a1, a3, a5, a7, b1, b3, b5, b7
 # ymm6: a14, a8, a10, a12, b14, b8, b10, b12
 # ymm7: a15, a9, a11, a13, b15, b9, b11, b13
-def kernel2d(target, output, degree):
-    label = kernel2d_name(target, degree)
+def kernel_2d(target, output, degree):
+    label = kernel_2d_name(target, degree)
     output.append(f"{label}:")
     # vpshufb indexes
-    if target.extension == SSE41:
+    if target.extension == SSE2:
+        assert degree == 1
+    elif target.extension == SSE41:
+        assert degree == 1
         output.append(f"movaps xmm14, xmmword ptr [ROT8+rip]")
         output.append(f"movaps xmm15, xmmword ptr [ROT16+rip]")
-    if target.extension == AVX2:
+    elif target.extension == AVX2:
+        assert degree == 2
        output.append(f"vbroadcasti128 ymm14, xmmword ptr [ROT16+rip]")
        output.append(f"vbroadcasti128 ymm15, xmmword ptr [ROT8+rip]")
-    if target.extension == AVX512:
+    elif target.extension == AVX512:
+        assert degree in (1, 2, 4)
         if degree == 4:
             output.append(f"mov {target.scratch32(0)}, 43690")
             output.append(f"kmovw k3, {target.scratch32(0)}")
             output.append(f"mov {target.scratch32(0)}, 34952")
             output.append(f"kmovw k4, {target.scratch32(0)}")
+    else:
+        raise NotImplementedError
     for round_number in range(7):
         if round_number > 0:
             # Un-diagonalize and permute before each round except the first.
@@ -482,12 +504,12 @@ def compress(target, output):
     output.append(f".global {label}")
     output.append(f"{label}:")
     compress_setup(target, output)
-    output.append(f"call {kernel2d_name(target, 1)}")
+    output.append(f"call {kernel_2d_name(target, 1)}")
     compress_finish(target, output)
     output.append(target.ret())
 
 
-def xof_setup2d(target, output, degree):
+def xof_setup_2d(target, output, degree):
     if target.extension == AVX512:
         if degree == 1:
             # state words
@@ -641,7 +663,7 @@ def xof_setup2d(target, output, degree):
         raise NotImplementedError
 
 
-def xof_stream_finish2d(target, output, degree):
+def xof_stream_finish_2d(target, output, degree):
     if target.extension == AVX512:
         if degree == 1:
             output.append(f"vpxor xmm2, xmm2, [{target.arg64(0)}]")
@@ -706,7 +728,7 @@ def xof_stream_finish2d(target, output, degree):
         raise NotImplementedError
 
 
-def xof_xor_finish2d(target, output, degree):
+def xof_xor_finish_2d(target, output, degree):
     if target.extension == AVX512:
         if degree == 1:
             output.append(f"vpxor xmm2, xmm2, [{target.arg64(0)}]")
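The asserts added to kernel_2d pin down what degree means for the 2D kernels (rows of one state per vector), while the widened branches in add_row/xor_row/bitrotate_row above also accept 8 and 16 for the 3D kernels introduced below (one state word across many blocks). The implied AVX-512 register-width mapping, as a sketch that is not part of the generator:

def row_register(degree, n):
    # AVX-512 only: 2D degrees count state rows per vector,
    # 3D degrees count blocks per vector.
    if degree == 1:
        return f"xmm{n}"    # one 4-word row
    elif degree in (2, 8):
        return f"ymm{n}"    # two rows, or one word of 8 blocks
    elif degree in (4, 16):
        return f"zmm{n}"    # four rows, or one word of 16 blocks
    else:
        raise NotImplementedError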
@@ -842,28 +864,558 @@ def xof_xor_finish2d(target, output, degree):
         raise NotImplementedError
 
 
+def g_function_3d(target, output, degree, columns, msg_words1, msg_words2):
+    if target.extension == SSE41:
+        assert degree == 4
+    elif target.extension == AVX2:
+        assert degree == 8
+    elif target.extension == AVX512:
+        assert degree in (8, 16)
+    else:
+        raise NotImplementedError
+    for (column, m1) in zip(columns, msg_words1):
+        add_row(target, output, degree, dest=column[0], src=m1)
+    for column in columns:
+        add_row(target, output, degree, dest=column[0], src=column[1])
+    for column in columns:
+        xor_row(target, output, degree, dest=column[3], src=column[0])
+    for column in columns:
+        bitrotate_row(target, output, degree, reg=column[3], bits=16)
+    for column in columns:
+        add_row(target, output, degree, dest=column[2], src=column[3])
+    for column in columns:
+        xor_row(target, output, degree, dest=column[1], src=column[2])
+    for column in columns:
+        bitrotate_row(target, output, degree, reg=column[1], bits=12)
+    for (column, m2) in zip(columns, msg_words2):
+        add_row(target, output, degree, dest=column[0], src=m2)
+    for column in columns:
+        add_row(target, output, degree, dest=column[0], src=column[1])
+    for column in columns:
+        xor_row(target, output, degree, dest=column[3], src=column[0])
+    for column in columns:
+        bitrotate_row(target, output, degree, reg=column[3], bits=8)
+    for column in columns:
+        add_row(target, output, degree, dest=column[2], src=column[3])
+    for column in columns:
+        xor_row(target, output, degree, dest=column[1], src=column[2])
+    for column in columns:
+        bitrotate_row(target, output, degree, reg=column[1], bits=7)
+
+
+def kernel_3d_name(target, degree):
+    return f"blake3_{target.extension}_kernel_3d_{degree}"
+
+
+def kernel_3d(target, output, degree):
+    label = kernel_3d_name(target, degree)
+    output.append(f"{label}:")
+    if target.extension == SSE41:
+        assert degree == 4
+    elif target.extension == AVX2:
+        assert degree == 8
+    elif target.extension == AVX512:
+        assert degree in (8, 16)
+    else:
+        raise NotImplementedError
+    for round_number in range(7):
+        straight_columns = [
+            [0, 4, 8, 12],
+            [1, 5, 9, 13],
+            [2, 6, 10, 14],
+            [3, 7, 11, 15],
+        ]
+        msg_words1 = [16 + MESSAGE_SCHEDULE[round_number][i] for i in [0, 2, 4, 6]]
+        msg_words2 = [16 + MESSAGE_SCHEDULE[round_number][i] for i in [1, 3, 5, 7]]
+        g_function_3d(target, output, degree, straight_columns, msg_words1, msg_words2)
+        diagonal_columns = [
+            [0, 5, 10, 15],
+            [1, 6, 11, 12],
+            [2, 7, 8, 13],
+            [3, 4, 9, 14],
+        ]
+        msg_words1 = [16 + MESSAGE_SCHEDULE[round_number][i] for i in [8, 10, 12, 14]]
+        msg_words2 = [16 + MESSAGE_SCHEDULE[round_number][i] for i in [9, 11, 13, 15]]
+        g_function_3d(target, output, degree, diagonal_columns, msg_words1, msg_words2)
+    # Xor the last two rows into the first two, but don't do the feed forward
+    # here. That's only done in the XOF case.
+    for dest in range(8):
+        xor_row(target, output, degree, dest=dest, src=dest + 8)
+    output.append(target.ret())
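For reference, this is the scalar G function that g_function_3d unrolls four columns at a time, with the BLAKE3 rotation constants 16, 12, 8, 7. In the 3D layout each of a, b, c, d is a whole vector register holding that state word for 8 or 16 blocks, and the message words live broadcast in registers 16-31, which is why the schedule indexes above are offset by 16. A sketch, not part of the patch, with this note's own names:

MASK32 = 0xFFFFFFFF

def rotate_right(x, n):
    # 32-bit rotate right, the scalar analogue of vprord.
    return ((x >> n) | (x << (32 - n))) & MASK32

def g_reference(state, a, b, c, d, mx, my):
    state[a] = (state[a] + mx) & MASK32
    state[a] = (state[a] + state[b]) & MASK32
    state[d] = rotate_right(state[d] ^ state[a], 16)
    state[c] = (state[c] + state[d]) & MASK32
    state[b] = rotate_right(state[b] ^ state[c], 12)
    state[a] = (state[a] + my) & MASK32
    state[a] = (state[a] + state[b]) & MASK32
    state[d] = rotate_right(state[d] ^ state[a], 8)
    state[c] = (state[c] + state[d]) & MASK32
    state[b] = rotate_right(state[b] ^ state[c], 7)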
+ output.append(f"vpermi2d zmm12, zmm9, zmm10") + output.append(f"vpermi2d zmm13, zmm9, zmm10") + # Load the block length. + output.append(f"vpbroadcastd zmm14, {target.arg32(3)}") + # Load the domain flags. + output.append(f"vpbroadcastd zmm15, {target.arg32(4)}") + # Load the IV constants. + for i in range(4): + output.append(f"vpbroadcastd zmm{i+8}, dword ptr [BLAKE3_IV+rip+{4*i}]") + else: + raise NotImplementedError + else: + raise NotImplementedError + + +def xof_stream_finish_3d(target, output, degree): + if target.extension == AVX512: + if degree == 16: + # Re-broadcast the input CV and feed it forward into the second half of the state. + output.append(f"vpbroadcastd zmm16, dword ptr [{target.arg64(0)} + 0 * 4]") + output.append(f"vpxord zmm8, zmm8, zmm16") + output.append(f"vpbroadcastd zmm17, dword ptr [{target.arg64(0)} + 1 * 4]") + output.append(f"vpxord zmm9, zmm9, zmm17") + output.append(f"vpbroadcastd zmm18, dword ptr [{target.arg64(0)} + 2 * 4]") + output.append(f"vpxord zmm10, zmm10, zmm18") + output.append(f"vpbroadcastd zmm19, dword ptr [{target.arg64(0)} + 3 * 4]") + output.append(f"vpxord zmm11, zmm11, zmm19") + output.append(f"vpbroadcastd zmm20, dword ptr [{target.arg64(0)} + 4 * 4]") + output.append(f"vpxord zmm12, zmm12, zmm20") + output.append(f"vpbroadcastd zmm21, dword ptr [{target.arg64(0)} + 5 * 4]") + output.append(f"vpxord zmm13, zmm13, zmm21") + output.append(f"vpbroadcastd zmm22, dword ptr [{target.arg64(0)} + 6 * 4]") + output.append(f"vpxord zmm14, zmm14, zmm22") + output.append(f"vpbroadcastd zmm23, dword ptr [{target.arg64(0)} + 7 * 4]") + output.append(f"vpxord zmm15, zmm15, zmm23") + # zmm0-zmm15 now contain the final extended state vectors, transposed. We need to un-transpose + # them before we write them out. As with blake3_avx512_blocks_16, we prefer to avoid expensive + # operations across 128-bit lanes, so we do a couple of interleaving passes and then write out + # 128 bits at a time. + # + # First, interleave 32-bit words. Use zmm16-zmm31 to hold the intermediate results. This + # takes the input vectors like: + # + # a0, b0, c0, d0, e0, f0, g0, h0, i0, j0, k0, l0, m0, n0, o0, p0 + # + # And produces vectors like: + # + # a0, a1, b0, b1, e0, e1, g0, g1, i0, i1, k0, k1, m0, m1, o0, o1 + # + # Then interleave 64-bit words back into zmm0-zmm15, producing vectors like: + # + # a0, a1, a2, a3, e0, e1, e2, e3, i0, i1, i2, i3, m0, m1, m2, m3 + # + # Finally, write out each 128-bit group, unaligned. 
+ output.append(f"vpunpckldq zmm16, zmm0, zmm1") + output.append(f"vpunpckhdq zmm17, zmm0, zmm1") + output.append(f"vpunpckldq zmm18, zmm2, zmm3") + output.append(f"vpunpckhdq zmm19, zmm2, zmm3") + output.append(f"vpunpcklqdq zmm0, zmm16, zmm18") + output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 0 * 16], xmm0") + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 16 * 16], zmm0, 1" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 32 * 16], zmm0, 2" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 48 * 16], zmm0, 3" + ) + output.append(f"vpunpckhqdq zmm1, zmm16, zmm18") + output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 4 * 16], xmm1") + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 20 * 16], zmm1, 1" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 36 * 16], zmm1, 2" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 52 * 16], zmm1, 3" + ) + output.append(f"vpunpcklqdq zmm2, zmm17, zmm19") + output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 8 * 16], xmm2") + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 24 * 16], zmm2, 1" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 40 * 16], zmm2, 2" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 56 * 16], zmm2, 3" + ) + output.append(f"vpunpckhqdq zmm3, zmm17, zmm19") + output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 12 * 16], xmm3") + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 28 * 16], zmm3, 1" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 44 * 16], zmm3, 2" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 60 * 16], zmm3, 3" + ) + output.append(f"vpunpckldq zmm20, zmm4, zmm5") + output.append(f"vpunpckhdq zmm21, zmm4, zmm5") + output.append(f"vpunpckldq zmm22, zmm6, zmm7") + output.append(f"vpunpckhdq zmm23, zmm6, zmm7") + output.append(f"vpunpcklqdq zmm4, zmm20, zmm22") + output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 1 * 16], xmm4") + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 17 * 16], zmm4, 1" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 33 * 16], zmm4, 2" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 49 * 16], zmm4, 3" + ) + output.append(f"vpunpckhqdq zmm5, zmm20, zmm22") + output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 5 * 16], xmm5") + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 21 * 16], zmm5, 1" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 37 * 16], zmm5, 2" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 53 * 16], zmm5, 3" + ) + output.append(f"vpunpcklqdq zmm6, zmm21, zmm23") + output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 9 * 16], xmm6") + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 25 * 16], zmm6, 1" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 41 * 16], zmm6, 2" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 57 * 16], zmm6, 3" + ) + output.append(f"vpunpckhqdq zmm7, zmm21, zmm23") + output.append(f"vmovdqu32 xmmword ptr [{target.arg64(5)} + 13 * 16], xmm7") + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 29 * 16], zmm7, 1" + ) + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + 45 * 16], zmm7, 2" + ) + output.append( + f"vextracti32x4 xmmword 
+
+
+def xof_xor_finish_3d(target, output, degree):
+    if target.extension == AVX512:
+        if degree == 16:
+            # Re-broadcast the input CV and feed it forward into the second half of the state.
+            output.append(f"vpbroadcastd zmm16, dword ptr [{target.arg64(0)} + 0 * 4]")
+            output.append(f"vpxord zmm8, zmm8, zmm16")
+            output.append(f"vpbroadcastd zmm17, dword ptr [{target.arg64(0)} + 1 * 4]")
+            output.append(f"vpxord zmm9, zmm9, zmm17")
+            output.append(f"vpbroadcastd zmm18, dword ptr [{target.arg64(0)} + 2 * 4]")
+            output.append(f"vpxord zmm10, zmm10, zmm18")
+            output.append(f"vpbroadcastd zmm19, dword ptr [{target.arg64(0)} + 3 * 4]")
+            output.append(f"vpxord zmm11, zmm11, zmm19")
+            output.append(f"vpbroadcastd zmm20, dword ptr [{target.arg64(0)} + 4 * 4]")
+            output.append(f"vpxord zmm12, zmm12, zmm20")
+            output.append(f"vpbroadcastd zmm21, dword ptr [{target.arg64(0)} + 5 * 4]")
+            output.append(f"vpxord zmm13, zmm13, zmm21")
+            output.append(f"vpbroadcastd zmm22, dword ptr [{target.arg64(0)} + 6 * 4]")
+            output.append(f"vpxord zmm14, zmm14, zmm22")
+            output.append(f"vpbroadcastd zmm23, dword ptr [{target.arg64(0)} + 7 * 4]")
+            output.append(f"vpxord zmm15, zmm15, zmm23")
+            # zmm0-zmm15 now contain the final extended state vectors, transposed. We
+            # need to un-transpose them before we write them out. Unlike
+            # blake3_avx512_xof_stream_16, we do a complete un-transpose here, to make
+            # the xor step easier.
+            #
+            # First interleave 32-bit words. This takes vectors like:
+            #
+            # a0, b0, c0, d0, e0, f0, g0, h0, i0, j0, k0, l0, m0, n0, o0, p0
+            #
+            # And produces vectors like:
+            #
+            # a0, a1, b0, b1, e0, e1, f0, f1, i0, i1, j0, j1, m0, m1, n0, n1
+            output.append(f"vpunpckldq zmm16, zmm0, zmm1")
+            output.append(f"vpunpckhdq zmm17, zmm0, zmm1")
+            output.append(f"vpunpckldq zmm18, zmm2, zmm3")
+            output.append(f"vpunpckhdq zmm19, zmm2, zmm3")
+            output.append(f"vpunpckldq zmm20, zmm4, zmm5")
+            output.append(f"vpunpckhdq zmm21, zmm4, zmm5")
+            output.append(f"vpunpckldq zmm22, zmm6, zmm7")
+            output.append(f"vpunpckhdq zmm23, zmm6, zmm7")
+            output.append(f"vpunpckldq zmm24, zmm8, zmm9")
+            output.append(f"vpunpckhdq zmm25, zmm8, zmm9")
+            output.append(f"vpunpckldq zmm26, zmm10, zmm11")
+            output.append(f"vpunpckhdq zmm27, zmm10, zmm11")
+            output.append(f"vpunpckldq zmm28, zmm12, zmm13")
+            output.append(f"vpunpckhdq zmm29, zmm12, zmm13")
+            output.append(f"vpunpckldq zmm30, zmm14, zmm15")
+            output.append(f"vpunpckhdq zmm31, zmm14, zmm15")
+            # Then interleave 64-bit words, producing vectors like:
+            #
+            # a0, a1, a2, a3, e0, e1, e2, e3, i0, i1, i2, i3, m0, m1, m2, m3
+            output.append(f"vpunpcklqdq zmm0, zmm16, zmm18")
+            output.append(f"vpunpckhqdq zmm1, zmm16, zmm18")
+            output.append(f"vpunpcklqdq zmm2, zmm17, zmm19")
+            output.append(f"vpunpckhqdq zmm3, zmm17, zmm19")
+            output.append(f"vpunpcklqdq zmm4, zmm20, zmm22")
+            output.append(f"vpunpckhqdq zmm5, zmm20, zmm22")
+            output.append(f"vpunpcklqdq zmm6, zmm21, zmm23")
+            output.append(f"vpunpckhqdq zmm7, zmm21, zmm23")
+            output.append(f"vpunpcklqdq zmm8, zmm24, zmm26")
+            output.append(f"vpunpckhqdq zmm9, zmm24, zmm26")
+            output.append(f"vpunpcklqdq zmm10, zmm25, zmm27")
+            output.append(f"vpunpckhqdq zmm11, zmm25, zmm27")
+            output.append(f"vpunpcklqdq zmm12, zmm28, zmm30")
+            output.append(f"vpunpckhqdq zmm13, zmm28, zmm30")
+            output.append(f"vpunpcklqdq zmm14, zmm29, zmm31")
+            output.append(f"vpunpckhqdq zmm15, zmm29, zmm31")
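The xor finish leans on vshufi32x4's immediate to pick 128-bit lanes: the result's low two lanes come from the first source and the high two from the second, each chosen by a 2-bit field. A Python model of that selection; a sketch, not part of the generator:

def shufi32x4(src1, src2, imm):
    # src1/src2 are 16-dword lists standing in for zmm registers.
    lanes1 = [src1[4 * i:4 * i + 4] for i in range(4)]
    lanes2 = [src2[4 * i:4 * i + 4] for i in range(4)]
    picks = [(imm >> (2 * i)) & 3 for i in range(4)]
    return lanes1[picks[0]] + lanes1[picks[1]] + lanes2[picks[2]] + lanes2[picks[3]]

# 0x88 gives picks (0, 2, 0, 2), the "lo lanes" pass in the code below;
# 0xdd gives (1, 3, 1, 3), the "hi lanes" pass.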
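After this change, xof_fn routes each (extension, degree) pair to either a 2D or a 3D kernel. A summary of the dispatch as it stands in this patch; a sketch, not part of the generator:

def xof_kernel(extension, degree):
    # Mirrors the routing in xof_fn below.
    if extension == "avx512" and degree in (8, 16):
        return f"blake3_avx512_kernel_3d_{degree}"
    if (extension, degree) in [
        ("avx512", 1), ("avx512", 2), ("avx512", 4),
        ("avx2", 2), ("sse41", 1), ("sse2", 1),
    ]:
        return f"blake3_{extension}_kernel_2d_{degree}"
    raise NotImplementedError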
zmm14, zmm29, zmm31") + output.append(f"vpunpckhqdq zmm15, zmm29, zmm31") + # Then interleave 128-bit lanes, producing vectors like: + # + # a0, a1, a2, a3, i0, i1, i2, i3, a4, a5, a6, a7, i4, i5, i6, i7 + output.append( + "vshufi32x4 zmm16, zmm0, zmm4, 0x88" + ) # lo lanes: 0x88 = 0b10001000 = (0, 2, 0, 2) + output.append(f"vshufi32x4 zmm17, zmm1, zmm5, 0x88") + output.append(f"vshufi32x4 zmm18, zmm2, zmm6, 0x88") + output.append(f"vshufi32x4 zmm19, zmm3, zmm7, 0x88") + output.append( + "vshufi32x4 zmm20, zmm0, zmm4, 0xdd" + ) # hi lanes: 0xdd = 0b11011101 = (1, 3, 1, 3) + output.append(f"vshufi32x4 zmm21, zmm1, zmm5, 0xdd") + output.append(f"vshufi32x4 zmm22, zmm2, zmm6, 0xdd") + output.append(f"vshufi32x4 zmm23, zmm3, zmm7, 0xdd") + output.append(f"vshufi32x4 zmm24, zmm8, zmm12, 0x88") # lo lanes + output.append(f"vshufi32x4 zmm25, zmm9, zmm13, 0x88") + output.append(f"vshufi32x4 zmm26, zmm10, zmm14, 0x88") + output.append(f"vshufi32x4 zmm27, zmm11, zmm15, 0x88") + output.append(f"vshufi32x4 zmm28, zmm8, zmm12, 0xdd") # hi lanes + output.append(f"vshufi32x4 zmm29, zmm9, zmm13, 0xdd") + output.append(f"vshufi32x4 zmm30, zmm10, zmm14, 0xdd") + output.append(f"vshufi32x4 zmm31, zmm11, zmm15, 0xdd") + # Finally interleave 128-bit lanes again (the same permutation as the previous pass, but + # different inputs), producing vectors like: + # + # a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15 + output.append(f"vshufi32x4 zmm0, zmm16, zmm24, 0x88") # lo lanes + output.append(f"vshufi32x4 zmm1, zmm17, zmm25, 0x88") + output.append(f"vshufi32x4 zmm2, zmm18, zmm26, 0x88") + output.append(f"vshufi32x4 zmm3, zmm19, zmm27, 0x88") + output.append(f"vshufi32x4 zmm4, zmm20, zmm28, 0x88") + output.append(f"vshufi32x4 zmm5, zmm21, zmm29, 0x88") + output.append(f"vshufi32x4 zmm6, zmm22, zmm30, 0x88") + output.append(f"vshufi32x4 zmm7, zmm23, zmm31, 0x88") + output.append(f"vshufi32x4 zmm8, zmm16, zmm24, 0xdd") # hi lanes + output.append(f"vshufi32x4 zmm9, zmm17, zmm25, 0xdd") + output.append(f"vshufi32x4 zmm10, zmm18, zmm26, 0xdd") + output.append(f"vshufi32x4 zmm11, zmm19, zmm27, 0xdd") + output.append(f"vshufi32x4 zmm12, zmm20, zmm28, 0xdd") + output.append(f"vshufi32x4 zmm13, zmm21, zmm29, 0xdd") + output.append(f"vshufi32x4 zmm14, zmm22, zmm30, 0xdd") + output.append(f"vshufi32x4 zmm15, zmm23, zmm31, 0xdd") + # zmm0-zmm15 now contain the fully un-transposed state words. Load each 64 block on input + # (unaligned), perform the xor, and write out the result (again unaligned). 
+ output.append(f"vmovdqu32 zmm16, zmmword ptr [{target.arg64(5)} + 0 * 64]") + output.append(f"vpxord zmm0, zmm0, zmm16") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 0 * 64], zmm0") + output.append(f"vmovdqu32 zmm17, zmmword ptr [{target.arg64(5)} + 1 * 64]") + output.append(f"vpxord zmm1, zmm1, zmm17") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 1 * 64], zmm1") + output.append(f"vmovdqu32 zmm18, zmmword ptr [{target.arg64(5)} + 2 * 64]") + output.append(f"vpxord zmm2, zmm2, zmm18") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 2 * 64], zmm2") + output.append(f"vmovdqu32 zmm19, zmmword ptr [{target.arg64(5)} + 3 * 64]") + output.append(f"vpxord zmm3, zmm3, zmm19") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 3 * 64], zmm3") + output.append(f"vmovdqu32 zmm20, zmmword ptr [{target.arg64(5)} + 4 * 64]") + output.append(f"vpxord zmm4, zmm4, zmm20") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 4 * 64], zmm4") + output.append(f"vmovdqu32 zmm21, zmmword ptr [{target.arg64(5)} + 5 * 64]") + output.append(f"vpxord zmm5, zmm5, zmm21") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 5 * 64], zmm5") + output.append(f"vmovdqu32 zmm22, zmmword ptr [{target.arg64(5)} + 6 * 64]") + output.append(f"vpxord zmm6, zmm6, zmm22") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 6 * 64], zmm6") + output.append(f"vmovdqu32 zmm23, zmmword ptr [{target.arg64(5)} + 7 * 64]") + output.append(f"vpxord zmm7, zmm7, zmm23") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 7 * 64], zmm7") + output.append(f"vmovdqu32 zmm24, zmmword ptr [{target.arg64(5)} + 8 * 64]") + output.append(f"vpxord zmm8, zmm8, zmm24") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 8 * 64], zmm8") + output.append(f"vmovdqu32 zmm25, zmmword ptr [{target.arg64(5)} + 9 * 64]") + output.append(f"vpxord zmm9, zmm9, zmm25") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 9 * 64], zmm9") + output.append(f"vmovdqu32 zmm26, zmmword ptr [{target.arg64(5)} + 10 * 64]") + output.append(f"vpxord zmm10, zmm10, zmm26") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 10 * 64], zmm10") + output.append(f"vmovdqu32 zmm27, zmmword ptr [{target.arg64(5)} + 11 * 64]") + output.append(f"vpxord zmm11, zmm11, zmm27") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 11 * 64], zmm11") + output.append(f"vmovdqu32 zmm28, zmmword ptr [{target.arg64(5)} + 12 * 64]") + output.append(f"vpxord zmm12, zmm12, zmm28") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 12 * 64], zmm12") + output.append(f"vmovdqu32 zmm29, zmmword ptr [{target.arg64(5)} + 13 * 64]") + output.append(f"vpxord zmm13, zmm13, zmm29") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 13 * 64], zmm13") + output.append(f"vmovdqu32 zmm30, zmmword ptr [{target.arg64(5)} + 14 * 64]") + output.append(f"vpxord zmm14, zmm14, zmm30") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 14 * 64], zmm14") + output.append(f"vmovdqu32 zmm31, zmmword ptr [{target.arg64(5)} + 15 * 64]") + output.append(f"vpxord zmm15, zmm15, zmm31") + output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 15 * 64], zmm15") + else: + raise NotImplementedError + else: + raise NotImplementedError + + def xof_fn(target, output, degree, xor): variant = "xor" if xor else "stream" - finish_fn_2d = xof_xor_finish2d if xor else xof_stream_finish2d + finish_fn_2d = xof_xor_finish_2d if xor else xof_stream_finish_2d + finish_fn_3d = xof_xor_finish_3d 
+
+
 def xof_fn(target, output, degree, xor):
     variant = "xor" if xor else "stream"
-    finish_fn_2d = xof_xor_finish2d if xor else xof_stream_finish2d
+    finish_fn_2d = xof_xor_finish_2d if xor else xof_stream_finish_2d
+    finish_fn_3d = xof_xor_finish_3d if xor else xof_stream_finish_3d
     label = f"blake3_{target.extension}_xof_{variant}_{degree}"
     output.append(f".global {label}")
     output.append(f"{label}:")
     if target.extension == AVX512:
         if degree in (1, 2, 4):
-            xof_setup2d(target, output, degree)
-            output.append(f"call {kernel2d_name(target, degree)}")
+            xof_setup_2d(target, output, degree)
+            output.append(f"call {kernel_2d_name(target, degree)}")
             finish_fn_2d(target, output, degree)
+        elif degree in (8, 16):
+            xof_setup_3d(target, output, degree)
+            output.append(f"call {kernel_3d_name(target, degree)}")
+            finish_fn_3d(target, output, degree)
         else:
             raise NotImplementedError
     elif target.extension == AVX2:
         assert degree == 2
-        xof_setup2d(target, output, degree)
-        output.append(f"call {kernel2d_name(target, degree)}")
+        xof_setup_2d(target, output, degree)
+        output.append(f"call {kernel_2d_name(target, degree)}")
         finish_fn_2d(target, output, degree)
     elif target.extension in (SSE41, SSE2):
         assert degree == 1
-        xof_setup2d(target, output, degree)
-        output.append(f"call {kernel2d_name(target, degree)}")
+        xof_setup_2d(target, output, degree)
+        output.append(f"call {kernel_2d_name(target, degree)}")
         finish_fn_2d(target, output, degree)
     else:
         raise NotImplementedError
@@ -877,7 +1429,7 @@ def emit_prelude(target, output):
 
 def emit_sse2(target, output):
     target = replace(target, extension=SSE2)
-    kernel2d(target, output, 1)
+    kernel_2d(target, output, 1)
     compress(target, output)
     xof_fn(target, output, 1, xor=False)
     xof_fn(target, output, 1, xor=True)
@@ -894,7 +1446,7 @@ def emit_sse2(target, output):
 
 def emit_sse41(target, output):
     target = replace(target, extension=SSE41)
-    kernel2d(target, output, 1)
+    kernel_2d(target, output, 1)
     compress(target, output)
     xof_fn(target, output, 1, xor=False)
     xof_fn(target, output, 1, xor=True)
@@ -902,24 +1454,35 @@ def emit_sse41(target, output):
 
 def emit_avx2(target, output):
     target = replace(target, extension=AVX2)
-    kernel2d(target, output, 2)
+    kernel_2d(target, output, 2)
     xof_fn(target, output, 2, xor=False)
     xof_fn(target, output, 2, xor=True)
 
 
 def emit_avx512(target, output):
     target = replace(target, extension=AVX512)
-    kernel2d(target, output, 1)
-    kernel2d(target, output, 2)
-    kernel2d(target, output, 4)
+
+    # degree 1
+    kernel_2d(target, output, 1)
     compress(target, output)
     xof_fn(target, output, 1, xor=False)
     xof_fn(target, output, 1, xor=True)
+
+    # degree 2
+    kernel_2d(target, output, 2)
     xof_fn(target, output, 2, xor=False)
     xof_fn(target, output, 2, xor=True)
+
+    # degree 4
+    kernel_2d(target, output, 4)
     xof_fn(target, output, 4, xor=False)
     xof_fn(target, output, 4, xor=True)
 
+    # degree 16
+    kernel_3d(target, output, 16)
+    xof_fn(target, output, 16, xor=False)
+    xof_fn(target, output, 16, xor=True)
+
 
 def emit_footer(target, output):
     output.append(".balign 16")
@@ -942,6 +1505,14 @@ def emit_footer(target, output):
     output.append(".balign 64")
     output.append("INCREMENT_2D:")
     output.append(".quad 0, 0, 1, 0, 2, 0, 3, 0")
+    output.append("INCREMENT_3D:")
+    output.append(".quad 0, 1, 2, 3, 4, 5, 6, 7")
+
+    output.append(".balign 64")
+    output.append("EVEN_INDEXES:")
+    output.append(".long 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30")
+    output.append("ODD_INDEXES:")
+    output.append(".long 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31")
 
 
 def format(output):
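EVEN_INDEXES and ODD_INDEXES drive the two vpermi2d instructions in xof_setup_3d: vpermi2d treats its two sources as one 32-dword table (indexes 16-31 select from the second source) and overwrites the index register with the selected dwords. With the 64-bit counters stored as (lo, hi) dword pairs, the even indexes gather the sixteen low halves and the odd indexes the sixteen high halves. A Python model; a sketch, not part of the generator:

def permi2d(indexes, src1, src2):
    # The index register doubles as the destination.
    table = src1 + src2
    return [table[i % 32] for i in indexes]

EVEN_INDEXES = list(range(0, 32, 2))  # -> low counter words (state word 12)
ODD_INDEXES = list(range(1, 32, 2))   # -> high counter words (state word 13)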
