| author | Jack O'Connor <[email protected]> | 2022-12-23 15:37:30 -0800 |
|---|---|---|
| committer | Jack O'Connor <[email protected]> | 2022-12-23 15:37:30 -0800 |
| commit | 93fd593881f2ee6cc0fc2074ec66076c0e67da27 (patch) | |
| tree | 2e4b4cba52313fb4c5701952a9e0e15370d333bf | |
| parent | 0c80427419382e696f91ffcae3ae3157f2bfe768 (diff) | |
try doing 512-bit loads (kernel2)
| -rw-r--r-- | src/kernel2.rs | 107 |
1 file changed, 60 insertions, 47 deletions
```diff
diff --git a/src/kernel2.rs b/src/kernel2.rs
index 8527f99..be0726c 100644
--- a/src/kernel2.rs
+++ b/src/kernel2.rs
@@ -827,55 +827,39 @@ global_asm!(
     "ret",
 );
 
-#[inline]
-#[target_feature(enable = "avx512f,avx512vl")]
+#[inline(always)]
 unsafe fn load_transposed_16(input: *const u8) -> [__m512i; 16] {
-    // We're going to load 16 vectors, each containing 16 words (64 bytes). We assume that these
-    // vectors are coming from contiguous chunks, so each is offset by CHUNK_LEN (1024 bytes) from
-    // the last. We'll name the input vectors a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, and p.
-    // Well denote the words of the input vectors:
-    //
-    //     a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_a, a_b, a_c, a_d, a_e, a_f
-    //     b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7, b_8, b_9, b_a, b_b, b_c, b_d, b_e, b_f
-    //     etc.
-    //
-    // Our goal is to load and transpose these into output vectors that look like:
-    //
-    //     a_0, b_0, c_0, d_0, e_0, f_0, g_0, h_0, i_0, j_0, k_0, l_0, m_0, n_0, o_0, p_0
-    //     a_1, b_1, c_1, d_1, e_1, f_1, g_1, h_1, i_1, j_1, k_1, l_1, m_1, n_1, o_1, p_1
-    //     etc.
+    // lane selectors for _mm512_permutex2var_epi64
+    let lower_256 = _mm512_setr_epi64(0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xa, 0xb);
+    let upper_256 = _mm512_setr_epi64(0x4, 0x5, 0x6, 0x7, 0xc, 0xd, 0xe, 0xf);
+    let lower_128 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xc, 0xd);
+    let upper_128 = _mm512_setr_epi64(0x2, 0x3, 0xa, 0xb, 0x6, 0x7, 0xe, 0xf);
 
-    // Because operations that cross 128-bit lanes are relatively expensive, we split each 512-bit
-    // load into four 128-bit loads. This results in vectors like:
-    //     a0, a1, a2, a3, e0, e1, e2, e3, i0, i1, i2, i3, m0, m1, m2, m3
-    #[inline(always)]
-    unsafe fn load_4_lanes(input: *const u8) -> __m512i {
-        let lane0 = _mm_loadu_epi32(input.add(0 * CHUNK_LEN) as *const i32);
-        let lane1 = _mm_loadu_epi32(input.add(4 * CHUNK_LEN) as *const i32);
-        let lane2 = _mm_loadu_epi32(input.add(8 * CHUNK_LEN) as *const i32);
-        let lane3 = _mm_loadu_epi32(input.add(12 * CHUNK_LEN) as *const i32);
-        let ret = _mm512_castsi128_si512(lane0);
-        let ret = _mm512_inserti32x4::<1>(ret, lane1);
-        let ret = _mm512_inserti32x4::<2>(ret, lane2);
-        let ret = _mm512_inserti32x4::<3>(ret, lane3);
-        ret
-    }
-    let aeim_0123 = load_4_lanes(input.add(0 * CHUNK_LEN + 0 * 16));
-    let aeim_4567 = load_4_lanes(input.add(0 * CHUNK_LEN + 1 * 16));
-    let aeim_89ab = load_4_lanes(input.add(0 * CHUNK_LEN + 2 * 16));
-    let aeim_cdef = load_4_lanes(input.add(0 * CHUNK_LEN + 3 * 16));
-    let bfjn_0123 = load_4_lanes(input.add(1 * CHUNK_LEN + 0 * 16));
-    let bfjn_4567 = load_4_lanes(input.add(1 * CHUNK_LEN + 1 * 16));
-    let bfjn_89ab = load_4_lanes(input.add(1 * CHUNK_LEN + 2 * 16));
-    let bfjn_cdef = load_4_lanes(input.add(1 * CHUNK_LEN + 3 * 16));
-    let cgko_0123 = load_4_lanes(input.add(2 * CHUNK_LEN + 0 * 16));
-    let cgko_4567 = load_4_lanes(input.add(2 * CHUNK_LEN + 1 * 16));
-    let cgko_89ab = load_4_lanes(input.add(2 * CHUNK_LEN + 2 * 16));
-    let cgko_cdef = load_4_lanes(input.add(2 * CHUNK_LEN + 3 * 16));
-    let dhlp_0123 = load_4_lanes(input.add(3 * CHUNK_LEN + 0 * 16));
-    let dhlp_4567 = load_4_lanes(input.add(3 * CHUNK_LEN + 1 * 16));
-    let dhlp_89ab = load_4_lanes(input.add(3 * CHUNK_LEN + 2 * 16));
-    let dhlp_cdef = load_4_lanes(input.add(3 * CHUNK_LEN + 3 * 16));
+    let a = _mm512_loadu_si512(input.add(0x0 * CHUNK_LEN) as *const i32);
+    let i = _mm512_loadu_si512(input.add(0x8 * CHUNK_LEN) as *const i32);
+    let ai_01234567 = _mm512_permutex2var_epi64(a, lower_256, i);
+    let ai_89abcdef = _mm512_permutex2var_epi64(a, upper_256, i);
+    let e = _mm512_loadu_si512(input.add(0x4 * CHUNK_LEN) as *const i32);
+    let m = _mm512_loadu_si512(input.add(0xc * CHUNK_LEN) as *const i32);
+    let em_01234567 = _mm512_permutex2var_epi64(e, lower_256, m);
+    let em_89abcdef = _mm512_permutex2var_epi64(e, upper_256, m);
+    let aeim_0123 = _mm512_permutex2var_epi64(ai_01234567, lower_128, em_01234567);
+    let aeim_4567 = _mm512_permutex2var_epi64(ai_01234567, upper_128, em_01234567);
+    let aeim_89ab = _mm512_permutex2var_epi64(ai_89abcdef, lower_128, em_89abcdef);
+    let aeim_cdef = _mm512_permutex2var_epi64(ai_89abcdef, upper_128, em_89abcdef);
+
+    let b = _mm512_loadu_si512(input.add(0x1 * CHUNK_LEN) as *const i32);
+    let j = _mm512_loadu_si512(input.add(0x9 * CHUNK_LEN) as *const i32);
+    let bj_01234567 = _mm512_permutex2var_epi64(b, lower_256, j);
+    let bj_89abcdef = _mm512_permutex2var_epi64(b, upper_256, j);
+    let f = _mm512_loadu_si512(input.add(0x5 * CHUNK_LEN) as *const i32);
+    let n = _mm512_loadu_si512(input.add(0xd * CHUNK_LEN) as *const i32);
+    let fn_01234567 = _mm512_permutex2var_epi64(f, lower_256, n);
+    let fn_89abcdef = _mm512_permutex2var_epi64(f, upper_256, n);
+    let bfjn_0123 = _mm512_permutex2var_epi64(bj_01234567, lower_128, fn_01234567);
+    let bfjn_4567 = _mm512_permutex2var_epi64(bj_01234567, upper_128, fn_01234567);
+    let bfjn_89ab = _mm512_permutex2var_epi64(bj_89abcdef, lower_128, fn_89abcdef);
+    let bfjn_cdef = _mm512_permutex2var_epi64(bj_89abcdef, upper_128, fn_89abcdef);
 
     // Interleave 32-bit words. This results in vectors like:
     //     a0, b0, a1, b1, e0, f0, e1, f1, i0, j0, i1, j1, m0, n0, m1, n1
@@ -887,6 +871,35 @@ unsafe fn load_transposed_16(input: *const u8) -> [__m512i; 16] {
     let abefijmn_ab = _mm512_unpackhi_epi32(aeim_89ab, bfjn_89ab);
     let abefijmn_cd = _mm512_unpacklo_epi32(aeim_cdef, bfjn_cdef);
     let abefijmn_ef = _mm512_unpackhi_epi32(aeim_cdef, bfjn_cdef);
+
+    let c = _mm512_loadu_si512(input.add(0x2 * CHUNK_LEN) as *const i32);
+    let k = _mm512_loadu_si512(input.add(0xa * CHUNK_LEN) as *const i32);
+    let ck_01234567 = _mm512_permutex2var_epi64(c, lower_256, k);
+    let ck_89abcdef = _mm512_permutex2var_epi64(c, upper_256, k);
+    let g = _mm512_loadu_si512(input.add(0x6 * CHUNK_LEN) as *const i32);
+    let o = _mm512_loadu_si512(input.add(0xe * CHUNK_LEN) as *const i32);
+    let go_01234567 = _mm512_permutex2var_epi64(g, lower_256, o);
+    let go_89abcdef = _mm512_permutex2var_epi64(g, upper_256, o);
+    let cgko_0123 = _mm512_permutex2var_epi64(ck_01234567, lower_128, go_01234567);
+    let cgko_4567 = _mm512_permutex2var_epi64(ck_01234567, upper_128, go_01234567);
+    let cgko_89ab = _mm512_permutex2var_epi64(ck_89abcdef, lower_128, go_89abcdef);
+    let cgko_cdef = _mm512_permutex2var_epi64(ck_89abcdef, upper_128, go_89abcdef);
+
+    let d = _mm512_loadu_si512(input.add(0x3 * CHUNK_LEN) as *const i32);
+    let l = _mm512_loadu_si512(input.add(0xb * CHUNK_LEN) as *const i32);
+    let dl_01234567 = _mm512_permutex2var_epi64(d, lower_256, l);
+    let dl_89abcdef = _mm512_permutex2var_epi64(d, upper_256, l);
+    let h = _mm512_loadu_si512(input.add(0x7 * CHUNK_LEN) as *const i32);
+    let p = _mm512_loadu_si512(input.add(0xf * CHUNK_LEN) as *const i32);
+    let hp_01234567 = _mm512_permutex2var_epi64(h, lower_256, p);
+    let hp_89abcdef = _mm512_permutex2var_epi64(h, upper_256, p);
+    let dhlp_0123 = _mm512_permutex2var_epi64(dl_01234567, lower_128, hp_01234567);
+    let dhlp_4567 = _mm512_permutex2var_epi64(dl_01234567, upper_128, hp_01234567);
+    let dhlp_89ab = _mm512_permutex2var_epi64(dl_89abcdef, lower_128, hp_89abcdef);
+    let dhlp_cdef = _mm512_permutex2var_epi64(dl_89abcdef, upper_128, hp_89abcdef);
+
+    // Interleave 32-bit words. This results in vectors like:
+    //     a0, b0, a1, b1, e0, f0, e1, f1, i0, j0, i1, j1, m0, n0, m1, n1
     let cdghklop_01 = _mm512_unpacklo_epi32(cgko_0123, dhlp_0123);
     let cdghklop_23 = _mm512_unpackhi_epi32(cgko_0123, dhlp_0123);
     let cdghklop_45 = _mm512_unpacklo_epi32(cgko_4567, dhlp_4567);
```
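Side note for readers tracing the new shuffle sequence (not part of the commit): `_mm512_permutex2var_epi64` picks each output qword from either operand, with index values 0-7 selecting from the first argument and 8-15 from the second. The sketch below is a hypothetical scalar model of that behavior on plain arrays, reusing the diff's `lower_256`/`upper_256`/`lower_128`/`upper_128` index patterns to show how two permute stages land four chunks' words in the same `aeim_0123`-style layout the removed `load_4_lanes` helper produced; the `permutex2var` function and the string labels are illustrative only.

```rust
// Hypothetical scalar stand-in for _mm512_permutex2var_epi64 at qword
// granularity: indices 0..=7 select from `a`, 8..=15 select from `b`.
fn permutex2var<T: Copy>(a: [T; 8], idx: [usize; 8], b: [T; 8]) -> [T; 8] {
    idx.map(|i| if i < 8 { a[i] } else { b[i - 8] })
}

fn main() {
    // The same index patterns the commit builds with _mm512_setr_epi64.
    let lower_256 = [0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xa, 0xb];
    let upper_256 = [0x4, 0x5, 0x6, 0x7, 0xc, 0xd, 0xe, 0xf];
    let lower_128 = [0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xc, 0xd];
    let upper_128 = [0x2, 0x3, 0xa, 0xb, 0x6, 0x7, 0xe, 0xf];

    // One 64-byte load from each of chunks a, e, i, m, labeled by qword.
    // (Qword a0 holds the 32-bit words a_0 and a_1 in the diff's naming.)
    let a = ["a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7"];
    let e = ["e0", "e1", "e2", "e3", "e4", "e5", "e6", "e7"];
    let i = ["i0", "i1", "i2", "i3", "i4", "i5", "i6", "i7"];
    let m = ["m0", "m1", "m2", "m3", "m4", "m5", "m6", "m7"];

    // Stage 1: interleave 256-bit halves (a with i, e with m).
    let ai_01234567 = permutex2var(a, lower_256, i); // a0..a3, i0..i3
    let ai_89abcdef = permutex2var(a, upper_256, i); // a4..a7, i4..i7
    let em_01234567 = permutex2var(e, lower_256, m); // e0..e3, m0..m3
    let em_89abcdef = permutex2var(e, upper_256, m); // e4..e7, m4..m7

    // Stage 2: interleave 128-bit quarters. The result is the layout the old
    // load_4_lanes helper built from four separate 128-bit loads.
    let aeim_0123 = permutex2var(ai_01234567, lower_128, em_01234567);
    assert_eq!(aeim_0123, ["a0", "a1", "e0", "e1", "i0", "i1", "m0", "m1"]);
    let aeim_4567 = permutex2var(ai_01234567, upper_128, em_01234567);
    assert_eq!(aeim_4567, ["a2", "a3", "e2", "e3", "i2", "i3", "m2", "m3"]);
    let aeim_89ab = permutex2var(ai_89abcdef, lower_128, em_89abcdef);
    assert_eq!(aeim_89ab, ["a4", "a5", "e4", "e5", "i4", "i5", "m4", "m5"]);
    println!("{:?}\n{:?}\n{:?}", aeim_0123, aeim_4567, aeim_89ab);
}
```

Because the output shape matches what `load_4_lanes` produced, the `_mm512_unpacklo_epi32`/`_mm512_unpackhi_epi32` transpose steps that follow can stay unchanged, while each 64-byte block of a chunk is now read with a single 512-bit load instead of four 128-bit loads.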
