diff options
| author | Jack O'Connor <[email protected]> | 2022-03-20 18:26:02 -0400 |
|---|---|---|
| committer | Jack O'Connor <[email protected]> | 2022-03-20 18:26:02 -0400 |
| commit | 39ee6f486858b3ccd54e615a9ded3dc32df46b82 (patch) | |
| tree | 9e801df0f8c4ff8cb76d4fcfd39b6c221a54565d | |
| parent | 08288c73bd15585e769b986fcfd114a019263805 (diff) | |
blake3_avx512_xof_stream_2
| -rwxr-xr-x | asm/asm.py | 59 | ||||
| -rw-r--r-- | asm/out.S | 40 | ||||
| -rw-r--r-- | src/kernel.rs | 21 |
3 files changed, 117 insertions, 3 deletions
@@ -510,6 +510,37 @@ def xof_setup2d(target, output, degree): output.append(f"vpshufd xmm6, xmm6, 0x93") # xmm6 = m14 m8 m10 m12 output.append(f"vpshufd xmm7, xmm7, 0x93") # xmm7 = m15 m9 m11 m13 # fmt: on + elif degree == 2: + # Load the state words. + output.append(f"vbroadcasti128 ymm0, xmmword ptr [{target.arg64(0)}]") + output.append(f"vbroadcasti128 ymm1, xmmword ptr [{target.arg64(0)}+0x10]") + # Load the counter increments. + output.append(f"vmovdqa ymm4, ymmword ptr [INCREMENT_2D+rip]") + # Load the IV constants. + output.append(f"vbroadcasti128 ymm2, xmmword ptr [BLAKE3_IV+rip]") + # Broadcast the counter. + output.append(f"vpbroadcastq ymm5, {target.arg64(2)}") + # Add the counter increments to the counter. + output.append(f"vpaddq ymm6, ymm4, ymm5") + # Combine the block length and flags into a 64-bit word. + output.append(f"shl {target.arg64(4)}, 32") + output.append(f"mov {target.arg32(3)}, {target.arg32(3)}") + output.append(f"or {target.arg64(3)}, {target.arg64(4)}") + # Broadcast the block length and flags. + output.append(f"vpbroadcastq ymm7, {target.arg64(3)}") + # Blend the counter, block length, and flags. + output.append(f"vpblendd ymm3, ymm6, ymm7, 0xCC") + # Load and permute the message words. + output.append(f"vbroadcasti128 ymm8, xmmword ptr [{target.arg64(1)}]") + output.append(f"vbroadcasti128 ymm9, xmmword ptr [{target.arg64(1)}+0x10]") + output.append(f"vshufps ymm4, ymm8, ymm9, 136") + output.append(f"vshufps ymm5, ymm8, ymm9, 221") + output.append(f"vbroadcasti128 ymm8, xmmword ptr [{target.arg64(1)}+0x20]") + output.append(f"vbroadcasti128 ymm9, xmmword ptr [{target.arg64(1)}+0x30]") + output.append(f"vshufps ymm6, ymm8, ymm9, 136") + output.append(f"vshufps ymm7, ymm8, ymm9, 221") + output.append(f"vpshufd ymm6, ymm6, 0x93") + output.append(f"vpshufd ymm7, ymm7, 0x93") else: raise NotImplementedError elif target.extension in (SSE41, SSE2): @@ -549,6 +580,27 @@ def xof_stream_finish2d(target, output, degree): output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x10], xmm1") output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x20], xmm2") output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x30], xmm3") + elif degree == 2: + output.append(f"vbroadcasti128 ymm4, xmmword ptr [{target.arg64(0)}]") + output.append(f"vpxor ymm2, ymm2, ymm4") + output.append(f"vbroadcasti128 ymm5, xmmword ptr [{target.arg64(0)} + 16]") + output.append(f"vpxor ymm3, ymm3, ymm5") + output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 0 * 16], xmm0") + output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 1 * 16], xmm1") + output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 2 * 16], xmm2") + output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 3 * 16], xmm3") + output.append( + f"vextracti128 xmmword ptr [{target.arg64(5)} + 4 * 16], ymm0, 1" + ) + output.append( + f"vextracti128 xmmword ptr [{target.arg64(5)} + 5 * 16], ymm1, 1" + ) + output.append( + f"vextracti128 xmmword ptr [{target.arg64(5)} + 6 * 16], ymm2, 1" + ) + output.append( + f"vextracti128 xmmword ptr [{target.arg64(5)} + 7 * 16], ymm3, 1" + ) else: raise NotImplementedError elif target.extension in (SSE41, SSE2): @@ -570,7 +622,7 @@ def xof_stream(target, output, degree): output.append(f".global {label}") output.append(f"{label}:") if target.extension == AVX512: - if degree == 1: + if degree in (1, 2, 4): xof_setup2d(target, output, degree) output.append(f"call {kernel2d_name(target, degree)}") xof_stream_finish2d(target, output, degree) @@ -626,6 +678,7 @@ def emit_avx512(target, output): kernel2d(target, output, 4) compress(target, output) xof_stream(target, output, 1) + xof_stream(target, output, 2) def emit_footer(target, output): @@ -646,6 +699,10 @@ def emit_footer(target, output): output.append("ROT8:") output.append(".byte 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12") + output.append(".balign 64") + output.append("INCREMENT_2D:") + output.append(".quad 0, 0, 1, 0, 2, 0, 3, 0") + def format(output): print("# This file is generated by asm.py. Don't edit this file directly.") @@ -2456,6 +2456,43 @@ blake3_avx512_xof_stream_1: vmovdqu xmmword ptr [r9+0x20], xmm2 vmovdqu xmmword ptr [r9+0x30], xmm3 ret +.global blake3_avx512_xof_stream_2 +blake3_avx512_xof_stream_2: + vbroadcasti128 ymm0, xmmword ptr [rdi] + vbroadcasti128 ymm1, xmmword ptr [rdi+0x10] + vmovdqa ymm4, ymmword ptr [INCREMENT_2D+rip] + vbroadcasti128 ymm2, xmmword ptr [BLAKE3_IV+rip] + vpbroadcastq ymm5, rdx + vpaddq ymm6, ymm4, ymm5 + shl r8, 32 + mov ecx, ecx + or rcx, r8 + vpbroadcastq ymm7, rcx + vpblendd ymm3, ymm6, ymm7, 0xCC + vbroadcasti128 ymm8, xmmword ptr [rsi] + vbroadcasti128 ymm9, xmmword ptr [rsi+0x10] + vshufps ymm4, ymm8, ymm9, 136 + vshufps ymm5, ymm8, ymm9, 221 + vbroadcasti128 ymm8, xmmword ptr [rsi+0x20] + vbroadcasti128 ymm9, xmmword ptr [rsi+0x30] + vshufps ymm6, ymm8, ymm9, 136 + vshufps ymm7, ymm8, ymm9, 221 + vpshufd ymm6, ymm6, 0x93 + vpshufd ymm7, ymm7, 0x93 + call blake3_avx512_kernel2d_2 + vbroadcasti128 ymm4, xmmword ptr [rdi] + vpxor ymm2, ymm2, ymm4 + vbroadcasti128 ymm5, xmmword ptr [rdi + 16] + vpxor ymm3, ymm3, ymm5 + vmovdqu xmmword ptr [r9 + 0 * 16], xmm0 + vmovdqu xmmword ptr [r9 + 1 * 16], xmm1 + vmovdqu xmmword ptr [r9 + 2 * 16], xmm2 + vmovdqu xmmword ptr [r9 + 3 * 16], xmm3 + vextracti128 xmmword ptr [r9 + 4 * 16], ymm0, 1 + vextracti128 xmmword ptr [r9 + 5 * 16], ymm1, 1 + vextracti128 xmmword ptr [r9 + 6 * 16], ymm2, 1 + vextracti128 xmmword ptr [r9 + 7 * 16], ymm3, 1 + ret .balign 16 BLAKE3_IV: BLAKE3_IV_0: @@ -2471,3 +2508,6 @@ ROT16: .byte 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13 ROT8: .byte 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12 +.balign 64 +INCREMENT_2D: +.quad 0, 0, 1, 0, 2, 0, 3, 0 diff --git a/src/kernel.rs b/src/kernel.rs index 3491bab..cd81bb5 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -49,6 +49,14 @@ extern "C" { flags: u32, out: *mut [u8; 64], ); + pub fn blake3_avx512_xof_stream_2( + cv: &[u32; 8], + block: &[u8; 64], + counter: u64, + block_len: u32, + flags: u32, + out: *mut [u8; 64 * 2], + ); } pub type CompressionFn = @@ -170,7 +178,7 @@ mod test { #[test] #[cfg(target_arch = "x86_64")] fn test_sse41_xof_1() { - if !is_x86_feature_detected!("sse2") { + if !is_x86_feature_detected!("sse4.1") { return; } test_xof_function(blake3_sse41_xof_stream_1); @@ -179,11 +187,20 @@ mod test { #[test] #[cfg(target_arch = "x86_64")] fn test_avx512_xof_1() { - if !is_x86_feature_detected!("sse2") { + if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") { return; } test_xof_function(blake3_avx512_xof_stream_1); } + + #[test] + #[cfg(target_arch = "x86_64")] + fn test_avx512_xof_2() { + if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") { + return; + } + test_xof_function(blake3_avx512_xof_stream_2); + } } global_asm!( |
