aboutsummaryrefslogtreecommitdiff
path: root/tools/mgen/match.ml
diff options
context:
space:
mode:
Diffstat (limited to 'tools/mgen/match.ml')
-rw-r--r--tools/mgen/match.ml651
1 files changed, 651 insertions, 0 deletions
diff --git a/tools/mgen/match.ml b/tools/mgen/match.ml
new file mode 100644
index 0000000..9c02ca4
--- /dev/null
+++ b/tools/mgen/match.ml
@@ -0,0 +1,651 @@
+type cls = Kw | Kl | Ks | Kd
+type op_base =
+ | Oadd
+ | Osub
+ | Omul
+ | Oor
+ | Oshl
+ | Oshr
+type op = cls * op_base
+
+let op_bases =
+ [Oadd; Osub; Omul; Oor; Oshl; Oshr]
+
+let commutative = function
+ | (_, (Oadd | Omul | Oor)) -> true
+ | (_, _) -> false
+
+let associative = function
+ | (_, (Oadd | Omul | Oor)) -> true
+ | (_, _) -> false
+
+type atomic_pattern =
+ | Tmp
+ | AnyCon
+ | Con of int64
+(* Tmp < AnyCon < Con k *)
+
+type pattern =
+ | Bnr of op * pattern * pattern
+ | Atm of atomic_pattern
+ | Var of string * atomic_pattern
+
+let is_atomic = function
+ | (Atm _ | Var _) -> true
+ | _ -> false
+
+let show_op_base o =
+ match o with
+ | Oadd -> "add"
+ | Osub -> "sub"
+ | Omul -> "mul"
+ | Oor -> "or"
+ | Oshl -> "shl"
+ | Oshr -> "shr"
+
+let show_op (k, o) =
+ show_op_base o ^
+ (match k with
+ | Kw -> "w"
+ | Kl -> "l"
+ | Ks -> "s"
+ | Kd -> "d")
+
+let rec show_pattern p =
+ match p with
+ | Atm Tmp -> "%"
+ | Atm AnyCon -> "$"
+ | Atm (Con n) -> Int64.to_string n
+ | Var (v, p) ->
+ show_pattern (Atm p) ^ "'" ^ v
+ | Bnr (o, pl, pr) ->
+ "(" ^ show_op o ^
+ " " ^ show_pattern pl ^
+ " " ^ show_pattern pr ^ ")"
+
+let get_atomic p =
+ match p with
+ | (Atm a | Var (_, a)) -> Some a
+ | _ -> None
+
+let rec pattern_match p w =
+ match p with
+ | Var (_, p) ->
+ pattern_match (Atm p) w
+ | Atm Tmp ->
+ begin match get_atomic w with
+ | Some (Con _ | AnyCon) -> false
+ | _ -> true
+ end
+ | Atm (Con _) -> w = p
+ | Atm (AnyCon) ->
+ not (pattern_match (Atm Tmp) w)
+ | Bnr (o, pl, pr) ->
+ begin match w with
+ | Bnr (o', wl, wr) ->
+ o' = o &&
+ pattern_match pl wl &&
+ pattern_match pr wr
+ | _ -> false
+ end
+
+type +'a cursor = (* a position inside a pattern *)
+ | Bnrl of op * 'a cursor * pattern
+ | Bnrr of op * pattern * 'a cursor
+ | Top of 'a
+
+let rec fold_cursor c p =
+ match c with
+ | Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p'))
+ | Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p))
+ | Top _ -> p
+
+let peel p x =
+ let once out (p, c) =
+ match p with
+ | Var (_, p) -> (Atm p, c) :: out
+ | Atm _ -> (p, c) :: out
+ | Bnr (o, pl, pr) ->
+ (pl, Bnrl (o, c, pr)) ::
+ (pr, Bnrr (o, pl, c)) :: out
+ in
+ let rec go l =
+ let l' = List.fold_left once [] l in
+ if List.length l' = List.length l
+ then l'
+ else go l'
+ in go [(p, Top x)]
+
+let fold_pairs l1 l2 ini f =
+ let rec go acc = function
+ | [] -> acc
+ | a :: l1' ->
+ go (List.fold_left
+ (fun acc b -> f (a, b) acc)
+ acc l2) l1'
+ in go ini l1
+
+let iter_pairs l f =
+ fold_pairs l l () (fun x () -> f x)
+
+let inverse l =
+ List.map (fun (a, b) -> (b, a)) l
+
+type 'a state =
+ { id: int
+ ; seen: pattern
+ ; point: ('a cursor) list }
+
+let rec binops side {point; _} =
+ List.filter_map (fun c ->
+ match c, side with
+ | Bnrl (o, c, r), `L -> Some ((o, c), r)
+ | Bnrr (o, l, c), `R -> Some ((o, c), l)
+ | _ -> None)
+ point
+
+let group_by_fst l =
+ List.fast_sort (fun (a, _) (b, _) ->
+ compare a b) l |>
+ List.fold_left (fun (oo, l, res) (o', c) ->
+ match oo with
+ | None -> (Some o', [c], [])
+ | Some o when o = o' -> (oo, c :: l, res)
+ | Some o -> (Some o', [c], (o, l) :: res))
+ (None, [], []) |>
+ (function
+ | (None, _, _) -> []
+ | (Some o, l, res) -> (o, l) :: res)
+
+let sort_uniq cmp l =
+ List.fast_sort cmp l |>
+ List.fold_left (fun (eo, l) e' ->
+ match eo with
+ | None -> (Some e', l)
+ | Some e when cmp e e' = 0 -> (eo, l)
+ | Some e -> (Some e', e :: l))
+ (None, []) |>
+ (function
+ | (None, _) -> []
+ | (Some e, l) -> List.rev (e :: l))
+
+let setify l =
+ sort_uniq compare l
+
+let normalize (point: ('a cursor) list) =
+ setify point
+
+let next_binary tmp s1 s2 =
+ let pm w (_, p) = pattern_match p w in
+ let o1 = binops `L s1 |>
+ List.filter (pm s2.seen) |>
+ List.map fst in
+ let o2 = binops `R s2 |>
+ List.filter (pm s1.seen) |>
+ List.map fst in
+ List.map (fun (o, l) ->
+ o,
+ { id = -1
+ ; seen = Bnr (o, s1.seen, s2.seen)
+ ; point = normalize (l @ tmp) })
+ (group_by_fst (o1 @ o2))
+
+type p = string
+
+module StateSet : sig
+ type t
+ val create: unit -> t
+ val add: t -> p state ->
+ [> `Added | `Found ] * p state
+ val iter: t -> (p state -> unit) -> unit
+ val elems: t -> (p state) list
+end = struct
+ open Hashtbl.Make(struct
+ type t = p state
+ let equal s1 s2 = s1.point = s2.point
+ let hash s = Hashtbl.hash s.point
+ end)
+ type nonrec t =
+ { h: int t
+ ; mutable next_id: int }
+ let create () =
+ { h = create 500; next_id = 0 }
+ let add set s =
+ assert (s.point = normalize s.point);
+ try
+ let id = find set.h s in
+ `Found, {s with id}
+ with Not_found -> begin
+ let id = set.next_id in
+ set.next_id <- id + 1;
+ add set.h s id;
+ `Added, {s with id}
+ end
+ let iter set f =
+ let f s id = f {s with id} in
+ iter f set.h
+ let elems set =
+ let res = ref [] in
+ iter set (fun s -> res := s :: !res);
+ !res
+end
+
+type table_key =
+ | K of op * p state * p state
+
+module StateMap = struct
+ include Map.Make(struct
+ type t = table_key
+ let compare ka kb =
+ match ka, kb with
+ | K (o, sl, sr), K (o', sl', sr') ->
+ compare (o, sl.id, sr.id)
+ (o', sl'.id, sr'.id)
+ end)
+ let invert n sm =
+ let rmap = Array.make n [] in
+ iter (fun k {id; _} ->
+ match k with
+ | K (o, sl, sr) ->
+ rmap.(id) <-
+ (o, (sl.id, sr.id)) :: rmap.(id)
+ ) sm;
+ Array.map group_by_fst rmap
+ let by_ops sm =
+ fold (fun tk s ops ->
+ match tk with
+ | K (op, l, r) ->
+ (op, ((l.id, r.id), s.id)) :: ops)
+ sm [] |> group_by_fst
+end
+
+type rule =
+ { name: string
+ ; vars: string list
+ ; pattern: pattern }
+
+let generate_table rl =
+ let states = StateSet.create () in
+ let rl =
+ (* these atomic patterns must occur in
+ * rules so that we are able to number
+ * all possible refs *)
+ [ { name = "$"; vars = []
+ ; pattern = Atm AnyCon }
+ ; { name = "%"; vars = []
+ ; pattern = Atm Tmp } ] @ rl
+ in
+ (* initialize states *)
+ let ground =
+ List.concat_map
+ (fun r -> peel r.pattern r.name) rl |>
+ group_by_fst
+ in
+ let tmp = List.assoc (Atm Tmp) ground in
+ let con = List.assoc (Atm AnyCon) ground in
+ let atoms = ref [] in
+ let () =
+ List.iter (fun (seen, l) ->
+ let point =
+ if pattern_match (Atm Tmp) seen
+ then normalize (tmp @ l)
+ else normalize (con @ l)
+ in
+ let s = {id = -1; seen; point} in
+ let _, s = StateSet.add states s in
+ match get_atomic seen with
+ | Some atm -> atoms := (atm, s) :: !atoms
+ | None -> ()
+ ) ground
+ in
+ (* setup loop state *)
+ let map = ref StateMap.empty in
+ let map_add k s' =
+ map := StateMap.add k s' !map
+ in
+ let flag = ref `Added in
+ let flagmerge = function
+ | `Added -> flag := `Added
+ | _ -> ()
+ in
+ (* iterate until fixpoint *)
+ while !flag = `Added do
+ flag := `Stop;
+ let statel = StateSet.elems states in
+ iter_pairs statel (fun (sl, sr) ->
+ next_binary tmp sl sr |>
+ List.iter (fun (o, s') ->
+ let flag', s' =
+ StateSet.add states s' in
+ flagmerge flag';
+ map_add (K (o, sl, sr)) s';
+ ));
+ done;
+ let states =
+ StateSet.elems states |>
+ List.sort (fun s s' -> compare s.id s'.id) |>
+ Array.of_list
+ in
+ (states, !atoms, !map)
+
+let intersperse x l =
+ let rec go left right out =
+ let out =
+ (List.rev left @ [x] @ right) ::
+ out in
+ match right with
+ | x :: right' ->
+ go (x :: left) right' out
+ | [] -> out
+ in go [] l []
+
+let rec permute = function
+ | [] -> [[]]
+ | x :: l ->
+ List.concat (List.map
+ (intersperse x) (permute l))
+
+(* build all binary trees with ordered
+ * leaves l *)
+let rec bins build l =
+ let rec go l r out =
+ match r with
+ | [] -> out
+ | x :: r' ->
+ go (l @ [x]) r'
+ (fold_pairs
+ (bins build l)
+ (bins build r)
+ out (fun (l, r) out ->
+ build l r :: out))
+ in
+ match l with
+ | [] -> []
+ | [x] -> [x]
+ | x :: l -> go [x] l []
+
+let products l ini f =
+ let rec go acc la = function
+ | [] -> f (List.rev la) acc
+ | xs :: l ->
+ List.fold_left (fun acc x ->
+ go acc (x :: la) l)
+ acc xs
+ in go ini [] l
+
+(* combinatorial nuke... *)
+let rec ac_equiv =
+ let rec alevel o = function
+ | Bnr (o', l, r) when o' = o ->
+ alevel o l @ alevel o r
+ | x -> [x]
+ in function
+ | Bnr (o, _, _) as p
+ when associative o ->
+ products
+ (List.map ac_equiv (alevel o p)) []
+ (fun choice out ->
+ List.concat_map
+ (bins (fun l r -> Bnr (o, l, r)))
+ (if commutative o
+ then permute choice
+ else [choice]) @ out)
+ | Bnr (o, l, r)
+ when commutative o ->
+ fold_pairs
+ (ac_equiv l) (ac_equiv r) []
+ (fun (l, r) out ->
+ Bnr (o, l, r) ::
+ Bnr (o, r, l) :: out)
+ | Bnr (o, l, r) ->
+ fold_pairs
+ (ac_equiv l) (ac_equiv r) []
+ (fun (l, r) out ->
+ Bnr (o, l, r) :: out)
+ | x -> [x]
+
+module Action: sig
+ type node =
+ | Switch of (int * t) list
+ | Push of bool * t
+ | Pop of t
+ | Set of string * t
+ | Stop
+ and t = private
+ { id: int; node: node }
+ val equal: t -> t -> bool
+ val size: t -> int
+ val stop: t
+ val mk_push: sym:bool -> t -> t
+ val mk_pop: t -> t
+ val mk_set: string -> t -> t
+ val mk_switch: int list -> (int -> t) -> t
+ val pp: Format.formatter -> t -> unit
+end = struct
+ type node =
+ | Switch of (int * t) list
+ | Push of bool * t
+ | Pop of t
+ | Set of string * t
+ | Stop
+ and t =
+ { id: int; node: node }
+
+ let equal a a' = a.id = a'.id
+ let size a =
+ let seen = Hashtbl.create 10 in
+ let rec node_size = function
+ | Switch l ->
+ List.fold_left
+ (fun n (_, a) -> n + size a) 0 l
+ | (Push (_, a) | Pop a | Set (_, a)) ->
+ size a
+ | Stop -> 0
+ and size {id; node} =
+ if Hashtbl.mem seen id
+ then 0
+ else begin
+ Hashtbl.add seen id ();
+ 1 + node_size node
+ end
+ in
+ size a
+
+ let mk =
+ let hcons = Hashtbl.create 100 in
+ let fresh = ref 0 in
+ fun node ->
+ let id =
+ try Hashtbl.find hcons node
+ with Not_found ->
+ let id = !fresh in
+ Hashtbl.add hcons node id;
+ fresh := id + 1;
+ id
+ in
+ {id; node}
+ let stop = mk Stop
+ let mk_push ~sym a = mk (Push (sym, a))
+ let mk_pop a =
+ match a.node with
+ | Stop -> a
+ | _ -> mk (Pop a)
+ let mk_set v a = mk (Set (v, a))
+ let mk_switch ids f =
+ match List.map f ids with
+ | [] -> failwith "empty switch";
+ | c :: cs as cases ->
+ if List.for_all (equal c) cs then c
+ else
+ let cases = List.combine ids cases in
+ mk (Switch cases)
+
+ open Format
+ let rec pp_node fmt = function
+ | Switch l ->
+ fprintf fmt "@[<v>@[<v2>switch{";
+ let pp_case (c, a) =
+ let pp_sep fmt () = fprintf fmt "," in
+ fprintf fmt "@,@[<2>→%a:@ @[%a@]@]"
+ (pp_print_list ~pp_sep pp_print_int)
+ c pp a
+ in
+ inverse l |> group_by_fst |> inverse |>
+ List.iter pp_case;
+ fprintf fmt "@]@,}@]"
+ | Push (true, a) -> fprintf fmt "pushsym@ %a" pp a
+ | Push (false, a) -> fprintf fmt "push@ %a" pp a
+ | Pop a -> fprintf fmt "pop@ %a" pp a
+ | Set (v, a) -> fprintf fmt "set(%s)@ %a" v pp a
+ | Stop -> fprintf fmt "•"
+ and pp fmt a = pp_node fmt a.node
+end
+
+(* a state is commutative if (a op b) enters
+ * it iff (b op a) enters it as well *)
+let symmetric rmap id =
+ List.for_all (fun (_, l) ->
+ let l1, l2 =
+ List.filter (fun (a, b) -> a <> b) l |>
+ List.partition (fun (a, b) -> a < b)
+ in
+ setify l1 = setify (inverse l2))
+ rmap.(id)
+
+(* left-to-right matching of a set of patterns;
+ * may raise if there is no lr matcher for the
+ * input rule *)
+let lr_matcher statemap states rules name =
+ let rmap =
+ let nstates = Array.length states in
+ StateMap.invert nstates statemap
+ in
+ let exception Stuck in
+ (* the list of ids represents a class of terms
+ * whose root ends up being labelled with one
+ * such id; the gen function generates a matcher
+ * that will, given any such term, assign values
+ * for the Var nodes of one pattern in pats *)
+ let rec gen
+ : 'a. int list -> (pattern * 'a) list
+ -> (int -> (pattern * 'a) list -> Action.t)
+ -> Action.t
+ = fun ids pats k ->
+ Action.mk_switch (setify ids) @@ fun id_top ->
+ let sym = symmetric rmap id_top in
+ let id_ops =
+ if sym then
+ let ordered (a, b) = a <= b in
+ List.map (fun (o, l) ->
+ (o, List.filter ordered l))
+ rmap.(id_top)
+ else rmap.(id_top)
+ in
+ (* consider only the patterns that are
+ * compatible with the current id *)
+ let atm_pats, bin_pats =
+ List.filter (function
+ | Bnr (o, _, _), _ ->
+ List.exists
+ (fun (o', _) -> o' = o)
+ id_ops
+ | _ -> true) pats |>
+ List.partition
+ (fun (pat, _) -> is_atomic pat)
+ in
+ try
+ if bin_pats = [] then raise Stuck;
+ let pats_l =
+ List.map (function
+ | (Bnr (o, l, r), x) ->
+ (l, (o, x, r))
+ | _ -> assert false)
+ bin_pats
+ and pats_r =
+ List.map (fun (l, (o, x, r)) ->
+ (r, (o, l, x)))
+ and patstop =
+ List.map (fun (r, (o, l, x)) ->
+ (Bnr (o, l, r), x))
+ in
+ let id_pairs = List.concat_map snd id_ops in
+ let ids_l = List.map fst id_pairs
+ and ids_r id_left =
+ List.filter_map (fun (l, r) ->
+ if l = id_left then Some r else None)
+ id_pairs
+ in
+ (* match the left arm *)
+ Action.mk_push ~sym
+ (gen ids_l pats_l
+ @@ fun lid pats ->
+ (* then the right arm, considering
+ * only the remaining possible
+ * patterns and knowing that the
+ * left arm was numbered 'lid' *)
+ Action.mk_pop
+ (gen (ids_r lid) (pats_r pats)
+ @@ fun _rid pats ->
+ (* continue with the parent *)
+ k id_top (patstop pats)))
+ with Stuck ->
+ let atm_pats =
+ let seen = states.(id_top).seen in
+ List.filter (fun (pat, _) ->
+ pattern_match pat seen) atm_pats
+ in
+ if atm_pats = [] then raise Stuck else
+ let vars =
+ List.filter_map (function
+ | (Var (v, _), _) -> Some v
+ | _ -> None) atm_pats |> setify
+ in
+ match vars with
+ | [] -> k id_top atm_pats
+ | [v] -> Action.mk_set v (k id_top atm_pats)
+ | _ -> failwith "ambiguous var match"
+ in
+ (* generate a matcher for the rule *)
+ let ids_top =
+ Array.to_list states |>
+ List.filter_map (fun {id; point = p; _} ->
+ if List.exists ((=) (Top name)) p then
+ Some id
+ else None)
+ in
+ let rec filter_dups pats =
+ match pats with
+ | p :: pats ->
+ if List.exists (pattern_match p) pats
+ then filter_dups pats
+ else p :: filter_dups pats
+ | [] -> []
+ in
+ let pats_top =
+ List.filter_map (fun r ->
+ if r.name = name then
+ Some r.pattern
+ else None) rules |>
+ filter_dups |>
+ List.map (fun p -> (p, ()))
+ in
+ gen ids_top pats_top (fun _ pats ->
+ assert (pats <> []);
+ Action.stop)
+
+type numberer =
+ { atoms: (atomic_pattern * p state) list
+ ; statemap: p state StateMap.t
+ ; states: p state array
+ ; mutable ops: op list
+ (* memoizes the list of possible operations
+ * according to the statemap *) }
+
+let make_numberer sa am sm =
+ { atoms = am
+ ; states = sa
+ ; statemap = sm
+ ; ops = [] }
+
+let atom_state n atm =
+ List.assoc atm n.atoms