aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorQuentin Carbonneaux <[email protected]>2022-02-11 08:42:28 +0100
committerQuentin Carbonneaux <[email protected]>2024-04-09 21:34:57 +0200
commit56e2263ca46166ffffb814ae225faf08fd52248c (patch)
treebb14b115690c2f3c174d43d0aa14565e31a05f37
parent8a5e1c3a2359054af8bbcd7fa53785e0bb054390 (diff)
fuse ac rules in ins-tree matching
The initial plan was to have one matcher per ac-variant, but that leads to way too much generated code. Instead, we can fuse ac variants of the rules and have a smarter matching algorithm to recover bound variables.
-rw-r--r--tools/match.ml80
-rw-r--r--tools/match_test.ml78
2 files changed, 86 insertions, 72 deletions
diff --git a/tools/match.ml b/tools/match.ml
index 4aeeae0..5de356b 100644
--- a/tools/match.ml
+++ b/tools/match.ml
@@ -23,11 +23,32 @@ type pattern =
| Atm of atomic_pattern
| Var of string * atomic_pattern
+let show_op (k, o) =
+ (match o with
+ | Oadd -> "add"
+ | Osub -> "sub"
+ | Omul -> "mul") ^
+ (match k with
+ | Kw -> "w"
+ | Kl -> "l"
+ | Ks -> "s"
+ | Kd -> "d")
+
+let rec show_pattern p =
+ match p with
+ | Var _ -> failwith "variable not allowed"
+ | Atm Tmp -> "%"
+ | Atm AnyCon -> "$"
+ | Atm (Con n) -> Int64.to_string n
+ | Bnr (o, pl, pr) ->
+ "(" ^ show_op o ^
+ " " ^ show_pattern pl ^
+ " " ^ show_pattern pr ^ ")"
+
let rec pattern_match p w =
match p with
- | Var _ ->
- failwith "variable not allowed"
- | Atm (Tmp) ->
+ | Var _ -> failwith "variable not allowed"
+ | Atm Tmp ->
begin match w with
| Atm (Con _ | AnyCon) -> false
| _ -> true
@@ -89,12 +110,12 @@ type 'a state =
; point: ('a cursor) list }
let rec binops side {point; _} =
- List.fold_left (fun res c ->
+ List.filter_map (fun c ->
match c, side with
- | Bnrl (o, c, r), `L -> ((o, c), r) :: res
- | Bnrr (o, l, c), `R -> ((o, c), l) :: res
- | _ -> res)
- [] point
+ | Bnrl (o, c, r), `L -> Some ((o, c), r)
+ | Bnrr (o, l, c), `R -> Some ((o, c), l)
+ | _ -> None)
+ point
let group_by_fst l =
List.fast_sort (fun (a, _) (b, _) ->
@@ -114,11 +135,9 @@ let sort_uniq cmp l =
List.fold_left (fun (eo, l) e' ->
match eo with
| None -> (Some e', l)
- | Some e ->
- if cmp e e' = 0
- then (eo, l)
- else (Some e', e :: l)
- ) (None, []) |>
+ | Some e when cmp e e' = 0 -> (eo, l)
+ | Some e -> (Some e', e :: l))
+ (None, []) |>
(function
| (None, _) -> []
| (Some e, l) -> List.rev (e :: l))
@@ -126,15 +145,14 @@ let sort_uniq cmp l =
let normalize (point: ('a cursor) list) =
sort_uniq compare point
-let nextbnr tmp s1 s2 =
+let next_binary tmp s1 s2 =
let pm w (_, p) = pattern_match p w in
let o1 = binops `L s1 |>
List.filter (pm s2.seen) |>
- List.map fst
- and o2 = binops `R s2 |>
+ List.map fst in
+ let o2 = binops `R s2 |>
List.filter (pm s1.seen) |>
- List.map fst
- in
+ List.map fst in
List.map (fun (o, l) ->
o,
{ id = 0
@@ -145,25 +163,24 @@ let nextbnr tmp s1 s2 =
type p = string
module StateSet : sig
- type set
- val create: unit -> set
- val add: set -> p state ->
+ type t
+ val create: unit -> t
+ val add: t -> p state ->
[> `Added | `Found ] * p state
- val iter: set -> (p state -> unit) -> unit
- val elems: set -> (p state) list
+ val iter: t -> (p state -> unit) -> unit
+ val elems: t -> (p state) list
end = struct
- include Hashtbl.Make(struct
+ open Hashtbl.Make(struct
type t = p state
let equal s1 s2 = s1.point = s2.point
let hash s = Hashtbl.hash s.point
end)
- type set =
+ type nonrec t =
{ h: int t
; mutable next_id: int }
let create () =
{ h = create 500; next_id = 1 }
let add set s =
- (* delete the check later *)
assert (s.point = normalize s.point);
try
let id = find set.h s in
@@ -171,6 +188,8 @@ end = struct
with Not_found -> begin
let id = set.next_id in
set.next_id <- id + 1;
+ Printf.printf "adding: %d [%s]\n"
+ id (show_pattern s.seen);
add set.h s id;
`Added, {s with id}
end
@@ -198,17 +217,14 @@ end)
type rule =
{ name: string
; pattern: pattern
- (* TODO access pattern *)
}
let generate_table rl =
let states = StateSet.create () in
(* initialize states *)
let ground =
- List.fold_left
- (fun ini r ->
- peel r.pattern r.name @ ini)
- [] rl |>
+ List.concat_map
+ (fun r -> peel r.pattern r.name) rl |>
group_by_fst
in
let find x d l =
@@ -242,7 +258,7 @@ let generate_table rl =
flag := `Stop;
let statel = StateSet.elems states in
iter_pairs statel (fun (sl, sr) ->
- nextbnr tmp sl sr |>
+ next_binary tmp sl sr |>
List.iter (fun (o, s') ->
let flag', s' =
StateSet.add states s' in
diff --git a/tools/match_test.ml b/tools/match_test.ml
index fe740c5..da63666 100644
--- a/tools/match_test.ml
+++ b/tools/match_test.ml
@@ -46,54 +46,52 @@ let ts =
}
let print_sm =
- let op_str (k, o) =
- Printf.sprintf "%s%s"
- (match o with
- | Oadd -> "add"
- | Osub -> "sub"
- | Omul -> "mul")
- (match k with
- | Kw -> "w"
- | Kl -> "l"
- | Ks -> "s"
- | Kd -> "d")
- in
StateMap.iter (fun k s' ->
match k with
| K (o, sl, sr) ->
+ let top =
+ List.fold_left (fun top c ->
+ match c with
+ | Top r -> top ^ " " ^ r
+ | _ -> top) "" s'.point
+ in
Printf.printf
- "(%s %d %d) -> %d\n"
- (op_str o)
- sl.id sr.id s'.id
- )
+ "(%s %d %d) -> %d%s\n"
+ (show_op o)
+ sl.id sr.id s'.id top)
-let address_rules =
+let rules =
let oa = Kl, Oadd in
let om = Kl, Omul in
- let rule name pattern =
- List.mapi (fun i pattern ->
- { name = Printf.sprintf "%s%d" name (i+1)
- ; pattern; })
- (ac_equiv pattern) in
-
+ match `X64Addr with
+ (* ------------------------------- *)
+ | `X64Addr ->
+ let rule name pattern =
+ List.mapi (fun i pattern ->
+ { name (* = Printf.sprintf "%s%d" name (i+1) *)
+ ; pattern })
+ (ac_equiv pattern) in
(* o + b *)
- rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
- @ (* b + s * i *)
- rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp)))
- @ (* o + s * i *)
- rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp)))
- @ (* b + o + s * i *)
- rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm AnyCon, Atm Tmp)))
+ rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
+ @ (* b + s * i *)
+ rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm (Con 4L), Atm Tmp)))
+ @ (* o + s * i *)
+ rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp)))
+ @ (* b + o + s * i *)
+ rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm (Con 4L), Atm Tmp)))
+ (* ------------------------------- *)
+ | `Add3 ->
+ [ { name = "add"
+ ; pattern = Bnr (oa, Atm Tmp, Bnr (oa, Atm Tmp, Atm Tmp)) } ] @
+ [ { name = "add"
+ ; pattern = Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), Atm Tmp) } ] @
+ [ { name = "mul"
+ ; pattern = Bnr (om, Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp),
+ Atm Tmp),
+ Bnr (oa, Atm Tmp,
+ Bnr (oa, Atm Tmp, Atm Tmp))) } ]
+
-let sl, sm = generate_table address_rules
+let sl, sm = generate_table rules
let s n = List.find (fun {id; _} -> id = n) sl
let () = print_sm sm
-
-(*
-let tp0 =
- let o = Kw, Oadd in
- Bnr (o, Atm Tmp, Atm (Con 0L))
-let tp1 =
- let o = Kw, Oadd in
- Bnr (o, tp0, Atm (Con 1L))
-*)