mirror of https://github.com/geohot/qira
Add byteweight implementation into extra
This commit is contained in:
parent
f90584139a
commit
8d0206023b
|
@ -0,0 +1,6 @@
|
|||
ByteWeight is a cross-platform function identification project. It is based
|
||||
on the CMU Binary Analysis Platform (BAP).
|
||||
|
||||
The current code contains both the training and the testing implementations. To train a
|
||||
signature file, run train as follows:
|
||||
./train -bin-dir [training binary directory] -sig [output signature file]
|
|
@ -0,0 +1,28 @@
|
|||
OASISFormat: 0.4
|
||||
Name: byteweight
|
||||
Version: 0.1
|
||||
Synopsis: Function Identification Tool
|
||||
Authors: Tiffany Bao
|
||||
License: MIT
|
||||
Plugins: META (0.4), DevFiles (0.4)
|
||||
BuildTools: ocamlbuild
|
||||
BuildDepends: str,
|
||||
bap-lifter,
|
||||
bap-types,
|
||||
bap-container
|
||||
|
||||
Library "byteweight"
|
||||
Path: .
|
||||
BuildTools: ocamlbuild
|
||||
Modules: Byteweight
|
||||
CompiledObject: best
|
||||
|
||||
Executable "byteweight"
|
||||
Path: .
|
||||
MainIs: byteweight.ml
|
||||
CompiledObject: best
|
||||
|
||||
Executable "train"
|
||||
Path: .
|
||||
MainIs: train.ml
|
||||
CompiledObject: best
|
|
@ -0,0 +1,21 @@
|
|||
|
||||
(* generate_keys : string -> string list
 * Returns every non-empty prefix of [bytes], shortest first and ending with
 * [bytes] itself. The empty string yields the empty list. *)
let generate_keys bytes =
  (* Walk the prefix length downwards so that consing produces the list in
   * ascending length order without a final reversal. Structural [=] is used
   * for the integer test; [==] is physical equality. *)
  let rec collect acc len =
    if len = 0 then acc
    else collect (String.sub bytes 0 len :: acc) (len - 1)
  in
  collect [] (String.length bytes)
|
||||
|
||||
(* consecutive : addr -> addr -> int -> Exec_container.t -> string
 * Reads the bytes starting at [addr] and stopping at whichever bound is
 * lower: [end_addr], or [addr] plus [len] (added as a 32-bit literal). *)
let consecutive addr end_addr len container =
  let window_end = Bitvector.plus addr (Bitvector.lit len 32) in
  let stop =
    if Bitvector.bool_of (Bitvector.lt end_addr window_end) then end_addr
    else window_end
  in
  Exec_container.Reader.get_bytes container addr stop
|
|
@ -0,0 +1,126 @@
|
|||
(* TODO: convert to Core_list? *)
|
||||
(* Local alias for the trie module exported by Dism. *)
module DismTrie = Dism.Trie
|
||||
(* Usage banner handed to Arg.parse and carried by the Arg.Bad errors. *)
let usage = "./match [binary file]"
|
||||
(* Passed as the [len] argument of Dism.consecutive when scoring an address. *)
let k = 10
|
||||
|
||||
(* Default signature file, located next to the running executable. *)
let g_wpt = Filename.concat (Filename.dirname Sys.executable_name) "signatures/sig_arm"
|
||||
|
||||
(* Score above which an address is reported; overridden by -t. *)
let threshold = ref 0.5
|
||||
(* Input directory for batch mode; set by -bin-dir. *)
let d_bin = ref None
|
||||
(* Output directory for batch mode; set by -o-dir. *)
let d_out = ref None
|
||||
(* Single input binary; set from the anonymous command-line argument. *)
let bin = ref None
|
||||
(* Output channel for single-binary mode; replaced by -o. *)
let out = ref stdout
|
||||
(* Signature file loaded by the command-line entry point; overridden by -wpt. *)
let f_wpt = ref g_wpt
|
||||
(* let arch = ref None *)
|
||||
|
||||
(* Command-line options understood by the matcher. Each handler only records
 * its argument in the corresponding global reference, except -o-dir, which
 * also tries to create the directory, and -o, which opens the output file
 * immediately. *)
let arg_specs = [
  ("-wpt", Arg.String (fun path -> f_wpt := path), "weighted prefix tree file");
  ("-bin-dir", Arg.String (fun dir -> d_bin := Some dir), "test binary directory");
  (* ("-bin", Arg.String(fun s -> bin := Some s), "test binary") *)
  ("-o-dir",
   Arg.String (fun dir ->
     d_out := Some dir;
     (* Best effort: any failure of mkdir is ignored here. *)
     try Unix.mkdir dir 0o755 with _ -> ()),
   "output directory");
  ("-o", Arg.String (fun path -> out := open_out path), "output file");
  ("-t", Arg.Float (fun value -> threshold := value), "threshold");
  (* Question: can BAP infer architecture from binaries? *)
]
|
||||
|
||||
(* Anonymous-argument handler: the bare argument is the single binary to
 * analyse. A later anonymous argument overwrites an earlier one. *)
let anon_fun s = bin := Some s
|
||||
|
||||
|
||||
(* fsi_container :
 *   Trie.t -> Container.exec_container -> (addr * addr) list -> addr list
 * For every (start, end) range in [codes], scores each address from start to
 * end inclusive by looking up the normalized disassembly of up to [k]
 * instructions in [trie]. Addresses scoring strictly above [!threshold] are
 * returned, sorted with Bitvector.compare. *)
let fsi_container trie container codes =
  let scan_range (start_addr, end_addr) =
    (* score : addr -> float *)
    let score addr =
      DismTrie.find trie (Dism.consecutive addr end_addr k container)
    in
    let rec rec_score addr fs =
      if Bitvector.bool_of (Bitvector.lt end_addr addr) then fs
      else
        let fs = if score addr > !threshold then addr :: fs else fs in
        rec_score (Bitvector.incr addr) fs
    in
    rec_score start_addr []
  in
  let fs_sec = List.map scan_range codes in
  (* Flatten tail-recursively. The result is sorted right after, so the
   * order produced here is irrelevant and no per-step List.rev of the
   * accumulator is needed (that made the flattening quadratic). *)
  let flat = List.fold_left (fun acc hits -> List.rev_append hits acc) [] fs_sec in
  List.sort Bitvector.compare flat
|
||||
|
||||
|
||||
(* output : out_channel -> addr list -> unit
 * Writes one hexadecimal address per line to [oc], then closes [oc]. *)
let output oc fsi =
  let write_addr addr = Printf.fprintf oc "%s\n" (Bitvector.to_hex addr) in
  List.iter write_addr fsi;
  close_out oc
|
||||
|
||||
(* get_code_segments : Exec_container.t -> (addr * addr) list
 * Collects the (start, end) address pair of every section whose permissions
 * include X. The fold conses onto the accumulator, so the result is in
 * reverse section order. *)
let get_code_segments container =
  let keep_executable acc
      {Exec_container.start_addr = lo;
       Exec_container.end_addr = hi;
       Exec_container.permissions = perms} =
    if List.mem Exec_container.X perms then (lo, hi) :: acc else acc
  in
  List.fold_left keep_executable []
    (Exec_container.Reader.get_sections container)
|
||||
|
||||
(* fsi_bin : string -> Trie.t -> addr list
 * Loads the binary at path [bin] and runs fsi_container over its executable
 * ranges. Raises Failure when the binary cannot be loaded. *)
let fsi_bin bin trie =
  match Dism.get_container bin with
  | None -> failwith (Printf.sprintf "Binary Load Error %s" bin)
  | Some container ->
    (* ranges : (addr * addr) list *)
    let ranges = get_code_segments container in
    fsi_container trie container ranges
|
||||
|
||||
|
||||
(* main *)
|
||||
(* main
 * Two modes, selected by which inputs were given:
 *   - batch: -bin-dir with -o-dir; every directory entry is analysed and its
 *     result written to a same-named file in the output directory;
 *   - single: one anonymous binary argument; the result goes to [!out].
 * Any other combination raises Arg.Bad. *)
let () =
  Arg.parse arg_specs anon_fun usage;
  let run_batch d_i d_o =
    let trie = Dism.load !f_wpt in
    let bins =
      List.map (Filename.concat d_i) (Array.to_list (Sys.readdir d_i))
    in
    let process bin =
      (* Analyse first; the output file is only created afterwards. *)
      let fs = fsi_bin bin trie in
      let oc = open_out (Filename.concat d_o (Filename.basename bin)) in
      output oc fs
    in
    List.iter process bins
  in
  match !bin, !d_bin with
  | None, Some d_i ->
    (match !d_out with
     | Some d_o -> run_batch d_i d_o
     | None ->
       let err = Printf.sprintf "Output directory is required.\n" ^ usage in
       raise (Arg.Bad err))
  | Some i, None ->
    let trie = Dism.load !f_wpt in
    let fs = fsi_bin i trie in
    output !out fs
  | _ -> raise (Arg.Bad usage)
|
||||
|
||||
|
||||
(* get_functions : exec_container -> addr list
 * Library entry point: loads the default signature file [g_wpt] and returns
 * the function starts found in the executable ranges of [container]. *)
let get_functions container =
  let trie = Dism.load g_wpt in
  fsi_container trie container (get_code_segments container)
|
|
@ -0,0 +1 @@
|
|||
(** [get_functions container] returns the addresses identified as function
    starts in [container], sorted in ascending [Bitvector.compare] order. The
    signature file is read from a [signatures/sig_arm] path next to the
    running executable on every call. *)
val get_functions : Exec_container.t -> Exec_container.addr list
|
|
@ -0,0 +1,103 @@
|
|||
(* Raised by get_disasm when the disassembler returns no text for an address. *)
exception No_Dism
|
||||
(* Hashable key type for one normalized instruction string. *)
module Instr = struct
|
||||
type t = string
|
||||
(* Structural string equality. *)
let equal i j = i = j
|
||||
let hash = Hashtbl.hash
|
||||
end
|
||||
|
||||
(* Hash table over Instr keys, extended with the [format] function that the
 * Trie functor requires; keys are already strings, so it is the identity. *)
module M = struct
|
||||
module D = Hashtbl.Make(Instr)
|
||||
include D
|
||||
let format a = a
|
||||
end
|
||||
|
||||
(* Trie whose keys are lists of instruction strings. *)
module Trie = Trie.Make(M)
|
||||
|
||||
(* Separator placed between instructions by generate_keys. *)
let sep = ";"
|
||||
|
||||
(* read_from_ic : in_channel -> (string list * float) list
 * Parses a signature file in which every line has the form
 *   instr;instr;...->P,N
 * and returns, for each line, the instruction list paired with P /. (P +. N).
 * Entries are returned in reverse file order (last line first).
 * The channel is closed before returning, and also when parsing fails.
 * Raises Failure "WPT File Format error" on a malformed line, and whatever
 * float_of_string raises on a malformed count. *)
let read_from_ic ic =
  (* Compile the separators once instead of once per line. *)
  let arrow = Str.regexp "->"
  and semicolon = Str.regexp ";"
  and comma = Str.regexp "," in
  let to_dism_score line =
    match Str.split arrow line with
    | [disms_str; counts] ->
      let disms = Str.split semicolon disms_str in
      let p, n =
        match Str.split comma counts with
        | [p; n] -> float_of_string p, float_of_string n
        | _ -> failwith "WPT File Format error"
      in
      disms, (p /. (p +. n))
    | _ -> failwith "WPT File Format error"
  in
  let next_line () = try Some (input_line ic) with End_of_file -> None in
  let rec read_all acc =
    match next_line () with
    | None -> acc
    | Some line -> read_all (to_dism_score line :: acc)
  in
  (* Do not leak the descriptor when a line is malformed. *)
  let sigs = try read_all [] with e -> close_in_noerr ic; raise e in
  close_in ic;
  sigs
|
||||
|
||||
(* load : string -> float Trie.t
 * Reads the signature file [file] and inserts every (instruction list,
 * score) entry into a fresh trie whose root value is 0.0. Entries are
 * inserted in the order read_from_ic returns them. *)
let load file =
  let ic = open_in file in
  (* entries : (string list * float) list *)
  let entries = read_from_ic ic in
  let trie = Trie.init 0.0 in
  let insert (key, score) = Trie.add trie key score in
  List.iter insert entries;
  trie
|
||||
|
||||
(* get_disasm container addr
 * Disassembles the ARM instruction at [addr], fetching bytes from
 * [container] one at a time. Returns the disassembly text together with the
 * fall-through address. Raises No_Dism when no text is produced. *)
let get_disasm container addr =
  let module ARM = Arch_arm.ARM in
  (* Printf.printf "%s\n" (Bitvector.to_hex addr); *)
  let read_byte at =
    let chunk =
      Exec_container.Reader.get_bytes container at (Bitvector.incr at)
    in
    String.get chunk 0
  in
  match ARM.disasm ARM.init_state read_byte addr with
  | _, _, fallthrough, Some d -> d, fallthrough
  | _, _, _, None -> raise No_Dism
|
||||
|
||||
|
||||
(* consecutive : addr -> addr -> int -> Container.exec_container -> asm list
 * Follows fall-through addresses from [addr], collecting the normalized
 * disassembly of each instruction in program order. Stops after [len]
 * instructions, once the current address is past [end_addr], or as soon as
 * disassembling or normalizing raises any exception; whatever was collected
 * up to that point is returned. *)
let consecutive addr end_addr len container =
  let rec walk cur count acc =
    let finished =
      count >= len || Bitvector.bool_of (Bitvector.lt end_addr cur)
    in
    if finished then List.rev acc
    else
      try
        let text, next = get_disasm container cur in
        walk next (count + 1) (Normalize.normalize text :: acc)
      with _ -> List.rev acc
  in
  walk addr 0 []
|
||||
|
||||
|
||||
(* get_container : string -> exec_container option
 * Reads the whole file at path [bin] in binary mode and hands its contents
 * to Elf_container.load_executable. The input channel is closed on every
 * path, including when sizing or reading the file raises. *)
let get_container bin =
  let ic = open_in_bin bin in
  let buf =
    try
      let buf = String.create (in_channel_length ic) in
      really_input ic buf 0 (String.length buf);
      buf
    with e ->
      (* Release the descriptor before propagating the error. *)
      close_in_noerr ic;
      raise e
  in
  let () = close_in ic in
  Elf_container.load_executable buf
|
||||
|
||||
|
||||
(* generate_keys : string list -> string list
 * Builds the cumulative keys "i1", "i1;i2", "i1;i2;i3", ... by joining the
 * instructions with [sep]. The longest key comes first in the result. While
 * the running prefix is the empty string, the next instruction is used
 * as-is, without a separator. *)
let generate_keys disms =
  let extend (keys, prefix) dism =
    let key =
      if prefix = "" then dism
      else Printf.sprintf "%s%s%s" prefix sep dism
    in
    (key :: keys, key)
  in
  fst (List.fold_left extend ([], "") disms)
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
open Core_kernel.Std
|
||||
open Or_error
|
||||
open Dwarf
|
||||
|
||||
module Buffer = Dwarf_data.Buffer
|
||||
(*
|
||||
let fb bin =
|
||||
let tmp_file = Filename.temp_file bin ".tmp" in
|
||||
let command =
|
||||
Printf.sprintf
|
||||
"arm-linux-gnueabi-objdump -t %s | grep \"F .text\" | \
|
||||
gawk `{ \
|
||||
start=strtonum(\"0x\"$1); \
|
||||
size=strtonum(\"0x\"$5); \
|
||||
printf(\"%%x %%x\\n\", start, start+size)}` | sort -u > %s"
|
||||
bin tmp_file
|
||||
in
|
||||
let () = Sys.command command in
|
||||
read_fb tmp_file
|
||||
*)
|
||||
|
||||
(* read_fs filename
 * Reads [filename], which holds one hexadecimal address per line without a
 * "0x" prefix, and returns the addresses as 32-bit bitvectors in file order.
 * A line that is not valid hexadecimal makes int_of_string raise. *)
let read_fs filename =
  let to_addr hex =
    Bitvector.lit (int_of_string (Printf.sprintf "0x%s" hex)) 32
  in
  List.map (In_channel.read_lines filename) ~f:to_addr
|
||||
|
||||
(* fs bin
 * Obtains function start addresses from the symbol table of [bin] by running
 * arm-linux-gnueabi-objdump through the shell and parsing the filtered
 * output with read_fs. Both file names are shell-quoted, so paths containing
 * spaces or shell metacharacters are passed through literally instead of
 * being interpreted by the shell.
 * NOTE(review): the exit status of the pipeline is ignored and the temporary
 * file is not removed afterwards. *)
let fs bin =
  let tmp_file = Filename.temp_file (Filename.basename bin) ".tmp" in
  let command =
    Printf.sprintf
      "arm-linux-gnueabi-objdump -t %s | grep \"F .text\" | \
       awk '{print $1}' | sort -u > %s"
      (Filename.quote bin) (Filename.quote tmp_file)
  in
  (* Printf.printf "%s%!" command; *)
  let _ = Sys.command command in
  read_fs tmp_file
|
||||
(*
|
||||
let fs_dwarf filename =
|
||||
let filedata = In_channel.read_all filename in
|
||||
match Elf.parse filedata with
|
||||
| None -> (*errorf "%s is not an elf file\n" filename *) []
|
||||
| Some elf ->
|
||||
let open Elf in
|
||||
let endian = match elf.e_data with
|
||||
| ELFDATA2LSB -> LittleEndian
|
||||
| ELFDATA2MSB -> BigEndian in
|
||||
let create name s = Some (name, Buffer.create s.sh_data) in
|
||||
let sections = List.filter_map elf.e_sections ~f:(fun s ->
|
||||
match s.sh_name with
|
||||
| ".debug_info" -> create Section.Info s
|
||||
| ".debug_abbrev" -> create Section.Abbrev s
|
||||
| ".debug_str" -> create Section.Str s
|
||||
| _ -> None) in
|
||||
match Dwarf_data.create endian sections with
|
||||
| Ok data ->
|
||||
(match Dff.create data with
|
||||
| Ok dff ->
|
||||
let seq = Sequence.map (Dff.functions dff) ~f:(fun (_, fn) ->
|
||||
match Dff.Fn.pc_lo fn with
|
||||
| Dwarf.Addr.Int64 x -> Bitvector.litz (Z.of_int64 x) 64
|
||||
| Dwarf.Addr.Int32 x -> Bitvector.litz (Z.of_int32 x) 32
|
||||
) in
|
||||
Sequence.to_list seq
|
||||
| Error err -> (* eprintf "error" @@ Error.to_string_hum err; *) [])
|
||||
| _ -> []
|
||||
*)
|
||||
|
||||
(* Converts a DWARF address to a bitvector whose width (64 or 32 bits)
 * follows the address constructor. *)
let dwarf_to_bitvector addr =
  match addr with
  | Dwarf.Addr.Int32 x -> Bitvector.litz (Z.of_int32 x) 32
  | Dwarf.Addr.Int64 x -> Bitvector.litz (Z.of_int64 x) 64
|
||||
|
||||
(* fs_dwarf filename
 * Returns the low PC of every function described in the DWARF data of the
 * ELF file [filename]. Only the .debug_info, .debug_abbrev and .debug_str
 * sections are handed to the DWARF reader. On any failure the error is
 * printed to stdout and the empty list is returned. *)
let fs_dwarf filename =
  let contents = In_channel.read_all filename in
  let outcome =
    match Elf.parse contents with
    | None -> errorf "%s is not an elf file\n" filename
    | Some elf ->
      let open Elf in
      let endian =
        match elf.e_data with
        | ELFDATA2LSB -> LittleEndian
        | ELFDATA2MSB -> BigEndian
      in
      let sections =
        List.filter_map elf.e_sections ~f:(fun sec ->
          let keep name = Some (name, Buffer.create sec.sh_data) in
          match sec.sh_name with
          | ".debug_info" -> keep Section.Info
          | ".debug_abbrev" -> keep Section.Abbrev
          | ".debug_str" -> keep Section.Str
          | _ -> None)
      in
      Dwarf_data.create endian sections >>= fun data ->
      Dff.create data >>| fun dff ->
      Sequence.to_list
        (Sequence.map (Dff.functions dff) ~f:(fun (_, fn) ->
           dwarf_to_bitvector (Dff.Fn.pc_lo fn)))
  in
  match outcome with
  | Ok entries ->
    (*
    let gt = Filename.concat "gt_dwarf" (Filename.basename filename) in
    Out_channel.write_lines gt (List.map x ~f:Bitvector.to_hex);
    *)
    entries
  | Error err ->
    Printf.printf "dwarf error %s: %s\n" filename (Error.to_string_hum err);
    []
|
|
@ -0,0 +1,46 @@
|
|||
(* Architectures whose assembly syntax normalize knows how to rewrite. *)
type t_arch =
|
||||
| Arm
|
||||
| X86
|
||||
| X86_64
|
||||
|
||||
|
||||
(* replace_patt str patt
 * Replaces every match of the regular expression [patt] in [str] with the
 * text of [patt] itself, used as the replacement template. All operands of
 * one class therefore collapse to the same token. *)
let replace_patt str patt =
  let re = Str.regexp patt in
  Str.global_replace re patt str
|
||||
|
||||
(* normalize ?arch s
 * Normalizes one line of disassembly text:
 *   1. trims surrounding whitespace;
 *   2. rewrites immediate constants (negative, positive, zero — in that
 *      order) and then branch/call targets via replace_patt, using the ARM
 *      patterns when [arch] is None or Some Arm and the x86 patterns
 *      otherwise;
 *   3. removes every remaining run of spaces and tabs. *)
let normalize ?(arch=None) s =
  let trimmed = String.trim s in
  let apply_all text patterns = List.fold_left replace_patt text patterns in
  let rewritten =
    match arch with
    | None | Some Arm ->
      let consts =
        apply_all trimmed
          ["#-[1-9a-f][0-9a-f]*"; "#[1-9a-f][0-9a-f]*"; "#0+"]
      in
      replace_patt consts "^b\\(l\\)?[ \t]+[1-9a-f]+"
    | Some X86 | Some X86_64 ->
      let consts =
        apply_all trimmed
          ["-\\(\\$\\)?\\(0x\\)?[0-9a-f]+";
           "\\(\\$\\)?\\(0x\\)?[0-9a-f]+";
           "\\(\\$\\)?\\(0x\\)?0+"]
      in
      apply_all consts
        ["^j[a-z]+[ \t]+\\(\\*\\)?[0-9a-f]+";
         "^call[a-z]+[ \t]+\\(\\*\\)?[0-9a-f]+"]
  in
  let pieces = Str.split (Str.regexp "[ \t]+") rewritten in
  (* Printf.printf "====\n%s\n%s\n====\n%!" s s_stripped; *)
  String.concat "" pieces
|
|
@ -0,0 +1,137 @@
|
|||
(* Key module shared with Dism. *)
module I = Dism.Instr
|
||||
(* Hash table from signature key to its counters. *)
module Sigs = Hashtbl.Make(I)
|
||||
(* Usage banner handed to Arg.parse and carried by Arg.Bad errors. The
 * directory given to -bin-dir holds the training binaries, matching the
 * description of that option in arg_specs. *)
let usage = "Train: ./train -bin-dir [training binary directory] -sig [output signature file]"
|
||||
(* Training binary directory; set by -bin-dir. *)
let d_bin = ref None
|
||||
(* Output signature file; set by -sig. *)
let sig_out = ref None
|
||||
(* Passed as the [len] argument of Byte.consecutive when building keys. *)
let k = 20
|
||||
|
||||
(* Command-line options of the trainer; both handlers only record their
 * argument in the corresponding global reference. *)
let arg_specs = [
  ("-bin-dir", Arg.String (fun dir -> d_bin := Some dir), "train binary directory");
  ("-sig", Arg.String (fun path -> sig_out := Some path), "output signature file");
]
|
||||
|
||||
(* The trainer accepts no anonymous arguments: any such argument is rejected
 * with the usage banner. *)
let anon_fun _ = raise (Arg.Bad usage)
|
||||
|
||||
(* parse_command : string * string
 * Evaluated once, when the module is initialised: parses the command line
 * and yields (training directory, signature file). Raises Arg.Bad when
 * either option is missing. *)
let parse_command =
  let () = Arg.parse arg_specs anon_fun usage in
  match !d_bin, !sig_out with
  | None, _ | _, None -> raise (Arg.Bad usage)
  | Some dir, Some sig_file -> (dir, sig_file)
|
||||
|
||||
(* build_sigs : (addr list * (addr * addr) list * container) list
 *              -> (int * int) Sigs.t
 * Positive pass of training. For every known function start, reads up to
 * [k] bytes (bounded by the end of the enclosing executable range), derives
 * all prefix keys, and increments the positive counter of each key; keys
 * seen for the first time start at (1, 0).
 * Raises Failure when a function start lies in none of the ranges. *)
let build_sigs info =
  let sigs = Sigs.create 1000 in
  let bump_positive key =
    try
      (* Printf.printf "%s\n%!" key; *)
      let (p, n) = Sigs.find sigs key in
      Sigs.replace sigs key (p + 1, n)
    with Not_found -> Sigs.add sigs key (1, 0)
  in
  (* End address of the first range that contains [addr], bounds inclusive. *)
  let rec range_end addr = function
    | [] ->
      failwith (
        Printf.sprintf "Function %s is not in executable segment"
          (Bitvector.to_hex addr))
    | (lo, hi) :: rest ->
      if Bitvector.bool_of (Bitvector.le addr hi)
      && Bitvector.bool_of (Bitvector.le lo addr)
      then hi
      else range_end addr rest
  in
  let learn_binary (fs, sections, container) =
    let learn_function addr =
      let sec_end = range_end addr sections in
      (* let disms = Dism.consecutive addr sec_end k container in
         Dism.generate_keys disms *)
      let bytes = Byte.consecutive addr sec_end k container in
      List.iter bump_positive (Byte.generate_keys bytes)
    in
    List.iter learn_function fs
  in
  List.iter learn_binary info;
  sigs
|
||||
|
||||
(* update_sigs :
 *   (int * int) Sigs.t
 *   -> (addr list * (addr * addr) list * container) list -> unit
 * Negative pass of training. Walks every address of every executable range
 * (bounds inclusive); for each address that is not a known function start,
 * derives the prefix keys of the bytes found there and increments the
 * negative counter of every key already present in [sigs]. Keys that were
 * never seen at a function start are ignored.
 * The range bound is tested with Bitvector.lt, as in the other range walks
 * of this code base, rather than with polymorphic comparison. *)
let update_sigs sigs info =
  let bump_negative key =
    try
      let (p, n) = Sigs.find sigs key in
      Sigs.replace sigs key (p, n + 1)
    with Not_found -> ()
  in
  List.iter (fun (fs, sections, container) ->
    List.iter (fun (start_addr, end_addr) ->
      let rec rec_update addr =
        if Bitvector.bool_of (Bitvector.lt end_addr addr) then ()
        else begin
          if not (List.mem addr fs) then begin
            (* let disms = Dism.consecutive addr end_addr k container in
               Dism.generate_keys disms *)
            let bytes = Byte.consecutive addr end_addr k container in
            List.iter bump_negative (Byte.generate_keys bytes)
          end;
          rec_update (Bitvector.incr addr)
        end
      in
      rec_update start_addr
    ) sections
  ) info
|
||||
|
||||
(* train : string -> (int * int) Sigs.t
 * Trains on every entry of directory [d]. Each binary is loaded, its function
 * starts are taken from DWARF data and its executable ranges from the
 * container; the positive pass (build_sigs) and the negative pass
 * (update_sigs) then run over all binaries.
 * Raises Failure "Binary Load Error" when a binary cannot be loaded. *)
let train d =
  (* TODO: currently they are segments, not sections
   * What we really want is sections *)
  let executable_ranges container =
    List.fold_left
      (fun acc {Exec_container.start_addr = lo;
                Exec_container.end_addr = hi;
                Exec_container.permissions = perms} ->
        if List.mem Exec_container.X perms then (lo, hi) :: acc else acc)
      []
      (Exec_container.Reader.get_sections container)
  in
  let describe bin =
    Printf.printf "%s\n%!" bin;
    match Dism.get_container bin with
    | None -> failwith "Binary Load Error"
    | Some container ->
      (* let fs = Mock.fs bin *)
      let fs = Mock.fs_dwarf bin in
      let codes = executable_ranges container in
      (fs, codes, container)
  in
  let bins = List.map (Filename.concat d) (Array.to_list (Sys.readdir d)) in
  let info = List.rev_map describe bins in
  let sigs = build_sigs info in
  update_sigs sigs info;
  sigs
|
||||
|
||||
(* output : (int * int) Sigs.t -> string -> unit
 * Writes every signature to [file] as one "key->positive,negative" line, in
 * the table's iteration order, then closes the file. *)
let output sigs file =
  let oc = open_out_bin file in
  let write_entry key (p, n) = Printf.fprintf oc "%s->%d,%d\n" key p n in
  Sigs.iter write_entry sigs;
  close_out oc
|
||||
|
||||
(* Entry point: train on the requested directory and write the signatures. *)
let () =
  let bin_dir, sig_file = parse_command in
  output (train bin_dir) sig_file
|
|
@ -0,0 +1,113 @@
|
|||
(* Mutable map interface required by Make: the Hashtbl-style operations the
 * trie uses for the children of a node, plus a key printer. *)
module type M = sig
|
||||
type 'a t
|
||||
type key
|
||||
val create : int -> 'a t
|
||||
val iter : (key -> 'a -> unit) -> 'a t -> unit
|
||||
val add : 'a t -> key -> 'a -> unit
|
||||
val replace : 'a t -> key -> 'a -> unit
|
||||
(* The trie relies on [find] raising Not_found for an unbound key. *)
val find : 'a t -> key -> 'a
|
||||
(* Renders a key as text; used by the trie's [output]. *)
val format : key -> string
|
||||
end
|
||||
|
||||
(* Interface of the tries produced by Make. *)
module type TRIE = sig
|
||||
type 'a t
|
||||
type key
|
||||
(* [init v] is a trie consisting of a single root node holding [v]. *)
val init : 'a -> 'a t
|
||||
(* [add t k v] stores [v] at the node reached by [k]; the empty key is a no-op. *)
val add : 'a t -> key -> 'a -> unit
|
||||
(* [find t k] is the value of the deepest node reachable along [k]. *)
val find : 'a t -> key -> 'a
|
||||
(* [output t file format_v] dumps every node of [t] to [file]. *)
val output : 'a t -> string -> ('a -> string) -> unit
|
||||
end
|
||||
|
||||
|
||||
module Make (M : M) : (TRIE with type key = M.key list) = struct
  type key = M.key list

  (* A node carries its own value and a mutable map to its children. *)
  type 'a t = Node of 'a * 'a t M.t

  let init v = Node (v, M.create 10)

  (* add : 'a t -> key -> 'a -> unit
   * Stores [value] at the node reached by [key]. An existing final node
   * keeps its children and only has its value replaced. Intermediate nodes
   * that have to be created on the way are initialised with [value] too.
   * The empty key leaves the trie untouched. *)
  let rec add (Node (_, children)) key value =
    match key with
    | [] -> ()
    | [last] ->
      (try
         let Node (_, grandchildren) = M.find children last in
         M.replace children last (Node (value, grandchildren))
       with Not_found ->
         (* New node: register it in its parent's map. *)
         M.add children last (init value))
    | step :: rest ->
      let child =
        try M.find children step
        with Not_found ->
          (* New node: register it in its parent's map. *)
          let fresh = init value in
          M.add children step fresh;
          fresh
      in
      add child rest value

  (* find : 'a t -> key -> 'a
   * Longest match: follows [key] for as long as children exist and returns
   * the value of the last node reached (the root value when not even the
   * first element matches). *)
  let find (Node (root_value, _) as root) key =
    let rec descend (Node (_, children)) remaining best =
      match remaining with
      | [] -> best
      | step :: rest ->
        let next = try Some (M.find children step) with Not_found -> None in
        (match next with
         (* No child for this step: the longest match has been reached. *)
         | None -> best
         | Some (Node (value, _) as child) -> descend child rest value)
    in
    descend root key root_value

  (* output : 'a t -> string -> ('a -> string) -> unit
   * Writes one "path->value" line per node to [file], parents before their
   * children; the path is the ";"-joined formatted keys from the root, so
   * the root line has an empty path. *)
  let output trie file format_v =
    let oc = open_out file in
    let rec dump rev_path (Node (value, children)) =
      let path = String.concat ";" (List.rev rev_path) in
      Printf.fprintf oc "%s->%s\n" path (format_v value);
      M.iter (fun k child -> dump (M.format k :: rev_path) child) children
    in
    dump [] trie;
    close_out oc
end
|
||||
|
||||
(*
|
||||
module type DISM = sig
|
||||
type t
|
||||
val equal : t -> t -> bool
|
||||
val hash : t -> int
|
||||
end
|
||||
|
||||
module Dism : DISM = struct
|
||||
type t = string
|
||||
let equal i j = i = j
|
||||
let hash = Hashtbl.hash
|
||||
end
|
||||
|
||||
module M = struct
|
||||
module D = Hashtbl.Make(Dism)
|
||||
include D
|
||||
let format a = a
|
||||
end
|
||||
|
||||
module DismTrie = Make(M) *)
|
|
@ -0,0 +1,21 @@
|
|||
(** Mutable map interface required by {!Make}: Hashtbl-style operations plus
    a key printer. *)
module type M = sig
|
||||
type 'a t
|
||||
type key
|
||||
val create : int -> 'a t
|
||||
val iter : (key -> 'a -> unit) -> 'a t -> unit
|
||||
val add : 'a t -> key -> 'a -> unit
|
||||
val replace : 'a t -> key -> 'a -> unit
|
||||
val find : 'a t -> key -> 'a
|
||||
(** Renders a key as text. *)
val format : key -> string
|
||||
end
|
||||
|
||||
(** Interface of the tries produced by {!Make}. *)
module type TRIE = sig
|
||||
type 'a t
|
||||
type key
|
||||
(** [init v] creates a trie whose root holds [v]. *)
val init : 'a -> 'a t
|
||||
(** [add t k v] stores [v] under key [k] in [t]. *)
val add : 'a t -> key -> 'a -> unit
|
||||
(** [find t k] looks [k] up in [t]. *)
val find : 'a t -> key -> 'a
|
||||
(** [output t file format_v] writes [t] to [file], rendering values with
    [format_v]. *)
val output : 'a t -> string -> ('a -> string) -> unit
|
||||
end
|
||||
|
||||
(** [Make (M)] builds a trie whose keys are lists of [M.key]. *)
module Make (M : M) : (TRIE with type key = M.key list)
|
Loading…
Reference in New Issue