aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack O'Connor <[email protected]>2022-03-20 19:04:41 -0400
committerJack O'Connor <[email protected]>2022-03-20 20:17:31 -0400
commitea94b544fcdb9cf4120eaf9308e8df937871ad3e (patch)
tree953deb76ef12cce9bb2b672bf19dbffd86c6695d
parent9139fa40e8e0c17a1afd5e137eaca5a25d865066 (diff)
blake3_avx512_xof_stream_4
-rwxr-xr-xasm/asm.py58
-rw-r--r--asm/out.S51
-rw-r--r--src/kernel.rs19
3 files changed, 125 insertions, 3 deletions
diff --git a/asm/asm.py b/asm/asm.py
index 1526ae8..de5b89f 100755
--- a/asm/asm.py
+++ b/asm/asm.py
@@ -361,8 +361,14 @@ def kernel2d(target, output, degree):
output.append(f"movaps xmm14, xmmword ptr [ROT8+rip]")
output.append(f"movaps xmm15, xmmword ptr [ROT16+rip]")
if target.extension == AVX2:
- output.append("vbroadcasti128 ymm14, xmmword ptr [ROT16+rip]")
- output.append("vbroadcasti128 ymm15, xmmword ptr [ROT8+rip]")
+ output.append(f"vbroadcasti128 ymm14, xmmword ptr [ROT16+rip]")
+ output.append(f"vbroadcasti128 ymm15, xmmword ptr [ROT8+rip]")
+ if target.extension == AVX512:
+ if degree == 4:
+ output.append(f"mov {target.scratch32(0)}, 43690")
+ output.append(f"kmovw k3, {target.scratch32(0)}")
+ output.append(f"mov {target.scratch32(0)}, 34952")
+ output.append(f"kmovw k4, {target.scratch32(0)}")
for round_number in range(7):
if round_number > 0:
# Un-diagonalize and permute before each round except the first.
@@ -541,6 +547,39 @@ def xof_setup2d(target, output, degree):
output.append(f"vshufps ymm7, ymm8, ymm9, 221")
output.append(f"vpshufd ymm6, ymm6, 0x93")
output.append(f"vpshufd ymm7, ymm7, 0x93")
+ elif degree == 4:
+ # Load the state words.
+ output.append(f"vbroadcasti32x4 zmm0, xmmword ptr [{target.arg64(0)}]")
+ output.append(f"vbroadcasti32x4 zmm1, xmmword ptr [{target.arg64(0)}+0x10]")
+ # Load the counter increments.
+ output.append(f"vmovdqa32 zmm4, zmmword ptr [INCREMENT_2D+rip]")
+ # Load the IV constants.
+ output.append(f"vbroadcasti32x4 zmm2, xmmword ptr [BLAKE3_IV+rip]")
+ # Broadcast the counter.
+ output.append(f"vpbroadcastq zmm5, {target.arg64(2)}")
+ # Add the counter increments to the counter.
+ output.append(f"vpaddq zmm6, zmm4, zmm5")
+ # Combine the block length and flags into a 64-bit word.
+ output.append(f"shl {target.arg64(4)}, 32")
+ output.append(f"mov {target.arg32(3)}, {target.arg32(3)}")
+ output.append(f"or {target.arg64(3)}, {target.arg64(4)}")
+ # Broadcast the block length and flags.
+ output.append(f"vpbroadcastq zmm7, {target.arg64(3)}")
+ # Blend the counter, block length, and flags.
+ output.append(f"mov {target.scratch32(0)}, 0xAA")
+ output.append(f"kmovw k2, {target.scratch32(0)}")
+ output.append(f"vpblendmq zmm3 {{k2}}, zmm6, zmm7")
+ # Load and permute the message words.
+ output.append(f"vbroadcasti32x4 zmm8, xmmword ptr [{target.arg64(1)}]")
+ output.append(f"vbroadcasti32x4 zmm9, xmmword ptr [{target.arg64(1)}+0x10]")
+ output.append(f"vshufps zmm4, zmm8, zmm9, 136")
+ output.append(f"vshufps zmm5, zmm8, zmm9, 221")
+ output.append(f"vbroadcasti32x4 zmm8, xmmword ptr [{target.arg64(1)}+0x20]")
+ output.append(f"vbroadcasti32x4 zmm9, xmmword ptr [{target.arg64(1)}+0x30]")
+ output.append(f"vshufps zmm6, zmm8, zmm9, 136")
+ output.append(f"vshufps zmm7, zmm8, zmm9, 221")
+ output.append(f"vpshufd zmm6, zmm6, 0x93")
+ output.append(f"vpshufd zmm7, zmm7, 0x93")
else:
raise NotImplementedError
elif target.extension == AVX2:
@@ -632,6 +671,20 @@ def xof_stream_finish2d(target, output, degree):
output.append(
f"vextracti128 xmmword ptr [{target.arg64(5)} + 7 * 16], ymm3, 1"
)
+ elif degree == 4:
+ output.append(f"vbroadcasti32x4 zmm4, xmmword ptr [{target.arg64(0)}]")
+ output.append(f"vpxord zmm2, zmm2, zmm4")
+ output.append(f"vbroadcasti32x4 zmm5, xmmword ptr [{target.arg64(0)} + 16]")
+ output.append(f"vpxord zmm3, zmm3, zmm5")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 0 * 16], xmm0")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 1 * 16], xmm1")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 2 * 16], xmm2")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 3 * 16], xmm3")
+ for i in range(1, 4):
+ for reg in range(0, 4):
+ output.append(
+ f"vextracti32x4 xmmword ptr [{target.arg64(5)} + {4*i+reg} * 16], zmm{reg}, {i}"
+ )
else:
raise NotImplementedError
elif target.extension == AVX2:
@@ -729,6 +782,7 @@ def emit_avx512(target, output):
compress(target, output)
xof_stream(target, output, 1)
xof_stream(target, output, 2)
+ xof_stream(target, output, 4)
def emit_footer(target, output):
diff --git a/asm/out.S b/asm/out.S
index 7ec3321..0d548b1 100644
--- a/asm/out.S
+++ b/asm/out.S
@@ -2120,6 +2120,10 @@ blake3_avx512_kernel2d_2:
vpxord ymm1, ymm1, ymm3
ret
blake3_avx512_kernel2d_4:
+ mov eax, 43690
+ kmovw k3, eax
+ mov eax, 34952
+ kmovw k4, eax
vpaddd zmm0, zmm0, zmm4
vpaddd zmm0, zmm0, zmm1
vpxord zmm3, zmm3, zmm0
@@ -2530,6 +2534,53 @@ blake3_avx512_xof_stream_2:
vextracti128 xmmword ptr [r9 + 6 * 16], ymm2, 1
vextracti128 xmmword ptr [r9 + 7 * 16], ymm3, 1
ret
+.global blake3_avx512_xof_stream_4
+blake3_avx512_xof_stream_4:
+ vbroadcasti32x4 zmm0, xmmword ptr [rdi]
+ vbroadcasti32x4 zmm1, xmmword ptr [rdi+0x10]
+ vmovdqa32 zmm4, zmmword ptr [INCREMENT_2D+rip]
+ vbroadcasti32x4 zmm2, xmmword ptr [BLAKE3_IV+rip]
+ vpbroadcastq zmm5, rdx
+ vpaddq zmm6, zmm4, zmm5
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vpbroadcastq zmm7, rcx
+ mov eax, 0xAA
+ kmovw k2, eax
+ vpblendmq zmm3 {k2}, zmm6, zmm7
+ vbroadcasti32x4 zmm8, xmmword ptr [rsi]
+ vbroadcasti32x4 zmm9, xmmword ptr [rsi+0x10]
+ vshufps zmm4, zmm8, zmm9, 136
+ vshufps zmm5, zmm8, zmm9, 221
+ vbroadcasti32x4 zmm8, xmmword ptr [rsi+0x20]
+ vbroadcasti32x4 zmm9, xmmword ptr [rsi+0x30]
+ vshufps zmm6, zmm8, zmm9, 136
+ vshufps zmm7, zmm8, zmm9, 221
+ vpshufd zmm6, zmm6, 0x93
+ vpshufd zmm7, zmm7, 0x93
+ call blake3_avx512_kernel2d_4
+ vbroadcasti32x4 zmm4, xmmword ptr [rdi]
+ vpxord zmm2, zmm2, zmm4
+ vbroadcasti32x4 zmm5, xmmword ptr [rdi + 16]
+ vpxord zmm3, zmm3, zmm5
+ vmovdqu xmmword ptr [r9 + 0 * 16], xmm0
+ vmovdqu xmmword ptr [r9 + 1 * 16], xmm1
+ vmovdqu xmmword ptr [r9 + 2 * 16], xmm2
+ vmovdqu xmmword ptr [r9 + 3 * 16], xmm3
+ vextracti32x4 xmmword ptr [r9 + 4 * 16], zmm0, 1
+ vextracti32x4 xmmword ptr [r9 + 5 * 16], zmm1, 1
+ vextracti32x4 xmmword ptr [r9 + 6 * 16], zmm2, 1
+ vextracti32x4 xmmword ptr [r9 + 7 * 16], zmm3, 1
+ vextracti32x4 xmmword ptr [r9 + 8 * 16], zmm0, 2
+ vextracti32x4 xmmword ptr [r9 + 9 * 16], zmm1, 2
+ vextracti32x4 xmmword ptr [r9 + 10 * 16], zmm2, 2
+ vextracti32x4 xmmword ptr [r9 + 11 * 16], zmm3, 2
+ vextracti32x4 xmmword ptr [r9 + 12 * 16], zmm0, 3
+ vextracti32x4 xmmword ptr [r9 + 13 * 16], zmm1, 3
+ vextracti32x4 xmmword ptr [r9 + 14 * 16], zmm2, 3
+ vextracti32x4 xmmword ptr [r9 + 15 * 16], zmm3, 3
+ ret
.balign 16
BLAKE3_IV:
BLAKE3_IV_0:
diff --git a/src/kernel.rs b/src/kernel.rs
index a6cca9f..f17db21 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -65,6 +65,14 @@ extern "C" {
flags: u32,
out: *mut [u8; 64 * 2],
);
+ pub fn blake3_avx512_xof_stream_4(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: *mut [u8; 64 * 4],
+ );
}
pub type CompressionFn =
@@ -138,7 +146,7 @@ mod test {
let mut block = [0; 64];
let block_len = 53;
crate::test::paint_test_input(&mut block[..block_len]);
- let counter = u64::MAX - 42;
+ let counter = u32::MAX as u64;
let flags = crate::CHUNK_START | crate::CHUNK_END | crate::ROOT;
let mut expected = [0; N];
@@ -218,6 +226,15 @@ mod test {
}
test_xof_function(blake3_avx512_xof_stream_2);
}
+
+ #[test]
+ #[cfg(target_arch = "x86_64")]
+ fn test_avx512_xof_4() {
+ if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") {
+ return;
+ }
+ test_xof_function(blake3_avx512_xof_stream_4);
+ }
}
global_asm!(