aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack O'Connor <[email protected]>2022-03-20 12:04:44 -0400
committerJack O'Connor <[email protected]>2022-03-20 16:15:12 -0400
commit08288c73bd15585e769b986fcfd114a019263805 (patch)
tree54b4efa16b461b5a9d175d95e87c99f04bd64c02
parent18962919e98cb1d7bb9ed78d2669fe47bf2ed645 (diff)
initial xof_stream functions
-rwxr-xr-xasm/asm.py108
-rw-r--r--asm/out.S97
-rw-r--r--src/kernel.rs100
3 files changed, 305 insertions, 0 deletions
diff --git a/asm/asm.py b/asm/asm.py
index 739bd2b..755f8e8 100755
--- a/asm/asm.py
+++ b/asm/asm.py
@@ -481,6 +481,111 @@ def compress(target, output):
output.append(target.ret())
+def xof_setup2d(target, output, degree):
+ if target.extension == AVX512:
+ if degree == 1:
+ # state words
+ output.append(f"vmovdqu xmm0, xmmword ptr [{target.arg64(0)}]")
+ output.append(f"vmovdqu xmm1, xmmword ptr [{target.arg64(0)}+0x10]")
+ # flags
+ output.append(f"shl {target.arg64(4)}, 32")
+ # block length
+ output.append(f"mov {target.arg32(3)}, {target.arg32(3)}")
+ output.append(f"or {target.arg64(3)}, {target.arg64(4)}")
+ # counter
+ output.append(f"vmovq xmm3, {target.arg64(2)}")
+ output.append(f"vmovq xmm4, {target.arg64(3)}")
+ output.append(f"vpunpcklqdq xmm3, xmm3, xmm4")
+ output.append(f"vmovaps xmm2, xmmword ptr [BLAKE3_IV+rip]")
+ # message words
+ # fmt: off
+ output.append(f"vmovups xmm8, xmmword ptr [{target.arg64(1)}]") # xmm8 = m0 m1 m2 m3
+ output.append(f"vmovups xmm9, xmmword ptr [{target.arg64(1)}+0x10]") # xmm9 = m4 m5 m6 m7
+ output.append(f"vshufps xmm4, xmm8, xmm9, 136") # xmm4 = m0 m2 m4 m6
+ output.append(f"vshufps xmm5, xmm8, xmm9, 221") # xmm5 = m1 m3 m5 m7
+ output.append(f"vmovups xmm8, xmmword ptr [{target.arg64(1)}+0x20]") # xmm8 = m8 m9 m10 m12
+ output.append(f"vmovups xmm9, xmmword ptr [{target.arg64(1)}+0x30]") # xmm9 = m12 m13 m14 m15
+ output.append(f"vshufps xmm6, xmm8, xmm9, 136") # xmm6 = m8 m10 m12 m14
+ output.append(f"vshufps xmm7, xmm8, xmm9, 221") # xmm7 = m9 m11 m13 m15
+ output.append(f"vpshufd xmm6, xmm6, 0x93") # xmm6 = m14 m8 m10 m12
+ output.append(f"vpshufd xmm7, xmm7, 0x93") # xmm7 = m15 m9 m11 m13
+ # fmt: on
+ else:
+ raise NotImplementedError
+ elif target.extension in (SSE41, SSE2):
+ assert degree == 1
+ output.append(f"movups xmm0, xmmword ptr [{target.arg64(0)}]")
+ output.append(f"movups xmm1, xmmword ptr [{target.arg64(0)}+0x10]")
+ output.append(f"movaps xmm2, xmmword ptr [BLAKE3_IV+rip]")
+ output.append(f"shl {target.arg64(4)}, 32")
+ output.append(f"mov {target.arg32(3)}, {target.arg32(3)}")
+ output.append(f"or {target.arg64(3)}, {target.arg64(4)}")
+ output.append(f"vmovq xmm3, {target.arg64(2)}")
+ output.append(f"vmovq xmm4, {target.arg64(3)}")
+ output.append(f"punpcklqdq xmm3, xmm4")
+ output.append(f"movups xmm4, xmmword ptr [{target.arg64(1)}]")
+ output.append(f"movups xmm5, xmmword ptr [{target.arg64(1)}+0x10]")
+ output.append(f"movaps xmm8, xmm4")
+ output.append(f"shufps xmm4, xmm5, 136")
+ output.append(f"shufps xmm8, xmm5, 221")
+ output.append(f"movaps xmm5, xmm8")
+ output.append(f"movups xmm6, xmmword ptr [{target.arg64(1)}+0x20]")
+ output.append(f"movups xmm7, xmmword ptr [{target.arg64(1)}+0x30]")
+ output.append(f"movaps xmm8, xmm6")
+ output.append(f"shufps xmm6, xmm7, 136")
+ output.append(f"pshufd xmm6, xmm6, 0x93")
+ output.append(f"shufps xmm8, xmm7, 221")
+ output.append(f"pshufd xmm7, xmm8, 0x93")
+ else:
+ raise NotImplementedError
+
+
+def xof_stream_finish2d(target, output, degree):
+ if target.extension == AVX512:
+ if degree == 1:
+ output.append(f"vpxor xmm2, xmm2, [{target.arg64(0)}]")
+ output.append(f"vpxor xmm3, xmm3, [{target.arg64(0)}+0x10]")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}], xmm0")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x10], xmm1")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x20], xmm2")
+ output.append(f"vmovdqu xmmword ptr [{target.arg64(5)}+0x30], xmm3")
+ else:
+ raise NotImplementedError
+ elif target.extension in (SSE41, SSE2):
+ assert degree == 1
+ output.append(f"movdqu xmm4, xmmword ptr [{target.arg64(0)}]")
+ output.append(f"movdqu xmm5, xmmword ptr [{target.arg64(0)}+0x10]")
+ output.append(f"pxor xmm2, xmm4")
+ output.append(f"pxor xmm3, xmm5")
+ output.append(f"movups xmmword ptr [{target.arg64(5)}], xmm0")
+ output.append(f"movups xmmword ptr [{target.arg64(5)}+0x10], xmm1")
+ output.append(f"movups xmmword ptr [{target.arg64(5)}+0x20], xmm2")
+ output.append(f"movups xmmword ptr [{target.arg64(5)}+0x30], xmm3")
+ else:
+ raise NotImplementedError
+
+
+def xof_stream(target, output, degree):
+ label = f"blake3_{target.extension}_xof_stream_{degree}"
+ output.append(f".global {label}")
+ output.append(f"{label}:")
+ if target.extension == AVX512:
+ if degree == 1:
+ xof_setup2d(target, output, degree)
+ output.append(f"call {kernel2d_name(target, degree)}")
+ xof_stream_finish2d(target, output, degree)
+ else:
+ raise NotImplementedError
+ elif target.extension in (SSE41, SSE2):
+ assert degree == 1
+ xof_setup2d(target, output, degree)
+ output.append(f"call {kernel2d_name(target, degree)}")
+ xof_stream_finish2d(target, output, degree)
+ else:
+ raise NotImplementedError
+ output.append(target.ret())
+
+
def emit_prelude(target, output):
# output.append(".intel_syntax noprefix")
pass
@@ -490,6 +595,7 @@ def emit_sse2(target, output):
target = replace(target, extension=SSE2)
kernel2d(target, output, 1)
compress(target, output)
+ xof_stream(target, output, 1)
output.append(".balign 16")
output.append("PBLENDW_0x33_MASK:")
output.append(".long 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000")
@@ -505,6 +611,7 @@ def emit_sse41(target, output):
target = replace(target, extension=SSE41)
kernel2d(target, output, 1)
compress(target, output)
+ xof_stream(target, output, 1)
def emit_avx2(target, output):
@@ -518,6 +625,7 @@ def emit_avx512(target, output):
kernel2d(target, output, 2)
kernel2d(target, output, 4)
compress(target, output)
+ xof_stream(target, output, 1)
def emit_footer(target, output):
diff --git a/asm/out.S b/asm/out.S
index 1d7eb82..d75c4a4 100644
--- a/asm/out.S
+++ b/asm/out.S
@@ -535,6 +535,40 @@ blake3_sse2_compress:
movups xmmword ptr [rdi], xmm0
movups xmmword ptr [rdi+0x10], xmm1
ret
+.global blake3_sse2_xof_stream_1
+blake3_sse2_xof_stream_1:
+ movups xmm0, xmmword ptr [rdi]
+ movups xmm1, xmmword ptr [rdi+0x10]
+ movaps xmm2, xmmword ptr [BLAKE3_IV+rip]
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vmovq xmm3, rdx
+ vmovq xmm4, rcx
+ punpcklqdq xmm3, xmm4
+ movups xmm4, xmmword ptr [rsi]
+ movups xmm5, xmmword ptr [rsi+0x10]
+ movaps xmm8, xmm4
+ shufps xmm4, xmm5, 136
+ shufps xmm8, xmm5, 221
+ movaps xmm5, xmm8
+ movups xmm6, xmmword ptr [rsi+0x20]
+ movups xmm7, xmmword ptr [rsi+0x30]
+ movaps xmm8, xmm6
+ shufps xmm6, xmm7, 136
+ pshufd xmm6, xmm6, 0x93
+ shufps xmm8, xmm7, 221
+ pshufd xmm7, xmm8, 0x93
+ call blake3_sse2_kernel2d_1
+ movdqu xmm4, xmmword ptr [rdi]
+ movdqu xmm5, xmmword ptr [rdi+0x10]
+ pxor xmm2, xmm4
+ pxor xmm3, xmm5
+ movups xmmword ptr [r9], xmm0
+ movups xmmword ptr [r9+0x10], xmm1
+ movups xmmword ptr [r9+0x20], xmm2
+ movups xmmword ptr [r9+0x30], xmm3
+ ret
.balign 16
PBLENDW_0x33_MASK:
.long 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000
@@ -996,6 +1030,40 @@ blake3_sse41_compress:
movups xmmword ptr [rdi], xmm0
movups xmmword ptr [rdi+0x10], xmm1
ret
+.global blake3_sse41_xof_stream_1
+blake3_sse41_xof_stream_1:
+ movups xmm0, xmmword ptr [rdi]
+ movups xmm1, xmmword ptr [rdi+0x10]
+ movaps xmm2, xmmword ptr [BLAKE3_IV+rip]
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vmovq xmm3, rdx
+ vmovq xmm4, rcx
+ punpcklqdq xmm3, xmm4
+ movups xmm4, xmmword ptr [rsi]
+ movups xmm5, xmmword ptr [rsi+0x10]
+ movaps xmm8, xmm4
+ shufps xmm4, xmm5, 136
+ shufps xmm8, xmm5, 221
+ movaps xmm5, xmm8
+ movups xmm6, xmmword ptr [rsi+0x20]
+ movups xmm7, xmmword ptr [rsi+0x30]
+ movaps xmm8, xmm6
+ shufps xmm6, xmm7, 136
+ pshufd xmm6, xmm6, 0x93
+ shufps xmm8, xmm7, 221
+ pshufd xmm7, xmm8, 0x93
+ call blake3_sse41_kernel2d_1
+ movdqu xmm4, xmmword ptr [rdi]
+ movdqu xmm5, xmmword ptr [rdi+0x10]
+ pxor xmm2, xmm4
+ pxor xmm3, xmm5
+ movups xmmword ptr [r9], xmm0
+ movups xmmword ptr [r9+0x10], xmm1
+ movups xmmword ptr [r9+0x20], xmm2
+ movups xmmword ptr [r9+0x30], xmm3
+ ret
blake3_avx2_kernel2d_2:
vbroadcasti128 ymm14, xmmword ptr [ROT16+rip]
vbroadcasti128 ymm15, xmmword ptr [ROT8+rip]
@@ -2359,6 +2427,35 @@ blake3_avx512_compress:
vmovdqu xmmword ptr [rdi], xmm0
vmovdqu xmmword ptr [rdi+0x10], xmm1
ret
+.global blake3_avx512_xof_stream_1
+blake3_avx512_xof_stream_1:
+ vmovdqu xmm0, xmmword ptr [rdi]
+ vmovdqu xmm1, xmmword ptr [rdi+0x10]
+ shl r8, 32
+ mov ecx, ecx
+ or rcx, r8
+ vmovq xmm3, rdx
+ vmovq xmm4, rcx
+ vpunpcklqdq xmm3, xmm3, xmm4
+ vmovaps xmm2, xmmword ptr [BLAKE3_IV+rip]
+ vmovups xmm8, xmmword ptr [rsi]
+ vmovups xmm9, xmmword ptr [rsi+0x10]
+ vshufps xmm4, xmm8, xmm9, 136
+ vshufps xmm5, xmm8, xmm9, 221
+ vmovups xmm8, xmmword ptr [rsi+0x20]
+ vmovups xmm9, xmmword ptr [rsi+0x30]
+ vshufps xmm6, xmm8, xmm9, 136
+ vshufps xmm7, xmm8, xmm9, 221
+ vpshufd xmm6, xmm6, 0x93
+ vpshufd xmm7, xmm7, 0x93
+ call blake3_avx512_kernel2d_1
+ vpxor xmm2, xmm2, [rdi]
+ vpxor xmm3, xmm3, [rdi+0x10]
+ vmovdqu xmmword ptr [r9], xmm0
+ vmovdqu xmmword ptr [r9+0x10], xmm1
+ vmovdqu xmmword ptr [r9+0x20], xmm2
+ vmovdqu xmmword ptr [r9+0x30], xmm3
+ ret
.balign 16
BLAKE3_IV:
BLAKE3_IV_0:
diff --git a/src/kernel.rs b/src/kernel.rs
index 5e37bdb..3491bab 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -25,11 +25,44 @@ extern "C" {
block_len: u32,
flags: u32,
);
+ pub fn blake3_sse2_xof_stream_1(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: *mut [u8; 64],
+ );
+ pub fn blake3_sse41_xof_stream_1(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: *mut [u8; 64],
+ );
+ pub fn blake3_avx512_xof_stream_1(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: *mut [u8; 64],
+ );
}
pub type CompressionFn =
unsafe extern "C" fn(cv: &[u32; 8], block: &[u8; 64], counter: u64, block_len: u32, flags: u32);
+pub type XofStreamFn<const N: usize> = unsafe extern "C" fn(
+ cv: &[u32; 8],
+ block: &[u8; 64],
+ counter: u64,
+ block_len: u32,
+ flags: u32,
+ out: *mut [u8; N],
+);
+
#[cfg(test)]
mod test {
use super::*;
@@ -84,6 +117,73 @@ mod test {
}
test_compression_function(blake3_avx512_compress);
}
+
+ fn test_xof_function<const N: usize>(f: XofStreamFn<N>) {
+ let mut block = [0; 64];
+ let block_len = 53;
+ crate::test::paint_test_input(&mut block[..block_len]);
+ let counter = u64::MAX - 42;
+ let flags = crate::CHUNK_START | crate::CHUNK_END | crate::ROOT;
+
+ let mut expected = [0; N];
+ let mut incrementing_counter = counter;
+ let mut i = 0;
+ assert_eq!(0, N % 64);
+ while i < N {
+ let out_block: &mut [u8; 64] = (&mut expected[i..][..64]).try_into().unwrap();
+ *out_block = crate::platform::Platform::Portable.compress_xof(
+ crate::IV,
+ &block,
+ block_len as u8,
+ incrementing_counter,
+ flags,
+ );
+ i += 64;
+ incrementing_counter += 1;
+ }
+ assert_eq!(incrementing_counter, counter + N as u64 / 64);
+
+ let mut found = [0; N];
+ unsafe {
+ f(
+ crate::IV,
+ &block,
+ counter,
+ block_len as u32,
+ flags as u32,
+ &mut found,
+ );
+ }
+
+ assert_eq!(expected, found);
+ }
+
+ #[test]
+ #[cfg(target_arch = "x86_64")]
+ fn test_sse2_xof_1() {
+ if !is_x86_feature_detected!("sse2") {
+ return;
+ }
+ test_xof_function(blake3_sse2_xof_stream_1);
+ }
+
+ #[test]
+ #[cfg(target_arch = "x86_64")]
+ fn test_sse41_xof_1() {
+ if !is_x86_feature_detected!("sse2") {
+ return;
+ }
+ test_xof_function(blake3_sse41_xof_stream_1);
+ }
+
+ #[test]
+ #[cfg(target_arch = "x86_64")]
+ fn test_avx512_xof_1() {
+ if !is_x86_feature_detected!("sse2") {
+ return;
+ }
+ test_xof_function(blake3_avx512_xof_stream_1);
+ }
}
global_asm!(