aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Jack O'Connor <[email protected]> 2022-03-20 18:26:02 -0400
committer Jack O'Connor <[email protected]> 2022-03-20 18:26:02 -0400
commit 39ee6f486858b3ccd54e615a9ded3dc32df46b82 (patch)
tree 9e801df0f8c4ff8cb76d4fcfd39b6c221a54565d
parent 08288c73bd15585e769b986fcfd114a019263805 (diff)
blake3_avx512_xof_stream_2
-rwxr-xr-x asm/asm.py 59
-rw-r--r-- asm/out.S 40
-rw-r--r-- src/kernel.rs 21
3 files changed, 117 insertions, 3 deletions
diff --git a/asm/asm.py b/asm/asm.py
index 755f8e8..da96c54 100755
--- a/asm/asm.py
+++ b/asm/asm.py
@@ -510,6 +510,37 @@ def xof_setup2d(target, output, degree):
output.append(f"vpshufd xmm6, xmm6, 0x93") # xmm6 = m14 m8 m10 m12
output.append(f"vpshufd xmm7, xmm7, 0x93") # xmm7 = m15 m9 m11 m13
# fmt: on
+ elif degree == 2:
+ # Load the state words.
+ output.append(f"vbroadcasti128 ymm0, xmmword ptr [{target.arg64(0)}]")
+ output.append(f"vbroadcasti128 ymm1, xmmword ptr [{target.arg64(0)}+0x10]")
+ # Load the counter increments.
+ output.append(f"vmovdqa ymm4, ymmword ptr [INCREMENT_2D+rip]")
+ # Load the IV constants.
+ output.append(f"vbroadcasti128 ymm2, xmmword ptr [BLAKE3_IV+rip]")
+ # Broadcast the counter.
+ output.append(f"vpbroadcastq ymm5, {target.arg64(2)}")
+ # Add the counter increments to the counter.
+ output.append(f"vpaddq ymm6, ymm4, ymm5")
+ # Combine the block length and flags into a 64-bit word.
+ output.append(f"shl {target.arg64(4)}, 32")
+ output.append(f"mov {target.arg32(3)}, {target.arg32(3)}")
+ output.append(f"or {target.arg64(3)}, {target.arg64(4)}")
+ # Broadcast the block length and flags.
+ output.append(f"vpbroadcastq ymm7, {target.arg64(3)}")
+ # Blend the counter, block length, and flags.
+ output.append(f"vpblendd ymm3, ymm6, ymm7, 0xCC")
+ # Load and permute the message words.
+ output.append(f"vbroadcasti128 ymm8, xmmword ptr [{target.arg64(1)}]")
+ output.append(f"vbroadcasti128 ymm9, xmmword ptr [{target.arg64(1)}+0x10]")
+ output.append(f"vshufps ymm4, ymm8, ymm9, 136")
+ output.append(f"vshufps ymm5, ymm8, ymm9, 221")
+ output.append(f"vbroadcasti128 ymm8, xmmword ptr [{target.arg64(1)}+0x20]")
+ output.append(f"vbroadcasti128 ymm9, xmmword ptr [{target.arg64(1)}+0x30]")
+ output.append(f"vshufps ymm6, ymm8, ymm9, 136")
+ output.append(f"vshufps ymm7, ymm8, ymm9, 221")
+ output.append(f"vpshufd ymm6, ymm6, 0x93")
+ output.append(f"vpshufd ymm7, ymm7, 0x93")
else:
raise NotImplementedError
elif target.extension in (SSE41, SSE2):
@@ -549,6 +580,27 @@ def xof_stream_finish2d(target, output, degree):
output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x10], xmm1")
output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x20], xmm2")
output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x30], xmm3")
+ elif degree == 2:
+ output.append(f"vbroadcasti128 ymm4, xmmword ptr [{target.arg64(0)}]")
+ output.append(f"vpxor ymm2, ymm2, ymm4")
+ output.append(f"vbroadcasti128 ymm5, xmmword ptr [{target.arg64(0)} + 16]")
+ output.append(f"vpxor ymm3, ymm3, ymm5")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 0 * 16], xmm0")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 1 * 16], xmm1")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 2 * 16], xmm2")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)} + 3 * 16], xmm3")
+ output.append(
+ f"vextracti128 xmmword ptr [{target.arg64(5)} + 4 * 16], ymm0, 1"
+ )
+ output.append(
+ f"vextracti128 xmmword ptr [{target.arg64(5)} + 5 * 16], ymm1, 1"
+ )
+ output.append(
+ f"vextracti128 xmmword ptr [{target.arg64(5)} + 6 * 16], ymm2, 1"
+ )
+ output.append(
+ f"vextracti128 xmmword ptr [{target.arg64(5)} + 7 * 16], ymm3, 1"
+ )
else:
raise NotImplementedError
elif target.extension in (SSE41, SSE2):
@@ -570,7 +622,7 @@ def xof_stream(target, output, degree):
output.append(f".global {label}")
output.append(f"{label}:")
if target.extension == AVX512:
- if degree == 1:
+ if degree in (1, 2, 4):
xof_setup2d(target, output, degree)
output.append(f"call {kernel2d_name(target, degree)}")
xof_stream_finish2d(target, output, degree)
@@ -626,6 +678,7 @@ def emit_avx512(target, output):
kernel2d(target, output, 4)
compress(target, output)
xof_stream(target, output, 1)
+ xof_stream(target, output, 2)
def emit_footer(target, output):
@@ -646,6 +699,10 @@ def emit_footer(target, output):
output.append("ROT8:")
output.append(".byte 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12")
+ output.append(".balign 64")
+ output.append("INCREMENT_2D:")
+ output.append(".quad 0, 0, 1, 0, 2, 0, 3, 0")
+
def format(output):
print("# This file is generated by asm.py. Don't edit this file directly.")
diff --git a/asm/out.S b/asm/out.S
index d75c4a4..14a5929 100644
--- a/asm/out.S
+++ b/asm/out.S
@@ -2456,6 +2456,43 @@ blake3_avx512_xof_stream_1:
vmovdqu xmmword ptr [r9+0x20], xmm2
vmovdqu xmmword ptr [r9+0x30], xmm3
ret
+.global blake3_avx512_xof_stream_2
+blake3_avx512_xof_stream_2:
+ vbroadcasti128 ymm0, xmmword ptr [rdi]
+ vbroadcasti128 ymm1, xmmword ptr [rdi+0x10]
+ vmovdqa ymm4, ymmword ptr [INCREMENT_2D+rip]
+ vbroadcasti128 ymm2, xmmword ptr [BLAKE3_IV+rip]
+ vpbroadcastq ymm5, rdx
+ vpaddq ymm6, ymm4, ymm5
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vpbroadcastq ymm7, rcx
+ vpblendd ymm3, ymm6, ymm7, 0xCC
+ vbroadcasti128 ymm8, xmmword ptr [rsi]
+ vbroadcasti128 ymm9, xmmword ptr [rsi+0x10]
+ vshufps ymm4, ymm8, ymm9, 136
+ vshufps ymm5, ymm8, ymm9, 221
+ vbroadcasti128 ymm8, xmmword ptr [rsi+0x20]
+ vbroadcasti128 ymm9, xmmword ptr [rsi+0x30]
+ vshufps ymm6, ymm8, ymm9, 136
+ vshufps ymm7, ymm8, ymm9, 221
+ vpshufd ymm6, ymm6, 0x93
+ vpshufd ymm7, ymm7, 0x93
+ call blake3_avx512_kernel2d_2
+ vbroadcasti128 ymm4, xmmword ptr [rdi]
+ vpxor ymm2, ymm2, ymm4
+ vbroadcasti128 ymm5, xmmword ptr [rdi + 16]
+ vpxor ymm3, ymm3, ymm5
+ vmovdqu xmmword ptr [r9 + 0 * 16], xmm0
+ vmovdqu xmmword ptr [r9 + 1 * 16], xmm1
+ vmovdqu xmmword ptr [r9 + 2 * 16], xmm2
+ vmovdqu xmmword ptr [r9 + 3 * 16], xmm3
+ vextracti128 xmmword ptr [r9 + 4 * 16], ymm0, 1
+ vextracti128 xmmword ptr [r9 + 5 * 16], ymm1, 1
+ vextracti128 xmmword ptr [r9 + 6 * 16], ymm2, 1
+ vextracti128 xmmword ptr [r9 + 7 * 16], ymm3, 1
+ ret
.balign 16
BLAKE3_IV:
BLAKE3_IV_0:
@@ -2471,3 +2508,6 @@ ROT16:
.byte 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13
ROT8:
.byte 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12
+.balign 64
+INCREMENT_2D:
+.quad 0, 0, 1, 0, 2, 0, 3, 0
diff --git a/src/kernel.rs b/src/kernel.rs
index 3491bab..cd81bb5 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -49,6 +49,14 @@ extern "C" {
flags: u32,
out: *mut [u8; 64],
);
+ pub fn blake3_avx512_xof_stream_2(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: *mut [u8; 64 * 2],
+ );
}
pub type CompressionFn =
@@ -170,7 +178,7 @@ mod test {
#[test]
#[cfg(target_arch = "x86_64")]
fn test_sse41_xof_1() {
- if !is_x86_feature_detected!("sse2") {
+ if !is_x86_feature_detected!("sse4.1") {
return;
}
test_xof_function(blake3_sse41_xof_stream_1);
@@ -179,11 +187,20 @@ mod test {
#[test]
#[cfg(target_arch = "x86_64")]
fn test_avx512_xof_1() {
- if !is_x86_feature_detected!("sse2") {
+ if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") {
return;
}
test_xof_function(blake3_avx512_xof_stream_1);
}
+
+ #[test]
+ #[cfg(target_arch = "x86_64")]
+ fn test_avx512_xof_2() {
+ if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512vl") {
+ return;
+ }
+ test_xof_function(blake3_avx512_xof_stream_2);
+ }
}
global_asm!(