authorJack O'Connor <[email protected]>2023-09-18 08:09:33 -0700
committerJack O'Connor <[email protected]>2023-09-18 08:09:40 -0700
commitf7b4c2bdc43b1ea2c6bbbd11146d8760fabc0888 (patch)
tree34d72d5764d88212147df60832cb59b107ee02ec
parentfd91b59473ddecc6006d49c61ff4a63c42d9a3c3 (diff)
riscv universal_hash passing all tests
-rw-r--r--  rust/guts/src/riscv64gcv.S  134
1 file changed, 105 insertions, 29 deletions
diff --git a/rust/guts/src/riscv64gcv.S b/rust/guts/src/riscv64gcv.S
index 2682619..424b95a 100644
--- a/rust/guts/src/riscv64gcv.S
+++ b/rust/guts/src/riscv64gcv.S
@@ -1618,17 +1618,7 @@ blake3_guts_riscv64gcv_xof_xor_partial_block:
.global blake3_guts_riscv64gcv_universal_hash
blake3_guts_riscv64gcv_universal_hash:
// t0 := full_blocks := input_len / 64
- // TODO: handle the partial block at the end
srli t0, a1, 6
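// For example, with input_len = 200: t0 = 200 >> 6 = 3, and the remaining
// 200 % 64 = 8 bytes form a partial block.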
- // Load the counter.
- vsetvli zero, t0, e64, m2, ta, ma
- vmv.v.x v8, a3
- vid.v v10
- vadd.vv v8, v8, v10
- vsetvli zero, t0, e32, m1, ta, ma
- vncvt.x.x.w v12, v8
- li t1, 32
- vnsrl.wx v13, v8, t1
// Load and transpose full message blocks. These are "strided segment
// loads". Each vlsseg8e32 instruction transposes 8 words from multiple
// message blocks into 8 registers, so we need two vlsseg8e32
@@ -1639,30 +1629,44 @@ blake3_guts_riscv64gcv_universal_hash:
// RISC-V ABI allows misaligned loads and stores. If we need to support
// an environment that doesn't allow them (or where they're
// unacceptably slow), we could add a fallback here.
+ vsetvli zero, t0, e32, m1, ta, ma
li t1, 64
addi t2, a0, 32
vlsseg8e32.v v16, (a0), t1
vlsseg8e32.v v24, (t2), t1
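// As a sketch of the resulting layout, for message blocks B0, B1, ... (each
// 16 words w0..w15): the first load gathers v16 = {B0.w0, B1.w0, ...} through
// v23 = {B0.w7, B1.w7, ...}, and the second, offset by 32 bytes, gathers
// v24 = {B0.w8, B1.w8, ...} through v31 = {B0.w15, B1.w15, ...}.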
- // Broadcast the key to v0-7.
- lw t0, 0(a2)
- vmv.v.x v0, t0
- lw t0, 4(a2)
- vmv.v.x v1, t0
- lw t0, 8(a2)
- vmv.v.x v2, t0
- lw t0, 12(a2)
- vmv.v.x v3, t0
- lw t0, 16(a2)
- vmv.v.x v4, t0
- lw t0, 20(a2)
- vmv.v.x v5, t0
- lw t0, 24(a2)
- vmv.v.x v6, t0
- lw t0, 28(a2)
- vmv.v.x v7, t0
// Broadcast the block length.
li t1, 64
vmv.v.x v14, t1
+ // If there's a partial block, handle it in an out-of-line branch.
+ andi t1, a1, 63
+ bnez t1, universal_hash_handle_partial_block
+universal_hash_partial_block_finished:
+ // Broadcast the key to v0-7.
+ lw t1, 0(a2)
+ vmv.v.x v0, t1
+ lw t1, 4(a2)
+ vmv.v.x v1, t1
+ lw t1, 8(a2)
+ vmv.v.x v2, t1
+ lw t1, 12(a2)
+ vmv.v.x v3, t1
+ lw t1, 16(a2)
+ vmv.v.x v4, t1
+ lw t1, 20(a2)
+ vmv.v.x v5, t1
+ lw t1, 24(a2)
+ vmv.v.x v6, t1
+ lw t1, 28(a2)
+ vmv.v.x v7, t1
+ // Load the counter.
+ vsetvli zero, t0, e64, m2, ta, ma
+ vmv.v.x v8, a3
+ vid.v v10
+ vadd.vv v8, v8, v10
+ vsetvli zero, t0, e32, m1, ta, ma
+ vncvt.x.x.w v12, v8
+ li t1, 32
+ vnsrl.wx v13, v8, t1
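+ // For example, with a3 = 0x00000001_FFFFFFFF and t0 = 2, v8 holds the
+ // 64-bit counters {0x1_FFFFFFFF, 0x2_00000000}, so v12 ends up with the
+ // low halves {0xFFFFFFFF, 0x00000000} and v13 with the high halves
+ // {0x00000001, 0x00000002}.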
// Broadcast the flags.
li t1, CHUNK_START | CHUNK_END | ROOT | KEYED_HASH
vmv.v.x v15, t1
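// With the standard BLAKE3 flag values (CHUNK_START = 1, CHUNK_END = 2,
// ROOT = 8, KEYED_HASH = 16), t1 here is 0x1b.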
@@ -1670,7 +1674,7 @@ blake3_guts_riscv64gcv_universal_hash:
mv t6, ra
call blake3_guts_riscv64gcv_kernel
mv ra, t6
- // XOR the first four words. The rest are dropped.
+ // Finish the first four state vectors. The rest are dropped.
vxor.vv v0, v0, v8
vxor.vv v1, v1, v9
vxor.vv v2, v2, v10
@@ -1690,5 +1694,77 @@ blake3_guts_riscv64gcv_universal_hash:
sw t0, 8(a4)
vmv.x.s t0, v3
sw t0, 12(a4)
-
ret
+universal_hash_handle_partial_block:
+ // Load the partial block into v8-v11. With LMUL=4, v8 is guaranteed to
+ // hold at least 64 bytes. Zero all 64 bytes first, for block padding.
+ // The block length is already in t1.
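+ // For example, with input_len = 200 and t1 = 8, the vle8 below reads the 8
+ // trailing input bytes from a0 + 192, and the other 56 bytes of the padded
+ // block stay zero.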
+ li t2, 64
+ vsetvli zero, t2, e8, m4, ta, ma
+ vmv.v.i v8, 0
+ vsetvli zero, t1, e8, m4, ta, ma
+ add t2, a0, a1
+ sub t2, t2, t1
+ vle8.v v8, (t2)
+ // If VLEN is longer than 128 bits (16 bytes), then half or all of the
+ // block bytes will be in v8. Make sure they're split evenly across
+ // v8-v11.
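+ // For example, with VLEN = 256 (vlenb = 32), the 64 bytes land in v8
+ // (words 0-7) and v9 (words 8-15); the moves below leave words 0-3 in v8,
+ // 4-7 in v9, 8-11 in v10, and 12-15 in v11.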
+ csrr t2, vlenb
+ li t3, 64
+ bltu t2, t3, universal_hash_vlenb_less_than_64
+ vsetivli zero, 8, e32, m1, ta, ma
+ vslidedown.vi v9, v8, 8
+universal_hash_vlenb_less_than_64:
+ li t3, 32
+ bltu t2, t3, universal_hash_vlenb_less_than_32
+ vsetivli zero, 4, e32, m1, ta, ma
+ vmv.v.v v10, v9
+ vslidedown.vi v11, v9, 4
+ vslidedown.vi v9, v8, 4
+universal_hash_vlenb_less_than_32:
+ // Shift each of the words of the padded partial block to the end of
+ // the corresponding message vector. t0 was previously the number of
+ // full blocks. Now we increment it, so that it's the number of all
+ // blocks (both full and partial).
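+ // For example, with 3 full blocks: t2 = 3 and t0 becomes 4, so each
+ // vslideup.vx below drops one word of the partial block into element 3 of
+ // its message vector (v16 gets word 0, v17 word 1, and so on), right after
+ // the 3 words that came from the full blocks.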
+ mv t2, t0
+ addi t0, t0, 1
+ // Set vl to at least 4, because v8-v11 each have 4 message words.
+ // Setting vl shorter will make vslide1down clobber those words.
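+ // For example, with a single full block t0 is now 2, so the maxu below
+ // clamps vl to 4 for the slides; the vsetvli after them restores vl = t0.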
+ li t3, 4
+ maxu t3, t0, t3
+ vsetvli zero, t3, e32, m1, ta, ma
+ vslideup.vx v16, v8, t2
+ vslide1down.vx v8, v8, zero
+ vslideup.vx v17, v8, t2
+ vslide1down.vx v8, v8, zero
+ vslideup.vx v18, v8, t2
+ vslide1down.vx v8, v8, zero
+ vslideup.vx v19, v8, t2
+ vslideup.vx v20, v9, t2
+ vslide1down.vx v9, v9, zero
+ vslideup.vx v21, v9, t2
+ vslide1down.vx v9, v9, zero
+ vslideup.vx v22, v9, t2
+ vslide1down.vx v9, v9, zero
+ vslideup.vx v23, v9, t2
+ vslideup.vx v24, v10, t2
+ vslide1down.vx v10, v10, zero
+ vslideup.vx v25, v10, t2
+ vslide1down.vx v10, v10, zero
+ vslideup.vx v26, v10, t2
+ vslide1down.vx v10, v10, zero
+ vslideup.vx v27, v10, t2
+ vslideup.vx v28, v11, t2
+ vslide1down.vx v11, v11, zero
+ vslideup.vx v29, v11, t2
+ vslide1down.vx v11, v11, zero
+ vslideup.vx v30, v11, t2
+ vslide1down.vx v11, v11, zero
+ vslideup.vx v31, v11, t2
+ // Set the updated VL.
+ vsetvli zero, t0, e32, m1, ta, ma
+ // Append the final block length, still in t1.
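+ // For example, with 3 full blocks and t1 = 8, t2 = 3 and v14 becomes
+ // {64, 64, 64, 8}.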
+ vmv.v.x v8, t1
+ addi t2, t0, -1
+ vslideup.vx v14, v8, t2
+ j universal_hash_partial_block_finished