From 92a02c34628343153b33602eae00cef46e28d191 Mon Sep 17 00:00:00 2001 From: Minijackson Date: Thu, 22 Dec 2022 12:19:59 +0100 Subject: WIP --- z3-solver/src/lib.rs | 691 ++++++++++++++++++++++++++++++++++++++++ z3-solver/src/solving/mod.rs | 9 + z3-solver/src/solving/z3/mod.rs | 44 +++ 3 files changed, 744 insertions(+) create mode 100644 z3-solver/src/lib.rs create mode 100644 z3-solver/src/solving/mod.rs create mode 100644 z3-solver/src/solving/z3/mod.rs (limited to 'z3-solver/src') diff --git a/z3-solver/src/lib.rs b/z3-solver/src/lib.rs new file mode 100644 index 0000000..04d77eb --- /dev/null +++ b/z3-solver/src/lib.rs @@ -0,0 +1,691 @@ +pub use z3; + +use diaphragm_core::{ + solving::VariableHandle, + types::{Bool, Float}, + SolverContext, SolverModel, +}; + +use z3::ast::Ast; + +use std::collections::HashMap; + +#[derive(Debug)] +pub struct Z3Context<'z3> { + ctx: &'z3 z3::Context, + solver: z3::Solver<'z3>, + + floats: HashMap>, + max_float_id: u32, + + bools: HashMap>, + pub max_bool_id: u32, +} + +impl Drop for Z3Context<'_> { + fn drop(&mut self) { + eprintln!("bool: {}", self.max_bool_id); + eprintln!("float: {}", self.max_float_id); + } +} + +fn value_to_num_den(value: f64) -> (i32, i32) { + let fract = value.fract(); + let number_of_fract_digits = -fract.log10().floor(); + + if number_of_fract_digits >= 1. && !number_of_fract_digits.is_infinite() { + let den = 10f64.powf(number_of_fract_digits); + ((value * den) as i32, den as i32) + } else { + (value as i32, 1) + } +} + +impl<'z3> Z3Context<'z3> { + pub fn new(ctx: &'z3 z3::Context) -> Self { + Self { + ctx, + solver: z3::Solver::new(&ctx), + floats: HashMap::new(), + max_float_id: 0, + bools: HashMap::new(), + max_bool_id: 0, + } + } + + fn anon_float(&mut self, f: z3::ast::Real<'z3>) -> Float { + self.max_float_id += 1; + let id = self.max_float_id; + let handle = VariableHandle::new(id as usize); + self.floats.insert(handle, f); + Float::from_handle(handle) + } + + fn anon_bool(&mut self, f: z3::ast::Bool<'z3>) -> Bool { + self.max_bool_id += 1; + let id = self.max_bool_id; + let handle = VariableHandle::new(id as usize); + self.bools.insert(handle, dbg!(f)); + Bool::new(handle) + } + + /* + fn float_handle(&mut self, f: Float) -> VariableHandle { + match f { + Float::Fixed(value) => { + let new_float = self.new_fixed_float(value); + self.float_handle(new_float) + } + Float::Variable(handle) => handle, + } + } + */ + + fn float(&self, f: Float) -> z3::ast::Real<'z3> { + let handle = match f { + Float::Fixed(value) => { + let (num, den) = value_to_num_den(value); + return z3::ast::Real::from_real(&self.ctx, num, den); + } + Float::Variable(handle) => handle, + }; + + self.floats + .get(&handle) + .expect("Couldn't get float") + .clone() + } + + fn bool(&self, f: Bool) -> &z3::ast::Bool<'z3> { + self.bools.get(&f.handle()).expect("Couldn't get float") + } +} + +impl<'z3> SolverContext for Z3Context<'z3> { + fn solve<'a>(&'a self) -> Box { + match self.solver.check() { + z3::SatResult::Unsat | z3::SatResult::Unknown => panic!("Failed solving"), + z3::SatResult::Sat => {} + } + Box::new(Z3Model { + ctx: self, + model: self.solver.get_model().unwrap(), + }) + } + + fn constrain(&mut self, assertion: Bool) { + self.solver.assert( + self.bools + .get(&assertion.handle()) + .expect("Couldn't get bool"), + ); + } + + // Floats + + fn new_free_float(&mut self) -> Float { + self.max_float_id += 1; + let id = self.max_float_id; + let handle = VariableHandle::new(id as usize); + self.floats + .insert(handle, z3::ast::Real::new_const(&self.ctx, id)); + Float::from_handle(handle) + } + + fn new_fixed_float(&mut self, value: f64) -> Float { + self.max_float_id += 1; + let id = self.max_float_id; + let handle = VariableHandle::new(id as usize); + + let (num, den) = value_to_num_den(value); + + self.floats + .insert(handle, z3::ast::Real::from_real(&self.ctx, num, den)); + Float::from_handle(handle) + } + + // TODO: that's a lot of copying things + + fn float_add(&mut self, values: &[Float]) -> Float { + let values = values.iter().map(|f| self.float(*f)).collect::>(); + let result = z3::ast::Real::add( + self.ctx, + &values.iter().collect::>() + ); + + self.anon_float(result) + } + + fn float_sub(&mut self, values: &[Float]) -> Float { + let values = values.iter().map(|f| self.float(*f)).collect::>(); + let result = z3::ast::Real::sub( + self.ctx, + &values.iter().collect::>() + ); + + self.anon_float(result) + } + + fn float_mul(&mut self, values: &[Float]) -> Float { + let values = values.iter().map(|f| self.float(*f)).collect::>(); + let result = z3::ast::Real::mul( + self.ctx, + &values.iter().collect::>() + ); + + self.anon_float(result) + } + + fn float_div(&mut self, lhs: Float, rhs: Float) -> Float { + let lhs = self.float(lhs); + let rhs = self.float(rhs); + let result = lhs.div(&rhs); + self.anon_float(result) + } + + fn float_neg(&mut self, value: Float) -> Float { + self.anon_float(self.float(value).unary_minus()) + } + + fn float_eq(&mut self, lhs: Float, rhs: Float) -> Bool { + let lhs = self.float(lhs); + let rhs = self.float(rhs); + let result = lhs._eq(&rhs); + self.anon_bool(result) + } + + fn float_ne(&mut self, lhs: Float, rhs: Float) -> Bool { + let lhs = self.float(lhs); + let rhs = self.float(rhs); + let result = lhs._eq(&rhs).not(); + self.anon_bool(result) + } + + fn float_gt(&mut self, lhs: Float, rhs: Float) -> Bool { + let lhs = self.float(lhs); + let rhs = self.float(rhs); + let result = lhs.gt(&rhs); + self.anon_bool(result) + } + + fn float_ge(&mut self, lhs: Float, rhs: Float) -> Bool { + let lhs = self.float(lhs); + let rhs = self.float(rhs); + let result = lhs.ge(&rhs); + self.anon_bool(result) + } + + fn float_lt(&mut self, lhs: Float, rhs: Float) -> Bool { + let lhs = self.float(lhs); + let rhs = self.float(rhs); + let result = lhs.lt(&rhs); + self.anon_bool(result) + } + + fn float_le(&mut self, lhs: Float, rhs: Float) -> Bool { + let lhs = self.float(lhs); + let rhs = self.float(rhs); + let result = lhs.le(&rhs); + self.anon_bool(result) + } + + // Bools + + fn new_free_bool(&mut self) -> Bool { + self.max_bool_id += 1; + let id = self.max_bool_id; + let handle = VariableHandle::new(id as usize); + self.bools + .insert(handle, z3::ast::Bool::new_const(&self.ctx, id)); + Bool::new(handle) + } + + fn new_fixed_bool(&mut self, value: bool) -> Bool { + self.max_bool_id += 1; + let id = self.max_bool_id; + let handle = VariableHandle::new(id as usize); + + self.bools + .insert(handle, z3::ast::Bool::from_bool(&self.ctx, value)); + Bool::new(handle) + } + + fn bool_eq(&mut self, lhs: Bool, rhs: Bool) -> Bool { + let lhs = self.bool(lhs); + let rhs = self.bool(rhs); + let result = lhs._eq(&rhs); + self.anon_bool(result) + } + + fn bool_ne(&mut self, lhs: Bool, rhs: Bool) -> Bool { + let lhs = self.bool(lhs); + let rhs = self.bool(rhs); + let result = lhs._eq(&rhs).not(); + self.anon_bool(result) + } + + fn bool_and(&mut self, values: &[Bool]) -> Bool { + let result = z3::ast::Bool::and( + self.ctx, + &values.iter().map(|b| self.bool(*b)).collect::>(), + ); + + self.anon_bool(result) + } + + fn bool_or(&mut self, values: &[Bool]) -> Bool { + let result = z3::ast::Bool::or( + self.ctx, + &values.iter().map(|b| self.bool(*b)).collect::>(), + ); + + self.anon_bool(result) + } + + fn bool_not(&mut self, value: Bool) -> Bool { + self.anon_bool(self.bool(value).not()) + } + + fn bool_implies(&mut self, lhs: Bool, rhs: Bool) -> Bool { + let lhs = self.bool(lhs); + let rhs = self.bool(rhs); + let result = lhs.implies(rhs); + self.anon_bool(result) + } +} + +pub struct Z3Model<'z3> { + ctx: &'z3 Z3Context<'z3>, + model: z3::Model<'z3>, +} + +impl SolverModel for Z3Model<'_> { + fn eval_float(&self, f: Float) -> Option { + let handle = match f { + Float::Fixed(value) => return Some(value), + Float::Variable(handle) => handle, + }; + + let (num, den) = self + .model + .eval::(self.ctx.floats.get(&handle).expect("Couldn't find float")) + .unwrap() + .as_real() + .unwrap(); + // TODO: handle errors + Some(num as f64 / den as f64) + } + + fn eval_bool(&self, f: Bool) -> Option { + Some( + self.model + .eval::(&self.ctx.bool(f)) + .unwrap() + .as_bool() + .unwrap(), + ) + // TODO: handle errors + } +} + +/* +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; + +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct Z3Float<'z3> { + id: u32, + real: z3::ast::Real<'z3>, + ctx: Z3Context<'z3>, +} + +// TODO: try to remove this clone? +fn get_real<'z3>(f: &dyn VariableFloat, ctx: &Z3Context<'z3>) -> z3::ast::Real<'z3> { + let id = f.id() as u32; + ctx.0 + .borrow() + .floats + .get(&id) + .expect("Couldn't find float") + .clone() +} + +impl<'z3> VariableFloat<'z3> for Z3Float<'z3> { + fn id(&self) -> usize { + self.id as usize + } + + fn dyn_clone(&self) -> Box + 'z3> { + Box::new(Z3Float { + id: self.id, + real: self.real.clone(), + ctx: self.ctx.clone(), + }) + } + + fn eq(&self, other: &dyn VariableFloat) -> Bool<'z3> { + let other = get_real(other, &self.ctx); + let result = self.real._eq(&other); + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn neq(&self, other: &dyn VariableFloat) -> Bool<'z3> { + let other = get_real(other, &self.ctx); + let result = self.real._eq(&other).not(); + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn gt(&self, other: &dyn VariableFloat) -> Bool<'z3> { + let other = get_real(other, &self.ctx); + let result = self.real.gt(&other); + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn ge(&self, other: &dyn VariableFloat) -> Bool<'z3> { + let other = get_real(other, &self.ctx); + let result = self.real.ge(&other); + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn lt(&self, other: &dyn VariableFloat) -> Bool<'z3> { + let other = get_real(other, &self.ctx); + let result = self.real.lt(&other); + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn le(&self, other: &dyn VariableFloat) -> Bool<'z3> { + let other = get_real(other, &self.ctx); + let result = self.real.le(&other); + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn add(&self, other: &dyn VariableFloat) -> Float<'z3> { + let other = get_real(other, &self.ctx); + let result = &self.real + &other; + let id = self.ctx.anon_float(&result); + + Float::new(Box::new(Z3Float { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn sub(&self, other: &dyn VariableFloat) -> Float<'z3> { + let other = get_real(other, &self.ctx); + let result = &self.real - &other; + let id = self.ctx.anon_float(&result); + + Float::new(Box::new(Z3Float { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn mul(&self, other: &dyn VariableFloat) -> Float<'z3> { + let other = get_real(other, &self.ctx); + let result = &self.real * &other; + let id = self.ctx.anon_float(&result); + + Float::new(Box::new(Z3Float { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn div(&self, other: &dyn VariableFloat) -> Float<'z3> { + let other = get_real(other, &self.ctx); + let result = &self.real / &other; + let id = self.ctx.anon_float(&result); + + Float::new(Box::new(Z3Float { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn neg(&self) -> Float<'z3> { + let result = self.real.unary_minus(); + let id = self.ctx.anon_float(&result); + + Float::new(Box::new(Z3Float { + id, + real: result, + ctx: self.ctx.clone(), + })) + } +} + +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct Z3Bool<'z3> { + id: u32, + real: z3::ast::Bool<'z3>, + ctx: Z3Context<'z3>, +} + +// TODO: try to remove this clone? +fn get_bool<'z3>(f: &dyn VariableBool, ctx: &Z3Context<'z3>) -> z3::ast::Bool<'z3> { + let id = f.id() as u32; + ctx.0 + .borrow() + .bools + .get(&id) + .expect("Couldn't find bool") + .clone() +} + +impl<'z3> VariableBool<'z3> for Z3Bool<'z3> { + fn id(&self) -> usize { + self.id as usize + } + + fn dyn_clone(&self) -> Box + 'z3> { + Box::new(Z3Bool { + id: self.id, + real: self.real.clone(), + ctx: self.ctx.clone(), + }) + } + + fn eq(&self, other: &dyn VariableBool) -> Bool<'z3> { + let other = get_bool(other, &self.ctx); + let result = self.real._eq(&other); + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn neq(&self, other: &dyn VariableBool) -> Bool<'z3> { + let other = get_bool(other, &self.ctx); + let result = self.real._eq(&other).not(); + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn and(&self, other: &dyn VariableBool) -> Bool<'z3> { + let other = get_bool(other, &self.ctx); + let result = &self.real & &other; + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn or(&self, other: &dyn VariableBool) -> Bool<'z3> { + let other = get_bool(other, &self.ctx); + let result = &self.real | &other; + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } + + fn not(&self) -> Bool<'z3> { + let result = self.real.not(); + let id = self.ctx.anon_bool(&result); + + Bool::new(Box::new(Z3Bool { + id, + real: result, + ctx: self.ctx.clone(), + })) + } +} + +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct Z3Context<'z3>(Rc>>); + +#[derive(Eq, PartialEq, Debug)] +pub struct Z3ContextImpl<'z3> { + ctx: z3::Context, + floats: HashMap>, + max_float_id: u32, + bools: HashMap>, + max_bool_id: u32, +} + +impl<'z3> Z3Context<'z3> { + pub fn new() -> Self { + let conf = z3::Config::new(); + let ctx = z3::Context::new(&conf); + Z3Context(Rc::new(RefCell::new(Z3ContextImpl { + ctx, + floats: HashMap::new(), + max_float_id: 0, + bools: HashMap::new(), + max_bool_id: 0, + }))) + } + + fn anon_float(&self, f: &z3::ast::Real<'z3>) -> u32 { + let mut ctx = self.0.borrow_mut(); + + ctx.max_float_id += 1; + let id = ctx.max_float_id; + ctx.floats.insert(id, f.clone()); + id + } + + fn anon_bool(&self, f: &z3::ast::Bool<'z3>) -> u32 { + let mut ctx = self.0.borrow_mut(); + + ctx.max_float_id += 1; + let id = ctx.max_float_id; + ctx.bools.insert(id, f.clone()); + id + } +} + +impl<'z3> SolverContext<'z3> for Z3Context<'z3> { + fn new_float(&'z3 mut self) -> Box + 'z3> { + let mut ctx: std::cell::RefMut<'z3, _> = self.0.borrow_mut(); + + ctx.max_float_id += 1; + let id = ctx.max_float_id; + // :( + // Should be safe since the z3 context inside the RefCell of self cannot be dropped while + // the Box is still alive + let f = z3::ast::Real::new_const(unsafe { std::mem::transmute(&ctx.ctx) }, id); + ctx.floats.insert(id, f.clone()); + Box::new(Z3Float { + id, + real: f, + ctx: self.clone(), + }) + } + + fn new_bool(&'z3 mut self) -> Box + 'z3> { + let mut ctx = self.0.borrow_mut(); + + ctx.max_bool_id += 1; + let id = ctx.max_float_id; + // :( + // Should be safe since the z3 context inside the RefCell of self cannot be dropped while + // the Box is still alive + let b = z3::ast::Bool::new_const(unsafe { std::mem::transmute(&ctx.ctx) }, id); + ctx.bools.insert(id, b.clone()); + Box::new(Z3Bool { + id: id, + real: b, + ctx: self.clone(), + }) + } +} + +pub struct Z3Solver<'z3>(z3::Solver<'z3>); + +impl<'z3> Z3Solver<'z3> { + pub fn new(ctx: &'z3 Z3Context) -> Self { + Self(z3::Solver::new(unsafe { + // :( + // again + std::mem::transmute(&ctx.0.borrow().ctx) + })) + } +} + +impl<'z3> Solver for Z3Solver<'z3> { + fn constrain(&mut self, assertion: &Bool) { + todo!() + } + + fn solve(&self) -> Box { + todo!() + } +} +*/ diff --git a/z3-solver/src/solving/mod.rs b/z3-solver/src/solving/mod.rs new file mode 100644 index 0000000..f3f6673 --- /dev/null +++ b/z3-solver/src/solving/mod.rs @@ -0,0 +1,9 @@ +pub mod z3; + +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy)] +pub struct FloatHandle(u32); + +pub trait Solver { + // TODO: make handles generic? + fn new_float<'a>(&'a mut self, handles: &mut z3::Handles<'a>) -> FloatHandle; +} diff --git a/z3-solver/src/solving/z3/mod.rs b/z3-solver/src/solving/z3/mod.rs new file mode 100644 index 0000000..2f29073 --- /dev/null +++ b/z3-solver/src/solving/z3/mod.rs @@ -0,0 +1,44 @@ +use super::{Solver, FloatHandle}; + +use std::collections::HashMap; + +pub struct Z3Solver { + ctx: z3::Context, +} + +pub struct Handles<'a> { + float: HashMap>, + float_max_id: u32, +} + +impl<'a> Handles<'a> { + pub fn new() -> Self { + Self { + float: HashMap::new(), + float_max_id: 0, + } + } +} + +impl Z3Solver { + pub fn new<'a>() -> (Self, Handles<'a>) { + let config = z3::Config::new(); + + ( + Self { + ctx: z3::Context::new(&config), + }, + Handles::new(), + ) + } +} + +impl Solver for Z3Solver { + fn new_float<'a>(&'a mut self, handles: &mut Handles<'a>) -> FloatHandle { + let id = handles.float_max_id; + let float = z3::ast::Real::new_const(&self.ctx, id); + handles.float_max_id += 1; + handles.float.insert(id, float); + FloatHandle(id) + } +} -- cgit v1.2.3