From ce10d5cc44c5f028f406edb3899628239452ae25 Mon Sep 17 00:00:00 2001 From: Andreas Lindner Date: Thu, 29 Aug 2024 15:33:10 +0200 Subject: [PATCH] Add foundation for faster z3 model importing --- src/shared/examples/test-z3_wrapper.sml | 23 +++++++ src/shared/smt/bir_smtLib.sml | 19 +++++- src/shared/smt/bir_smtlibLib.sml | 15 +++- src/shared/smt/holba_z3Lib.sml | 91 ++++++++++++++++++++++++- src/shared/smt/z3_wrapper.py | 79 +++++++++++++++++---- 5 files changed, 210 insertions(+), 17 deletions(-) diff --git a/src/shared/examples/test-z3_wrapper.sml b/src/shared/examples/test-z3_wrapper.sml index 1141e1616..f1997a499 100644 --- a/src/shared/examples/test-z3_wrapper.sml +++ b/src/shared/examples/test-z3_wrapper.sml @@ -93,6 +93,29 @@ in end; end; +(* +val use_holsmt = true; +val name = "simple addition"; +val query = + `` + ((x:word32) + y = 10w) + ``; +val name = "simple contradiction"; +val query = + `` + ((x:word32) + y = 10w) /\ + ((x:word32) + y = 11w) + ``; + +val use_holsmt = false; +open bslSyntax; +val name = "simple addition bir"; +val query = beq (bplus (bden (bvarimm32 "x"), bden (bvarimm32 "y")), bconstii 32 10); + +val name = "simple contradiction bir"; +val query = band (beq (bplus (bden (bvarimm32 "x"), bden (bvarimm32 "y")), bconstii 32 10), + beq (bplus (bden (bvarimm32 "x"), bden (bvarimm32 "y")), bconstii 32 11)); +*) val _ = List.map (fn (name, query) => let val _ = print ("\n\n=============== >>> RUNNING TEST CASE '" ^ name ^ "'\n"); diff --git a/src/shared/smt/bir_smtLib.sml b/src/shared/smt/bir_smtLib.sml index 2443c3df0..f8fd018fb 100644 --- a/src/shared/smt/bir_smtLib.sml +++ b/src/shared/smt/bir_smtLib.sml @@ -182,11 +182,28 @@ fun bir_smt_set_trace use_holsmt = else (fn _ => ()); +(* TODO: should not be operating on word expressions in this library, just bir expressions *) fun bir_smt_get_model use_holsmt = if use_holsmt then Z3_SAT_modelLib.Z3_GET_SAT_MODEL else - raise ERR "bir_smt_get_model" "not implemented"; + let + open holba_z3Lib; + open bir_smtlibLib; + in + (fn bexp => + let + val _ = if type_of bexp = bir_expSyntax.bir_exp_t_ty then () else + raise ERR "bir_smt_get_model" "need a bir expression"; + val exst = export_bexp bexp exst_empty; + val q = querysmt_mk_q (exst_to_querysmt exst); + val (res, model) = querysmt_getmodel q; + val _ = if res = BirSmtSat then () else + raise ERR "bir_smt_get_model" "unsatisfiable"; + in + smtmodel_to_wordfmap model + end) + end; (* ======================================= *) diff --git a/src/shared/smt/bir_smtlibLib.sml b/src/shared/smt/bir_smtlibLib.sml index 5e84405b8..69921a0dc 100644 --- a/src/shared/smt/bir_smtlibLib.sml +++ b/src/shared/smt/bir_smtlibLib.sml @@ -657,7 +657,20 @@ BExp_Store (BExp_Den (BVar "fr_269_MEM" (BType_Mem Bit32 Bit8))) exst end); -(* TODO: add a model importer *) +local + (* TODO: need to add conversion from word to bir: values are constant bitvector/imm or constant array/memory *) + (* TODO: also need variable name conversion, holv_ are going to be words, birv_ would be bir constant expressions *) + fun modellines_to_pairs [] acc = acc + | modellines_to_pairs [_] _ = raise ERR "modellines_to_pairs" "the returned model does not have an even number of lines" + | modellines_to_pairs (vname::holterm::lines) acc = + modellines_to_pairs lines ((vname, Parse.Term [QUOTE holterm])::acc); + open wordsSyntax; + open finite_mapSyntax; +in + fun smtmodel_to_wordfmap model = + rev (modellines_to_pairs model []); + (*fun smtmodel_to_bexp model = ;*) +end end (* local *) diff --git a/src/shared/smt/holba_z3Lib.sml b/src/shared/smt/holba_z3Lib.sml index ffce4152f..f63fbe14a 100644 --- a/src/shared/smt/holba_z3Lib.sml +++ b/src/shared/smt/holba_z3Lib.sml @@ -18,13 +18,21 @@ val z3bin = "/home/andreas/data/hol/HolBA_opt/z3-4.8.4/bin/z3"; fun openz3 z3bin = (Unix.execute (z3bin, ["-in"])) : (TextIO.instream, TextIO.outstream) Unix.proc; +(* +val z3wrap = "/home/andreas/data/hol/HolBA_symbexec/src/shared/smt/z3_wrapper.py"; +val prelude_path = "/home/andreas/data/hol/HolBA_symbexec/src/shared/smt/holba_z3Lib_prelude.z3"; +*) +fun openz3wrap z3wrap prelude_path = + (Unix.execute (z3wrap, [prelude_path, "loop"])) : (TextIO.instream, TextIO.outstream) Unix.proc; + fun endmeexit p = Unix.fromStatus (Unix.reap p); fun get_streams p = Unix.streamsOf p; val z3proc_bin_o = ref (NONE : string option); val z3proc_o = ref (NONE : ((TextIO.instream, TextIO.outstream) Unix.proc) option); -val prelude_z3 = read_from_file (holpathdb.subst_pathvars "$(HOLBADIR)/src/shared/smt/holba_z3Lib_prelude.z3"); +val prelude_z3_path = holpathdb.subst_pathvars "$(HOLBADIR)/src/shared/smt/holba_z3Lib_prelude.z3"; +val prelude_z3 = read_from_file prelude_z3_path; val prelude_z3_n = prelude_z3 ^ "\n"; val use_stack = true; val debug_print = false; @@ -59,6 +67,25 @@ fun get_z3proc z3bin = p end; +val z3wrapproc_o = ref (NONE : ((TextIO.instream, TextIO.outstream) Unix.proc) option); +fun get_z3wrapproc () = + let + val z3wrapproc_ = !z3wrapproc_o; + val p = if isSome z3wrapproc_ then valOf z3wrapproc_ else + let + val z3wrap = case OS.Process.getEnv "HOL4_Z3_WRAPPED_EXECUTABLE" of + SOME x => x + | NONE => raise ERR "get_z3wrapproc" "variable HOL4_Z3_WRAPPED_EXECUTABLE not defined"; + val _ = if not debug_print then () else + print ("starting: " ^ z3wrap ^ "\n"); + val p = openz3wrap z3wrap prelude_z3_path; + in (z3wrapproc_o := SOME p; p) end; + in + p + end; + +(* =========================================================== *) + fun inputLines_until m ins acc = let val line_o = TextIO.inputLine ins; @@ -102,6 +129,34 @@ fun sendreceive_query z3bin q = in out_lines end; + +fun sendreceive_wrap_query q = + let + val p = get_z3wrapproc (); + val (s_in,s_out) = get_streams p; + + val q_fixed = String.concat (List.map (fn c => if c = #"\n" then "\\n" else str c) (String.explode q)); + val _ = if not debug_print then () else + (print "sending: "; print q_fixed; print "\n"); + + val timer = holba_miscLib.timer_start 0; + val z3wrap_done_marker = "z3_wrapper query done"; + val () = TextIO.output (s_out, q_fixed ^ "\n"); + val out_lines = inputLines_until (z3wrap_done_marker ^ "\n") s_in []; + val _ = if debug_print then holba_miscLib.timer_stop + (fn delta_s => print (" wrapped query took " ^ delta_s ^ "\n")) timer else (); + + val _ = if not debug_print then () else + (map print out_lines; print "\n\n"); + in + out_lines + end; +(* + val q = "(declare-const x (_ BitVec 8))\n(assert (= x #xFF))\n"; + val q = "(declare-const x (_ BitVec 8))\n(assert (= x #xAA))\n(assert (= x #xFF))\n"; + + sendreceive_wrap_query q; +*) (* =========================================================== *) datatype bir_smt_result = @@ -134,6 +189,9 @@ fun sendreceive_query z3bin q = out_lines end; + fun querysmt_prepare_getmodel z3bin_o = + querysmt_raw z3bin_o NONE "(set-option :model.compact false)\n"; + (* querysmt_raw NONE NONE "(simplify ((_ extract 3 2) #xFC))"; @@ -155,6 +213,24 @@ querysmt_raw NONE NONE "(display (_ bv20 16))" print "\n============================\n"; raise ERR "querysmt_parse_checksat" "unknown output from z3"); + fun querysmt_parse_getmodel out_lines = + if hd out_lines = "sat\n" then + let + val model_lines = tl out_lines; + val model_lines_fix = map (fn line => if (hd o rev o explode) line = #"\n" then (implode o rev o tl o rev o explode) line else line) model_lines; + in + (BirSmtSat, model_lines_fix) + end + else if hd out_lines = "unsat\n" then + (BirSmtUnsat, []) + else if hd out_lines = "unknown\n" then + (BirSmtUnknown, []) + else + (print "\n============================\n"; + map print out_lines; + print "\n============================\n"; + raise ERR "querysmt_parse_getmodel" "unknown output from z3"); + (* https://rise4fun.com/z3/tutorial *) (* val q = "(declare-const a Int)\n" ^ @@ -182,13 +258,23 @@ querysmt_raw NONE NONE "(display (_ bv20 16))" val q = "(check-sat)\n"; val result = querysmt_parse_checksat (querysmt_raw NONE NONE q); + val result = (querysmt_raw NONE NONE (q^"(get-model)\n")); *) fun querysmt_checksat_gen z3bin_o timeout_o q = querysmt_parse_checksat (querysmt_raw z3bin_o timeout_o (q ^ "(check-sat)\n")); val querysmt_checksat = querysmt_checksat_gen NONE; - (* TODO: add querysmt_getmodel *) + fun querysmt_getmodel q = + querysmt_parse_getmodel (sendreceive_wrap_query q); + + (* + val q = "(declare-const x (_ BitVec 8))\n(assert (= x #xFF))\n"; + val q = "(declare-const x (_ BitVec 8))\n(assert (= x #xAA))\n(assert (= x #xFF))\n"; + + querysmt_checksat NONE q + querysmt_getmodel q + *) (* ------------------------------------------------------------------------ *) @@ -387,6 +473,7 @@ fun gen_smt_store_as_funcall valm valad valv opparam = [("(= x #xFF)", SMTTY_Bool), ("(= x #xAA)", SMTTY_Bool)]); querysmt_checksat NONE q + querysmt_getmodel q *) end (* local *) diff --git a/src/shared/smt/z3_wrapper.py b/src/shared/smt/z3_wrapper.py index 783e3c266..94750d8a1 100755 --- a/src/shared/smt/z3_wrapper.py +++ b/src/shared/smt/z3_wrapper.py @@ -210,9 +210,10 @@ def strip_z3_name(x): return len(x.split('_', maxsplit=1)) > 1 and x.split('_', maxsplit=1)[1] or x.split('_', maxsplit=1)[0] # create list of string pairs from model: (varname, holterm) -def model_to_list(model): +def model_to_list(model, strip_names): # map to pair (model variables (stripped hol name), variables value) and filter auxiliary assignments - assigns_pre = filter (lambda x: not ("!" in x[0]), map(lambda x: (strip_z3_name(str(x.name())), model[x]), model)) + stripfun = strip_z3_name if strip_names else (lambda x: x) + assigns_pre = filter (lambda x: not ("!" in x[0]), map(lambda x: (stripfun(str(x.name())), model[x]), model)) # partition, sort individually, put together again assign_ast = [] @@ -234,21 +235,79 @@ def model_to_list(model): # return the collected hol assignments return sml_list +def print_model_for_holba(model, strip_names = True): + hol_list = model_to_list(model, strip_names) + + for (varname, term) in hol_list: + print(varname) + print(term) + #print("on stdout: {}".format(line), file=sys.stderr) + +def send_query(s): + r = s.check() + model = [] + if r == sat: + model = s.model() + return (r, model) + +s = Solver() + +# from z3_wrapper import * +# load_prelude("holba_z3Lib_prelude.z3") +# q = "(declare-const x (_ BitVec 8))\n(assert (= x #xFF))\n" +# preluded_query(q) +def load_prelude(filename): + with open(filename, "r") as f: + pre = f.read() + s.from_string(pre) + s.push() + +def preluded_query(q): + s.from_string(q) + (r, model) = send_query(s) + s.pop() + s.push() + if r == unsat: + print("unsat") + elif r == unknown: + print("unknown") + else: + print("sat") + print_model_for_holba(model, strip_names = False) + +# python3 z3_wrapper.py holba_z3Lib_prelude.z3 loop # script entry point def main(): - use_files = len(sys.argv) > 1 - s = Solver() + use_files = False + preluded_loop = False + if len(sys.argv) > 1: + filename = sys.argv[1] + if len(sys.argv) > 2: + preluded_loop = True + else: + use_files = True + + if preluded_loop: + load_prelude(filename) + while True: + #print("waiting for input", file=sys.stderr) + q = sys.stdin.readline().replace("\\n", "\n") + #print("sending input to query", file=sys.stderr) + preluded_query(q) + print("z3_wrapper query done", flush=True) + + exit(-1) do_debug = False if do_debug: debug_input(s) elif use_files: - s.from_file(sys.argv[1]) + s.from_file(filename) else: stdin = "\n".join(sys.stdin.readlines()) s.from_string(stdin) - r = s.check() + (r, model) = send_query(s) if r == unsat: print("unsat") exit(0) @@ -261,13 +320,7 @@ def main(): print("sat") #print(s.model(), file=sys.stderr) - model = s.model() - hol_list = model_to_list(model) - - for (varname, term) in hol_list: - print(varname) - print(term) - #print("on stdout: {}".format(line), file=sys.stderr) + print_model_for_holba(model) if __name__ == '__main__':