aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Jack O'Connor <[email protected]> 2022-03-20 18:35:41 -0400
committer Jack O'Connor <[email protected]> 2022-03-20 18:35:41 -0400
commit 9139fa40e8e0c17a1afd5e137eaca5a25d865066 (patch)
tree 7681e75ab34d88474205859e2ebdbb99cf7b9723
parent 39ee6f486858b3ccd54e615a9ded3dc32df46b82 (diff)
blake3_avx2_xof_stream_2
-rwxr-xr-x asm/asm.py 50
-rw-r--r-- asm/out.S 37
-rw-r--r-- src/kernel.rs 17
3 files changed, 104 insertions, 0 deletions
diff --git a/asm/asm.py b/asm/asm.py
index da96c54..1526ae8 100755
--- a/asm/asm.py
+++ b/asm/asm.py
@@ -543,6 +543,37 @@ def xof_setup2d(target, output, degree):
output.append(f"vpshufd ymm7, ymm7, 0x93")
else:
raise NotImplementedError
+ elif target.extension == AVX2:
+ # Load the state words.
+ output.append(f"vbroadcasti128 ymm0, xmmword ptr [{target.arg64(0)}]")
+ output.append(f"vbroadcasti128 ymm1, xmmword ptr [{target.arg64(0)}+0x10]")
+ # Load the counter increments.
+ output.append(f"vmovdqa ymm4, ymmword ptr [INCREMENT_2D+rip]")
+ # Load the IV constants.
+ output.append(f"vbroadcasti128 ymm2, xmmword ptr [BLAKE3_IV+rip]")
+ # Broadcast the counter.
+ output.append(f"vpbroadcastq ymm5, {target.arg64(2)}")
+ # Add the counter increments to the counter.
+ output.append(f"vpaddq ymm6, ymm4, ymm5")
+ # Combine the block length and flags into a 64-bit word.
+ output.append(f"shl {target.arg64(4)}, 32")
+ output.append(f"mov {target.arg32(3)}, {target.arg32(3)}")
+ output.append(f"or {target.arg64(3)}, {target.arg64(4)}")
+ # Broadcast the block length and flags.
+ output.append(f"vpbroadcastq ymm7, {target.arg64(3)}")
+ # Blend the counter, block length, and flags.
+ output.append(f"vpblendd ymm3, ymm6, ymm7, 0xCC")
+ # Load and permute the message words.
+ output.append(f"vbroadcasti128 ymm8, xmmword ptr [{target.arg64(1)}]")
+ output.append(f"vbroadcasti128 ymm9, xmmword ptr [{target.arg64(1)}+0x10]")
+ output.append(f"vshufps ymm4, ymm8, ymm9, 136")
+ output.append(f"vshufps ymm5, ymm8, ymm9, 221")
+ output.append(f"vbroadcasti128 ymm8, xmmword ptr [{target.arg64(1)}+0x20]")
+ output.append(f"vbroadcasti128 ymm9, xmmword ptr [{target.arg64(1)}+0x30]")
+ output.append(f"vshufps ymm6, ymm8, ymm9, 136")
+ output.append(f"vshufps ymm7, ymm8, ymm9, 221")
+ output.append(f"vpshufd ymm6, ymm6, 0x93")
+ output.append(f"vpshufd ymm7, ymm7, 0x93")
elif target.extension in (SSE41, SSE2):
assert degree == 1
output.append(f"movups xmm0, xmmword ptr [{target.arg64(0)}]")
@@ -603,6 +634,19 @@ def xof_stream_finish2d(target, output, degree):
)
else:
raise NotImplementedError
+ elif target.extension == AVX2:
+ output.append(f"vbroadcasti128 ymm4, xmmword ptr [{target.arg64(0)}]")
+ output.append(f"vpxor ymm2, ymm2, ymm4")
+ output.append(f"vbroadcasti128 ymm5, xmmword ptr [{target.arg64(0)} + 16]")
+ output.append(f"vpxor ymm3, ymm3, ymm5")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 0 * 16], xmm0")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 1 * 16], xmm1")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 2 * 16], xmm2")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 3 * 16], xmm3")
+ output.append(f"vextracti128 xmmword ptr [{target.arg64(5)} + 4 * 16], ymm0, 1")
+ output.append(f"vextracti128 xmmword ptr [{target.arg64(5)} + 5 * 16], ymm1, 1")
+ output.append(f"vextracti128 xmmword ptr [{target.arg64(5)} + 6 * 16], ymm2, 1")
+ output.append(f"vextracti128 xmmword ptr [{target.arg64(5)} + 7 * 16], ymm3, 1")
elif target.extension in (SSE41, SSE2):
assert degree == 1
output.append(f"movdqu xmm4, xmmword ptr [{target.arg64(0)}]")
@@ -628,6 +672,11 @@ def xof_stream(target, output, degree):
xof_stream_finish2d(target, output, degree)
else:
raise NotImplementedError
+ elif target.extension == AVX2:
+ assert degree == 2
+ xof_setup2d(target, output, degree)
+ output.append(f"call {kernel2d_name(target, degree)}")
+ xof_stream_finish2d(target, output, degree)
elif target.extension in (SSE41, SSE2):
assert degree == 1
xof_setup2d(target, output, degree)
@@ -669,6 +718,7 @@ def emit_sse41(target, output):
# Emit the AVX2 routines into `output`: the target is first re-tagged with
# extension=AVX2 (presumably via dataclasses.replace -- confirm against the
# file's imports), then the degree-2 2D kernel is emitted, followed by the
# degree-2 XOF stream routine, which calls that kernel by name.
def emit_avx2(target, output):
target = replace(target, extension=AVX2)
kernel2d(target, output, 2)
+ xof_stream(target, output, 2)
def emit_avx512(target, output):
diff --git a/asm/out.S b/asm/out.S
index 14a5929..7ec3321 100644
--- a/asm/out.S
+++ b/asm/out.S
@@ -1442,6 +1442,43 @@ blake3_avx2_kernel2d_2:
vpxord ymm0, ymm0, ymm2
vpxord ymm1, ymm1, ymm3
ret
+.global blake3_avx2_xof_stream_2
+/*
+ * blake3_avx2_xof_stream_2
+ *
+ * Register usage on entry (matches the six-argument extern "C" declaration
+ * in src/kernel.rs under the System V AMD64 calling convention):
+ *   rdi = pointer to 8 state words (32 bytes are read)
+ *   rsi = pointer to the 64-byte message block
+ *   rdx = 64-bit counter
+ *   ecx = block length
+ *   r8  = flags (only the low 32 bits survive the shift below)
+ *   r9  = output pointer; 8 * 16 = 128 bytes are written, unaligned
+ *
+ * Each 128-bit lane of the ymm registers carries one instance of the state,
+ * so one call to the 2D kernel yields two 64-byte output blocks.
+ *
+ * Only VEX-encoded (AVX/AVX2) instructions are used here. In particular the
+ * GPR-source form "vpbroadcastq ymm, r64" is EVEX-only (AVX-512F + VL) and
+ * must not be used in this routine; the value is moved into an xmm register
+ * first and broadcast from there.
+ *
+ * NOTE(review): rdi and r9 are read again after the call, so the kernel is
+ * assumed not to clobber them -- confirm against blake3_avx2_kernel2d_2.
+ * NOTE(review): the visible tail of blake3_avx2_kernel2d_2 uses vpxord,
+ * which is likewise an EVEX-only mnemonic -- confirm that is intended for
+ * an AVX2 code path.
+ */
+blake3_avx2_xof_stream_2:
+ /* Load the state words, duplicated into both 128-bit lanes. */
+ vbroadcasti128 ymm0, xmmword ptr [rdi]
+ vbroadcasti128 ymm1, xmmword ptr [rdi+0x10]
+ /* Load the counter increments. */
+ /* NOTE(review): vmovdqa requires INCREMENT_2D to be 32-byte aligned -- confirm. */
+ vmovdqa ymm4, ymmword ptr [INCREMENT_2D+rip]
+ /* Load the IV constants. */
+ vbroadcasti128 ymm2, xmmword ptr [BLAKE3_IV+rip]
+ /* Broadcast the counter to all four qwords (via xmm5 to stay within AVX2). */
+ vmovq xmm5, rdx
+ vpbroadcastq ymm5, xmm5
+ /* Add the counter increments to the counter. */
+ vpaddq ymm6, ymm4, ymm5
+ /* Combine the block length (low dword) and flags (high dword) into rcx. */
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ /* Broadcast the block length and flags (via xmm7 to stay within AVX2). */
+ vmovq xmm7, rcx
+ vpbroadcastq ymm7, xmm7
+ /* Blend: dwords 0-1 of each lane from the counter, dwords 2-3 from ymm7. */
+ vpblendd ymm3, ymm6, ymm7, 0xCC
+ /* Load and permute the message words. */
+ vbroadcasti128 ymm8, xmmword ptr [rsi]
+ vbroadcasti128 ymm9, xmmword ptr [rsi+0x10]
+ vshufps ymm4, ymm8, ymm9, 136
+ vshufps ymm5, ymm8, ymm9, 221
+ vbroadcasti128 ymm8, xmmword ptr [rsi+0x20]
+ vbroadcasti128 ymm9, xmmword ptr [rsi+0x30]
+ vshufps ymm6, ymm8, ymm9, 136
+ vshufps ymm7, ymm8, ymm9, 221
+ vpshufd ymm6, ymm6, 0x93
+ vpshufd ymm7, ymm7, 0x93
+ call blake3_avx2_kernel2d_2
+ /* XOR the input state words into rows 2 and 3 of the kernel output. */
+ vbroadcasti128 ymm4, xmmword ptr [rdi]
+ vpxor ymm2, ymm2, ymm4
+ vbroadcasti128 ymm5, xmmword ptr [rdi + 16]
+ vpxor ymm3, ymm3, ymm5
+ /* Low lanes of ymm0-ymm3 form the first 64 output bytes. */
+ vmovdqu xmmword ptr [r9 + 0 * 16], xmm0
+ vmovdqu xmmword ptr [r9 + 1 * 16], xmm1
+ vmovdqu xmmword ptr [r9 + 2 * 16], xmm2
+ vmovdqu xmmword ptr [r9 + 3 * 16], xmm3
+ /* High lanes of ymm0-ymm3 form the second 64 output bytes. */
+ vextracti128 xmmword ptr [r9 + 4 * 16], ymm0, 1
+ vextracti128 xmmword ptr [r9 + 5 * 16], ymm1, 1
+ vextracti128 xmmword ptr [r9 + 6 * 16], ymm2, 1
+ vextracti128 xmmword ptr [r9 + 7 * 16], ymm3, 1
+ ret
blake3_avx512_kernel2d_1:
vpaddd xmm0, xmm0, xmm4
vpaddd xmm0, xmm0, xmm1
diff --git a/src/kernel.rs b/src/kernel.rs
index cd81bb5..a6cca9f 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -49,6 +49,14 @@ extern "C" {
flags: u32,
out: *mut [u8; 64],
);
+ /// AVX2 XOF routine implemented in `asm/out.S`.
+ ///
+ /// Reads the 8 state words in `cv` and the 64-byte `block`, combines them
+ /// with `counter`, `block_len` and `flags`, and writes 128 bytes (two
+ /// 64-byte output blocks) to `out` using unaligned stores.
+ ///
+ /// The two blocks use `counter` plus the per-lane increments loaded from
+ /// `INCREMENT_2D` in the assembly (presumably consecutive counter values;
+ /// confirm against that table).
+ ///
+ /// Safety: `out` must be valid for 128 bytes of writes, and the CPU must
+ /// support every instruction the assembly routine uses.
+ pub fn blake3_avx2_xof_stream_2(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: *mut [u8; 64 * 2],
+ );
pub fn blake3_avx512_xof_stream_2(
cv: &[u32; 8],
block: &[u8; 64],
@@ -195,6 +203,15 @@ mod test {
#[test]
#[cfg(target_arch = "x86_64")]
+ fn test_avx2_xof_2() {
+ // Exercise the AVX2 XOF routine only when the running CPU reports AVX2;
+ // otherwise the test is a silent no-op.
+ if is_x86_feature_detected!("avx2") {
+ test_xof_function(blake3_avx2_xof_stream_2);
+ }
+ }
+
+ #[test]
+ #[cfg(target_arch = "x86_64")]
fn test_avx512_xof_2() {
if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") {
return;