summaryrefslogtreecommitdiffstats
path: root/z3-solver/src/lib.rs
diff options
context:
space:
mode:
authorMinijackson <minijackson@riseup.net>2022-12-22 12:19:59 +0100
committerMinijackson <minijackson@riseup.net>2022-12-22 12:19:59 +0100
commit92a02c34628343153b33602eae00cef46e28d191 (patch)
tree8622ec528d24e456be22d984d93aa9bcafc97399 /z3-solver/src/lib.rs
downloaddiaphragm-92a02c34628343153b33602eae00cef46e28d191.tar.gz
diaphragm-92a02c34628343153b33602eae00cef46e28d191.zip
WIP
Diffstat (limited to 'z3-solver/src/lib.rs')
-rw-r--r--z3-solver/src/lib.rs691
1 files changed, 691 insertions, 0 deletions
diff --git a/z3-solver/src/lib.rs b/z3-solver/src/lib.rs
new file mode 100644
index 0000000..04d77eb
--- /dev/null
+++ b/z3-solver/src/lib.rs
@@ -0,0 +1,691 @@
1pub use z3;
2
3use diaphragm_core::{
4 solving::VariableHandle,
5 types::{Bool, Float},
6 SolverContext, SolverModel,
7};
8
9use z3::ast::Ast;
10
11use std::collections::HashMap;
12
13#[derive(Debug)]
14pub struct Z3Context<'z3> {
15 ctx: &'z3 z3::Context,
16 solver: z3::Solver<'z3>,
17
18 floats: HashMap<VariableHandle, z3::ast::Real<'z3>>,
19 max_float_id: u32,
20
21 bools: HashMap<VariableHandle, z3::ast::Bool<'z3>>,
22 pub max_bool_id: u32,
23}
24
25impl Drop for Z3Context<'_> {
26 fn drop(&mut self) {
27 eprintln!("bool: {}", self.max_bool_id);
28 eprintln!("float: {}", self.max_float_id);
29 }
30}
31
32fn value_to_num_den(value: f64) -> (i32, i32) {
33 let fract = value.fract();
34 let number_of_fract_digits = -fract.log10().floor();
35
36 if number_of_fract_digits >= 1. && !number_of_fract_digits.is_infinite() {
37 let den = 10f64.powf(number_of_fract_digits);
38 ((value * den) as i32, den as i32)
39 } else {
40 (value as i32, 1)
41 }
42}
43
44impl<'z3> Z3Context<'z3> {
45 pub fn new(ctx: &'z3 z3::Context) -> Self {
46 Self {
47 ctx,
48 solver: z3::Solver::new(&ctx),
49 floats: HashMap::new(),
50 max_float_id: 0,
51 bools: HashMap::new(),
52 max_bool_id: 0,
53 }
54 }
55
56 fn anon_float(&mut self, f: z3::ast::Real<'z3>) -> Float {
57 self.max_float_id += 1;
58 let id = self.max_float_id;
59 let handle = VariableHandle::new(id as usize);
60 self.floats.insert(handle, f);
61 Float::from_handle(handle)
62 }
63
64 fn anon_bool(&mut self, f: z3::ast::Bool<'z3>) -> Bool {
65 self.max_bool_id += 1;
66 let id = self.max_bool_id;
67 let handle = VariableHandle::new(id as usize);
68 self.bools.insert(handle, dbg!(f));
69 Bool::new(handle)
70 }
71
72 /*
73 fn float_handle(&mut self, f: Float) -> VariableHandle {
74 match f {
75 Float::Fixed(value) => {
76 let new_float = self.new_fixed_float(value);
77 self.float_handle(new_float)
78 }
79 Float::Variable(handle) => handle,
80 }
81 }
82 */
83
84 fn float(&self, f: Float) -> z3::ast::Real<'z3> {
85 let handle = match f {
86 Float::Fixed(value) => {
87 let (num, den) = value_to_num_den(value);
88 return z3::ast::Real::from_real(&self.ctx, num, den);
89 }
90 Float::Variable(handle) => handle,
91 };
92
93 self.floats
94 .get(&handle)
95 .expect("Couldn't get float")
96 .clone()
97 }
98
99 fn bool(&self, f: Bool) -> &z3::ast::Bool<'z3> {
100 self.bools.get(&f.handle()).expect("Couldn't get float")
101 }
102}
103
104impl<'z3> SolverContext for Z3Context<'z3> {
105 fn solve<'a>(&'a self) -> Box<dyn SolverModel + 'a> {
106 match self.solver.check() {
107 z3::SatResult::Unsat | z3::SatResult::Unknown => panic!("Failed solving"),
108 z3::SatResult::Sat => {}
109 }
110 Box::new(Z3Model {
111 ctx: self,
112 model: self.solver.get_model().unwrap(),
113 })
114 }
115
116 fn constrain(&mut self, assertion: Bool) {
117 self.solver.assert(
118 self.bools
119 .get(&assertion.handle())
120 .expect("Couldn't get bool"),
121 );
122 }
123
124 // Floats
125
126 fn new_free_float(&mut self) -> Float {
127 self.max_float_id += 1;
128 let id = self.max_float_id;
129 let handle = VariableHandle::new(id as usize);
130 self.floats
131 .insert(handle, z3::ast::Real::new_const(&self.ctx, id));
132 Float::from_handle(handle)
133 }
134
135 fn new_fixed_float(&mut self, value: f64) -> Float {
136 self.max_float_id += 1;
137 let id = self.max_float_id;
138 let handle = VariableHandle::new(id as usize);
139
140 let (num, den) = value_to_num_den(value);
141
142 self.floats
143 .insert(handle, z3::ast::Real::from_real(&self.ctx, num, den));
144 Float::from_handle(handle)
145 }
146
147 // TODO: that's a lot of copying things
148
149 fn float_add(&mut self, values: &[Float]) -> Float {
150 let values = values.iter().map(|f| self.float(*f)).collect::<Vec<_>>();
151 let result = z3::ast::Real::add(
152 self.ctx,
153 &values.iter().collect::<Vec<_>>()
154 );
155
156 self.anon_float(result)
157 }
158
159 fn float_sub(&mut self, values: &[Float]) -> Float {
160 let values = values.iter().map(|f| self.float(*f)).collect::<Vec<_>>();
161 let result = z3::ast::Real::sub(
162 self.ctx,
163 &values.iter().collect::<Vec<_>>()
164 );
165
166 self.anon_float(result)
167 }
168
169 fn float_mul(&mut self, values: &[Float]) -> Float {
170 let values = values.iter().map(|f| self.float(*f)).collect::<Vec<_>>();
171 let result = z3::ast::Real::mul(
172 self.ctx,
173 &values.iter().collect::<Vec<_>>()
174 );
175
176 self.anon_float(result)
177 }
178
179 fn float_div(&mut self, lhs: Float, rhs: Float) -> Float {
180 let lhs = self.float(lhs);
181 let rhs = self.float(rhs);
182 let result = lhs.div(&rhs);
183 self.anon_float(result)
184 }
185
186 fn float_neg(&mut self, value: Float) -> Float {
187 self.anon_float(self.float(value).unary_minus())
188 }
189
190 fn float_eq(&mut self, lhs: Float, rhs: Float) -> Bool {
191 let lhs = self.float(lhs);
192 let rhs = self.float(rhs);
193 let result = lhs._eq(&rhs);
194 self.anon_bool(result)
195 }
196
197 fn float_ne(&mut self, lhs: Float, rhs: Float) -> Bool {
198 let lhs = self.float(lhs);
199 let rhs = self.float(rhs);
200 let result = lhs._eq(&rhs).not();
201 self.anon_bool(result)
202 }
203
204 fn float_gt(&mut self, lhs: Float, rhs: Float) -> Bool {
205 let lhs = self.float(lhs);
206 let rhs = self.float(rhs);
207 let result = lhs.gt(&rhs);
208 self.anon_bool(result)
209 }
210
211 fn float_ge(&mut self, lhs: Float, rhs: Float) -> Bool {
212 let lhs = self.float(lhs);
213 let rhs = self.float(rhs);
214 let result = lhs.ge(&rhs);
215 self.anon_bool(result)
216 }
217
218 fn float_lt(&mut self, lhs: Float, rhs: Float) -> Bool {
219 let lhs = self.float(lhs);
220 let rhs = self.float(rhs);
221 let result = lhs.lt(&rhs);
222 self.anon_bool(result)
223 }
224
225 fn float_le(&mut self, lhs: Float, rhs: Float) -> Bool {
226 let lhs = self.float(lhs);
227 let rhs = self.float(rhs);
228 let result = lhs.le(&rhs);
229 self.anon_bool(result)
230 }
231
232 // Bools
233
234 fn new_free_bool(&mut self) -> Bool {
235 self.max_bool_id += 1;
236 let id = self.max_bool_id;
237 let handle = VariableHandle::new(id as usize);
238 self.bools
239 .insert(handle, z3::ast::Bool::new_const(&self.ctx, id));
240 Bool::new(handle)
241 }
242
243 fn new_fixed_bool(&mut self, value: bool) -> Bool {
244 self.max_bool_id += 1;
245 let id = self.max_bool_id;
246 let handle = VariableHandle::new(id as usize);
247
248 self.bools
249 .insert(handle, z3::ast::Bool::from_bool(&self.ctx, value));
250 Bool::new(handle)
251 }
252
253 fn bool_eq(&mut self, lhs: Bool, rhs: Bool) -> Bool {
254 let lhs = self.bool(lhs);
255 let rhs = self.bool(rhs);
256 let result = lhs._eq(&rhs);
257 self.anon_bool(result)
258 }
259
260 fn bool_ne(&mut self, lhs: Bool, rhs: Bool) -> Bool {
261 let lhs = self.bool(lhs);
262 let rhs = self.bool(rhs);
263 let result = lhs._eq(&rhs).not();
264 self.anon_bool(result)
265 }
266
267 fn bool_and(&mut self, values: &[Bool]) -> Bool {
268 let result = z3::ast::Bool::and(
269 self.ctx,
270 &values.iter().map(|b| self.bool(*b)).collect::<Vec<_>>(),
271 );
272
273 self.anon_bool(result)
274 }
275
276 fn bool_or(&mut self, values: &[Bool]) -> Bool {
277 let result = z3::ast::Bool::or(
278 self.ctx,
279 &values.iter().map(|b| self.bool(*b)).collect::<Vec<_>>(),
280 );
281
282 self.anon_bool(result)
283 }
284
285 fn bool_not(&mut self, value: Bool) -> Bool {
286 self.anon_bool(self.bool(value).not())
287 }
288
289 fn bool_implies(&mut self, lhs: Bool, rhs: Bool) -> Bool {
290 let lhs = self.bool(lhs);
291 let rhs = self.bool(rhs);
292 let result = lhs.implies(rhs);
293 self.anon_bool(result)
294 }
295}
296
297pub struct Z3Model<'z3> {
298 ctx: &'z3 Z3Context<'z3>,
299 model: z3::Model<'z3>,
300}
301
302impl SolverModel for Z3Model<'_> {
303 fn eval_float(&self, f: Float) -> Option<f64> {
304 let handle = match f {
305 Float::Fixed(value) => return Some(value),
306 Float::Variable(handle) => handle,
307 };
308
309 let (num, den) = self
310 .model
311 .eval::<z3::ast::Real>(self.ctx.floats.get(&handle).expect("Couldn't find float"))
312 .unwrap()
313 .as_real()
314 .unwrap();
315 // TODO: handle errors
316 Some(num as f64 / den as f64)
317 }
318
319 fn eval_bool(&self, f: Bool) -> Option<bool> {
320 Some(
321 self.model
322 .eval::<z3::ast::Bool>(&self.ctx.bool(f))
323 .unwrap()
324 .as_bool()
325 .unwrap(),
326 )
327 // TODO: handle errors
328 }
329}
330
331/*
332use std::cell::RefCell;
333use std::collections::HashMap;
334use std::rc::Rc;
335
336#[derive(Clone, Eq, PartialEq, Debug)]
337pub struct Z3Float<'z3> {
338 id: u32,
339 real: z3::ast::Real<'z3>,
340 ctx: Z3Context<'z3>,
341}
342
343// TODO: try to remove this clone?
344fn get_real<'z3>(f: &dyn VariableFloat, ctx: &Z3Context<'z3>) -> z3::ast::Real<'z3> {
345 let id = f.id() as u32;
346 ctx.0
347 .borrow()
348 .floats
349 .get(&id)
350 .expect("Couldn't find float")
351 .clone()
352}
353
354impl<'z3> VariableFloat<'z3> for Z3Float<'z3> {
355 fn id(&self) -> usize {
356 self.id as usize
357 }
358
359 fn dyn_clone(&self) -> Box<dyn VariableFloat<'z3> + 'z3> {
360 Box::new(Z3Float {
361 id: self.id,
362 real: self.real.clone(),
363 ctx: self.ctx.clone(),
364 })
365 }
366
367 fn eq(&self, other: &dyn VariableFloat) -> Bool<'z3> {
368 let other = get_real(other, &self.ctx);
369 let result = self.real._eq(&other);
370 let id = self.ctx.anon_bool(&result);
371
372 Bool::new(Box::new(Z3Bool {
373 id,
374 real: result,
375 ctx: self.ctx.clone(),
376 }))
377 }
378
379 fn neq(&self, other: &dyn VariableFloat) -> Bool<'z3> {
380 let other = get_real(other, &self.ctx);
381 let result = self.real._eq(&other).not();
382 let id = self.ctx.anon_bool(&result);
383
384 Bool::new(Box::new(Z3Bool {
385 id,
386 real: result,
387 ctx: self.ctx.clone(),
388 }))
389 }
390
391 fn gt(&self, other: &dyn VariableFloat) -> Bool<'z3> {
392 let other = get_real(other, &self.ctx);
393 let result = self.real.gt(&other);
394 let id = self.ctx.anon_bool(&result);
395
396 Bool::new(Box::new(Z3Bool {
397 id,
398 real: result,
399 ctx: self.ctx.clone(),
400 }))
401 }
402
403 fn ge(&self, other: &dyn VariableFloat) -> Bool<'z3> {
404 let other = get_real(other, &self.ctx);
405 let result = self.real.ge(&other);
406 let id = self.ctx.anon_bool(&result);
407
408 Bool::new(Box::new(Z3Bool {
409 id,
410 real: result,
411 ctx: self.ctx.clone(),
412 }))
413 }
414
415 fn lt(&self, other: &dyn VariableFloat) -> Bool<'z3> {
416 let other = get_real(other, &self.ctx);
417 let result = self.real.lt(&other);
418 let id = self.ctx.anon_bool(&result);
419
420 Bool::new(Box::new(Z3Bool {
421 id,
422 real: result,
423 ctx: self.ctx.clone(),
424 }))
425 }
426
427 fn le(&self, other: &dyn VariableFloat) -> Bool<'z3> {
428 let other = get_real(other, &self.ctx);
429 let result = self.real.le(&other);
430 let id = self.ctx.anon_bool(&result);
431
432 Bool::new(Box::new(Z3Bool {
433 id,
434 real: result,
435 ctx: self.ctx.clone(),
436 }))
437 }
438
439 fn add(&self, other: &dyn VariableFloat) -> Float<'z3> {
440 let other = get_real(other, &self.ctx);
441 let result = &self.real + &other;
442 let id = self.ctx.anon_float(&result);
443
444 Float::new(Box::new(Z3Float {
445 id,
446 real: result,
447 ctx: self.ctx.clone(),
448 }))
449 }
450
451 fn sub(&self, other: &dyn VariableFloat) -> Float<'z3> {
452 let other = get_real(other, &self.ctx);
453 let result = &self.real - &other;
454 let id = self.ctx.anon_float(&result);
455
456 Float::new(Box::new(Z3Float {
457 id,
458 real: result,
459 ctx: self.ctx.clone(),
460 }))
461 }
462
463 fn mul(&self, other: &dyn VariableFloat) -> Float<'z3> {
464 let other = get_real(other, &self.ctx);
465 let result = &self.real * &other;
466 let id = self.ctx.anon_float(&result);
467
468 Float::new(Box::new(Z3Float {
469 id,
470 real: result,
471 ctx: self.ctx.clone(),
472 }))
473 }
474
475 fn div(&self, other: &dyn VariableFloat) -> Float<'z3> {
476 let other = get_real(other, &self.ctx);
477 let result = &self.real / &other;
478 let id = self.ctx.anon_float(&result);
479
480 Float::new(Box::new(Z3Float {
481 id,
482 real: result,
483 ctx: self.ctx.clone(),
484 }))
485 }
486
487 fn neg(&self) -> Float<'z3> {
488 let result = self.real.unary_minus();
489 let id = self.ctx.anon_float(&result);
490
491 Float::new(Box::new(Z3Float {
492 id,
493 real: result,
494 ctx: self.ctx.clone(),
495 }))
496 }
497}
498
499#[derive(Clone, Eq, PartialEq, Debug)]
500pub struct Z3Bool<'z3> {
501 id: u32,
502 real: z3::ast::Bool<'z3>,
503 ctx: Z3Context<'z3>,
504}
505
506// TODO: try to remove this clone?
507fn get_bool<'z3>(f: &dyn VariableBool, ctx: &Z3Context<'z3>) -> z3::ast::Bool<'z3> {
508 let id = f.id() as u32;
509 ctx.0
510 .borrow()
511 .bools
512 .get(&id)
513 .expect("Couldn't find bool")
514 .clone()
515}
516
517impl<'z3> VariableBool<'z3> for Z3Bool<'z3> {
518 fn id(&self) -> usize {
519 self.id as usize
520 }
521
522 fn dyn_clone(&self) -> Box<dyn VariableBool<'z3> + 'z3> {
523 Box::new(Z3Bool {
524 id: self.id,
525 real: self.real.clone(),
526 ctx: self.ctx.clone(),
527 })
528 }
529
530 fn eq(&self, other: &dyn VariableBool) -> Bool<'z3> {
531 let other = get_bool(other, &self.ctx);
532 let result = self.real._eq(&other);
533 let id = self.ctx.anon_bool(&result);
534
535 Bool::new(Box::new(Z3Bool {
536 id,
537 real: result,
538 ctx: self.ctx.clone(),
539 }))
540 }
541
542 fn neq(&self, other: &dyn VariableBool) -> Bool<'z3> {
543 let other = get_bool(other, &self.ctx);
544 let result = self.real._eq(&other).not();
545 let id = self.ctx.anon_bool(&result);
546
547 Bool::new(Box::new(Z3Bool {
548 id,
549 real: result,
550 ctx: self.ctx.clone(),
551 }))
552 }
553
554 fn and(&self, other: &dyn VariableBool) -> Bool<'z3> {
555 let other = get_bool(other, &self.ctx);
556 let result = &self.real & &other;
557 let id = self.ctx.anon_bool(&result);
558
559 Bool::new(Box::new(Z3Bool {
560 id,
561 real: result,
562 ctx: self.ctx.clone(),
563 }))
564 }
565
566 fn or(&self, other: &dyn VariableBool) -> Bool<'z3> {
567 let other = get_bool(other, &self.ctx);
568 let result = &self.real | &other;
569 let id = self.ctx.anon_bool(&result);
570
571 Bool::new(Box::new(Z3Bool {
572 id,
573 real: result,
574 ctx: self.ctx.clone(),
575 }))
576 }
577
578 fn not(&self) -> Bool<'z3> {
579 let result = self.real.not();
580 let id = self.ctx.anon_bool(&result);
581
582 Bool::new(Box::new(Z3Bool {
583 id,
584 real: result,
585 ctx: self.ctx.clone(),
586 }))
587 }
588}
589
590#[derive(Clone, Eq, PartialEq, Debug)]
591pub struct Z3Context<'z3>(Rc<RefCell<Z3ContextImpl<'z3>>>);
592
593#[derive(Eq, PartialEq, Debug)]
594pub struct Z3ContextImpl<'z3> {
595 ctx: z3::Context,
596 floats: HashMap<u32, z3::ast::Real<'z3>>,
597 max_float_id: u32,
598 bools: HashMap<u32, z3::ast::Bool<'z3>>,
599 max_bool_id: u32,
600}
601
602impl<'z3> Z3Context<'z3> {
603 pub fn new() -> Self {
604 let conf = z3::Config::new();
605 let ctx = z3::Context::new(&conf);
606 Z3Context(Rc::new(RefCell::new(Z3ContextImpl {
607 ctx,
608 floats: HashMap::new(),
609 max_float_id: 0,
610 bools: HashMap::new(),
611 max_bool_id: 0,
612 })))
613 }
614
615 fn anon_float(&self, f: &z3::ast::Real<'z3>) -> u32 {
616 let mut ctx = self.0.borrow_mut();
617
618 ctx.max_float_id += 1;
619 let id = ctx.max_float_id;
620 ctx.floats.insert(id, f.clone());
621 id
622 }
623
624 fn anon_bool(&self, f: &z3::ast::Bool<'z3>) -> u32 {
625 let mut ctx = self.0.borrow_mut();
626
627 ctx.max_float_id += 1;
628 let id = ctx.max_float_id;
629 ctx.bools.insert(id, f.clone());
630 id
631 }
632}
633
634impl<'z3> SolverContext<'z3> for Z3Context<'z3> {
635 fn new_float(&'z3 mut self) -> Box<dyn VariableFloat<'z3> + 'z3> {
636 let mut ctx: std::cell::RefMut<'z3, _> = self.0.borrow_mut();
637
638 ctx.max_float_id += 1;
639 let id = ctx.max_float_id;
640 // :(
641 // Should be safe since the z3 context inside the RefCell of self cannot be dropped while
642 // the Box is still alive
643 let f = z3::ast::Real::new_const(unsafe { std::mem::transmute(&ctx.ctx) }, id);
644 ctx.floats.insert(id, f.clone());
645 Box::new(Z3Float {
646 id,
647 real: f,
648 ctx: self.clone(),
649 })
650 }
651
652 fn new_bool(&'z3 mut self) -> Box<dyn VariableBool<'z3> + 'z3> {
653 let mut ctx = self.0.borrow_mut();
654
655 ctx.max_bool_id += 1;
656 let id = ctx.max_float_id;
657 // :(
658 // Should be safe since the z3 context inside the RefCell of self cannot be dropped while
659 // the Box is still alive
660 let b = z3::ast::Bool::new_const(unsafe { std::mem::transmute(&ctx.ctx) }, id);
661 ctx.bools.insert(id, b.clone());
662 Box::new(Z3Bool {
663 id: id,
664 real: b,
665 ctx: self.clone(),
666 })
667 }
668}
669
670pub struct Z3Solver<'z3>(z3::Solver<'z3>);
671
672impl<'z3> Z3Solver<'z3> {
673 pub fn new(ctx: &'z3 Z3Context) -> Self {
674 Self(z3::Solver::new(unsafe {
675 // :(
676 // again
677 std::mem::transmute(&ctx.0.borrow().ctx)
678 }))
679 }
680}
681
682impl<'z3> Solver for Z3Solver<'z3> {
683 fn constrain(&mut self, assertion: &Bool) {
684 todo!()
685 }
686
687 fn solve(&self) -> Box<dyn SolverModel> {
688 todo!()
689 }
690}
691*/