pub use z3; use diaphragm_core::{ solving::VariableHandle, types::{Bool, Float}, SolverContext, SolverModel, }; use z3::ast::Ast; use std::collections::HashMap; #[derive(Debug)] pub struct Z3Context<'z3> { ctx: &'z3 z3::Context, solver: z3::Optimize<'z3>, floats: HashMap>, max_float_id: u32, bools: HashMap>, pub max_bool_id: u32, } impl Drop for Z3Context<'_> { fn drop(&mut self) { eprintln!("bool: {}", self.max_bool_id); eprintln!("float: {}", self.max_float_id); } } fn value_to_num_den(value: f64) -> (i32, i32) { // TODO: FIXME: so hacky, because I'm so lazy... const FACTOR: f64 = 524_288.; const FACTOR_I32: i32 = FACTOR as _; const LIMIT: f64 = i32::MAX as f64 / FACTOR; if value < LIMIT { ((value * FACTOR) as _, FACTOR_I32) } else { (value as _, 1) } // let fract = value.fract(); // let number_of_fract_digits = -fract.log10().floor(); // // if number_of_fract_digits >= 1. && !number_of_fract_digits.is_infinite() { // let den = 10f64.powf(number_of_fract_digits); // ((value * den) as i32, den as i32) // } else { // (value as i32, 1) // } } impl<'z3> Z3Context<'z3> { pub fn new(ctx: &'z3 z3::Context) -> Self { Self { ctx, solver: z3::Optimize::new(ctx), floats: HashMap::new(), max_float_id: 0, bools: HashMap::new(), max_bool_id: 0, } } fn anon_float(&mut self, f: z3::ast::Real<'z3>) -> Float { self.max_float_id += 1; let id = self.max_float_id; let handle = VariableHandle::new(id as usize); self.floats.insert(handle, f); Float::from_handle(handle) } fn anon_bool(&mut self, f: z3::ast::Bool<'z3>) -> Bool { self.max_bool_id += 1; let id = self.max_bool_id; let handle = VariableHandle::new(id as usize); self.bools.insert(handle, f); Bool::new(handle) } /* fn float_handle(&mut self, f: Float) -> VariableHandle { match f { Float::Fixed(value) => { let new_float = self.new_fixed_float(value); self.float_handle(new_float) } Float::Variable(handle) => handle, } } */ fn float(&self, f: Float) -> z3::ast::Real<'z3> { let handle = match f { Float::Fixed(value) => { let (num, den) = value_to_num_den(value); return z3::ast::Real::from_real(self.ctx, num, den); } Float::Variable(handle) => handle, }; self.floats .get(&handle) .expect("Couldn't get float") .clone() } fn bool(&self, f: Bool) -> &z3::ast::Bool<'z3> { self.bools.get(&f.handle()).expect("Couldn't get float") } } impl<'z3> SolverContext for Z3Context<'z3> { fn solve<'a>(&'a self) -> Box { match self.solver.check(&[]) { z3::SatResult::Unsat | z3::SatResult::Unknown => panic!("Failed solving"), z3::SatResult::Sat => {} } Box::new(Z3Model { ctx: self, model: self.solver.get_model().unwrap(), }) } fn constrain(&mut self, assertion: Bool) { self.solver.assert( self.bools .get(&assertion.handle()) .expect("Couldn't get bool"), ); } // Floats fn new_free_float(&mut self) -> Float { self.max_float_id += 1; let id = self.max_float_id; let handle = VariableHandle::new(id as usize); self.floats .insert(handle, z3::ast::Real::new_const(self.ctx, id)); Float::from_handle(handle) } fn new_fixed_float(&mut self, value: f64) -> Float { self.max_float_id += 1; let id = self.max_float_id; let handle = VariableHandle::new(id as usize); let (num, den) = value_to_num_den(value); self.floats .insert(handle, z3::ast::Real::from_real(self.ctx, num, den)); Float::from_handle(handle) } // TODO: that's a lot of copying things fn float_add(&mut self, values: &[Float]) -> Float { let values = values.iter().map(|f| self.float(*f)).collect::>(); let result = z3::ast::Real::add( self.ctx, &values.iter().collect::>() ); self.anon_float(result) } fn float_sub(&mut self, values: &[Float]) -> Float { let values = values.iter().map(|f| self.float(*f)).collect::>(); let result = z3::ast::Real::sub( self.ctx, &values.iter().collect::>() ); self.anon_float(result) } fn float_mul(&mut self, values: &[Float]) -> Float { let values = values.iter().map(|f| self.float(*f)).collect::>(); let result = z3::ast::Real::mul( self.ctx, &values.iter().collect::>() ); self.anon_float(result) } fn float_div(&mut self, lhs: Float, rhs: Float) -> Float { let lhs = self.float(lhs); let rhs = self.float(rhs); let result = lhs.div(&rhs); self.anon_float(result) } fn float_neg(&mut self, value: Float) -> Float { self.anon_float(self.float(value).unary_minus()) } fn float_eq(&mut self, lhs: Float, rhs: Float) -> Bool { let lhs = self.float(lhs); let rhs = self.float(rhs); let result = lhs._eq(&rhs); self.anon_bool(result) } fn float_ne(&mut self, lhs: Float, rhs: Float) -> Bool { let lhs = self.float(lhs); let rhs = self.float(rhs); let result = lhs._eq(&rhs).not(); self.anon_bool(result) } fn float_gt(&mut self, lhs: Float, rhs: Float) -> Bool { let lhs = self.float(lhs); let rhs = self.float(rhs); let result = lhs.gt(&rhs); self.anon_bool(result) } fn float_ge(&mut self, lhs: Float, rhs: Float) -> Bool { let lhs = self.float(lhs); let rhs = self.float(rhs); let result = lhs.ge(&rhs); self.anon_bool(result) } fn float_lt(&mut self, lhs: Float, rhs: Float) -> Bool { let lhs = self.float(lhs); let rhs = self.float(rhs); let result = lhs.lt(&rhs); self.anon_bool(result) } fn float_le(&mut self, lhs: Float, rhs: Float) -> Bool { let lhs = self.float(lhs); let rhs = self.float(rhs); let result = lhs.le(&rhs); self.anon_bool(result) } fn float_maximize(&mut self, value: Float) { self.solver.maximize(&self.float(value)); } fn float_minimize(&mut self, value: Float) { self.solver.minimize(&self.float(value)); } // Bools fn new_free_bool(&mut self) -> Bool { self.max_bool_id += 1; let id = self.max_bool_id; let handle = VariableHandle::new(id as usize); self.bools .insert(handle, z3::ast::Bool::new_const(self.ctx, id)); Bool::new(handle) } fn new_fixed_bool(&mut self, value: bool) -> Bool { self.max_bool_id += 1; let id = self.max_bool_id; let handle = VariableHandle::new(id as usize); self.bools .insert(handle, z3::ast::Bool::from_bool(self.ctx, value)); Bool::new(handle) } fn bool_eq(&mut self, lhs: Bool, rhs: Bool) -> Bool { let lhs = self.bool(lhs); let rhs = self.bool(rhs); let result = lhs._eq(rhs); self.anon_bool(result) } fn bool_ne(&mut self, lhs: Bool, rhs: Bool) -> Bool { let lhs = self.bool(lhs); let rhs = self.bool(rhs); let result = lhs._eq(rhs).not(); self.anon_bool(result) } fn bool_and(&mut self, values: &[Bool]) -> Bool { let result = z3::ast::Bool::and( self.ctx, &values.iter().map(|b| self.bool(*b)).collect::>(), ); self.anon_bool(result) } fn bool_or(&mut self, values: &[Bool]) -> Bool { let result = z3::ast::Bool::or( self.ctx, &values.iter().map(|b| self.bool(*b)).collect::>(), ); self.anon_bool(result) } fn bool_not(&mut self, value: Bool) -> Bool { self.anon_bool(self.bool(value).not()) } fn bool_implies(&mut self, lhs: Bool, rhs: Bool) -> Bool { let lhs = self.bool(lhs); let rhs = self.bool(rhs); let result = lhs.implies(rhs); self.anon_bool(result) } } pub struct Z3Model<'z3> { ctx: &'z3 Z3Context<'z3>, model: z3::Model<'z3>, } impl SolverModel for Z3Model<'_> { fn eval_float(&self, f: Float) -> Option { let handle = match f { Float::Fixed(value) => return Some(value), Float::Variable(handle) => handle, }; let (num, den) = self .model .eval::(self.ctx.floats.get(&handle).expect("Couldn't find float"), true) .unwrap() .as_real() .unwrap(); // TODO: handle errors Some(num as f64 / den as f64) } fn eval_bool(&self, f: Bool) -> Option { Some( self.model .eval::(self.ctx.bool(f), true) .unwrap() .as_bool() .unwrap(), ) // TODO: handle errors } } /* use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; #[derive(Clone, Eq, PartialEq, Debug)] pub struct Z3Float<'z3> { id: u32, real: z3::ast::Real<'z3>, ctx: Z3Context<'z3>, } // TODO: try to remove this clone? fn get_real<'z3>(f: &dyn VariableFloat, ctx: &Z3Context<'z3>) -> z3::ast::Real<'z3> { let id = f.id() as u32; ctx.0 .borrow() .floats .get(&id) .expect("Couldn't find float") .clone() } impl<'z3> VariableFloat<'z3> for Z3Float<'z3> { fn id(&self) -> usize { self.id as usize } fn dyn_clone(&self) -> Box + 'z3> { Box::new(Z3Float { id: self.id, real: self.real.clone(), ctx: self.ctx.clone(), }) } fn eq(&self, other: &dyn VariableFloat) -> Bool<'z3> { let other = get_real(other, &self.ctx); let result = self.real._eq(&other); let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } fn neq(&self, other: &dyn VariableFloat) -> Bool<'z3> { let other = get_real(other, &self.ctx); let result = self.real._eq(&other).not(); let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } fn gt(&self, other: &dyn VariableFloat) -> Bool<'z3> { let other = get_real(other, &self.ctx); let result = self.real.gt(&other); let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } fn ge(&self, other: &dyn VariableFloat) -> Bool<'z3> { let other = get_real(other, &self.ctx); let result = self.real.ge(&other); let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } fn lt(&self, other: &dyn VariableFloat) -> Bool<'z3> { let other = get_real(other, &self.ctx); let result = self.real.lt(&other); let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } fn le(&self, other: &dyn VariableFloat) -> Bool<'z3> { let other = get_real(other, &self.ctx); let result = self.real.le(&other); let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } fn add(&self, other: &dyn VariableFloat) -> Float<'z3> { let other = get_real(other, &self.ctx); let result = &self.real + &other; let id = self.ctx.anon_float(&result); Float::new(Box::new(Z3Float { id, real: result, ctx: self.ctx.clone(), })) } fn sub(&self, other: &dyn VariableFloat) -> Float<'z3> { let other = get_real(other, &self.ctx); let result = &self.real - &other; let id = self.ctx.anon_float(&result); Float::new(Box::new(Z3Float { id, real: result, ctx: self.ctx.clone(), })) } fn mul(&self, other: &dyn VariableFloat) -> Float<'z3> { let other = get_real(other, &self.ctx); let result = &self.real * &other; let id = self.ctx.anon_float(&result); Float::new(Box::new(Z3Float { id, real: result, ctx: self.ctx.clone(), })) } fn div(&self, other: &dyn VariableFloat) -> Float<'z3> { let other = get_real(other, &self.ctx); let result = &self.real / &other; let id = self.ctx.anon_float(&result); Float::new(Box::new(Z3Float { id, real: result, ctx: self.ctx.clone(), })) } fn neg(&self) -> Float<'z3> { let result = self.real.unary_minus(); let id = self.ctx.anon_float(&result); Float::new(Box::new(Z3Float { id, real: result, ctx: self.ctx.clone(), })) } } #[derive(Clone, Eq, PartialEq, Debug)] pub struct Z3Bool<'z3> { id: u32, real: z3::ast::Bool<'z3>, ctx: Z3Context<'z3>, } // TODO: try to remove this clone? fn get_bool<'z3>(f: &dyn VariableBool, ctx: &Z3Context<'z3>) -> z3::ast::Bool<'z3> { let id = f.id() as u32; ctx.0 .borrow() .bools .get(&id) .expect("Couldn't find bool") .clone() } impl<'z3> VariableBool<'z3> for Z3Bool<'z3> { fn id(&self) -> usize { self.id as usize } fn dyn_clone(&self) -> Box + 'z3> { Box::new(Z3Bool { id: self.id, real: self.real.clone(), ctx: self.ctx.clone(), }) } fn eq(&self, other: &dyn VariableBool) -> Bool<'z3> { let other = get_bool(other, &self.ctx); let result = self.real._eq(&other); let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } fn neq(&self, other: &dyn VariableBool) -> Bool<'z3> { let other = get_bool(other, &self.ctx); let result = self.real._eq(&other).not(); let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } fn and(&self, other: &dyn VariableBool) -> Bool<'z3> { let other = get_bool(other, &self.ctx); let result = &self.real & &other; let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } fn or(&self, other: &dyn VariableBool) -> Bool<'z3> { let other = get_bool(other, &self.ctx); let result = &self.real | &other; let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } fn not(&self) -> Bool<'z3> { let result = self.real.not(); let id = self.ctx.anon_bool(&result); Bool::new(Box::new(Z3Bool { id, real: result, ctx: self.ctx.clone(), })) } } #[derive(Clone, Eq, PartialEq, Debug)] pub struct Z3Context<'z3>(Rc>>); #[derive(Eq, PartialEq, Debug)] pub struct Z3ContextImpl<'z3> { ctx: z3::Context, floats: HashMap>, max_float_id: u32, bools: HashMap>, max_bool_id: u32, } impl<'z3> Z3Context<'z3> { pub fn new() -> Self { let conf = z3::Config::new(); let ctx = z3::Context::new(&conf); Z3Context(Rc::new(RefCell::new(Z3ContextImpl { ctx, floats: HashMap::new(), max_float_id: 0, bools: HashMap::new(), max_bool_id: 0, }))) } fn anon_float(&self, f: &z3::ast::Real<'z3>) -> u32 { let mut ctx = self.0.borrow_mut(); ctx.max_float_id += 1; let id = ctx.max_float_id; ctx.floats.insert(id, f.clone()); id } fn anon_bool(&self, f: &z3::ast::Bool<'z3>) -> u32 { let mut ctx = self.0.borrow_mut(); ctx.max_float_id += 1; let id = ctx.max_float_id; ctx.bools.insert(id, f.clone()); id } } impl<'z3> SolverContext<'z3> for Z3Context<'z3> { fn new_float(&'z3 mut self) -> Box + 'z3> { let mut ctx: std::cell::RefMut<'z3, _> = self.0.borrow_mut(); ctx.max_float_id += 1; let id = ctx.max_float_id; // :( // Should be safe since the z3 context inside the RefCell of self cannot be dropped while // the Box is still alive let f = z3::ast::Real::new_const(unsafe { std::mem::transmute(&ctx.ctx) }, id); ctx.floats.insert(id, f.clone()); Box::new(Z3Float { id, real: f, ctx: self.clone(), }) } fn new_bool(&'z3 mut self) -> Box + 'z3> { let mut ctx = self.0.borrow_mut(); ctx.max_bool_id += 1; let id = ctx.max_float_id; // :( // Should be safe since the z3 context inside the RefCell of self cannot be dropped while // the Box is still alive let b = z3::ast::Bool::new_const(unsafe { std::mem::transmute(&ctx.ctx) }, id); ctx.bools.insert(id, b.clone()); Box::new(Z3Bool { id: id, real: b, ctx: self.clone(), }) } } pub struct Z3Solver<'z3>(z3::Solver<'z3>); impl<'z3> Z3Solver<'z3> { pub fn new(ctx: &'z3 Z3Context) -> Self { Self(z3::Solver::new(unsafe { // :( // again std::mem::transmute(&ctx.0.borrow().ctx) })) } } impl<'z3> Solver for Z3Solver<'z3> { fn constrain(&mut self, assertion: &Bool) { todo!() } fn solve(&self) -> Box { todo!() } } */