diff options
| author | Jack O'Connor <[email protected]> | 2022-03-20 19:04:41 -0400 |
|---|---|---|
| committer | Jack O'Connor <[email protected]> | 2022-03-20 20:17:31 -0400 |
| commit | ea94b544fcdb9cf4120eaf9308e8df937871ad3e (patch) | |
| tree | 953deb76ef12cce9bb2b672bf19dbffd86c6695d | |
| parent | 9139fa40e8e0c17a1afd5e137eaca5a25d865066 (diff) | |
blake3_avx512_xof_stream_4
| -rwxr-xr-x | asm/asm.py | 58 | ||||
| -rw-r--r-- | asm/out.S | 51 | ||||
| -rw-r--r-- | src/kernel.rs | 19 |
3 files changed, 125 insertions, 3 deletions
@@ -361,8 +361,14 @@ def kernel2d(target, output, degree): output.append(f"movaps xmm14, xmmword ptr [ROT8+rip]") output.append(f"movaps xmm15, xmmword ptr [ROT16+rip]") if target.extension == AVX2: - output.append("vbroadcasti128 ymm14, xmmword ptr [ROT16+rip]") - output.append("vbroadcasti128 ymm15, xmmword ptr [ROT8+rip]") + output.append(f"vbroadcasti128 ymm14, xmmword ptr [ROT16+rip]") + output.append(f"vbroadcasti128 ymm15, xmmword ptr [ROT8+rip]") + if target.extension == AVX512: + if degree == 4: + output.append(f"mov {target.scratch32(0)}, 43690") + output.append(f"kmovw k3, {target.scratch32(0)}") + output.append(f"mov {target.scratch32(0)}, 34952") + output.append(f"kmovw k4, {target.scratch32(0)}") for round_number in range(7): if round_number > 0: # Un-diagonalize and permute before each round except the first. @@ -541,6 +547,39 @@ def xof_setup2d(target, output, degree): output.append(f"vshufps ymm7, ymm8, ymm9, 221") output.append(f"vpshufd ymm6, ymm6, 0x93") output.append(f"vpshufd ymm7, ymm7, 0x93") + elif degree == 4: + # Load the state words. + output.append(f"vbroadcasti32x4 zmm0, xmmword ptr [{target.arg64(0)}]") + output.append(f"vbroadcasti32x4 zmm1, xmmword ptr [{target.arg64(0)}+0x10]") + # Load the counter increments. + output.append(f"vmovdqa32 zmm4, zmmword ptr [INCREMENT_2D+rip]") + # Load the IV constants. + output.append(f"vbroadcasti32x4 zmm2, xmmword ptr [BLAKE3_IV+rip]") + # Broadcast the counter. + output.append(f"vpbroadcastq zmm5, {target.arg64(2)}") + # Add the counter increments to the counter. + output.append(f"vpaddq zmm6, zmm4, zmm5") + # Combine the block length and flags into a 64-bit word. + output.append(f"shl {target.arg64(4)}, 32") + output.append(f"mov {target.arg32(3)}, {target.arg32(3)}") + output.append(f"or {target.arg64(3)}, {target.arg64(4)}") + # Broadcast the block length and flags. + output.append(f"vpbroadcastq zmm7, {target.arg64(3)}") + # Blend the counter, block length, and flags. + output.append(f"mov {target.scratch32(0)}, 0xAA") + output.append(f"kmovw k2, {target.scratch32(0)}") + output.append(f"vpblendmq zmm3 {{k2}}, zmm6, zmm7") + # Load and permute the message words. + output.append(f"vbroadcasti32x4 zmm8, xmmword ptr [{target.arg64(1)}]") + output.append(f"vbroadcasti32x4 zmm9, xmmword ptr [{target.arg64(1)}+0x10]") + output.append(f"vshufps zmm4, zmm8, zmm9, 136") + output.append(f"vshufps zmm5, zmm8, zmm9, 221") + output.append(f"vbroadcasti32x4 zmm8, xmmword ptr [{target.arg64(1)}+0x20]") + output.append(f"vbroadcasti32x4 zmm9, xmmword ptr [{target.arg64(1)}+0x30]") + output.append(f"vshufps zmm6, zmm8, zmm9, 136") + output.append(f"vshufps zmm7, zmm8, zmm9, 221") + output.append(f"vpshufd zmm6, zmm6, 0x93") + output.append(f"vpshufd zmm7, zmm7, 0x93") else: raise NotImplementedError elif target.extension == AVX2: @@ -632,6 +671,20 @@ def xof_stream_finish2d(target, output, degree): output.append( f"vextracti128 xmmword ptr [{target.arg64(5)} + 7 * 16], ymm3, 1" ) + elif degree == 4: + output.append(f"vbroadcasti32x4 zmm4, xmmword ptr [{target.arg64(0)}]") + output.append(f"vpxord zmm2, zmm2, zmm4") + output.append(f"vbroadcasti32x4 zmm5, xmmword ptr [{target.arg64(0)} + 16]") + output.append(f"vpxord zmm3, zmm3, zmm5") + output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 0 * 16], xmm0") + output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 1 * 16], xmm1") + output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 2 * 16], xmm2") + output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 3 * 16], xmm3") + for i in range(1, 4): + for reg in range(0, 4): + output.append( + f"vextracti32x4 xmmword ptr [{target.arg64(5)} + {4*i+reg} * 16], zmm{reg}, {i}" + ) else: raise NotImplementedError elif target.extension == AVX2: @@ -729,6 +782,7 @@ def emit_avx512(target, output): compress(target, output) xof_stream(target, output, 1) xof_stream(target, output, 2) + xof_stream(target, output, 4) def emit_footer(target, output): @@ -2120,6 +2120,10 @@ blake3_avx512_kernel2d_2: vpxord ymm1, ymm1, ymm3 ret blake3_avx512_kernel2d_4: + mov eax, 43690 + kmovw k3, eax + mov eax, 34952 + kmovw k4, eax vpaddd zmm0, zmm0, zmm4 vpaddd zmm0, zmm0, zmm1 vpxord zmm3, zmm3, zmm0 @@ -2530,6 +2534,53 @@ blake3_avx512_xof_stream_2: vextracti128 xmmword ptr [r9 + 6 * 16], ymm2, 1 vextracti128 xmmword ptr [r9 + 7 * 16], ymm3, 1 ret +.global blake3_avx512_xof_stream_4 +blake3_avx512_xof_stream_4: + vbroadcasti32x4 zmm0, xmmword ptr [rdi] + vbroadcasti32x4 zmm1, xmmword ptr [rdi+0x10] + vmovdqa32 zmm4, zmmword ptr [INCREMENT_2D+rip] + vbroadcasti32x4 zmm2, xmmword ptr [BLAKE3_IV+rip] + vpbroadcastq zmm5, rdx + vpaddq zmm6, zmm4, zmm5 + shl r8, 32 + mov ecx, ecx + or rcx, r8 + vpbroadcastq zmm7, rcx + mov eax, 0xAA + kmovw k2, eax + vpblendmq zmm3 {k2}, zmm6, zmm7 + vbroadcasti32x4 zmm8, xmmword ptr [rsi] + vbroadcasti32x4 zmm9, xmmword ptr [rsi+0x10] + vshufps zmm4, zmm8, zmm9, 136 + vshufps zmm5, zmm8, zmm9, 221 + vbroadcasti32x4 zmm8, xmmword ptr [rsi+0x20] + vbroadcasti32x4 zmm9, xmmword ptr [rsi+0x30] + vshufps zmm6, zmm8, zmm9, 136 + vshufps zmm7, zmm8, zmm9, 221 + vpshufd zmm6, zmm6, 0x93 + vpshufd zmm7, zmm7, 0x93 + call blake3_avx512_kernel2d_4 + vbroadcasti32x4 zmm4, xmmword ptr [rdi] + vpxord zmm2, zmm2, zmm4 + vbroadcasti32x4 zmm5, xmmword ptr [rdi + 16] + vpxord zmm3, zmm3, zmm5 + vmovdqu xmmword ptr [r9 + 0 * 16], xmm0 + vmovdqu xmmword ptr [r9 + 1 * 16], xmm1 + vmovdqu xmmword ptr [r9 + 2 * 16], xmm2 + vmovdqu xmmword ptr [r9 + 3 * 16], xmm3 + vextracti32x4 xmmword ptr [r9 + 4 * 16], zmm0, 1 + vextracti32x4 xmmword ptr [r9 + 5 * 16], zmm1, 1 + vextracti32x4 xmmword ptr [r9 + 6 * 16], zmm2, 1 + vextracti32x4 xmmword ptr [r9 + 7 * 16], zmm3, 1 + vextracti32x4 xmmword ptr [r9 + 8 * 16], zmm0, 2 + vextracti32x4 xmmword ptr [r9 + 9 * 16], zmm1, 2 + vextracti32x4 xmmword ptr [r9 + 10 * 16], zmm2, 2 + vextracti32x4 xmmword ptr [r9 + 11 * 16], zmm3, 2 + vextracti32x4 xmmword ptr [r9 + 12 * 16], zmm0, 3 + vextracti32x4 xmmword ptr [r9 + 13 * 16], zmm1, 3 + vextracti32x4 xmmword ptr [r9 + 14 * 16], zmm2, 3 + vextracti32x4 xmmword ptr [r9 + 15 * 16], zmm3, 3 + ret .balign 16 BLAKE3_IV: BLAKE3_IV_0: diff --git a/src/kernel.rs b/src/kernel.rs index a6cca9f..f17db21 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -65,6 +65,14 @@ extern "C" { flags: u32, out: *mut [u8; 64 * 2], ); + pub fn blake3_avx512_xof_stream_4( + cv: &[u32; 8], + block: &[u8; 64], + counter: u64, + block_len: u32, + flags: u32, + out: *mut [u8; 64 * 4], + ); } pub type CompressionFn = @@ -138,7 +146,7 @@ mod test { let mut block = [0; 64]; let block_len = 53; crate::test::paint_test_input(&mut block[..block_len]); - let counter = u64::MAX - 42; + let counter = u32::MAX as u64; let flags = crate::CHUNK_START | crate::CHUNK_END | crate::ROOT; let mut expected = [0; N]; @@ -218,6 +226,15 @@ mod test { } test_xof_function(blake3_avx512_xof_stream_2); } + + #[test] + #[cfg(target_arch = "x86_64")] + fn test_avx512_xof_4() { + if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") { + return; + } + test_xof_function(blake3_avx512_xof_stream_4); + } } global_asm!( |
