author      Jack O'Connor <[email protected]>    2022-03-25 12:19:50 -0400
committer   Jack O'Connor <[email protected]>    2022-03-26 11:18:39 -0400
commit      35ad4ededdbf259c507c49b2e7ac529b43b61671 (patch)
tree        930f3cb141e14c138319fa67777b0a5dbc01a53d
parent      ea94b544fcdb9cf4120eaf9308e8df937871ad3e (diff)
xor_xof variants for the 2d kernel
-rwxr-xr-x  asm/asm.py     182
-rw-r--r--  asm/out.S      254
-rw-r--r--  src/kernel.rs   96
3 files changed, 495 insertions, 37 deletions
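Note on the change: the new xof_xor functions mirror the existing xof_stream functions and differ only in their finish step. Where the stream variants overwrite the destination with the generated XOF block(s), the xor variants XOR the generated block(s) into whatever bytes are already at the destination. A minimal scalar sketch of that difference, in Python, for illustration only (the real functions below are SIMD assembly emitted by asm/asm.py, and `stream` here is a hypothetical stand-in for the raw XOF output bytes of one call):

# Illustration only, not the generated code: the semantic difference between
# the "stream" and "xor" finish steps, assuming `stream` already holds the raw
# XOF output bytes for this call.
def finish_stream(out: bytearray, stream: bytes) -> None:
    out[: len(stream)] = stream  # overwrite the destination

def finish_xor(out: bytearray, stream: bytes) -> None:
    for i, b in enumerate(stream):
        out[i] ^= b  # fold the output into the existing destination bytes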
diff --git a/asm/asm.py b/asm/asm.py
index de5b89f..75c7012 100755
--- a/asm/asm.py
+++ b/asm/asm.py
@@ -659,18 +659,10 @@ def xof_stream_finish2d(target, output, degree):
output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 1 * 16], xmm1")
output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 2 * 16], xmm2")
output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 3 * 16], xmm3")
- output.append(
- f"vextracti128 xmmword ptr [{target.arg64(5)} + 4 * 16], ymm0, 1"
- )
- output.append(
- f"vextracti128 xmmword ptr [{target.arg64(5)} + 5 * 16], ymm1, 1"
- )
- output.append(
- f"vextracti128 xmmword ptr [{target.arg64(5)} + 6 * 16], ymm2, 1"
- )
- output.append(
- f"vextracti128 xmmword ptr [{target.arg64(5)} + 7 * 16], ymm3, 1"
- )
+ output.append(f"vextracti128 xmmword ptr [{target.arg64(5)}+4*16], ymm0, 1")
+ output.append(f"vextracti128 xmmword ptr [{target.arg64(5)}+5*16], ymm1, 1")
+ output.append(f"vextracti128 xmmword ptr [{target.arg64(5)}+6*16], ymm2, 1")
+ output.append(f"vextracti128 xmmword ptr [{target.arg64(5)}+7*16], ymm3, 1")
elif degree == 4:
output.append(f"vbroadcasti32x4 zmm4, xmmword ptr [{target.arg64(0)}]")
output.append(f"vpxord zmm2, zmm2, zmm4")
@@ -714,27 +706,165 @@ def xof_stream_finish2d(target, output, degree):
         raise NotImplementedError
-def xof_stream(target, output, degree):
-    label = f"blake3_{target.extension}_xof_stream_{degree}"
+def xof_xor_finish2d(target, output, degree):
+    if target.extension == AVX512:
+        if degree == 1:
+            output.append(f"vpxor xmm2, xmm2, [{target.arg64(0)}]")
+            output.append(f"vpxor xmm3, xmm3, [{target.arg64(0)}+0x10]")
+            output.append(f"vpxor xmm0, xmm0, [{target.arg64(5)}]")
+            output.append(f"vpxor xmm1, xmm1, [{target.arg64(5)}+0x10]")
+            output.append(f"vpxor xmm2, xmm2, [{target.arg64(5)}+0x20]")
+            output.append(f"vpxor xmm3, xmm3, [{target.arg64(5)}+0x30]")
+            output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}], xmm0")
+            output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x10], xmm1")
+            output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x20], xmm2")
+            output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x30], xmm3")
+        elif degree == 2:
+            output.append(f"vbroadcasti128 ymm4, xmmword ptr [{target.arg64(0)}]")
+            output.append(f"vpxor ymm2, ymm2, ymm4")
+            output.append(f"vbroadcasti128 ymm5, xmmword ptr [{target.arg64(0)} + 16]")
+            output.append(f"vpxor ymm3, ymm3, ymm5")
+            # Each vector now holds rows from two different states:
+            # ymm0: a0, a1, a2, a3, b0, b1, b2, b3
+            # ymm1: a4, a5, a6, a7, b4, b5, b6, b7
+            # ymm2: a8, a9, a10, a11, b8, b9, b10, b11
+            # ymm3: a12, a13, a14, a15, b12, b13, b14, b15
+            # We want to rearrange the 128-bit lanes like this, so we can load
+            # destination bytes and XOR them in directly.
+            # ymm4: a0, a1, a2, a3, a4, a5, a6, a7
+            # ymm5: a8, a9, a10, a11, a12, a13, a14, a15
+            # ymm6: b0, b1, b2, b3, b4, b5, b6, b7
+            # ymm7: b8, b9, b10, b11, b12, b13, b14, b15
+            output.append(f"vperm2f128 ymm4, ymm0, ymm1, {0b0010_0000}")  # lower 128
+            output.append(f"vperm2f128 ymm5, ymm2, ymm3, {0b0010_0000}")
+            output.append(f"vperm2f128 ymm6, ymm0, ymm1, {0b0011_0001}")  # upper 128
+            output.append(f"vperm2f128 ymm7, ymm2, ymm3, {0b0011_0001}")
+            # XOR in the bytes that are already in the destination.
+            output.append(f"vpxor ymm4, ymm4, ymmword ptr [{target.arg64(5)} + 0 * 32]")
+            output.append(f"vpxor ymm5, ymm5, ymmword ptr [{target.arg64(5)} + 1 * 32]")
+            output.append(f"vpxor ymm6, ymm6, ymmword ptr [{target.arg64(5)} + 2 * 32]")
+            output.append(f"vpxor ymm7, ymm7, ymmword ptr [{target.arg64(5)} + 3 * 32]")
+            # Write out the XOR results.
+            output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 0 * 32], ymm4")
+            output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 1 * 32], ymm5")
+            output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 2 * 32], ymm6")
+            output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 3 * 32], ymm7")
+        elif degree == 4:
+            output.append(f"vbroadcasti32x4 zmm4, xmmword ptr [{target.arg64(0)}]")
+            output.append(f"vpxord zmm2, zmm2, zmm4")
+            output.append(f"vbroadcasti32x4 zmm5, xmmword ptr [{target.arg64(0)} + 16]")
+            output.append(f"vpxord zmm3, zmm3, zmm5")
+            # Each vector now holds rows from four different states:
+            # zmm0: a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3, d0, d1, d2, d3
+            # zmm1: a4, a5, a6, a7, b4, b5, b6, b7, c4, c5, c6, c7, d4, d5, d6, d7
+            # zmm2: a8, a9, a10, a11, b8, b9, b10, b11, c8, c9, c10, c11, d8, d9, d10, d11
+            # zmm3: a12, a13, a14, a15, b12, b13, b14, b15, c12, c13, c14, c15, d12, d13, d14, d15
+            # We want to rearrange the 128-bit lanes like this, so we can load
+            # destination bytes and XOR them in directly.
+            # zmm0: a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15
+            # zmm1: b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15
+            # zmm2: c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15
+            # zmm3: d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15
+            #
+            # This first interleaving of 256-bit lanes produces vectors like:
+            # zmm4: a0, a1, a2, a3, b0, b1, b2, b3, a4, a5, a6, a7, b4, b5, b6, b7
+            output.append(f"vshufi32x4 zmm4, zmm0, zmm1, {0b0100_0100}")  # low 256
+            output.append(f"vshufi32x4 zmm5, zmm0, zmm1, {0b1110_1110}")  # high 256
+            output.append(f"vshufi32x4 zmm6, zmm2, zmm3, {0b0100_0100}")
+            output.append(f"vshufi32x4 zmm7, zmm2, zmm3, {0b1110_1110}")
+            # And this second interleaving of 128-bit lanes within each 256-bit
+            # lane produces the vectors we want.
+            output.append(f"vshufi32x4 zmm0, zmm4, zmm6, {0b1000_1000}")  # low 128
+            output.append(f"vshufi32x4 zmm1, zmm4, zmm6, {0b1101_1101}")  # high 128
+            output.append(f"vshufi32x4 zmm2, zmm5, zmm7, {0b1000_1000}")
+            output.append(f"vshufi32x4 zmm3, zmm5, zmm7, {0b1101_1101}")
+            # XOR in the bytes that are already in the destination.
+            output.append(f"vpxord zmm0, zmm0, zmmword ptr [{target.arg64(5)} + 0*64]")
+            output.append(f"vpxord zmm1, zmm1, zmmword ptr [{target.arg64(5)} + 1*64]")
+            output.append(f"vpxord zmm2, zmm2, zmmword ptr [{target.arg64(5)} + 2*64]")
+            output.append(f"vpxord zmm3, zmm3, zmmword ptr [{target.arg64(5)} + 3*64]")
+            # Write out the XOR results.
+            output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 0*64], zmm0")
+            output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 1*64], zmm1")
+            output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 2*64], zmm2")
+            output.append(f"vmovdqu32 zmmword ptr [{target.arg64(5)} + 3*64], zmm3")
+        else:
+            raise NotImplementedError
+    elif target.extension == AVX2:
+        output.append(f"vbroadcasti128 ymm4, xmmword ptr [{target.arg64(0)}]")
+        output.append(f"vpxor ymm2, ymm2, ymm4")
+        output.append(f"vbroadcasti128 ymm5, xmmword ptr [{target.arg64(0)} + 16]")
+        output.append(f"vpxor ymm3, ymm3, ymm5")
+        # Each vector now holds rows from two different states:
+        # ymm0: a0, a1, a2, a3, b0, b1, b2, b3
+        # ymm1: a4, a5, a6, a7, b4, b5, b6, b7
+        # ymm2: a8, a9, a10, a11, b8, b9, b10, b11
+        # ymm3: a12, a13, a14, a15, b12, b13, b14, b15
+        # We want to rearrange the 128-bit lanes like this, so we can load
+        # destination bytes and XOR them in directly.
+        # ymm4: a0, a1, a2, a3, a4, a5, a6, a7
+        # ymm5: a8, a9, a10, a11, a12, a13, a14, a15
+        # ymm6: b0, b1, b2, b3, b4, b5, b6, b7
+        # ymm7: b8, b9, b10, b11, b12, b13, b14, b15
+        output.append(f"vperm2f128 ymm4, ymm0, ymm1, {0b0010_0000}")  # lower 128
+        output.append(f"vperm2f128 ymm5, ymm2, ymm3, {0b0010_0000}")
+        output.append(f"vperm2f128 ymm6, ymm0, ymm1, {0b0011_0001}")  # upper 128
+        output.append(f"vperm2f128 ymm7, ymm2, ymm3, {0b0011_0001}")
+        # XOR in the bytes that are already in the destination.
+        output.append(f"vpxor ymm4, ymm4, ymmword ptr [{target.arg64(5)} + 0 * 32]")
+        output.append(f"vpxor ymm5, ymm5, ymmword ptr [{target.arg64(5)} + 1 * 32]")
+        output.append(f"vpxor ymm6, ymm6, ymmword ptr [{target.arg64(5)} + 2 * 32]")
+        output.append(f"vpxor ymm7, ymm7, ymmword ptr [{target.arg64(5)} + 3 * 32]")
+        # Write out the XOR results.
+        output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 0 * 32], ymm4")
+        output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 1 * 32], ymm5")
+        output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 2 * 32], ymm6")
+        output.append(f"vmovdqu ymmword ptr [{target.arg64(5)} + 3 * 32], ymm7")
+    elif target.extension in (SSE41, SSE2):
+        assert degree == 1
+        output.append(f"movdqu xmm4, xmmword ptr [{target.arg64(0)}]")
+        output.append(f"movdqu xmm5, xmmword ptr [{target.arg64(0)}+0x10]")
+        output.append(f"pxor xmm2, xmm4")
+        output.append(f"pxor xmm3, xmm5")
+        output.append(f"movdqu xmm4, [{target.arg64(5)}]")
+        output.append(f"movdqu xmm5, [{target.arg64(5)}+0x10]")
+        output.append(f"movdqu xmm6, [{target.arg64(5)}+0x20]")
+        output.append(f"movdqu xmm7, [{target.arg64(5)}+0x30]")
+        output.append(f"pxor xmm0, xmm4")
+        output.append(f"pxor xmm1, xmm5")
+        output.append(f"pxor xmm2, xmm6")
+        output.append(f"pxor xmm3, xmm7")
+        output.append(f"movups xmmword ptr [{target.arg64(5)}], xmm0")
+        output.append(f"movups xmmword ptr [{target.arg64(5)}+0x10], xmm1")
+        output.append(f"movups xmmword ptr [{target.arg64(5)}+0x20], xmm2")
+        output.append(f"movups xmmword ptr [{target.arg64(5)}+0x30], xmm3")
+    else:
+        raise NotImplementedError
+
+
+def xof_fn(target, output, degree, xor):
+    variant = "xor" if xor else "stream"
+    finish_fn_2d = xof_xor_finish2d if xor else xof_stream_finish2d
+    label = f"blake3_{target.extension}_xof_{variant}_{degree}"
     output.append(f".global {label}")
     output.append(f"{label}:")
     if target.extension == AVX512:
         if degree in (1, 2, 4):
             xof_setup2d(target, output, degree)
             output.append(f"call {kernel2d_name(target, degree)}")
-            xof_stream_finish2d(target, output, degree)
+            finish_fn_2d(target, output, degree)
         else:
             raise NotImplementedError
     elif target.extension == AVX2:
         assert degree == 2
         xof_setup2d(target, output, degree)
         output.append(f"call {kernel2d_name(target, degree)}")
-        xof_stream_finish2d(target, output, degree)
+        finish_fn_2d(target, output, degree)
     elif target.extension in (SSE41, SSE2):
         assert degree == 1
         xof_setup2d(target, output, degree)
         output.append(f"call {kernel2d_name(target, degree)}")
-        xof_stream_finish2d(target, output, degree)
+        finish_fn_2d(target, output, degree)
     else:
         raise NotImplementedError
     output.append(target.ret())
@@ -749,7 +879,8 @@ def emit_sse2(target, output):
     target = replace(target, extension=SSE2)
     kernel2d(target, output, 1)
     compress(target, output)
-    xof_stream(target, output, 1)
+    xof_fn(target, output, 1, xor=False)
+    xof_fn(target, output, 1, xor=True)
     output.append(".balign 16")
     output.append("PBLENDW_0x33_MASK:")
     output.append(".long 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000")
@@ -765,13 +896,15 @@ def emit_sse41(target, output):
     target = replace(target, extension=SSE41)
     kernel2d(target, output, 1)
     compress(target, output)
-    xof_stream(target, output, 1)
+    xof_fn(target, output, 1, xor=False)
+    xof_fn(target, output, 1, xor=True)
 def emit_avx2(target, output):
     target = replace(target, extension=AVX2)
     kernel2d(target, output, 2)
-    xof_stream(target, output, 2)
+    xof_fn(target, output, 2, xor=False)
+    xof_fn(target, output, 2, xor=True)
 def emit_avx512(target, output):
@@ -780,9 +913,12 @@ def emit_avx512(target, output):
     kernel2d(target, output, 2)
     kernel2d(target, output, 4)
     compress(target, output)
-    xof_stream(target, output, 1)
-    xof_stream(target, output, 2)
-    xof_stream(target, output, 4)
+    xof_fn(target, output, 1, xor=False)
+    xof_fn(target, output, 1, xor=True)
+    xof_fn(target, output, 2, xor=False)
+    xof_fn(target, output, 2, xor=True)
+    xof_fn(target, output, 4, xor=False)
+    xof_fn(target, output, 4, xor=True)
 def emit_footer(target, output):
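The comment blocks in xof_xor_finish2d above describe the only nontrivial part of the new finish code: after the kernel returns, each ymm register holds one row of state a next to the matching row of state b, so the rows have to be regrouped into whole 64-byte blocks before the destination can be loaded, XORed, and stored with full-width vectors. The two vperm2f128 immediates do that regrouping: 0b0010_0000 selects the lower 128-bit lane of each source, and 0b0011_0001 selects the upper lanes. A scalar model of the degree-2 case, for illustration only, with 4-word lists standing in for 128-bit lanes:

# Illustration only: model vperm2f128 on lists of 32-bit words and check the
# degree-2 regrouping described in the comments above.
def perm2f128(src1, src2, imm):
    lanes = [src1[:4], src1[4:], src2[:4], src2[4:]]
    return lanes[imm & 0x3] + lanes[(imm >> 4) & 0x3]

a = list(range(16))        # state a, words a0..a15
b = list(range(100, 116))  # state b, words b0..b15
# ymm0..ymm3 as the kernel leaves them: row r of a next to row r of b.
ymm = [a[4 * r : 4 * r + 4] + b[4 * r : 4 * r + 4] for r in range(4)]
ymm4 = perm2f128(ymm[0], ymm[1], 0b0010_0000)  # a0..a7
ymm5 = perm2f128(ymm[2], ymm[3], 0b0010_0000)  # a8..a15
ymm6 = perm2f128(ymm[0], ymm[1], 0b0011_0001)  # b0..b7
ymm7 = perm2f128(ymm[2], ymm[3], 0b0011_0001)  # b8..b15
assert ymm4 + ymm5 == a and ymm6 + ymm7 == b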
diff --git a/asm/out.S b/asm/out.S
index 0d548b1..9d5f77f 100644
--- a/asm/out.S
+++ b/asm/out.S
@@ -569,6 +569,48 @@ blake3_sse2_xof_stream_1:
movups xmmword ptr [r9+0x20], xmm2
movups xmmword ptr [r9+0x30], xmm3
ret
+.global blake3_sse2_xof_xor_1
+blake3_sse2_xof_xor_1:
+ movups xmm0, xmmword ptr [rdi]
+ movups xmm1, xmmword ptr [rdi+0x10]
+ movaps xmm2, xmmword ptr [BLAKE3_IV+rip]
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vmovq xmm3, rdx
+ vmovq xmm4, rcx
+ punpcklqdq xmm3, xmm4
+ movups xmm4, xmmword ptr [rsi]
+ movups xmm5, xmmword ptr [rsi+0x10]
+ movaps xmm8, xmm4
+ shufps xmm4, xmm5, 136
+ shufps xmm8, xmm5, 221
+ movaps xmm5, xmm8
+ movups xmm6, xmmword ptr [rsi+0x20]
+ movups xmm7, xmmword ptr [rsi+0x30]
+ movaps xmm8, xmm6
+ shufps xmm6, xmm7, 136
+ pshufd xmm6, xmm6, 0x93
+ shufps xmm8, xmm7, 221
+ pshufd xmm7, xmm8, 0x93
+ call blake3_sse2_kernel2d_1
+ movdqu xmm4, xmmword ptr [rdi]
+ movdqu xmm5, xmmword ptr [rdi+0x10]
+ pxor xmm2, xmm4
+ pxor xmm3, xmm5
+ movdqu xmm4, [r9]
+ movdqu xmm5, [r9+0x10]
+ movdqu xmm6, [r9+0x20]
+ movdqu xmm7, [r9+0x30]
+ pxor xmm0, xmm4
+ pxor xmm1, xmm5
+ pxor xmm2, xmm6
+ pxor xmm3, xmm7
+ movups xmmword ptr [r9], xmm0
+ movups xmmword ptr [r9+0x10], xmm1
+ movups xmmword ptr [r9+0x20], xmm2
+ movups xmmword ptr [r9+0x30], xmm3
+ ret
.balign 16
PBLENDW_0x33_MASK:
.long 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000
@@ -1064,6 +1106,48 @@ blake3_sse41_xof_stream_1:
movups xmmword ptr [r9+0x20], xmm2
movups xmmword ptr [r9+0x30], xmm3
ret
+.global blake3_sse41_xof_xor_1
+blake3_sse41_xof_xor_1:
+ movups xmm0, xmmword ptr [rdi]
+ movups xmm1, xmmword ptr [rdi+0x10]
+ movaps xmm2, xmmword ptr [BLAKE3_IV+rip]
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vmovq xmm3, rdx
+ vmovq xmm4, rcx
+ punpcklqdq xmm3, xmm4
+ movups xmm4, xmmword ptr [rsi]
+ movups xmm5, xmmword ptr [rsi+0x10]
+ movaps xmm8, xmm4
+ shufps xmm4, xmm5, 136
+ shufps xmm8, xmm5, 221
+ movaps xmm5, xmm8
+ movups xmm6, xmmword ptr [rsi+0x20]
+ movups xmm7, xmmword ptr [rsi+0x30]
+ movaps xmm8, xmm6
+ shufps xmm6, xmm7, 136
+ pshufd xmm6, xmm6, 0x93
+ shufps xmm8, xmm7, 221
+ pshufd xmm7, xmm8, 0x93
+ call blake3_sse41_kernel2d_1
+ movdqu xmm4, xmmword ptr [rdi]
+ movdqu xmm5, xmmword ptr [rdi+0x10]
+ pxor xmm2, xmm4
+ pxor xmm3, xmm5
+ movdqu xmm4, [r9]
+ movdqu xmm5, [r9+0x10]
+ movdqu xmm6, [r9+0x20]
+ movdqu xmm7, [r9+0x30]
+ pxor xmm0, xmm4
+ pxor xmm1, xmm5
+ pxor xmm2, xmm6
+ pxor xmm3, xmm7
+ movups xmmword ptr [r9], xmm0
+ movups xmmword ptr [r9+0x10], xmm1
+ movups xmmword ptr [r9+0x20], xmm2
+ movups xmmword ptr [r9+0x30], xmm3
+ ret
blake3_avx2_kernel2d_2:
vbroadcasti128 ymm14, xmmword ptr [ROT16+rip]
vbroadcasti128 ymm15, xmmword ptr [ROT8+rip]
@@ -1479,6 +1563,47 @@ blake3_avx2_xof_stream_2:
vextracti128 xmmword ptr [r9 + 6 * 16], ymm2, 1
vextracti128 xmmword ptr [r9 + 7 * 16], ymm3, 1
ret
+.global blake3_avx2_xof_xor_2
+blake3_avx2_xof_xor_2:
+ vbroadcasti128 ymm0, xmmword ptr [rdi]
+ vbroadcasti128 ymm1, xmmword ptr [rdi+0x10]
+ vmovdqa ymm4, ymmword ptr [INCREMENT_2D+rip]
+ vbroadcasti128 ymm2, xmmword ptr [BLAKE3_IV+rip]
+ vpbroadcastq ymm5, rdx
+ vpaddq ymm6, ymm4, ymm5
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vpbroadcastq ymm7, rcx
+ vpblendd ymm3, ymm6, ymm7, 0xCC
+ vbroadcasti128 ymm8, xmmword ptr [rsi]
+ vbroadcasti128 ymm9, xmmword ptr [rsi+0x10]
+ vshufps ymm4, ymm8, ymm9, 136
+ vshufps ymm5, ymm8, ymm9, 221
+ vbroadcasti128 ymm8, xmmword ptr [rsi+0x20]
+ vbroadcasti128 ymm9, xmmword ptr [rsi+0x30]
+ vshufps ymm6, ymm8, ymm9, 136
+ vshufps ymm7, ymm8, ymm9, 221
+ vpshufd ymm6, ymm6, 0x93
+ vpshufd ymm7, ymm7, 0x93
+ call blake3_avx2_kernel2d_2
+ vbroadcasti128 ymm4, xmmword ptr [rdi]
+ vpxor ymm2, ymm2, ymm4
+ vbroadcasti128 ymm5, xmmword ptr [rdi + 16]
+ vpxor ymm3, ymm3, ymm5
+ vperm2f128 ymm4, ymm0, ymm1, 32
+ vperm2f128 ymm5, ymm2, ymm3, 32
+ vperm2f128 ymm6, ymm0, ymm1, 49
+ vperm2f128 ymm7, ymm2, ymm3, 49
+ vpxor ymm4, ymm4, ymmword ptr [r9 + 0 * 32]
+ vpxor ymm5, ymm5, ymmword ptr [r9 + 1 * 32]
+ vpxor ymm6, ymm6, ymmword ptr [r9 + 2 * 32]
+ vpxor ymm7, ymm7, ymmword ptr [r9 + 3 * 32]
+ vmovdqu ymmword ptr [r9 + 0 * 32], ymm4
+ vmovdqu ymmword ptr [r9 + 1 * 32], ymm5
+ vmovdqu ymmword ptr [r9 + 2 * 32], ymm6
+ vmovdqu ymmword ptr [r9 + 3 * 32], ymm7
+ ret
blake3_avx512_kernel2d_1:
vpaddd xmm0, xmm0, xmm4
vpaddd xmm0, xmm0, xmm1
@@ -2497,6 +2622,39 @@ blake3_avx512_xof_stream_1:
vmovdqu xmmword ptr [r9+0x20], xmm2
vmovdqu xmmword ptr [r9+0x30], xmm3
ret
+.global blake3_avx512_xof_xor_1
+blake3_avx512_xof_xor_1:
+ vmovdqu xmm0, xmmword ptr [rdi]
+ vmovdqu xmm1, xmmword ptr [rdi+0x10]
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vmovq xmm3, rdx
+ vmovq xmm4, rcx
+ vpunpcklqdq xmm3, xmm3, xmm4
+ vmovaps xmm2, xmmword ptr [BLAKE3_IV+rip]
+ vmovups xmm8, xmmword ptr [rsi]
+ vmovups xmm9, xmmword ptr [rsi+0x10]
+ vshufps xmm4, xmm8, xmm9, 136
+ vshufps xmm5, xmm8, xmm9, 221
+ vmovups xmm8, xmmword ptr [rsi+0x20]
+ vmovups xmm9, xmmword ptr [rsi+0x30]
+ vshufps xmm6, xmm8, xmm9, 136
+ vshufps xmm7, xmm8, xmm9, 221
+ vpshufd xmm6, xmm6, 0x93
+ vpshufd xmm7, xmm7, 0x93
+ call blake3_avx512_kernel2d_1
+ vpxor xmm2, xmm2, [rdi]
+ vpxor xmm3, xmm3, [rdi+0x10]
+ vpxor xmm0, xmm0, [r9]
+ vpxor xmm1, xmm1, [r9+0x10]
+ vpxor xmm2, xmm2, [r9+0x20]
+ vpxor xmm3, xmm3, [r9+0x30]
+ vmovdqu xmmword ptr [r9], xmm0
+ vmovdqu xmmword ptr [r9+0x10], xmm1
+ vmovdqu xmmword ptr [r9+0x20], xmm2
+ vmovdqu xmmword ptr [r9+0x30], xmm3
+ ret
.global blake3_avx512_xof_stream_2
blake3_avx512_xof_stream_2:
vbroadcasti128 ymm0, xmmword ptr [rdi]
@@ -2529,10 +2687,51 @@ blake3_avx512_xof_stream_2:
vmovdqu xmmword ptr [r9 + 1 * 16], xmm1
vmovdqu xmmword ptr [r9 + 2 * 16], xmm2
vmovdqu xmmword ptr [r9 + 3 * 16], xmm3
- vextracti128 xmmword ptr [r9 + 4 * 16], ymm0, 1
- vextracti128 xmmword ptr [r9 + 5 * 16], ymm1, 1
- vextracti128 xmmword ptr [r9 + 6 * 16], ymm2, 1
- vextracti128 xmmword ptr [r9 + 7 * 16], ymm3, 1
+ vextracti128 xmmword ptr [r9+4*16], ymm0, 1
+ vextracti128 xmmword ptr [r9+5*16], ymm1, 1
+ vextracti128 xmmword ptr [r9+6*16], ymm2, 1
+ vextracti128 xmmword ptr [r9+7*16], ymm3, 1
+ ret
+.global blake3_avx512_xof_xor_2
+blake3_avx512_xof_xor_2:
+ vbroadcasti128 ymm0, xmmword ptr [rdi]
+ vbroadcasti128 ymm1, xmmword ptr [rdi+0x10]
+ vmovdqa ymm4, ymmword ptr [INCREMENT_2D+rip]
+ vbroadcasti128 ymm2, xmmword ptr [BLAKE3_IV+rip]
+ vpbroadcastq ymm5, rdx
+ vpaddq ymm6, ymm4, ymm5
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vpbroadcastq ymm7, rcx
+ vpblendd ymm3, ymm6, ymm7, 0xCC
+ vbroadcasti128 ymm8, xmmword ptr [rsi]
+ vbroadcasti128 ymm9, xmmword ptr [rsi+0x10]
+ vshufps ymm4, ymm8, ymm9, 136
+ vshufps ymm5, ymm8, ymm9, 221
+ vbroadcasti128 ymm8, xmmword ptr [rsi+0x20]
+ vbroadcasti128 ymm9, xmmword ptr [rsi+0x30]
+ vshufps ymm6, ymm8, ymm9, 136
+ vshufps ymm7, ymm8, ymm9, 221
+ vpshufd ymm6, ymm6, 0x93
+ vpshufd ymm7, ymm7, 0x93
+ call blake3_avx512_kernel2d_2
+ vbroadcasti128 ymm4, xmmword ptr [rdi]
+ vpxor ymm2, ymm2, ymm4
+ vbroadcasti128 ymm5, xmmword ptr [rdi + 16]
+ vpxor ymm3, ymm3, ymm5
+ vperm2f128 ymm4, ymm0, ymm1, 32
+ vperm2f128 ymm5, ymm2, ymm3, 32
+ vperm2f128 ymm6, ymm0, ymm1, 49
+ vperm2f128 ymm7, ymm2, ymm3, 49
+ vpxor ymm4, ymm4, ymmword ptr [r9 + 0 * 32]
+ vpxor ymm5, ymm5, ymmword ptr [r9 + 1 * 32]
+ vpxor ymm6, ymm6, ymmword ptr [r9 + 2 * 32]
+ vpxor ymm7, ymm7, ymmword ptr [r9 + 3 * 32]
+ vmovdqu ymmword ptr [r9 + 0 * 32], ymm4
+ vmovdqu ymmword ptr [r9 + 1 * 32], ymm5
+ vmovdqu ymmword ptr [r9 + 2 * 32], ymm6
+ vmovdqu ymmword ptr [r9 + 3 * 32], ymm7
ret
.global blake3_avx512_xof_stream_4
blake3_avx512_xof_stream_4:
@@ -2581,6 +2780,53 @@ blake3_avx512_xof_stream_4:
vextracti32x4 xmmword ptr [r9 + 14 * 16], zmm2, 3
vextracti32x4 xmmword ptr [r9 + 15 * 16], zmm3, 3
ret
+.global blake3_avx512_xof_xor_4
+blake3_avx512_xof_xor_4:
+ vbroadcasti32x4 zmm0, xmmword ptr [rdi]
+ vbroadcasti32x4 zmm1, xmmword ptr [rdi+0x10]
+ vmovdqa32 zmm4, zmmword ptr [INCREMENT_2D+rip]
+ vbroadcasti32x4 zmm2, xmmword ptr [BLAKE3_IV+rip]
+ vpbroadcastq zmm5, rdx
+ vpaddq zmm6, zmm4, zmm5
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vpbroadcastq zmm7, rcx
+ mov eax, 0xAA
+ kmovw k2, eax
+ vpblendmq zmm3 {k2}, zmm6, zmm7
+ vbroadcasti32x4 zmm8, xmmword ptr [rsi]
+ vbroadcasti32x4 zmm9, xmmword ptr [rsi+0x10]
+ vshufps zmm4, zmm8, zmm9, 136
+ vshufps zmm5, zmm8, zmm9, 221
+ vbroadcasti32x4 zmm8, xmmword ptr [rsi+0x20]
+ vbroadcasti32x4 zmm9, xmmword ptr [rsi+0x30]
+ vshufps zmm6, zmm8, zmm9, 136
+ vshufps zmm7, zmm8, zmm9, 221
+ vpshufd zmm6, zmm6, 0x93
+ vpshufd zmm7, zmm7, 0x93
+ call blake3_avx512_kernel2d_4
+ vbroadcasti32x4 zmm4, xmmword ptr [rdi]
+ vpxord zmm2, zmm2, zmm4
+ vbroadcasti32x4 zmm5, xmmword ptr [rdi + 16]
+ vpxord zmm3, zmm3, zmm5
+ vshufi32x4 zmm4, zmm0, zmm1, 68
+ vshufi32x4 zmm5, zmm0, zmm1, 238
+ vshufi32x4 zmm6, zmm2, zmm3, 68
+ vshufi32x4 zmm7, zmm2, zmm3, 238
+ vshufi32x4 zmm0, zmm4, zmm6, 136
+ vshufi32x4 zmm1, zmm4, zmm6, 221
+ vshufi32x4 zmm2, zmm5, zmm7, 136
+ vshufi32x4 zmm3, zmm5, zmm7, 221
+ vpxord zmm0, zmm0, zmmword ptr [r9 + 0*64]
+ vpxord zmm1, zmm1, zmmword ptr [r9 + 1*64]
+ vpxord zmm2, zmm2, zmmword ptr [r9 + 2*64]
+ vpxord zmm3, zmm3, zmmword ptr [r9 + 3*64]
+ vmovdqu32 zmmword ptr [r9 + 0*64], zmm0
+ vmovdqu32 zmmword ptr [r9 + 1*64], zmm1
+ vmovdqu32 zmmword ptr [r9 + 2*64], zmm2
+ vmovdqu32 zmmword ptr [r9 + 3*64], zmm3
+ ret
.balign 16
BLAKE3_IV:
BLAKE3_IV_0:
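blake3_avx512_xof_xor_4 above does the same regrouping across four states, but a 512-bit register has four 128-bit lanes, so it takes two rounds of vshufi32x4: the first round (immediates 68 and 238) interleaves 256-bit halves of register pairs, and the second round (immediates 136 and 221) interleaves 128-bit lanes within those, leaving each zmm register holding one contiguous 64-byte block. A scalar model of that permutation, for illustration only, with 4-word lists standing in for 128-bit lanes:

# Illustration only: model vshufi32x4 on lists of 32-bit words and check the
# two-step degree-4 regrouping used by blake3_avx512_xof_xor_4.
def shufi32x4(src1, src2, imm):
    l1 = [src1[4 * i : 4 * i + 4] for i in range(4)]
    l2 = [src2[4 * i : 4 * i + 4] for i in range(4)]
    return l1[imm & 3] + l1[(imm >> 2) & 3] + l2[(imm >> 4) & 3] + l2[(imm >> 6) & 3]

# Four states a..d; zmm0..zmm3 as the kernel leaves them: row r of each state.
a, b, c, d = ([100 * s + i for i in range(16)] for s in (1, 2, 3, 4))
row = lambda state, r: state[4 * r : 4 * r + 4]
zmm = [row(a, r) + row(b, r) + row(c, r) + row(d, r) for r in range(4)]
zmm4 = shufi32x4(zmm[0], zmm[1], 68)   # 0b0100_0100: low 256-bit halves
zmm5 = shufi32x4(zmm[0], zmm[1], 238)  # 0b1110_1110: high 256-bit halves
zmm6 = shufi32x4(zmm[2], zmm[3], 68)
zmm7 = shufi32x4(zmm[2], zmm[3], 238)
out = [
    shufi32x4(zmm4, zmm6, 136),  # 0b1000_1000: all of a
    shufi32x4(zmm4, zmm6, 221),  # 0b1101_1101: all of b
    shufi32x4(zmm5, zmm7, 136),  # all of c
    shufi32x4(zmm5, zmm7, 221),  # all of d
]
assert out == [a, b, c, d]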
diff --git a/src/kernel.rs b/src/kernel.rs
index f17db21..cdcb25a 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -73,6 +73,54 @@ extern "C" {
flags: u32,
out: *mut [u8; 64 * 4],
);
+ pub fn blake3_sse2_xof_xor_1(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: &mut [u8; 64],
+ );
+ pub fn blake3_sse41_xof_xor_1(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: &mut [u8; 64],
+ );
+ pub fn blake3_avx512_xof_xor_1(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: &mut [u8; 64],
+ );
+ pub fn blake3_avx2_xof_xor_2(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: &mut [u8; 64 * 2],
+ );
+ pub fn blake3_avx512_xof_xor_2(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: &mut [u8; 64 * 2],
+ );
+ pub fn blake3_avx512_xof_xor_4(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: &mut [u8; 64 * 4],
+ );
}
pub type CompressionFn =
@@ -87,6 +135,15 @@ pub type XofStreamFn<const N: usize> = unsafe extern "C" fn(
out: *mut [u8; N],
);
+pub type XofXorFn<const N: usize> = unsafe extern "C" fn(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: &mut [u8; N],
+);
+
#[cfg(test)]
mod test {
use super::*;
@@ -142,13 +199,14 @@ mod test {
test_compression_function(blake3_avx512_compress);
}
- fn test_xof_function<const N: usize>(f: XofStreamFn<N>) {
+ fn test_xof_functions<const N: usize>(stream_fn: XofStreamFn<N>, xor_fn: XofXorFn<N>) {
let mut block = [0; 64];
let block_len = 53;
crate::test::paint_test_input(&mut block[..block_len]);
let counter = u32::MAX as u64;
let flags = crate::CHUNK_START | crate::CHUNK_END | crate::ROOT;
+ // First compute the expected stream.
let mut expected = [0; N];
let mut incrementing_counter = counter;
let mut i = 0;
@@ -167,9 +225,11 @@ mod test {
}
assert_eq!(incrementing_counter, counter + N as u64 / 64);
- let mut found = [0; N];
+ // And compare that to the stream under test. The 0x42 bytes are there to make sure we
+ // overwrite them.
+ let mut found = [0x42; N];
unsafe {
- f(
+ stream_fn(
crate::IV,
&block,
counter,
@@ -178,8 +238,24 @@ mod test {
&mut found,
);
}
-
assert_eq!(expected, found);
+
+ // XOR 0x99 bytes into the found stream. Then run the xof_xor variant on that stream again.
+ // This should cancel out the original stream, leaving only the 0x99 bytes.
+ for b in &mut found {
+ *b ^= 0x99;
+ }
+ unsafe {
+ xor_fn(
+ crate::IV,
+ &block,
+ counter,
+ block_len as u32,
+ flags as u32,
+ &mut found,
+ );
+ }
+ assert_eq!([0x99; N], found);
}
#[test]
@@ -188,7 +264,7 @@ mod test {
if !is_x86_feature_detected!("sse2") {
return;
}
- test_xof_function(blake3_sse2_xof_stream_1);
+ test_xof_functions(blake3_sse2_xof_stream_1, blake3_sse2_xof_xor_1);
}
#[test]
@@ -197,7 +273,7 @@ mod test {
if !is_x86_feature_detected!("sse4.1") {
return;
}
- test_xof_function(blake3_sse41_xof_stream_1);
+ test_xof_functions(blake3_sse41_xof_stream_1, blake3_sse41_xof_xor_1);
}
#[test]
@@ -206,7 +282,7 @@ mod test {
if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") {
return;
}
- test_xof_function(blake3_avx512_xof_stream_1);
+ test_xof_functions(blake3_avx512_xof_stream_1, blake3_avx512_xof_xor_1);
}
#[test]
@@ -215,7 +291,7 @@ mod test {
if !is_x86_feature_detected!("avx2") {
return;
}
- test_xof_function(blake3_avx2_xof_stream_2);
+ test_xof_functions(blake3_avx2_xof_stream_2, blake3_avx2_xof_xor_2);
}
#[test]
@@ -224,7 +300,7 @@ mod test {
if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") {
return;
}
- test_xof_function(blake3_avx512_xof_stream_2);
+ test_xof_functions(blake3_avx512_xof_stream_2, blake3_avx512_xof_xor_2);
}
#[test]
@@ -233,7 +309,7 @@ mod test {
if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") {
return;
}
- test_xof_function(blake3_avx512_xof_stream_4);
+ test_xof_functions(blake3_avx512_xof_stream_4, blake3_avx512_xof_xor_4);
}
}
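The updated test covers both variants with one buffer: it first checks that the stream function overwrites the 0x42-filled buffer with the expected stream, then XORs 0x99 into every byte and calls the xor function, which folds the same stream in a second time and cancels it, leaving only the 0x99 bytes. A short Python sketch of the cancellation property the test relies on, for illustration only:

# Illustration only: XORing the same stream into a buffer twice restores
# whatever was mixed in between (here, the 0x99 bytes).
import os

stream = os.urandom(64)    # stands in for one XOF output block
found = bytearray(stream)  # buffer contents after the stream function ran
for i in range(len(found)):
    found[i] ^= 0x99       # the test's 0x99 step
after = bytes(x ^ s for x, s in zip(found, stream))  # what the xor function leaves
assert after == bytes([0x99] * 64)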