aboutsummaryrefslogtreecommitdiff
path: root/tools/mgen/cgen.ml
blob: 297265cc3086e07c2d5d9c6af5b1ce74c54c9b98 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
open Match

(* Code generation options:
 *  - [pfx] is prepended to every emitted C identifier
 *    (<pfx>opn, <pfx>refn, <pfx>match, <pfx>matcher, tables);
 *  - [static] prefixes the emitted definitions with "static ";
 *  - [oc] is the channel the C code is written to. *)
type options =
  { pfx: string
  ; static: bool
  ; oc: out_channel }

(* Operand of a binary operation: [L] is printed as "l" and [R] as
 * "r" in the generated C. *)
type side = L | R

(* Predicate on a state id, printed by emit_case as:
 *  - [InBitSet bs]: BIT(id) & bs, where cgen_case sets bit k of
 *    [bs] for every admitted id k;
 *  - [Ge n]: n <= id;
 *  - [Eq n]: id == n. *)
type id_pred =
  | InBitSet of Int64.t
  | Ge of int
  | Eq of int

(* Conjunction of per-operand predicates; emit_case prints each
 * conjunct as one (nested) if. *)
and id_test =
  | Pred of (side * id_pred)
  | And of id_test * id_test

(* Code computing the result state of one operation inside opn():
 *  - [Table map]: look (l, r) up in a generated table filled from
 *    the ((l, r), state) pairs of [map];
 *  - [IfThen {test; cif; cthen}]: run [cif] when [test] holds;
 *    [cthen], when present, is emitted after the conditional, so it
 *    is reached whenever [cif] does not return (e.g. [test] fails);
 *  - [Return id]: return state [id]. *)
type case_code =
  | Table of ((int * int) * int) list
  | IfThen of
      { test: id_test
      ; cif: case_code
      ; cthen: case_code option }
  | Return of int

(* Per-operation code.  When [swap] is set the code assumes l >= r:
 * the operands are exchanged first when l < r (see emit_swap) and
 * only entries with r <= l are kept. *)
type case =
  { swap: bool
  ; code: case_code }

(* Decide how opn() computes the result state of one operation.
 *
 * [map] lists the ((l, r), result) entries of the operation, [tmp]
 * is the fallback state and [nstates] the number of states.  A map
 * equal to its mirror image is marked [swap] and reduced to the
 * entries with r <= l.  When every entry yields the same state and
 * at least one operand takes a single value, a test on the operands
 * replaces the lookup table. *)
let cgen_case tmp nstates map =
  let test_of_ids = function
    | [single] -> Eq single
    | ids ->
        let lowest = List.fold_left min max_int ids in
        if List.length ids <> nstates - lowest then begin
          (* the mask is a 64-bit integer *)
          assert (nstates <= 64);
          let set_bit acc id = Int64.logor acc (Int64.shift_left 1L id) in
          InBitSet (List.fold_left set_bit 0L ids)
        end else
          (* ids is exactly lowest .. nstates-1 *)
          Ge lowest
  in
  let swap =
    let flip ((l, r), x) = ((r, l), x) in
    setify map = setify (List.map flip map)
  in
  let map =
    if swap then List.filter (fun ((l, r), _) -> r <= l) map else map
  in
  let table = { swap; code = Table map } in
  match setify (List.map snd map) with
  | [st] ->
      let ls, rs = List.split (List.map fst map) in
      let ls = setify ls and rs = setify rs in
      if List.length ls > 1 && List.length rs > 1 then table
      else
        let test =
          And (Pred (L, test_of_ids ls), Pred (R, test_of_ids rs))
        in
        { swap
        ; code =
            IfThen
              { test
              ; cif = Return st
              ; cthen = Some (Return tmp) } }
  | _ -> table

(* Name of operation [op] in the generated C (used e.g. as the case
 * label in opn()): "O" followed by show_op_base of the operation.
 * The class component of the pair is ignored. *)
let show_op (_, op) =
  Printf.sprintf "O%s" (show_op_base op)

(* Write [i] tab characters to [oc].  Any depth [i >= 0] is
 * accepted; a negative [i] raises Invalid_argument. *)
let indent oc i =
  output_string oc (String.make i '\t')

(* Emit, at indentation [i], the C statement that exchanges the
 * operands when l < r (through the scratch variable t), so that
 * l >= r afterwards. *)
let emit_swap oc i =
  let line depth text = indent oc depth; output_string oc text in
  line i "if (l < r)\n";
  line (i + 1) "t = l, l = r, r = t;\n"

(* Emit the static lookup table(s) used by the opn() case of
 * operation [op], at indentation level 1.
 *
 * Every [Table map] reached in [c.code] gets its own uchar array.
 * The first is named <pfx>O<op>tbl; later ones insert a counter
 * before "tbl".  These are the names emit_case references.  With
 * [c.swap] only the lower triangle (r <= l) is stored, row by row,
 * which emit_case indexes as (l + l*l)/2 + r.  Otherwise the table
 * is a full nstates x nstates matrix.  Pairs missing from [map] hold
 * state [tmp]. *)
let gen_tables oc tmp pfx nstates (op, c) =
  let i = 1 in
  let pf m = Printf.fprintf oc m in
  let pfi n m = indent oc n; pf m in
  let ntables = ref 0 in
  (* we must follow the order in which
   * we visit code in emit_case, or
   * else ntables goes out of sync *)
  let base = pfx ^ show_op op in
  let swap = c.swap in
  let rec gen c =
    match c with
    | Table map ->
        let name =
          if !ntables = 0 then base else
          base ^ string_of_int !ntables
        in
        (* emit_case bumps its counter for every table it references;
         * without the same increment here a second table in one case
         * would be declared under the first table's name *)
        incr ntables;
        (* entries are stored as uchar *)
        assert (nstates <= 256);
        if swap then
          let n = nstates * (nstates + 1) / 2 in
          pfi i "static uchar %stbl[%d] = {\n" name n
        else
          pfi i "static uchar %stbl[%d][%d] = {\n"
            name nstates nstates;
        for l = 0 to nstates - 1 do
          pfi (i+1) "";
          for r = 0 to nstates - 1 do
            if not swap || r <= l then
              begin
                pf "%d"
                  (try List.assoc (l,r) map
                   with Not_found -> tmp);
                pf ",";
              end
          done;
          pf "\n";
        done;
        pfi i "};\n"
    | IfThen {cif; cthen; _} ->
        gen cif;
        Option.iter gen cthen
    | Return _ -> ()
  in
  gen c.code

(* Emit the "case O<op>:" arm of opn()'s switch for operation [op].
 * The label is at indentation 1 and the body at indentation 2.
 * [no_swap] is true when the caller already emitted the operand
 * swap once before the switch (emit_numberer does so when every case
 * swaps), so it is not repeated here.  Table names and the order in
 * which [Table] nodes are visited must match gen_tables. *)
let emit_case oc pfx no_swap (op, c) =
  let fpf = Printf.fprintf in
  let pf m = fpf oc m in
  let pfi n m = indent oc n; pf m in
  (* plain (non-recursive) printer for an operand name *)
  let side oc = function
    | L -> fpf oc "l"
    | R -> fpf oc "r"
  in
  (* print one predicate as a C condition on operand [s] *)
  let pred oc (s, pred) =
    match pred with
    | InBitSet bs -> fpf oc "BIT(%a) & %#Lx" side s bs
    | Eq id -> fpf oc "%a == %d" side s id
    | Ge id -> fpf oc "%d <= %a" id side s
  in
  let base = pfx ^ show_op op in
  let swap = c.swap in
  let ntables = ref 0 in
  let rec code i c =
    match c with
    | Return id -> pfi i "return %d;\n" id
    | Table _ ->
        let name =
          if !ntables = 0 then base else
          base ^ string_of_int !ntables
        in
        incr ntables;
        (* swapped tables hold only the lower triangle r <= l,
         * row l starting at index l*(l+1)/2 *)
        if swap then
          pfi i "return %stbl[(l + l*l)/2 + r];\n" name
        else pfi i "return %stbl[l][r];\n" name
    | IfThen ({test = And (And (t1, t2), t3); _} as r) ->
        (* re-associate so the leftmost conjunct is a single Pred *)
        code i @@ IfThen
          {r with test = And (t1, And (t2, t3))}
    | IfThen {test = And (Pred p, t); cif; cthen} ->
        (* each conjunct becomes one nested if, at the same indent *)
        pfi i "if (%a)\n" pred p;
        code i (IfThen {test = t; cif; cthen})
    | IfThen {test = Pred p; cif; cthen} ->
        pfi i "if (%a) {\n" pred p;
        code (i+1) cif;
        pfi i "}\n";
        (* [cthen] follows the if statement *)
        Option.iter (code i) cthen
  in
  pfi 1 "case %s:\n" (show_op op);
  if not no_swap && c.swap then
    emit_swap oc 2;
  code 2 c.code

(* Print the items of [l], each rendered with [f] and separated by
 * [sep], wrapping lines so they stay within [limit] columns
 * (default 60).
 *
 * [col] is the column the output is at when emit_list is called.
 * [indent] is the number of tabs starting each continuation line,
 * with a tab counted as 8 columns.  With [cut_before_sep] the
 * separator, stripped of a leading space, starts the continuation
 * line.  Otherwise it ends the broken line, stripped of a trailing
 * space.  Every line holds at least one item, even if that item
 * alone exceeds [limit].  No trailing newline is printed. *)
let emit_list
    ?(limit=60) ?(cut_before_sep=false)
    ~col ~indent:i ~sep ~f oc l =
  let sl = String.length sep in
  (* [sep] without its trailing space; [sl > 0] guards the index
   * when [sep] is empty *)
  let rstripped_sep, rssl =
    if sl > 0 && sep.[sl - 1] = ' ' then
      String.sub sep 0 (sl - 1), sl - 1
    else sep, sl
  in
  (* [sep] without its leading space *)
  let lstripped_sep, lssl =
    if sl > 0 && sep.[0] = ' ' then
      String.sub sep 1 (sl - 1), sl - 1
    else sep, sl
  in
  (* take the items that fit on the current line; returns the
   * items taken and the remaining ones *)
  let rec line col acc = function
    | [] -> (List.rev acc, [])
    | s :: rest ->
        let col = col + sl + String.length s in
        let no_space =
          if cut_before_sep || rest = [] then
            col > limit
          else
            col + rssl > limit
        in
        (* always take at least one item: an item wider than the
         * remaining space would otherwise never be consumed and
         * [go] would loop forever emitting empty lines *)
        if no_space && acc <> [] then
          (List.rev acc, s :: rest)
        else
          line col (s :: acc) rest
  in
  let rec go col l =
    if l = [] then () else
    let ll, l = line col [] l in
    Printf.fprintf oc "%s" (String.concat sep ll);
    if l <> [] && cut_before_sep then begin
      Printf.fprintf oc "\n";
      indent oc i;
      Printf.fprintf oc "%s" lstripped_sep;
      go (8*i + lssl) l
    end else if l <> [] then begin
      Printf.fprintf oc "%s\n" rstripped_sep;
      indent oc i;
      go (8*i) l
    end else ()
  in
  go col (List.map f l)

(* Emit the C state numberer for automaton [n] on [opts.oc]:
 *  - <pfx>opn(op, l, r): result state of operation [op] on operand
 *    states [l] and [r].  There is one switch case per operation,
 *    built by cgen_case, gen_tables and emit_case; other operations
 *    return the Tmp state.
 *  - <pfx>refn(r, tn, con): state of a temporary or constant
 *    reference.  Other reference kinds give INT_MIN.
 *  - <pfx>match[]: for each state, the OR of BIT(P<name>) over the
 *    [Top name] entries of its point, "$" and "%" excluded.
 * Each definition is preceded by "static " when [opts.static]. *)
let emit_numberer opts n =
  let oc = opts.oc and pfx = opts.pfx in
  let pf m = Printf.fprintf oc m in
  let storage_class () = if opts.static then pf "static " in
  let tmp = (atom_state n Tmp).id in
  let con = (atom_state n AnyCon).id in
  let nst = Array.length n.states in
  let cases =
    StateMap.by_ops n.statemap
    |> List.map (fun (op, map) -> (op, cgen_case tmp nst map))
  in
  let emit_opn () =
    let swaps = List.map (fun (_, c) -> c.swap) cases in
    storage_class ();
    pf "int\n";
    pf "%sopn(int op, int l, int r)\n" pfx;
    pf "{\n";
    List.iter (gen_tables oc tmp pfx nst) cases;
    if List.mem true swaps then pf "\tint t;\n\n";
    (* when every case swaps, do it once before the switch *)
    let all_swap = not (List.mem false swaps) in
    if all_swap then emit_swap oc 1;
    pf "\tswitch (op) {\n";
    List.iter (emit_case oc pfx all_swap) cases;
    pf "\tdefault:\n";
    pf "\t\treturn %d;\n" tmp;
    pf "\t}\n";
    pf "}\n\n"
  in
  let emit_refn () =
    let cons =
      List.filter_map
        (fun (a, s) -> match a with Con c -> Some (c, s.id) | _ -> None)
        n.atoms
    in
    storage_class ();
    pf "int\n";
    pf "%srefn(Ref r, Num *tn, Con *con)\n" pfx;
    pf "{\n";
    if cons <> [] then pf "\tint64_t n;\n\n";
    pf "\tswitch (rtype(r)) {\n";
    pf "\tcase RTmp:\n";
    if tmp <> 0 then begin
      (* the code below treats tn[].n == 0 as "unset" and replaces
       * it with the Tmp state, so state 0 must exist and belong
       * to constant atoms only *)
      assert
        (List.exists (fun (_, s) -> s.id = 0) n.atoms
         && List.for_all
              (fun (a, s) ->
                 s.id <> 0
                 || (match a with AnyCon | Con _ -> true | _ -> false))
              n.atoms);
      pf "\t\tif (!tn[r.val].n)\n";
      pf "\t\t\ttn[r.val].n = %d;\n" tmp
    end;
    pf "\t\treturn tn[r.val].n;\n";
    pf "\tcase RCon:\n";
    if cons <> [] then begin
      pf "\t\tif (con[r.val].type != CBits)\n";
      pf "\t\t\treturn %d;\n" con;
      pf "\t\tn = con[r.val].bits.i;\n";
      (* one test per state, listing all constants mapped to it *)
      List.iter
        (fun (id, cs) ->
           pf "\t\tif (";
           emit_list ~cut_before_sep:true
             ~col:20 ~indent:2 ~sep:" || "
             ~f:(fun c -> "n == " ^ Int64.to_string c)
             oc cs;
           pf ")\n";
           pf "\t\t\treturn %d;\n" id)
        (group_by_fst (inverse cons))
    end;
    pf "\t\treturn %d;\n" con;
    pf "\tdefault:\n";
    pf "\t\treturn INT_MIN;\n";
    pf "\t}\n";
    pf "}\n\n"
  in
  let emit_match_table () =
    storage_class ();
    pf "bits %smatch[%d] = {\n" pfx nst;
    n.states |> Array.iteri (fun sn s ->
        let tops =
          List.filter_map (function
              | Top ("$" | "%") -> None
              | Top r -> Some ("BIT(P" ^ r ^ ")")
              | _ -> None)
            s.point
          |> setify
        in
        if tops <> [] then
          pf "\t[%d] = %s,\n" sn (String.concat " | " tops));
    pf "};\n\n"
  in
  emit_opn ();
  emit_refn ();
  emit_match_table ()

(* Index of the first occurrence of [f] in [vars] (0-based).
 * Raises Not_found when [f] is absent. *)
let var_id vars f =
  let rec find i = function
    | [] -> raise Not_found
    | x :: rest -> if compare x f = 0 then i else find (i + 1) rest
  in
  find 0 vars

(* Compile the action tree [act] of a pattern into the flat list of
 * integers emitted as that pattern's matcher byte code (see
 * emit_matchers).  [vars] is the pattern's variable list; a variable
 * is encoded by its index in it (var_id).
 *
 * Encoding, as produced below:
 *   0               Stop
 *   1 / 2           Push (1 when its boolean flag is true, else 2)
 *   3 v             Set variable number v.  A Pop directly after the
 *                   Set is not emitted; code continues with the Pop's
 *                   successor (presumably the interpreter pops as part
 *                   of 3 - confirm against the C side)
 *   4               Pop
 *   5 n (c o)^n d   Switch: n case values c, each with the offset o
 *                   of its code, then the default offset d.  Offsets
 *                   count from the position of the 5.
 *   10 + pc         jump to position pc, where a node with the same
 *                   id was already compiled
 * Positions (pc) are indices into the returned list. *)
let compile_action vars act =
  (* node id -> position of its already-emitted code *)
  let pcs = Hashtbl.create 100 in
  (* [gen pc act]: code for [act], to be placed at position [pc] *)
  let rec gen pc (act: Action.t) =
    try
      [10 + Hashtbl.find pcs act.id]
    with Not_found ->
      let code =
        match act.node with
        | Action.Stop ->
            [0]
        | Action.Push (sym, k) ->
            let c = if sym then 1 else 2 in
            [c] @ gen (pc + 1) k
        (* in the Pop case [k] is the Pop's successor, so the Pop
         * itself is not emitted *)
        | Action.Set (v, {node = Action.Pop k; _})
        | Action.Set (v, ({node = Action.Stop; _} as k)) ->
            let v = var_id vars v in
            [3; v] @ gen (pc + 2) k
        | Action.Set _ ->
            (* for now, only atomic patterns can be
             * tied to a variable, so Set must be
             * followed by either Pop or Stop *)
            assert false
        | Action.Pop k ->
            [4] @ gen (pc + 1) k
        | Action.Switch cases ->
            (* group the case values by target action, the group
             * with the most values first *)
            let cases =
              inverse cases |> group_by_fst |>
              List.sort (fun (_, cs1) (_, cs2) ->
                  let n1 = List.length cs1
                  and n2 = List.length cs2 in
                  compare n2 n1)
            in
            (* the last case is the one with
             * the max number of entries *)
            (* it becomes the default; the other groups are listed
             * by ascending size.  Assumes at least one case, since
             * List.hd/List.tl fail on an empty list. *)
            let cases = List.rev (List.tl cases)
            and last = fst (List.hd cases) in
            (* number of explicit (value, offset) pairs *)
            let ncases =
              List.fold_left (fun n (_, cs) ->
                  List.length cs + n)
                0 cases
            in
            (* header size: opcode, count, pairs, default offset *)
            let body_off = 2 + 2 * ncases + 1 in
            (* bodies are laid out back to back after the header;
             * [pc] is threaded as the absolute position of the next
             * body, and every value of a group maps to the offset of
             * that group's body *)
            let pc, tbl, body =
              List.fold_left
                (fun (pc, tbl, body) (a, cs) ->
                   let ofs = body_off + List.length body in
                   let case = gen pc a in
                   let pc = pc + List.length case in
                   let body = body @ case in
                   let tbl =
                     List.fold_left (fun tbl c ->
                         tbl @ [c; ofs]
                       ) tbl cs
                   in
                   (pc, tbl, body))
                (pc + body_off, [], [])
                cases
            in
            (* the default code comes right after all bodies *)
            let ofs = body_off + List.length body in
            let tbl = tbl @ [ofs] in
            assert (2 + List.length tbl = body_off);
            [5; ncases] @ tbl @ body @ gen pc last
      in
      (* Stop is never shared: a jump would take as much room as
       * the single 0 it replaces *)
      if act.node <> Action.Stop then
        Hashtbl.replace pcs act.id pc;
      code
  in
  gen 0 act

(* Emit the <pfx>matcher[] array.  For every (vars, pname, action)
 * in [ms], entry P<pname> is an anonymous uchar array holding the
 * byte code produced by compile_action. *)
let emit_matchers opts ms =
  let pf m = Printf.fprintf opts.oc m in
  if opts.static then pf "static ";
  pf "uchar *%smatcher[] = {\n" opts.pfx;
  List.iter (fun (vars, pname, m) ->
      pf "\t[P%s] = (uchar[]){\n" pname;
      pf "\t\t";
      let bytes = compile_action vars m in
      (* the code is stored as uchar: a value outside 0..255 (e.g. a
       * jump 10 + pc past position 245) would be wrapped modulo 256
       * by the C compiler and corrupt the matcher, so refuse it
       * here, as gen_tables does for its tables *)
      List.iter (fun b -> assert (0 <= b && b <= 255)) bytes;
      emit_list
        ~col:16 ~indent:2 ~sep:","
        ~f:string_of_int opts.oc bytes;
      pf "\n";
      pf "\t},\n")
    ms;
  pf "};\n\n"

(* Generate C code for automaton [n] on [opts.oc].  Only the
 * numberer (emit_numberer) is produced here; the matcher byte code
 * is not emitted by this function (see emit_matchers). *)
let emit_c opts n =
  n |> emit_numberer opts