aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorQuentin Carbonneaux <[email protected]>2017-12-14 22:35:30 +0100
committerQuentin Carbonneaux <[email protected]>2024-04-09 21:32:49 +0200
commita374da3c2e205bb8c8548c1a63f186b6d9188e9c (patch)
treefda4fc786fccb35b615632fc7abace320a6622f8
parent24d132442411804f140b2aacb20f41139deb20e4 (diff)
modulo ac matching and more tests
-rw-r--r--tools/match.ml357
-rw-r--r--tools/match_test.ml102
2 files changed, 392 insertions, 67 deletions
diff --git a/tools/match.ml b/tools/match.ml
index 0eaa244..4aeeae0 100644
--- a/tools/match.ml
+++ b/tools/match.ml
@@ -5,26 +5,36 @@ type op_base =
| Omul
type op = cls * op_base
+let commutative = function
+ | (_, (Oadd | Omul)) -> true
+ | (_, _) -> false
+
+let associative = function
+ | (_, (Oadd | Omul)) -> true
+ | (_, _) -> false
+
type atomic_pattern =
- | Any
+ | Tmp
+ | AnyCon
| Con of int64
type pattern =
| Bnr of op * pattern * pattern
- | Unr of op * pattern
| Atm of atomic_pattern
+ | Var of string * atomic_pattern
let rec pattern_match p w =
match p with
- | Atm (Any) -> true
- | Atm (Con _) -> w = p
- | Unr (o, pa) ->
+ | Var _ ->
+ failwith "variable not allowed"
+ | Atm (Tmp) ->
begin match w with
- | Unr (o', wa) ->
- o' = o &&
- pattern_match pa wa
- | _ -> false
+ | Atm (Con _ | AnyCon) -> false
+ | _ -> true
end
+ | Atm (Con _) -> w = p
+ | Atm (AnyCon) ->
+ not (pattern_match (Atm Tmp) w)
| Bnr (o, pl, pr) ->
begin match w with
| Bnr (o', wl, wr) ->
@@ -34,75 +44,288 @@ let rec pattern_match p w =
| _ -> false
end
-let test_pattern_match =
- let pm = pattern_match
- and nm = fun x y -> not (pattern_match x y)
- and o = (Kw, Oadd) in
- begin
- assert (pm (Atm Any) (Atm (Con 42L)));
- assert (pm (Atm Any) (Unr (o, Atm Any)));
- assert (nm (Atm (Con 42L)) (Atm Any));
- assert (pm (Unr (o, Atm Any))
- (Unr (o, Atm (Con 42L))));
- assert (nm (Unr (o, Atm Any))
- (Unr ((Kl, Oadd), Atm (Con 42L))));
- assert (nm (Unr (o, Atm Any))
- (Bnr (o, Atm (Con 42L), Atm Any)));
- end
-
-type cursor = (* a position inside a pattern *)
- | Bnrl of op * cursor * pattern
- | Bnrr of op * pattern * cursor
- | Unra of op * cursor
- | Top
+type 'a cursor = (* a position inside a pattern *)
+ | Bnrl of op * 'a cursor * pattern
+ | Bnrr of op * pattern * 'a cursor
+ | Top of 'a
let rec fold_cursor c p =
match c with
| Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p'))
| Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p))
- | Unra (o, c') -> fold_cursor c' (Unr (o, p))
- | Top -> p
+ | Top _ -> p
-let peel p =
- let once out (c, p) =
+let peel p x =
+ let once out (p, c) =
match p with
- | Atm _ -> (c, p) :: out
- | Unr (o, pa) ->
- (Unra (o, c), pa) :: out
+ | Var _ -> failwith "variable not allowed"
+ | Atm _ -> (p, c) :: out
| Bnr (o, pl, pr) ->
- (Bnrl (o, c, pr), pl) ::
- (Bnrr (o, pl, c), pr) :: out
+ (pl, Bnrl (o, c, pr)) ::
+ (pr, Bnrr (o, pl, c)) :: out
in
let rec go l =
let l' = List.fold_left once [] l in
if List.length l' = List.length l
then l
else go l'
- in go [(Top, p)]
-
-let test_peel =
- let o = Kw, Oadd in
- let p = Bnr (o, Bnr (o, Atm Any, Atm Any),
- Atm (Con 42L)) in
- let l = peel p in
- let () = assert (List.length l = 3) in
- let atomic_p (_, p) =
- match p with Atm _ -> true | _ -> false in
- let () = assert (List.for_all atomic_p l) in
- let l = List.map (fun (c, p) -> fold_cursor c p) l in
- let () = assert (List.for_all ((=) p) l) in
- ()
-
-(* we want to compute all the configurations we could
- * possibly be in when processing a block of instructions;
- * to do so, we start with all the possible cursors for
- * the list of patterns we are given, this will be our
- * main "initial state"; each constant (used in the
- * patterns) also generates a state of its own
- *
- * to create new states we can take pairs of states, and
- * combine them with binary operations, we keep the
- * result if it is non-trivial (non-empty) and new (we
- * have not seen this cursor combination yet); we can
- * also do the same with unary operations
- * *)
+ in go [(p, Top x)]
+
+let fold_pairs l1 l2 ini f =
+ let rec go acc = function
+ | [] -> acc
+ | a :: l1' ->
+ go (List.fold_left
+ (fun acc b -> f (a, b) acc)
+ acc l2) l1'
+ in go ini l1
+
+let iter_pairs l f =
+ fold_pairs l l () (fun x () -> f x)
+
+type 'a state =
+ { id: int
+ ; seen: pattern
+ ; point: ('a cursor) list }
+
+let rec binops side {point; _} =
+ List.fold_left (fun res c ->
+ match c, side with
+ | Bnrl (o, c, r), `L -> ((o, c), r) :: res
+ | Bnrr (o, l, c), `R -> ((o, c), l) :: res
+ | _ -> res)
+ [] point
+
+let group_by_fst l =
+ List.fast_sort (fun (a, _) (b, _) ->
+ compare a b) l |>
+ List.fold_left (fun (oo, l, res) (o', c) ->
+ match oo with
+ | None -> (Some o', [c], [])
+ | Some o when o = o' -> (oo, c :: l, res)
+ | Some o -> (Some o', [c], (o, l) :: res))
+ (None, [], []) |>
+ (function
+ | (None, _, _) -> []
+ | (Some o, l, res) -> (o, l) :: res)
+
+let sort_uniq cmp l =
+ List.fast_sort cmp l |>
+ List.fold_left (fun (eo, l) e' ->
+ match eo with
+ | None -> (Some e', l)
+ | Some e ->
+ if cmp e e' = 0
+ then (eo, l)
+ else (Some e', e :: l)
+ ) (None, []) |>
+ (function
+ | (None, _) -> []
+ | (Some e, l) -> List.rev (e :: l))
+
+let normalize (point: ('a cursor) list) =
+ sort_uniq compare point
+
+let nextbnr tmp s1 s2 =
+ let pm w (_, p) = pattern_match p w in
+ let o1 = binops `L s1 |>
+ List.filter (pm s2.seen) |>
+ List.map fst
+ and o2 = binops `R s2 |>
+ List.filter (pm s1.seen) |>
+ List.map fst
+ in
+ List.map (fun (o, l) ->
+ o,
+ { id = 0
+ ; seen = Bnr (o, s1.seen, s2.seen)
+ ; point = normalize (l @ tmp)
+ }) (group_by_fst (o1 @ o2))
+
+type p = string
+
+module StateSet : sig
+ type set
+ val create: unit -> set
+ val add: set -> p state ->
+ [> `Added | `Found ] * p state
+ val iter: set -> (p state -> unit) -> unit
+ val elems: set -> (p state) list
+end = struct
+ include Hashtbl.Make(struct
+ type t = p state
+ let equal s1 s2 = s1.point = s2.point
+ let hash s = Hashtbl.hash s.point
+ end)
+ type set =
+ { h: int t
+ ; mutable next_id: int }
+ let create () =
+ { h = create 500; next_id = 1 }
+ let add set s =
+ (* delete the check later *)
+ assert (s.point = normalize s.point);
+ try
+ let id = find set.h s in
+ `Found, {s with id}
+ with Not_found -> begin
+ let id = set.next_id in
+ set.next_id <- id + 1;
+ add set.h s id;
+ `Added, {s with id}
+ end
+ let iter set f =
+ let f s id = f {s with id} in
+ iter f set.h
+ let elems set =
+ let res = ref [] in
+ iter set (fun s -> res := s :: !res);
+ !res
+end
+
+type table_key =
+ | K of op * p state * p state
+
+module StateMap = Map.Make(struct
+ type t = table_key
+ let compare ka kb =
+ match ka, kb with
+ | K (o, sl, sr), K (o', sl', sr') ->
+ compare (o, sl.id, sr.id)
+ (o', sl'.id, sr'.id)
+end)
+
+type rule =
+ { name: string
+ ; pattern: pattern
+ (* TODO access pattern *)
+ }
+
+let generate_table rl =
+ let states = StateSet.create () in
+ (* initialize states *)
+ let ground =
+ List.fold_left
+ (fun ini r ->
+ peel r.pattern r.name @ ini)
+ [] rl |>
+ group_by_fst
+ in
+ let find x d l =
+ try List.assoc x l with Not_found -> d in
+ let tmp = find (Atm Tmp) [] ground in
+ let con = find (Atm AnyCon) [] ground in
+ let () =
+ List.iter (fun (seen, l) ->
+ let point =
+ if pattern_match (Atm Tmp) seen
+ then normalize (tmp @ l)
+ else normalize (con @ l)
+ in
+ let s = {id = 0; seen; point} in
+ let flag, _ = StateSet.add states s in
+ assert (flag = `Added)
+ ) ground
+ in
+ (* setup loop state *)
+ let map = ref StateMap.empty in
+ let map_add k s' =
+ map := StateMap.add k s' !map
+ in
+ let flag = ref `Added in
+ let flagmerge = function
+ | `Added -> flag := `Added
+ | _ -> ()
+ in
+ (* iterate until fixpoint *)
+ while !flag = `Added do
+ flag := `Stop;
+ let statel = StateSet.elems states in
+ iter_pairs statel (fun (sl, sr) ->
+ nextbnr tmp sl sr |>
+ List.iter (fun (o, s') ->
+ let flag', s' =
+ StateSet.add states s' in
+ flagmerge flag';
+ map_add (K (o, sl, sr)) s';
+ ));
+ done;
+ (StateSet.elems states, !map)
+
+let intersperse x l =
+ let rec go left right out =
+ let out =
+ (List.rev left @ [x] @ right) ::
+ out in
+ match right with
+ | x :: right' ->
+ go (x :: left) right' out
+ | [] -> out
+ in go [] l []
+
+let rec permute = function
+ | [] -> [[]]
+ | x :: l ->
+ List.concat (List.map
+ (intersperse x) (permute l))
+
+(* build all binary trees with ordered
+ * leaves l *)
+let rec bins build l =
+ let rec go l r out =
+ match r with
+ | [] -> out
+ | x :: r' ->
+ go (l @ [x]) r'
+ (fold_pairs
+ (bins build l)
+ (bins build r)
+ out (fun (l, r) out ->
+ build l r :: out))
+ in
+ match l with
+ | [] -> []
+ | [x] -> [x]
+ | x :: l -> go [x] l []
+
+let products l ini f =
+ let rec go acc la = function
+ | [] -> f (List.rev la) acc
+ | xs :: l ->
+ List.fold_left (fun acc x ->
+ go acc (x :: la) l)
+ acc xs
+ in go ini [] l
+
+(* combinatorial nuke... *)
+let rec ac_equiv =
+ let rec alevel o = function
+ | Bnr (o', l, r) when o' = o ->
+ alevel o l @ alevel o r
+ | x -> [x]
+ in function
+ | Bnr (o, _, _) as p
+ when associative o ->
+ products
+ (List.map ac_equiv (alevel o p)) []
+ (fun choice out ->
+ List.map
+ (bins (fun l r -> Bnr (o, l, r)))
+ (if commutative o
+ then permute choice
+ else [choice]) |>
+ List.concat |>
+ (fun l -> List.rev_append l out))
+ | Bnr (o, l, r)
+ when commutative o ->
+ fold_pairs
+ (ac_equiv l) (ac_equiv r) []
+ (fun (l, r) out ->
+ Bnr (o, l, r) ::
+ Bnr (o, r, l) :: out)
+ | Bnr (o, l, r) ->
+ fold_pairs
+ (ac_equiv l) (ac_equiv r) []
+ (fun (l, r) out ->
+ Bnr (o, l, r) :: out)
+ | x -> [x]
diff --git a/tools/match_test.ml b/tools/match_test.ml
new file mode 100644
index 0000000..75e2005
--- /dev/null
+++ b/tools/match_test.ml
@@ -0,0 +1,102 @@
+#use "match.ml"
+
+let test_pattern_match =
+ let pm = pattern_match
+ and nm = fun x y -> not (pattern_match x y) in
+ begin
+ assert (nm (Atm Tmp) (Atm (Con 42L)));
+ assert (pm (Atm AnyCon) (Atm (Con 42L)));
+ assert (nm (Atm (Con 42L)) (Atm AnyCon));
+ assert (nm (Atm (Con 42L)) (Atm Tmp));
+ end
+
+let test_peel =
+ let o = Kw, Oadd in
+ let p = Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
+ Atm (Con 42L)) in
+ let l = peel p () in
+ let () = assert (List.length l = 3) in
+ let atomic_p (p, _) =
+ match p with Atm _ -> true | _ -> false in
+ let () = assert (List.for_all atomic_p l) in
+ let l = List.map (fun (p, c) -> fold_cursor c p) l in
+ let () = assert (List.for_all ((=) p) l) in
+ ()
+
+let test_fold_pairs =
+ let l = [1; 2; 3; 4; 5] in
+ let p = fold_pairs l l [] (fun a b -> a :: b) in
+ let () = assert (List.length p = 25) in
+ let p = sort_uniq compare p in
+ let () = assert (List.length p = 25) in
+ ()
+
+(* test pattern & state *)
+let tp =
+ let o = Kw, Oadd in
+ Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
+ Atm (Con 0L))
+let ts =
+ { id = 0
+ ; seen = Atm Tmp
+ ; point =
+ List.map snd
+ (List.filter (fun (p, _) -> p = Atm Tmp)
+ (peel tp ()))
+ }
+
+let print_sm =
+ let op_str (k, o) =
+ Printf.sprintf "%s%s"
+ (match o with
+ | Oadd -> "add"
+ | Osub -> "sub"
+ | Omul -> "mul")
+ (match k with
+ | Kw -> "w"
+ | Kl -> "l"
+ | Ks -> "s"
+ | Kd -> "d")
+ in
+ StateMap.iter (fun k s' ->
+ match k with
+ | K (o, sl, sr) ->
+ Printf.printf
+ "(%s %d %d) -> %d\n"
+ (op_str o)
+ sl.id sr.id s'.id
+ )
+
+let address_rules =
+ let oa = Kl, Oadd in
+ let om = Kl, Omul in
+ let rule name pattern = { name; pattern; } in
+ (* o + b *)
+ [ rule "ob1" (Bnr (oa, Atm Tmp, Atm AnyCon))
+ ; rule "ob2" (Bnr (oa, Atm AnyCon, Atm Tmp))
+
+ (* b + s * i *)
+ ; rule "bs1" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp)))
+ ; rule "bs2" (Bnr (oa, Atm Tmp, Bnr (om, Atm Tmp, Atm AnyCon)))
+ ; rule "bs3" (Bnr (oa, Bnr (om, Atm AnyCon, Atm Tmp), Atm Tmp))
+ ; rule "bs4" (Bnr (oa, Bnr (om, Atm Tmp, Atm AnyCon), Atm Tmp))
+
+ (* o + s * i *)
+ ; rule "os1" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp)))
+ ; rule "os2" (Bnr (oa, Atm AnyCon, Bnr (om, Atm Tmp, Atm AnyCon)))
+ ; rule "os3" (Bnr (oa, Bnr (om, Atm AnyCon, Atm Tmp), Atm AnyCon))
+ ; rule "os4" (Bnr (oa, Bnr (om, Atm Tmp, Atm AnyCon), Atm AnyCon))
+ ]
+
+(*
+let sl, sm = generate_table address_rules
+let s n = List.find (fun {id; _} -> id = n) sl
+let () = print_sm sm
+*)
+
+let tp0 =
+ let o = Kw, Oadd in
+ Bnr (o, Atm Tmp, Atm (Con 0L))
+let tp1 =
+ let o = Kw, Oadd in
+ Bnr (o, tp0, Atm (Con 1L))