| author | Jack O'Connor <[email protected]> | 2022-03-25 12:19:50 -0400 |
|---|---|---|
| committer | Jack O'Connor <[email protected]> | 2022-03-26 11:18:39 -0400 |
| commit | 35ad4ededdbf259c507c49b2e7ac529b43b61671 (patch) | |
| tree | 930f3cb141e14c138319fa67777b0a5dbc01a53d | |
| parent | ea94b544fcdb9cf4120eaf9308e8df937871ad3e (diff) | |
xor_xof variants for the 2d kernel
| -rwxr-xr-x | asm/asm.py | 182 |
| -rw-r--r-- | asm/out.S | 254 |
| -rw-r--r-- | src/kernel.rs | 96 |
3 files changed, 495 insertions, 37 deletions
```diff
diff --git a/asm/asm.py b/asm/asm.py
--- a/asm/asm.py
+++ b/asm/asm.py
@@ -659,18 +659,10 @@ def xof_stream_finish2d(target, output, degree):
         output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 1 * 16], xmm1")
         output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 2 * 16], xmm2")
         output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 3 * 16], xmm3")
-        output.append(
-            f"vextracti128 xmmword ptr [{target.arg64(5)} + 4 * 16], ymm0, 1"
-        )
-        output.append(
-            f"vextracti128 xmmword ptr [{target.arg64(5)} + 5 * 16], ymm1, 1"
-        )
-        output.append(
-            f"vextracti128 xmmword ptr [{target.arg64(5)} + 6 * 16], ymm2, 1"
-        )
-        output.append(
-            f"vextracti128 xmmword ptr [{target.arg64(5)} + 7 * 16], ymm3, 1"
-        )
+        output.append(f"vextracti128 xmmword ptr [{target.arg64(5)}+4*16], ymm0, 1")
+        output.append(f"vextracti128 xmmword ptr [{target.arg64(5)}+5*16], ymm1, 1")
+        output.append(f"vextracti128 xmmword ptr [{target.arg64(5)}+6*16], ymm2, 1")
+        output.append(f"vextracti128 xmmword ptr [{target.arg64(5)}+7*16], ymm3, 1")
     elif degree == 4:
         output.append(f"vbroadcasti32x4 zmm4, xmmword ptr [{target.arg64(0)}]")
         output.append(f"vpxord zmm2, zmm2, zmm4")
@@ -714,27 +706,165 @@ def xof_stream_finish2d(target, output, degree):
         raise NotImplementedError
 
 
-def xof_stream(target, output, degree):
-    label = f"blake3_{target.extension}_xof_stream_{degree}"
+def xof_xor_finish2d(target, output, degree):
+    if target.extension == AVX512:
+        if degree == 1:
+            output.append(f"vpxor xmm2, xmm2, [{target.arg64(0)}]")
+            output.append(f"vpxor xmm3, xmm3, [{target.arg64(0)}+0x10]")
+            output.append(f"vpxor xmm0, xmm0, [{target.arg64(5)}]")
+            output.append(f"vpxor xmm1, xmm1, [{target.arg64(5)}+0x10]")
+            output.append(f"vpxor xmm2, xmm2, [{target.arg64(5)}+0x20]")
+            output.append(f"vpxor xmm3, xmm3, [{target.arg64(5)}+0x30]")
+            output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}], xmm0")
+            output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x10], xmm1")
+            output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x20], xmm2")
+            output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x30], xmm3")
+        elif degree == 2:
+            output.append(f"vbroadcasti128 ymm4, xmmword ptr [{target.arg64(0)}]")
+            output.append(f"vpxor ymm2, ymm2, ymm4")
+            output.append(f"vbroadcasti128 ymm5, xmmword ptr [{target.arg64(0)} + 16]")
+            output.append(f"vpxor ymm3, ymm3, ymm5")
+            # Each vector now holds rows from two different states:
+            #     ymm0: a0, a1, a2, a3, b0, b1, b2, b3
+            #     ymm1: a4, a5, a6, a7, b4, b5, b6, b7
+            #     ymm2: a8, a9, a10, a11, b8, b9, b10, b11
+            #     ymm3: a12, a13, a14, a15, b12, b13, b14, b15
+            # We want to rearrange the 128-bit lanes like this, so we can load
+            # destination bytes and XOR them in directly.
+            #     ymm4: a0, a1, a2, a3, a4, a5, a6, a7
+            #     ymm5: a8, a9, a10, a11, a12, a13, a14, a15
+            #     ymm6: b0, b1, b2, b3, b4, b5, b6, b7
+            #     ymm7: b8, b9, b10, b11, b12, b13, b14, b15
+            output.append(f"vperm2f128 ymm4, ymm0, ymm1, {0b0010_0000}")  # lower 128
+            output.append(f"vperm2f128 ymm5, ymm2, ymm3, {0b0010_0000}")
+            output.append(f"vperm2f128 ymm6, ymm0, ymm1, {0b0011_0001}")  # upper 128
+            output.append(f"vperm2f128 ymm7, ymm2, ymm3, {0b0011_0001}")
+            # XOR in the bytes that are already in the destination.
+            output.append(f"vpxor ymm4, ymm4, ymmword ptr [{target.arg64(5)} + 0 * 32]")
+            output.append(f"vpxor ymm5, ymm5, ymmword ptr [{target.arg64(5)} + 1 * 32]")
+            output.append(f"vpxor ymm6, ymm6, ymmword ptr [{target.arg64(5)} + 2 * 32]")
+            output.append(f"vpxor ymm7, ymm7, ymmword ptr [{target.arg64(5)} + 3 * 32]")
+            # Write out the XOR results.
+            output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 0 * 32], ymm4")
+            output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 1 * 32], ymm5")
+            output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 2 * 32], ymm6")
+            output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 3 * 32], ymm7")
+        elif degree == 4:
+            output.append(f"vbroadcasti32x4 zmm4, xmmword ptr [{target.arg64(0)}]")
+            output.append(f"vpxord zmm2, zmm2, zmm4")
+            output.append(f"vbroadcasti32x4 zmm5, xmmword ptr [{target.arg64(0)} + 16]")
+            output.append(f"vpxord zmm3, zmm3, zmm5")
+            # Each vector now holds rows from four different states:
+            #     zmm0: a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3, d0, d1, d2, d3
+            #     zmm1: a4, a5, a6, a7, b4, b5, b6, b7, c4, c5, c6, c7, d4, d5, d6, d7
+            #     zmm2: a8, a9, a10, a11, b8, b9, b10, b11, c8, c9, c10, c11, d8, d9, d10, d11
+            #     zmm3: a12, a13, a14, a15, b12, b13, b14, b15, c12, c13, c14, c15, d12, d13, d14, d15
+            # We want to rearrange the 128-bit lanes like this, so we can load
+            # destination bytes and XOR them in directly.
+            #     zmm0: a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15
+            #     zmm1: b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15
+            #     zmm2: c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15
+            #     zmm3: d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15
+            #
+            # This first interleaving of 256-bit lanes produces vectors like:
+            #     zmm4: a0, a1, a2, a3, b0, b1, b2, b3, a4, a5, a6, a7, b4, b5, b6, b7
+            output.append(f"vshufi32x4 zmm4, zmm0, zmm1, {0b0100_0100}")  # low 256
+            output.append(f"vshufi32x4 zmm5, zmm0, zmm1, {0b1110_1110}")  # high 256
+            output.append(f"vshufi32x4 zmm6, zmm2, zmm3, {0b0100_0100}")
+            output.append(f"vshufi32x4 zmm7, zmm2, zmm3, {0b1110_1110}")
+            # And this second interleaving of 128-bit lanes within each 256-bit
+            # lane produces the vectors we want.
+            output.append(f"vshufi32x4 zmm0, zmm4, zmm6, {0b1000_1000}")  # low 128
+            output.append(f"vshufi32x4 zmm1, zmm4, zmm6, {0b1101_1101}")  # high 128
+            output.append(f"vshufi32x4 zmm2, zmm5, zmm7, {0b1000_1000}")
+            output.append(f"vshufi32x4 zmm3, zmm5, zmm7, {0b1101_1101}")
+            # XOR in the bytes that are already in the destination.
+            output.append(f"vpxord zmm0, zmm0, zmmword ptr [{target.arg64(5)} + 0*64]")
+            output.append(f"vpxord zmm1, zmm1, zmmword ptr [{target.arg64(5)} + 1*64]")
+            output.append(f"vpxord zmm2, zmm2, zmmword ptr [{target.arg64(5)} + 2*64]")
+            output.append(f"vpxord zmm3, zmm3, zmmword ptr [{target.arg64(5)} + 3*64]")
+            # Write out the XOR results.
+            output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 0*64], zmm0")
+            output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 1*64], zmm1")
+            output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 2*64], zmm2")
+            output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 3*64], zmm3")
+        else:
+            raise NotImplementedError
+    elif target.extension == AVX2:
+        output.append(f"vbroadcasti128 ymm4, xmmword ptr [{target.arg64(0)}]")
+        output.append(f"vpxor ymm2, ymm2, ymm4")
+        output.append(f"vbroadcasti128 ymm5, xmmword ptr [{target.arg64(0)} + 16]")
+        output.append(f"vpxor ymm3, ymm3, ymm5")
+        # Each vector now holds rows from two different states:
+        #     ymm0: a0, a1, a2, a3, b0, b1, b2, b3
+        #     ymm1: a4, a5, a6, a7, b4, b5, b6, b7
+        #     ymm2: a8, a9, a10, a11, b8, b9, b10, b11
+        #     ymm3: a12, a13, a14, a15, b12, b13, b14, b15
+        # We want to rearrange the 128-bit lanes like this, so we can load
+        # destination bytes and XOR them in directly.
+        #     ymm4: a0, a1, a2, a3, a4, a5, a6, a7
+        #     ymm5: a8, a9, a10, a11, a12, a13, a14, a15
+        #     ymm6: b0, b1, b2, b3, b4, b5, b6, b7
+        #     ymm7: b8, b9, b10, b11, b12, b13, b14, b15
+        output.append(f"vperm2f128 ymm4, ymm0, ymm1, {0b0010_0000}")  # lower 128
+        output.append(f"vperm2f128 ymm5, ymm2, ymm3, {0b0010_0000}")
+        output.append(f"vperm2f128 ymm6, ymm0, ymm1, {0b0011_0001}")  # upper 128
+        output.append(f"vperm2f128 ymm7, ymm2, ymm3, {0b0011_0001}")
+        # XOR in the bytes that are already in the destination.
+        output.append(f"vpxor ymm4, ymm4, ymmword ptr [{target.arg64(5)} + 0 * 32]")
+        output.append(f"vpxor ymm5, ymm5, ymmword ptr [{target.arg64(5)} + 1 * 32]")
+        output.append(f"vpxor ymm6, ymm6, ymmword ptr [{target.arg64(5)} + 2 * 32]")
+        output.append(f"vpxor ymm7, ymm7, ymmword ptr [{target.arg64(5)} + 3 * 32]")
+        # Write out the XOR results.
+        output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 0 * 32], ymm4")
+        output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 1 * 32], ymm5")
+        output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 2 * 32], ymm6")
+        output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 3 * 32], ymm7")
+    elif target.extension in (SSE41, SSE2):
+        assert degree == 1
+        output.append(f"movdqu xmm4, xmmword ptr [{target.arg64(0)}]")
+        output.append(f"movdqu xmm5, xmmword ptr [{target.arg64(0)}+0x10]")
+        output.append(f"pxor xmm2, xmm4")
+        output.append(f"pxor xmm3, xmm5")
+        output.append(f"movdqu xmm4, [{target.arg64(5)}]")
+        output.append(f"movdqu xmm5, [{target.arg64(5)}+0x10]")
+        output.append(f"movdqu xmm6, [{target.arg64(5)}+0x20]")
+        output.append(f"movdqu xmm7, [{target.arg64(5)}+0x30]")
+        output.append(f"pxor xmm0, xmm4")
+        output.append(f"pxor xmm1, xmm5")
+        output.append(f"pxor xmm2, xmm6")
+        output.append(f"pxor xmm3, xmm7")
+        output.append(f"movups xmmword ptr [{target.arg64(5)}], xmm0")
+        output.append(f"movups xmmword ptr [{target.arg64(5)}+0x10], xmm1")
+        output.append(f"movups xmmword ptr [{target.arg64(5)}+0x20], xmm2")
+        output.append(f"movups xmmword ptr [{target.arg64(5)}+0x30], xmm3")
+    else:
+        raise NotImplementedError
+
+
+def xof_fn(target, output, degree, xor):
+    variant = "xor" if xor else "stream"
+    finish_fn_2d = xof_xor_finish2d if xor else xof_stream_finish2d
+    label = f"blake3_{target.extension}_xof_{variant}_{degree}"
     output.append(f".global {label}")
     output.append(f"{label}:")
     if target.extension == AVX512:
         if degree in (1, 2, 4):
             xof_setup2d(target, output, degree)
             output.append(f"call {kernel2d_name(target, degree)}")
-            xof_stream_finish2d(target, output, degree)
+            finish_fn_2d(target, output, degree)
         else:
             raise NotImplementedError
     elif target.extension == AVX2:
         assert degree == 2
         xof_setup2d(target, output, degree)
         output.append(f"call {kernel2d_name(target, degree)}")
-        xof_stream_finish2d(target, output, degree)
+        finish_fn_2d(target, output, degree)
     elif target.extension in (SSE41, SSE2):
         assert degree == 1
         xof_setup2d(target, output, degree)
         output.append(f"call {kernel2d_name(target, degree)}")
-        xof_stream_finish2d(target, output, degree)
+        finish_fn_2d(target, output, degree)
     else:
         raise NotImplementedError
     output.append(target.ret())
@@ -749,7 +879,8 @@ def emit_sse2(target, output):
     target = replace(target, extension=SSE2)
     kernel2d(target, output, 1)
     compress(target, output)
-    xof_stream(target, output, 1)
+    xof_fn(target, output, 1, xor=False)
+    xof_fn(target, output, 1, xor=True)
     output.append(".balign 16")
     output.append("PBLENDW_0x33_MASK:")
     output.append(".long 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000")
@@ -765,13 +896,15 @@ def emit_sse41(target, output):
     target = replace(target, extension=SSE41)
     kernel2d(target, output, 1)
    compress(target, output)
-    xof_stream(target, output, 1)
+    xof_fn(target, output, 1, xor=False)
+    xof_fn(target, output, 1, xor=True)
 
 
 def emit_avx2(target, output):
     target = replace(target, extension=AVX2)
     kernel2d(target, output, 2)
-    xof_stream(target, output, 2)
+    xof_fn(target, output, 2, xor=False)
+    xof_fn(target, output, 2, xor=True)
 
 
 def emit_avx512(target, output):
@@ -780,9 +913,12 @@ def emit_avx512(target, output):
     kernel2d(target, output, 2)
     kernel2d(target, output, 4)
     compress(target, output)
-    xof_stream(target, output, 1)
-    xof_stream(target, output, 2)
-    xof_stream(target, output, 4)
+    xof_fn(target, output, 1, xor=False)
+    xof_fn(target, output, 1, xor=True)
+    xof_fn(target, output, 2, xor=False)
+    xof_fn(target, output, 2, xor=True)
+    xof_fn(target, output, 4, xor=False)
+    xof_fn(target, output, 4, xor=True)
 
 
 def emit_footer(target, output):
diff --git a/asm/out.S b/asm/out.S
--- a/asm/out.S
+++ b/asm/out.S
@@ -569,6 +569,48 @@ blake3_sse2_xof_stream_1:
     movups xmmword ptr [r9+0x20], xmm2
     movups xmmword ptr [r9+0x30], xmm3
     ret
+.global blake3_sse2_xof_xor_1
+blake3_sse2_xof_xor_1:
+    movups xmm0, xmmword ptr [rdi]
+    movups xmm1, xmmword ptr [rdi+0x10]
+    movaps xmm2, xmmword ptr [BLAKE3_IV+rip]
+    shl r8, 32
+    mov ecx, ecx
+    or rcx, r8
+    vmovq xmm3, rdx
+    vmovq xmm4, rcx
+    punpcklqdq xmm3, xmm4
+    movups xmm4, xmmword ptr [rsi]
+    movups xmm5, xmmword ptr [rsi+0x10]
+    movaps xmm8, xmm4
+    shufps xmm4, xmm5, 136
+    shufps xmm8, xmm5, 221
+    movaps xmm5, xmm8
+    movups xmm6, xmmword ptr [rsi+0x20]
+    movups xmm7, xmmword ptr [rsi+0x30]
+    movaps xmm8, xmm6
+    shufps xmm6, xmm7, 136
+    pshufd xmm6, xmm6, 0x93
+    shufps xmm8, xmm7, 221
+    pshufd xmm7, xmm8, 0x93
+    call blake3_sse2_kernel2d_1
+    movdqu xmm4, xmmword ptr [rdi]
+    movdqu xmm5, xmmword ptr [rdi+0x10]
+    pxor xmm2, xmm4
+    pxor xmm3, xmm5
+    movdqu xmm4, [r9]
+    movdqu xmm5, [r9+0x10]
+    movdqu xmm6, [r9+0x20]
+    movdqu xmm7, [r9+0x30]
+    pxor xmm0, xmm4
+    pxor xmm1, xmm5
+    pxor xmm2, xmm6
+    pxor xmm3, xmm7
+    movups xmmword ptr [r9], xmm0
+    movups xmmword ptr [r9+0x10], xmm1
+    movups xmmword ptr [r9+0x20], xmm2
+    movups xmmword ptr [r9+0x30], xmm3
+    ret
 .balign 16
 PBLENDW_0x33_MASK:
 .long 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000
@@ -1064,6 +1106,48 @@ blake3_sse41_xof_stream_1:
     movups xmmword ptr [r9+0x20], xmm2
     movups xmmword ptr [r9+0x30], xmm3
     ret
+.global blake3_sse41_xof_xor_1
+blake3_sse41_xof_xor_1:
+    movups xmm0, xmmword ptr [rdi]
+    movups xmm1, xmmword ptr [rdi+0x10]
+    movaps xmm2, xmmword ptr [BLAKE3_IV+rip]
+    shl r8, 32
+    mov ecx, ecx
+    or rcx, r8
+    vmovq xmm3, rdx
+    vmovq xmm4, rcx
+    punpcklqdq xmm3, xmm4
+    movups xmm4, xmmword ptr [rsi]
+    movups xmm5, xmmword ptr [rsi+0x10]
+    movaps xmm8, xmm4
+    shufps xmm4, xmm5, 136
+    shufps xmm8, xmm5, 221
+    movaps xmm5, xmm8
+    movups xmm6, xmmword ptr [rsi+0x20]
+    movups xmm7, xmmword ptr [rsi+0x30]
+    movaps xmm8, xmm6
+    shufps xmm6, xmm7, 136
+    pshufd xmm6, xmm6, 0x93
+    shufps xmm8, xmm7, 221
+    pshufd xmm7, xmm8, 0x93
+    call blake3_sse41_kernel2d_1
+    movdqu xmm4, xmmword ptr [rdi]
+    movdqu xmm5, xmmword ptr [rdi+0x10]
+    pxor xmm2, xmm4
+    pxor xmm3, xmm5
+    movdqu xmm4, [r9]
+    movdqu xmm5, [r9+0x10]
+    movdqu xmm6, [r9+0x20]
+    movdqu xmm7, [r9+0x30]
+    pxor xmm0, xmm4
+    pxor xmm1, xmm5
+    pxor xmm2, xmm6
+    pxor xmm3, xmm7
+    movups xmmword ptr [r9], xmm0
+    movups xmmword ptr [r9+0x10], xmm1
+    movups xmmword ptr [r9+0x20], xmm2
+    movups xmmword ptr [r9+0x30], xmm3
+    ret
 blake3_avx2_kernel2d_2:
     vbroadcasti128 ymm14, xmmword ptr [ROT16+rip]
     vbroadcasti128 ymm15, xmmword ptr [ROT8+rip]
@@ -1479,6 +1563,47 @@ blake3_avx2_xof_stream_2:
     vextracti128 xmmword ptr [r9 + 6 * 16], ymm2, 1
     vextracti128 xmmword ptr [r9 + 7 * 16], ymm3, 1
     ret
+.global blake3_avx2_xof_xor_2
+blake3_avx2_xof_xor_2:
+    vbroadcasti128 ymm0, xmmword ptr [rdi]
+    vbroadcasti128 ymm1, xmmword ptr [rdi+0x10]
+    vmovdqa ymm4, ymmword ptr [INCREMENT_2D+rip]
+    vbroadcasti128 ymm2, xmmword ptr [BLAKE3_IV+rip]
+    vpbroadcastq ymm5, rdx
+    vpaddq ymm6, ymm4, ymm5
+    shl r8, 32
+    mov ecx, ecx
+    or rcx, r8
+    vpbroadcastq ymm7, rcx
+    vpblendd ymm3, ymm6, ymm7, 0xCC
+    vbroadcasti128 ymm8, xmmword ptr [rsi]
+    vbroadcasti128 ymm9, xmmword ptr [rsi+0x10]
+    vshufps ymm4, ymm8, ymm9, 136
+    vshufps ymm5, ymm8, ymm9, 221
+    vbroadcasti128 ymm8, xmmword ptr [rsi+0x20]
+    vbroadcasti128 ymm9, xmmword ptr [rsi+0x30]
+    vshufps ymm6, ymm8, ymm9, 136
+    vshufps ymm7, ymm8, ymm9, 221
+    vpshufd ymm6, ymm6, 0x93
+    vpshufd ymm7, ymm7, 0x93
+    call blake3_avx2_kernel2d_2
+    vbroadcasti128 ymm4, xmmword ptr [rdi]
+    vpxor ymm2, ymm2, ymm4
+    vbroadcasti128 ymm5, xmmword ptr [rdi + 16]
+    vpxor ymm3, ymm3, ymm5
+    vperm2f128 ymm4, ymm0, ymm1, 32
+    vperm2f128 ymm5, ymm2, ymm3, 32
+    vperm2f128 ymm6, ymm0, ymm1, 49
+    vperm2f128 ymm7, ymm2, ymm3, 49
+    vpxor ymm4, ymm4, ymmword ptr [r9 + 0 * 32]
+    vpxor ymm5, ymm5, ymmword ptr [r9 + 1 * 32]
+    vpxor ymm6, ymm6, ymmword ptr [r9 + 2 * 32]
+    vpxor ymm7, ymm7, ymmword ptr [r9 + 3 * 32]
+    vmovdqu ymmword ptr [r9 + 0 * 32], ymm4
+    vmovdqu ymmword ptr [r9 + 1 * 32], ymm5
+    vmovdqu ymmword ptr [r9 + 2 * 32], ymm6
+    vmovdqu ymmword ptr [r9 + 3 * 32], ymm7
+    ret
 blake3_avx512_kernel2d_1:
     vpaddd xmm0, xmm0, xmm4
     vpaddd xmm0, xmm0, xmm1
@@ -2497,6 +2622,39 @@ blake3_avx512_xof_stream_1:
     vmovdqu xmmword ptr [r9+0x20], xmm2
     vmovdqu xmmword ptr [r9+0x30], xmm3
     ret
+.global blake3_avx512_xof_xor_1
+blake3_avx512_xof_xor_1:
+    vmovdqu xmm0, xmmword ptr [rdi]
+    vmovdqu xmm1, xmmword ptr [rdi+0x10]
+    shl r8, 32
+    mov ecx, ecx
+    or rcx, r8
+    vmovq xmm3, rdx
+    vmovq xmm4, rcx
+    vpunpcklqdq xmm3, xmm3, xmm4
+    vmovaps xmm2, xmmword ptr [BLAKE3_IV+rip]
+    vmovups xmm8, xmmword ptr [rsi]
+    vmovups xmm9, xmmword ptr [rsi+0x10]
+    vshufps xmm4, xmm8, xmm9, 136
+    vshufps xmm5, xmm8, xmm9, 221
+    vmovups xmm8, xmmword ptr [rsi+0x20]
+    vmovups xmm9, xmmword ptr [rsi+0x30]
+    vshufps xmm6, xmm8, xmm9, 136
+    vshufps xmm7, xmm8, xmm9, 221
+    vpshufd xmm6, xmm6, 0x93
+    vpshufd xmm7, xmm7, 0x93
+    call blake3_avx512_kernel2d_1
+    vpxor xmm2, xmm2, [rdi]
+    vpxor xmm3, xmm3, [rdi+0x10]
+    vpxor xmm0, xmm0, [r9]
+    vpxor xmm1, xmm1, [r9+0x10]
+    vpxor xmm2, xmm2, [r9+0x20]
+    vpxor xmm3, xmm3, [r9+0x30]
+    vmovdqu xmmword ptr [r9], xmm0
+    vmovdqu xmmword ptr [r9+0x10], xmm1
+    vmovdqu xmmword ptr [r9+0x20], xmm2
+    vmovdqu xmmword ptr [r9+0x30], xmm3
+    ret
 .global blake3_avx512_xof_stream_2
 blake3_avx512_xof_stream_2:
     vbroadcasti128 ymm0, xmmword ptr [rdi]
@@ -2529,10 +2687,51 @@ blake3_avx512_xof_stream_2:
     vmovdqu xmmword ptr [r9 + 1 * 16], xmm1
     vmovdqu xmmword ptr [r9 + 2 * 16], xmm2
     vmovdqu xmmword ptr [r9 + 3 * 16], xmm3
-    vextracti128 xmmword ptr [r9 + 4 * 16], ymm0, 1
-    vextracti128 xmmword ptr [r9 + 5 * 16], ymm1, 1
-    vextracti128 xmmword ptr [r9 + 6 * 16], ymm2, 1
-    vextracti128 xmmword ptr [r9 + 7 * 16], ymm3, 1
+    vextracti128 xmmword ptr [r9+4*16], ymm0, 1
+    vextracti128 xmmword ptr [r9+5*16], ymm1, 1
+    vextracti128 xmmword ptr [r9+6*16], ymm2, 1
+    vextracti128 xmmword ptr [r9+7*16], ymm3, 1
+    ret
+.global blake3_avx512_xof_xor_2
+blake3_avx512_xof_xor_2:
+    vbroadcasti128 ymm0, xmmword ptr [rdi]
+    vbroadcasti128 ymm1, xmmword ptr [rdi+0x10]
+    vmovdqa ymm4, ymmword ptr [INCREMENT_2D+rip]
+    vbroadcasti128 ymm2, xmmword ptr [BLAKE3_IV+rip]
+    vpbroadcastq ymm5, rdx
+    vpaddq ymm6, ymm4, ymm5
+    shl r8, 32
+    mov ecx, ecx
+    or rcx, r8
+    vpbroadcastq ymm7, rcx
+    vpblendd ymm3, ymm6, ymm7, 0xCC
+    vbroadcasti128 ymm8, xmmword ptr [rsi]
+    vbroadcasti128 ymm9, xmmword ptr [rsi+0x10]
+    vshufps ymm4, ymm8, ymm9, 136
+    vshufps ymm5, ymm8, ymm9, 221
+    vbroadcasti128 ymm8, xmmword ptr [rsi+0x20]
+    vbroadcasti128 ymm9, xmmword ptr [rsi+0x30]
+    vshufps ymm6, ymm8, ymm9, 136
+    vshufps ymm7, ymm8, ymm9, 221
+    vpshufd ymm6, ymm6, 0x93
+    vpshufd ymm7, ymm7, 0x93
+    call blake3_avx512_kernel2d_2
+    vbroadcasti128 ymm4, xmmword ptr [rdi]
+    vpxor ymm2, ymm2, ymm4
+    vbroadcasti128 ymm5, xmmword ptr [rdi + 16]
+    vpxor ymm3, ymm3, ymm5
+    vperm2f128 ymm4, ymm0, ymm1, 32
+    vperm2f128 ymm5, ymm2, ymm3, 32
+    vperm2f128 ymm6, ymm0, ymm1, 49
+    vperm2f128 ymm7, ymm2, ymm3, 49
+    vpxor ymm4, ymm4, ymmword ptr [r9 + 0 * 32]
+    vpxor ymm5, ymm5, ymmword ptr [r9 + 1 * 32]
+    vpxor ymm6, ymm6, ymmword ptr [r9 + 2 * 32]
+    vpxor ymm7, ymm7, ymmword ptr [r9 + 3 * 32]
+    vmovdqu ymmword ptr [r9 + 0 * 32], ymm4
+    vmovdqu ymmword ptr [r9 + 1 * 32], ymm5
+    vmovdqu ymmword ptr [r9 + 2 * 32], ymm6
+    vmovdqu ymmword ptr [r9 + 3 * 32], ymm7
     ret
 .global blake3_avx512_xof_stream_4
 blake3_avx512_xof_stream_4:
@@ -2581,6 +2780,53 @@ blake3_avx512_xof_stream_4:
     vextracti32x4 xmmword ptr [r9 + 14 * 16], zmm2, 3
     vextracti32x4 xmmword ptr [r9 + 15 * 16], zmm3, 3
     ret
+.global blake3_avx512_xof_xor_4
+blake3_avx512_xof_xor_4:
+    vbroadcasti32x4 zmm0, xmmword ptr [rdi]
+    vbroadcasti32x4 zmm1, xmmword ptr [rdi+0x10]
+    vmovdqa32 zmm4, zmmword ptr [INCREMENT_2D+rip]
+    vbroadcasti32x4 zmm2, xmmword ptr [BLAKE3_IV+rip]
+    vpbroadcastq zmm5, rdx
+    vpaddq zmm6, zmm4, zmm5
+    shl r8, 32
+    mov ecx, ecx
+    or rcx, r8
+    vpbroadcastq zmm7, rcx
+    mov eax, 0xAA
+    kmovw k2, eax
+    vpblendmq zmm3 {k2}, zmm6, zmm7
+    vbroadcasti32x4 zmm8, xmmword ptr [rsi]
+    vbroadcasti32x4 zmm9, xmmword ptr [rsi+0x10]
+    vshufps zmm4, zmm8, zmm9, 136
+    vshufps zmm5, zmm8, zmm9, 221
+    vbroadcasti32x4 zmm8, xmmword ptr [rsi+0x20]
+    vbroadcasti32x4 zmm9, xmmword ptr [rsi+0x30]
+    vshufps zmm6, zmm8, zmm9, 136
+    vshufps zmm7, zmm8, zmm9, 221
+    vpshufd zmm6, zmm6, 0x93
+    vpshufd zmm7, zmm7, 0x93
+    call blake3_avx512_kernel2d_4
+    vbroadcasti32x4 zmm4, xmmword ptr [rdi]
+    vpxord zmm2, zmm2, zmm4
+    vbroadcasti32x4 zmm5, xmmword ptr [rdi + 16]
+    vpxord zmm3, zmm3, zmm5
+    vshufi32x4 zmm4, zmm0, zmm1, 68
+    vshufi32x4 zmm5, zmm0, zmm1, 238
+    vshufi32x4 zmm6, zmm2, zmm3, 68
+    vshufi32x4 zmm7, zmm2, zmm3, 238
+    vshufi32x4 zmm0, zmm4, zmm6, 136
+    vshufi32x4 zmm1, zmm4, zmm6, 221
+    vshufi32x4 zmm2, zmm5, zmm7, 136
+    vshufi32x4 zmm3, zmm5, zmm7, 221
+    vpxord zmm0, zmm0, zmmword ptr [r9 + 0*64]
+    vpxord zmm1, zmm1, zmmword ptr [r9 + 1*64]
+    vpxord zmm2, zmm2, zmmword ptr [r9 + 2*64]
+    vpxord zmm3, zmm3, zmmword ptr [r9 + 3*64]
+    vmovdqu32 zmmword ptr [r9 + 0*64], zmm0
+    vmovdqu32 zmmword ptr [r9 + 1*64], zmm1
+    vmovdqu32 zmmword ptr [r9 + 2*64], zmm2
+    vmovdqu32 zmmword ptr [r9 + 3*64], zmm3
+    ret
 .balign 16
 BLAKE3_IV:
 BLAKE3_IV_0:
diff --git a/src/kernel.rs b/src/kernel.rs
index f17db21..cdcb25a 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -73,6 +73,54 @@ extern "C" {
         flags: u32,
         out: *mut [u8; 64 * 4],
     );
+    pub fn blake3_sse2_xof_xor_1(
+        cv: &[u32; 8],
+        block: &[u8; 64],
+        counter: u64,
+        block_len: u32,
+        flags: u32,
+        out: &mut [u8; 64],
+    );
+    pub fn blake3_sse41_xof_xor_1(
+        cv: &[u32; 8],
+        block: &[u8; 64],
+        counter: u64,
+        block_len: u32,
+        flags: u32,
+        out: &mut [u8; 64],
+    );
+    pub fn blake3_avx512_xof_xor_1(
+        cv: &[u32; 8],
+        block: &[u8; 64],
+        counter: u64,
+        block_len: u32,
+        flags: u32,
+        out: &mut [u8; 64],
+    );
+    pub fn blake3_avx2_xof_xor_2(
+        cv: &[u32; 8],
+        block: &[u8; 64],
+        counter: u64,
+        block_len: u32,
+        flags: u32,
+        out: &mut [u8; 64 * 2],
+    );
+    pub fn blake3_avx512_xof_xor_2(
+        cv: &[u32; 8],
+        block: &[u8; 64],
+        counter: u64,
+        block_len: u32,
+        flags: u32,
+        out: &mut [u8; 64 * 2],
+    );
+    pub fn blake3_avx512_xof_xor_4(
+        cv: &[u32; 8],
+        block: &[u8; 64],
+        counter: u64,
+        block_len: u32,
+        flags: u32,
+        out: &mut [u8; 64 * 4],
+    );
 }
 
 pub type CompressionFn =
@@ -87,6 +135,15 @@ pub type XofStreamFn<const N: usize> = unsafe extern "C" fn(
     out: *mut [u8; N],
 );
 
+pub type XofXorFn<const N: usize> = unsafe extern "C" fn(
+    cv: &[u32; 8],
+    block: &[u8; 64],
+    counter: u64,
+    block_len: u32,
+    flags: u32,
+    out: &mut [u8; N],
+);
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -142,13 +199,14 @@ mod test {
         test_compression_function(blake3_avx512_compress);
     }
 
-    fn test_xof_function<const N: usize>(f: XofStreamFn<N>) {
+    fn test_xof_functions<const N: usize>(stream_fn: XofStreamFn<N>, xor_fn: XofXorFn<N>) {
         let mut block = [0; 64];
         let block_len = 53;
         crate::test::paint_test_input(&mut block[..block_len]);
         let counter = u32::MAX as u64;
         let flags = crate::CHUNK_START | crate::CHUNK_END | crate::ROOT;
 
+        // First compute the expected stream.
         let mut expected = [0; N];
         let mut incrementing_counter = counter;
         let mut i = 0;
@@ -167,9 +225,11 @@ mod test {
         }
         assert_eq!(incrementing_counter, counter + N as u64 / 64);
 
-        let mut found = [0; N];
+        // And compare that to the stream under test. The 0x42 bytes are there to make sure we
+        // overwrite them.
+        let mut found = [0x42; N];
         unsafe {
-            f(
+            stream_fn(
                 crate::IV,
                 &block,
                 counter,
@@ -178,8 +238,24 @@ mod test {
                 &mut found,
             );
         }
-        assert_eq!(expected, found);
+
+        // XOR 0x99 bytes into the found stream. Then run the xof_xor variant on that stream again.
+        // This should cancel out the original stream, leaving only the 0x99 bytes.
+        for b in &mut found {
+            *b ^= 0x99;
+        }
+        unsafe {
+            xor_fn(
+                crate::IV,
+                &block,
+                counter,
+                block_len as u32,
+                flags as u32,
+                &mut found,
+            );
+        }
+        assert_eq!([0x99; N], found);
     }
 
     #[test]
@@ -188,7 +264,7 @@ mod test {
         if !is_x86_feature_detected!("sse2") {
             return;
         }
-        test_xof_function(blake3_sse2_xof_stream_1);
+        test_xof_functions(blake3_sse2_xof_stream_1, blake3_sse2_xof_xor_1);
     }
 
     #[test]
@@ -197,7 +273,7 @@ mod test {
         if !is_x86_feature_detected!("sse4.1") {
             return;
         }
-        test_xof_function(blake3_sse41_xof_stream_1);
+        test_xof_functions(blake3_sse41_xof_stream_1, blake3_sse41_xof_xor_1);
     }
 
     #[test]
@@ -206,7 +282,7 @@ mod test {
         if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") {
             return;
         }
-        test_xof_function(blake3_avx512_xof_stream_1);
+        test_xof_functions(blake3_avx512_xof_stream_1, blake3_avx512_xof_xor_1);
     }
 
     #[test]
@@ -215,7 +291,7 @@ mod test {
         if !is_x86_feature_detected!("avx2") {
            return;
         }
-        test_xof_function(blake3_avx2_xof_stream_2);
+        test_xof_functions(blake3_avx2_xof_stream_2, blake3_avx2_xof_xor_2);
     }
 
     #[test]
@@ -224,7 +300,7 @@ mod test {
         if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") {
             return;
         }
-        test_xof_function(blake3_avx512_xof_stream_2);
+        test_xof_functions(blake3_avx512_xof_stream_2, blake3_avx512_xof_xor_2);
    }
 
     #[test]
@@ -233,7 +309,7 @@ mod test {
         if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") {
             return;
         }
-        test_xof_function(blake3_avx512_xof_stream_4);
+        test_xof_functions(blake3_avx512_xof_stream_4, blake3_avx512_xof_xor_4);
     }
 
 }
```
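A note on the `vperm2f128` immediates in the degree-2 finish paths above: the kernel leaves each ymm register holding one row from each of two states, and the two immediates (`0b0010_0000` and `0b0011_0001`) regroup the 128-bit lanes so that each register holds one contiguous state. Here is a small Python model of that shuffle; it is illustrative only, the register contents are symbolic, and the instruction's zeroing bits (which this code never sets) are ignored:

```python
def vperm2f128(src1, src2, imm):
    # Model a ymm register as a list of 8 words (two 128-bit lanes of 4).
    # Bits 1:0 of imm pick the destination's low lane and bits 5:4 its high
    # lane, from {src1.low, src1.high, src2.low, src2.high}.
    lanes = [src1[0:4], src1[4:8], src2[0:4], src2[4:8]]
    return lanes[imm & 0b11] + lanes[(imm >> 4) & 0b11]

# Rows 0-7 of two interleaved states "a" and "b", as the kernel leaves them.
ymm0 = [f"a{i}" for i in range(4)] + [f"b{i}" for i in range(4)]
ymm1 = [f"a{i}" for i in range(4, 8)] + [f"b{i}" for i in range(4, 8)]

assert vperm2f128(ymm0, ymm1, 0b0010_0000) == [f"a{i}" for i in range(8)]  # lower 128
assert vperm2f128(ymm0, ymm1, 0b0011_0001) == [f"b{i}" for i in range(8)]  # upper 128
```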
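The degree-4 path does the same regrouping in two `vshufi32x4` passes, exactly as the comments in `xof_xor_finish2d` describe: first interleave 256-bit halves, then interleave 128-bit lanes within each half. The same kind of sketch, with each string standing in for one 128-bit lane:

```python
def vshufi32x4(src1, src2, imm):
    # Model a zmm register as four 128-bit lanes. Each 2-bit field of imm
    # selects one lane; the low two fields index src1, the high two src2.
    sel = [(imm >> (2 * i)) & 0b11 for i in range(4)]
    return [src1[sel[0]], src1[sel[1]], src2[sel[2]], src2[sel[3]]]

# Rows of four interleaved states a..d, one string per 128-bit lane.
zmm0 = ["a0..3", "b0..3", "c0..3", "d0..3"]
zmm1 = ["a4..7", "b4..7", "c4..7", "d4..7"]
zmm2 = ["a8..11", "b8..11", "c8..11", "d8..11"]
zmm3 = ["a12..15", "b12..15", "c12..15", "d12..15"]

# First pass: interleave 256-bit halves.
zmm4 = vshufi32x4(zmm0, zmm1, 0b0100_0100)  # a0..3, b0..3, a4..7, b4..7
zmm5 = vshufi32x4(zmm0, zmm1, 0b1110_1110)
zmm6 = vshufi32x4(zmm2, zmm3, 0b0100_0100)
zmm7 = vshufi32x4(zmm2, zmm3, 0b1110_1110)

# Second pass: each result is one contiguous 64-byte state.
assert vshufi32x4(zmm4, zmm6, 0b1000_1000) == ["a0..3", "a4..7", "a8..11", "a12..15"]
assert vshufi32x4(zmm4, zmm6, 0b1101_1101) == ["b0..3", "b4..7", "b8..11", "b12..15"]
assert vshufi32x4(zmm5, zmm7, 0b1000_1000) == ["c0..3", "c4..7", "c8..11", "c12..15"]
assert vshufi32x4(zmm5, zmm7, 0b1101_1101) == ["d0..3", "d4..7", "d8..11", "d12..15"]
```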
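Each of the new entry points also opens with the same `shl r8, 32` / `mov ecx, ecx` / `or rcx, r8` sequence. Going by the System V calling convention and the Rust signatures (`rdx` = counter, `rcx` = block_len, `r8` = flags), it packs block_len and flags into one 64-bit word for the last state row, with `mov ecx, ecx` zero-extending away whatever sat in the upper half of `rcx`. A sketch of the arithmetic, where the concrete values are just examples taken from the test:

```python
counter = 0xFFFF_FFFF          # the test's u32::MAX counter
block_len = 53
flags = 1 | 2 | 8              # CHUNK_START | CHUNK_END | ROOT

r8 = flags << 32               # shl r8, 32
rcx = block_len & 0xFFFF_FFFF  # mov ecx, ecx (zero-extend the upper half)
rcx |= r8                      # or rcx, r8

# The vmovq/punpcklqdq pair then forms state row 3 as four 32-bit words:
row3 = [counter & 0xFFFF_FFFF, counter >> 32, rcx & 0xFFFF_FFFF, rcx >> 32]
assert row3 == [0xFFFF_FFFF, 0x0, 53, 0x0B]
```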
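Finally, the updated test leans on XOR cancellation: `stream_fn` fills `found` with the keystream, the test flips every byte with 0x99, and `xor_fn` XORs the same keystream in again, so only the 0x99 mask should survive. The identity in miniature, in pure Python with an arbitrary stand-in stream:

```python
stream = bytes(range(64))                 # stand-in for one XOF output block
masked = bytes(b ^ 0x99 for b in stream)  # the test XORs 0x99 into "found"
result = bytes(m ^ s for m, s in zip(masked, stream))  # what xof_xor does
assert result == bytes([0x99] * 64)       # stream ^ 0x99 ^ stream == 0x99
```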
