diff options
Diffstat (limited to 'tools/mgen/match.ml')
| -rw-r--r-- | tools/mgen/match.ml | 651 |
1 files changed, 651 insertions, 0 deletions
diff --git a/tools/mgen/match.ml b/tools/mgen/match.ml new file mode 100644 index 0000000..9c02ca4 --- /dev/null +++ b/tools/mgen/match.ml @@ -0,0 +1,651 @@ +type cls = Kw | Kl | Ks | Kd +type op_base = + | Oadd + | Osub + | Omul + | Oor + | Oshl + | Oshr +type op = cls * op_base + +let op_bases = + [Oadd; Osub; Omul; Oor; Oshl; Oshr] + +let commutative = function + | (_, (Oadd | Omul | Oor)) -> true + | (_, _) -> false + +let associative = function + | (_, (Oadd | Omul | Oor)) -> true + | (_, _) -> false + +type atomic_pattern = + | Tmp + | AnyCon + | Con of int64 +(* Tmp < AnyCon < Con k *) + +type pattern = + | Bnr of op * pattern * pattern + | Atm of atomic_pattern + | Var of string * atomic_pattern + +let is_atomic = function + | (Atm _ | Var _) -> true + | _ -> false + +let show_op_base o = + match o with + | Oadd -> "add" + | Osub -> "sub" + | Omul -> "mul" + | Oor -> "or" + | Oshl -> "shl" + | Oshr -> "shr" + +let show_op (k, o) = + show_op_base o ^ + (match k with + | Kw -> "w" + | Kl -> "l" + | Ks -> "s" + | Kd -> "d") + +let rec show_pattern p = + match p with + | Atm Tmp -> "%" + | Atm AnyCon -> "$" + | Atm (Con n) -> Int64.to_string n + | Var (v, p) -> + show_pattern (Atm p) ^ "'" ^ v + | Bnr (o, pl, pr) -> + "(" ^ show_op o ^ + " " ^ show_pattern pl ^ + " " ^ show_pattern pr ^ ")" + +let get_atomic p = + match p with + | (Atm a | Var (_, a)) -> Some a + | _ -> None + +let rec pattern_match p w = + match p with + | Var (_, p) -> + pattern_match (Atm p) w + | Atm Tmp -> + begin match get_atomic w with + | Some (Con _ | AnyCon) -> false + | _ -> true + end + | Atm (Con _) -> w = p + | Atm (AnyCon) -> + not (pattern_match (Atm Tmp) w) + | Bnr (o, pl, pr) -> + begin match w with + | Bnr (o', wl, wr) -> + o' = o && + pattern_match pl wl && + pattern_match pr wr + | _ -> false + end + +type +'a cursor = (* a position inside a pattern *) + | Bnrl of op * 'a cursor * pattern + | Bnrr of op * pattern * 'a cursor + | Top of 'a + +let rec fold_cursor c p = + match c with + | Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p')) + | Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p)) + | Top _ -> p + +let peel p x = + let once out (p, c) = + match p with + | Var (_, p) -> (Atm p, c) :: out + | Atm _ -> (p, c) :: out + | Bnr (o, pl, pr) -> + (pl, Bnrl (o, c, pr)) :: + (pr, Bnrr (o, pl, c)) :: out + in + let rec go l = + let l' = List.fold_left once [] l in + if List.length l' = List.length l + then l' + else go l' + in go [(p, Top x)] + +let fold_pairs l1 l2 ini f = + let rec go acc = function + | [] -> acc + | a :: l1' -> + go (List.fold_left + (fun acc b -> f (a, b) acc) + acc l2) l1' + in go ini l1 + +let iter_pairs l f = + fold_pairs l l () (fun x () -> f x) + +let inverse l = + List.map (fun (a, b) -> (b, a)) l + +type 'a state = + { id: int + ; seen: pattern + ; point: ('a cursor) list } + +let rec binops side {point; _} = + List.filter_map (fun c -> + match c, side with + | Bnrl (o, c, r), `L -> Some ((o, c), r) + | Bnrr (o, l, c), `R -> Some ((o, c), l) + | _ -> None) + point + +let group_by_fst l = + List.fast_sort (fun (a, _) (b, _) -> + compare a b) l |> + List.fold_left (fun (oo, l, res) (o', c) -> + match oo with + | None -> (Some o', [c], []) + | Some o when o = o' -> (oo, c :: l, res) + | Some o -> (Some o', [c], (o, l) :: res)) + (None, [], []) |> + (function + | (None, _, _) -> [] + | (Some o, l, res) -> (o, l) :: res) + +let sort_uniq cmp l = + List.fast_sort cmp l |> + List.fold_left (fun (eo, l) e' -> + match eo with + | None -> (Some e', l) + | Some e when cmp e e' = 0 -> (eo, l) + | Some e -> (Some e', e :: l)) + (None, []) |> + (function + | (None, _) -> [] + | (Some e, l) -> List.rev (e :: l)) + +let setify l = + sort_uniq compare l + +let normalize (point: ('a cursor) list) = + setify point + +let next_binary tmp s1 s2 = + let pm w (_, p) = pattern_match p w in + let o1 = binops `L s1 |> + List.filter (pm s2.seen) |> + List.map fst in + let o2 = binops `R s2 |> + List.filter (pm s1.seen) |> + List.map fst in + List.map (fun (o, l) -> + o, + { id = -1 + ; seen = Bnr (o, s1.seen, s2.seen) + ; point = normalize (l @ tmp) }) + (group_by_fst (o1 @ o2)) + +type p = string + +module StateSet : sig + type t + val create: unit -> t + val add: t -> p state -> + [> `Added | `Found ] * p state + val iter: t -> (p state -> unit) -> unit + val elems: t -> (p state) list +end = struct + open Hashtbl.Make(struct + type t = p state + let equal s1 s2 = s1.point = s2.point + let hash s = Hashtbl.hash s.point + end) + type nonrec t = + { h: int t + ; mutable next_id: int } + let create () = + { h = create 500; next_id = 0 } + let add set s = + assert (s.point = normalize s.point); + try + let id = find set.h s in + `Found, {s with id} + with Not_found -> begin + let id = set.next_id in + set.next_id <- id + 1; + add set.h s id; + `Added, {s with id} + end + let iter set f = + let f s id = f {s with id} in + iter f set.h + let elems set = + let res = ref [] in + iter set (fun s -> res := s :: !res); + !res +end + +type table_key = + | K of op * p state * p state + +module StateMap = struct + include Map.Make(struct + type t = table_key + let compare ka kb = + match ka, kb with + | K (o, sl, sr), K (o', sl', sr') -> + compare (o, sl.id, sr.id) + (o', sl'.id, sr'.id) + end) + let invert n sm = + let rmap = Array.make n [] in + iter (fun k {id; _} -> + match k with + | K (o, sl, sr) -> + rmap.(id) <- + (o, (sl.id, sr.id)) :: rmap.(id) + ) sm; + Array.map group_by_fst rmap + let by_ops sm = + fold (fun tk s ops -> + match tk with + | K (op, l, r) -> + (op, ((l.id, r.id), s.id)) :: ops) + sm [] |> group_by_fst +end + +type rule = + { name: string + ; vars: string list + ; pattern: pattern } + +let generate_table rl = + let states = StateSet.create () in + let rl = + (* these atomic patterns must occur in + * rules so that we are able to number + * all possible refs *) + [ { name = "$"; vars = [] + ; pattern = Atm AnyCon } + ; { name = "%"; vars = [] + ; pattern = Atm Tmp } ] @ rl + in + (* initialize states *) + let ground = + List.concat_map + (fun r -> peel r.pattern r.name) rl |> + group_by_fst + in + let tmp = List.assoc (Atm Tmp) ground in + let con = List.assoc (Atm AnyCon) ground in + let atoms = ref [] in + let () = + List.iter (fun (seen, l) -> + let point = + if pattern_match (Atm Tmp) seen + then normalize (tmp @ l) + else normalize (con @ l) + in + let s = {id = -1; seen; point} in + let _, s = StateSet.add states s in + match get_atomic seen with + | Some atm -> atoms := (atm, s) :: !atoms + | None -> () + ) ground + in + (* setup loop state *) + let map = ref StateMap.empty in + let map_add k s' = + map := StateMap.add k s' !map + in + let flag = ref `Added in + let flagmerge = function + | `Added -> flag := `Added + | _ -> () + in + (* iterate until fixpoint *) + while !flag = `Added do + flag := `Stop; + let statel = StateSet.elems states in + iter_pairs statel (fun (sl, sr) -> + next_binary tmp sl sr |> + List.iter (fun (o, s') -> + let flag', s' = + StateSet.add states s' in + flagmerge flag'; + map_add (K (o, sl, sr)) s'; + )); + done; + let states = + StateSet.elems states |> + List.sort (fun s s' -> compare s.id s'.id) |> + Array.of_list + in + (states, !atoms, !map) + +let intersperse x l = + let rec go left right out = + let out = + (List.rev left @ [x] @ right) :: + out in + match right with + | x :: right' -> + go (x :: left) right' out + | [] -> out + in go [] l [] + +let rec permute = function + | [] -> [[]] + | x :: l -> + List.concat (List.map + (intersperse x) (permute l)) + +(* build all binary trees with ordered + * leaves l *) +let rec bins build l = + let rec go l r out = + match r with + | [] -> out + | x :: r' -> + go (l @ [x]) r' + (fold_pairs + (bins build l) + (bins build r) + out (fun (l, r) out -> + build l r :: out)) + in + match l with + | [] -> [] + | [x] -> [x] + | x :: l -> go [x] l [] + +let products l ini f = + let rec go acc la = function + | [] -> f (List.rev la) acc + | xs :: l -> + List.fold_left (fun acc x -> + go acc (x :: la) l) + acc xs + in go ini [] l + +(* combinatorial nuke... *) +let rec ac_equiv = + let rec alevel o = function + | Bnr (o', l, r) when o' = o -> + alevel o l @ alevel o r + | x -> [x] + in function + | Bnr (o, _, _) as p + when associative o -> + products + (List.map ac_equiv (alevel o p)) [] + (fun choice out -> + List.concat_map + (bins (fun l r -> Bnr (o, l, r))) + (if commutative o + then permute choice + else [choice]) @ out) + | Bnr (o, l, r) + when commutative o -> + fold_pairs + (ac_equiv l) (ac_equiv r) [] + (fun (l, r) out -> + Bnr (o, l, r) :: + Bnr (o, r, l) :: out) + | Bnr (o, l, r) -> + fold_pairs + (ac_equiv l) (ac_equiv r) [] + (fun (l, r) out -> + Bnr (o, l, r) :: out) + | x -> [x] + +module Action: sig + type node = + | Switch of (int * t) list + | Push of bool * t + | Pop of t + | Set of string * t + | Stop + and t = private + { id: int; node: node } + val equal: t -> t -> bool + val size: t -> int + val stop: t + val mk_push: sym:bool -> t -> t + val mk_pop: t -> t + val mk_set: string -> t -> t + val mk_switch: int list -> (int -> t) -> t + val pp: Format.formatter -> t -> unit +end = struct + type node = + | Switch of (int * t) list + | Push of bool * t + | Pop of t + | Set of string * t + | Stop + and t = + { id: int; node: node } + + let equal a a' = a.id = a'.id + let size a = + let seen = Hashtbl.create 10 in + let rec node_size = function + | Switch l -> + List.fold_left + (fun n (_, a) -> n + size a) 0 l + | (Push (_, a) | Pop a | Set (_, a)) -> + size a + | Stop -> 0 + and size {id; node} = + if Hashtbl.mem seen id + then 0 + else begin + Hashtbl.add seen id (); + 1 + node_size node + end + in + size a + + let mk = + let hcons = Hashtbl.create 100 in + let fresh = ref 0 in + fun node -> + let id = + try Hashtbl.find hcons node + with Not_found -> + let id = !fresh in + Hashtbl.add hcons node id; + fresh := id + 1; + id + in + {id; node} + let stop = mk Stop + let mk_push ~sym a = mk (Push (sym, a)) + let mk_pop a = + match a.node with + | Stop -> a + | _ -> mk (Pop a) + let mk_set v a = mk (Set (v, a)) + let mk_switch ids f = + match List.map f ids with + | [] -> failwith "empty switch"; + | c :: cs as cases -> + if List.for_all (equal c) cs then c + else + let cases = List.combine ids cases in + mk (Switch cases) + + open Format + let rec pp_node fmt = function + | Switch l -> + fprintf fmt "@[<v>@[<v2>switch{"; + let pp_case (c, a) = + let pp_sep fmt () = fprintf fmt "," in + fprintf fmt "@,@[<2>→%a:@ @[%a@]@]" + (pp_print_list ~pp_sep pp_print_int) + c pp a + in + inverse l |> group_by_fst |> inverse |> + List.iter pp_case; + fprintf fmt "@]@,}@]" + | Push (true, a) -> fprintf fmt "pushsym@ %a" pp a + | Push (false, a) -> fprintf fmt "push@ %a" pp a + | Pop a -> fprintf fmt "pop@ %a" pp a + | Set (v, a) -> fprintf fmt "set(%s)@ %a" v pp a + | Stop -> fprintf fmt "•" + and pp fmt a = pp_node fmt a.node +end + +(* a state is commutative if (a op b) enters + * it iff (b op a) enters it as well *) +let symmetric rmap id = + List.for_all (fun (_, l) -> + let l1, l2 = + List.filter (fun (a, b) -> a <> b) l |> + List.partition (fun (a, b) -> a < b) + in + setify l1 = setify (inverse l2)) + rmap.(id) + +(* left-to-right matching of a set of patterns; + * may raise if there is no lr matcher for the + * input rule *) +let lr_matcher statemap states rules name = + let rmap = + let nstates = Array.length states in + StateMap.invert nstates statemap + in + let exception Stuck in + (* the list of ids represents a class of terms + * whose root ends up being labelled with one + * such id; the gen function generates a matcher + * that will, given any such term, assign values + * for the Var nodes of one pattern in pats *) + let rec gen + : 'a. int list -> (pattern * 'a) list + -> (int -> (pattern * 'a) list -> Action.t) + -> Action.t + = fun ids pats k -> + Action.mk_switch (setify ids) @@ fun id_top -> + let sym = symmetric rmap id_top in + let id_ops = + if sym then + let ordered (a, b) = a <= b in + List.map (fun (o, l) -> + (o, List.filter ordered l)) + rmap.(id_top) + else rmap.(id_top) + in + (* consider only the patterns that are + * compatible with the current id *) + let atm_pats, bin_pats = + List.filter (function + | Bnr (o, _, _), _ -> + List.exists + (fun (o', _) -> o' = o) + id_ops + | _ -> true) pats |> + List.partition + (fun (pat, _) -> is_atomic pat) + in + try + if bin_pats = [] then raise Stuck; + let pats_l = + List.map (function + | (Bnr (o, l, r), x) -> + (l, (o, x, r)) + | _ -> assert false) + bin_pats + and pats_r = + List.map (fun (l, (o, x, r)) -> + (r, (o, l, x))) + and patstop = + List.map (fun (r, (o, l, x)) -> + (Bnr (o, l, r), x)) + in + let id_pairs = List.concat_map snd id_ops in + let ids_l = List.map fst id_pairs + and ids_r id_left = + List.filter_map (fun (l, r) -> + if l = id_left then Some r else None) + id_pairs + in + (* match the left arm *) + Action.mk_push ~sym + (gen ids_l pats_l + @@ fun lid pats -> + (* then the right arm, considering + * only the remaining possible + * patterns and knowing that the + * left arm was numbered 'lid' *) + Action.mk_pop + (gen (ids_r lid) (pats_r pats) + @@ fun _rid pats -> + (* continue with the parent *) + k id_top (patstop pats))) + with Stuck -> + let atm_pats = + let seen = states.(id_top).seen in + List.filter (fun (pat, _) -> + pattern_match pat seen) atm_pats + in + if atm_pats = [] then raise Stuck else + let vars = + List.filter_map (function + | (Var (v, _), _) -> Some v + | _ -> None) atm_pats |> setify + in + match vars with + | [] -> k id_top atm_pats + | [v] -> Action.mk_set v (k id_top atm_pats) + | _ -> failwith "ambiguous var match" + in + (* generate a matcher for the rule *) + let ids_top = + Array.to_list states |> + List.filter_map (fun {id; point = p; _} -> + if List.exists ((=) (Top name)) p then + Some id + else None) + in + let rec filter_dups pats = + match pats with + | p :: pats -> + if List.exists (pattern_match p) pats + then filter_dups pats + else p :: filter_dups pats + | [] -> [] + in + let pats_top = + List.filter_map (fun r -> + if r.name = name then + Some r.pattern + else None) rules |> + filter_dups |> + List.map (fun p -> (p, ())) + in + gen ids_top pats_top (fun _ pats -> + assert (pats <> []); + Action.stop) + +type numberer = + { atoms: (atomic_pattern * p state) list + ; statemap: p state StateMap.t + ; states: p state array + ; mutable ops: op list + (* memoizes the list of possible operations + * according to the statemap *) } + +let make_numberer sa am sm = + { atoms = am + ; states = sa + ; statemap = sm + ; ops = [] } + +let atom_state n atm = + List.assoc atm n.atoms |
