Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 236 additions & 16 deletions bi/src/bi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

use crate::proto_trace::*;
use crate::signal_trace::*;
use baa::{BitVecOps, BitVecValue};
use baa::{BitVecMutOps, BitVecOps, BitVecValue, WidthInt};
use protocols::ir::*;
use rustc_hash::{FxHashMap, FxHashSet};

Expand Down Expand Up @@ -36,11 +36,12 @@ impl<T: SignalTrace> BackwardsInterpreter<T> {
) -> Self {
let transactions: Vec<_> = transactions_and_symbols
.enumerate()
.map(|(proto_id, (t, _))| {
.map(|(proto_id, (t, sym))| {
let next_stmt = t.next_stmt_mapping();
ProtoInfo {
proto_id,
proto: t.clone(),
sym: sym.clone(),
next_stmt,
}
})
Expand Down Expand Up @@ -103,6 +104,18 @@ impl<T: SignalTrace> BackwardsInterpreter<T> {
debug_assert!(new_paths.iter().all(|p| !p.failed()));
self.active.append(&mut new_paths);
}
PathResult::Branch(new_threads) => {
self.active
.extend(new_threads.into_iter().enumerate().map(|(id, new_t)| {
let mut p = path.clone();
p.active.push(new_t);
if id > 0 {
// ensure every new path has a unique trace
p.trace_id = self.traces.fork(p.trace_id);
}
p
}))
}
PathResult::FinishedStep => {
debug_assert!(
!path.active.is_empty()
Expand Down Expand Up @@ -130,6 +143,10 @@ impl<T: SignalTrace> BackwardsInterpreter<T> {
}
BIResult::Ok
} else {
// println!("BI: Finished step {}. Traces failed={}, active={}", self.step, self.failed.len(), self.next.len());
// for f in self.failed.iter().chain(self.next.iter()) {
// println!(" - {}", f.thread_string());
// }
// otherwise all paths must be finished
debug_assert!(self.active.is_empty());
debug_assert!(self.failed.iter().all(|p| p.failed()));
Expand Down Expand Up @@ -198,6 +215,7 @@ struct Path {
#[derive(Debug, Clone)]
enum PathResult {
Ok,
Branch(Vec<Thread>),
Fork,
Failed,
FinishedStep,
Expand All @@ -223,7 +241,7 @@ impl Path {
let active = self
.active
.iter()
.map(|t| t.name.as_str())
.map(|t| format!("{}@{:?}", t.name.as_str(), t.next_stmt))
.collect::<Vec<_>>()
.join(", ");
let next = self
Expand Down Expand Up @@ -263,11 +281,18 @@ impl Path {
p.trace_id = traces.fork(p.trace_id);
}

let arg_values = t
.proto
.args
.iter()
.map(|a| ArgValue::unknown(&t.sym, a))
.collect();

let t = Thread {
name: format!("{}@{}", t.proto.name, self.step),
transaction_id: id,
next_stmt: Some(t.proto.body),
arg_values: vec![None; t.proto.args.len()],
arg_values,
has_forked: false,
step: 0,
pin_assignments: vec![],
Expand Down Expand Up @@ -305,7 +330,11 @@ impl Path {
(PathResult::Ok, None)
}
ThreadResult::Fork => {
assert!(!self.has_forked_this_step && !self.fork_next_step);
assert!(
!self.has_forked_this_step,
"another thread already forked in the same thread!"
);
assert!(!self.fork_next_step);
self.has_forked_this_step = true;
self.active.push(thread);
(PathResult::Fork, None)
Expand All @@ -314,6 +343,11 @@ impl Path {
self.next.push(thread);
(PathResult::Ok, None)
}
ThreadResult::RepeatLoop => {
let tid = thread.transaction_id;
let (a, b) = thread.exec_repeat_loop_branch(&tis[tid]);
(PathResult::Branch(vec![a, b]), None)
}
ThreadResult::FinalStep => (
PathResult::Ok,
Some(thread_to_call(tis, thread, Some(self.step))),
Expand Down Expand Up @@ -346,7 +380,7 @@ impl Path {
fn thread_to_call(tis: &[ProtoInfo], thread: Thread, end: Option<u32>) -> ProtoCall {
assert!(end.is_none() || thread.next_stmt.is_none());
let name = tis[thread.transaction_id].proto.name.clone();
let args = thread.arg_values;
let args = thread.arg_values.iter().map(|a| a.get_known()).collect();
let start = thread.start_step;
ProtoCall {
name,
Expand All @@ -362,7 +396,7 @@ struct Thread {
name: String,
transaction_id: usize,
next_stmt: Option<StmtId>,
arg_values: Vec<Option<BitVecValue>>,
arg_values: Vec<ArgValue>,
pin_assignments: Vec<(StmtId, SymbolId, ExprId)>,
has_forked: bool,
step: u32,
Expand All @@ -373,6 +407,7 @@ struct Thread {
#[derive(Debug, Clone)]
enum ThreadResult {
Ok,
RepeatLoop,
Step,
Fork,
FinalStepAndFork,
Expand All @@ -381,14 +416,52 @@ enum ThreadResult {
}

impl Thread {
fn exec_repeat_loop_branch(self, ti: &ProtoInfo) -> (Thread, Thread) {
let stmt = self.next_stmt.expect("");
debug_assert!(
matches!(&ti.proto[stmt], Stmt::RepeatLoop(_, _)),
"repeat loop!"
);
let (body, arg_id) = if let Stmt::RepeatLoop(arg, body) = &ti.proto[stmt] {
(*body, as_arg(&ti.proto, *arg).unwrap().0)
} else {
unreachable!(
"this function may only be called when executing a repeat loop statement!"
);
};

// if we take the repeat loop branch, we up the value
let mut taken = self.clone();
let arg_value = taken.arg_values[arg_id]
.as_repeat()
.expect("must be a uint arg");
arg_value.increment_current_value();
let value = arg_value.current_value();
taken.name = format!("{}{value}?", taken.name);
taken.next_stmt = Some(body);

// of we do not take the branch, we jump after the loop
let mut not_taken = self;
let arg_value = not_taken.arg_values[arg_id]
.as_repeat()
.expect("must be a uint arg");
let max_iter = arg_value.current_value();
*arg_value = RepeatValue::Exactly(max_iter, 0);
not_taken.name = format!("{}{max_iter}!", not_taken.name);
not_taken.next_stmt = ti.next_stmt[&stmt];

// we need to explore both versions of our thread
(taken, not_taken)
}

fn exec_stmt(
&mut self,
ti: &ProtoInfo,
get_value: &impl Fn(SymbolId) -> BitVecValue,
) -> ThreadResult {
use ThreadResult::*;
if let Some(stmt) = self.next_stmt {
// println!("{:?}", &ti.transaction[stmt]);
// println!("{:?}", &ti.proto[stmt]);
match &ti.proto[stmt] {
Stmt::Block(stmt_ids) => {
self.next_stmt = stmt_ids.first().cloned();
Expand Down Expand Up @@ -417,6 +490,7 @@ impl Thread {
}
}
Stmt::Fork => {
assert!(self.step > 0, "[{}] Cannot fork at step zero!", self.name);
self.has_forked = true;
self.next_stmt = ti.next_stmt[&stmt];
Fork
Expand All @@ -432,11 +506,32 @@ impl Thread {
};
Ok
}
Stmt::RepeatLoop(repetitions, _body) => {
let (_arg_id, arg) = as_arg(&ti.proto, *repetitions)
.expect("repeat loop repetition count needs to be an argument");

todo!("repeat {:?}", arg.symbol())
Stmt::RepeatLoop(repetitions, body) => {
let arg_id = as_arg(&ti.proto, *repetitions)
.expect("repeat loop repetition count needs to be an argument")
.0;

let arg_value = self.arg_values[arg_id]
.as_repeat()
.expect("must be a repeat arg");

// check to see if the number of iterations is known
// in this case we essentially just have a for(i=0; i < max_iters; i++) loop
if let Some(max_iters) = arg_value.num_iters() {
if arg_value.current_value() < max_iters {
arg_value.increment_current_value();
self.next_stmt = Some(*body);
} else {
// this will make the current value overflow and go back to zero
arg_value.increment_current_value();
self.next_stmt = ti.next_stmt[&stmt]; // exit loop
}
Ok
} else {
// we do not actually know the correct number of steps, and we need to explore both possibilities
// this must happen outside of this method since it involves cloning the thread
RepeatLoop
}
}
Stmt::IfElse(cond, tru, fals) => {
let cond_value = self
Expand Down Expand Up @@ -536,13 +631,16 @@ impl Thread {
}
} else {
if let Some((arg_id, _arg)) = as_arg(&ti.proto, rhs) {
debug_assert!(self.arg_values[arg_id].is_none());
if let ArgValue::Data(d) = &mut self.arg_values[arg_id] {
d.define_value(lhs_value);
} else {
unreachable!("assignments/assert_eq must involve data values (bit vectors)!")
}
// println!(
// "[{}] Discovered that ?? = {}",
// self.name,
// lhs_value.to_bit_str()
// );
self.arg_values[arg_id] = Some(lhs_value);
} else {
todo!()
}
Expand All @@ -557,7 +655,7 @@ impl Thread {
expr: ExprId,
) -> Option<BitVecValue> {
if let Some((arg_id, _)) = as_arg(transaction, expr) {
return self.arg_values[arg_id].clone();
return self.arg_values[arg_id].get_known();
}
match &transaction[expr] {
Expr::Const(value) => Some(value.clone()),
Expand Down Expand Up @@ -619,5 +717,127 @@ fn as_arg(transaction: &Transaction, expr: ExprId) -> Option<(usize, Arg)> {
struct ProtoInfo {
proto_id: usize,
proto: Transaction,
sym: SymbolTable,
next_stmt: FxHashMap<StmtId, Option<StmtId>>,
}

#[derive(Debug, Clone)]
enum ArgValue {
Data(DataValue),
Repeat(RepeatValue),
}

impl ArgValue {
fn unknown(sym: &SymbolTable, arg: &Arg) -> Self {
match sym[arg.symbol()].tpe() {
Type::UnsignedInt => Self::Repeat(RepeatValue::default()),
Type::BitVec(w) => Self::Data(DataValue::unknown(w)),
Type::Struct(_) => unreachable!("args cannot be structs"),
Type::Unknown => unreachable!("arg types are always known"),
}
}

fn get_known(&self) -> Option<BitVecValue> {
match self {
ArgValue::Data(d) => d.get_known(),
ArgValue::Repeat(RepeatValue::Exactly(v, _)) => {
Some(BitVecValue::from_u64(*v as u64, 32))
}
ArgValue::Repeat(RepeatValue::AtLeast(_)) => None,
}
}

fn as_repeat(&mut self) -> Option<&mut RepeatValue> {
if let Self::Repeat(v) = self {
Some(v)
} else {
None
}
}
}

#[derive(Debug, Clone)]
enum RepeatValue {
AtLeast(u32),
Exactly(u32, u32),
}

impl Default for RepeatValue {
fn default() -> Self {
Self::AtLeast(0)
}
}

impl RepeatValue {
fn current_value(&self) -> u32 {
match self {
RepeatValue::AtLeast(v) => *v,
RepeatValue::Exactly(_, v) => *v,
}
}

fn increment_current_value(&mut self) {
match self {
RepeatValue::AtLeast(v) => *v += 1,
RepeatValue::Exactly(b, v) => {
if *v + 1 == *b {
*v = 0;
} else {
*v += 1;
}
debug_assert!(*v < *b);
}
}
}

fn num_iters(&self) -> Option<u32> {
match self {
RepeatValue::AtLeast(_) => None,
RepeatValue::Exactly(b, _) => Some(*b),
}
}
}

#[derive(Debug, Clone)]
struct DataValue {
value: BitVecValue,
known: BitVecValue,
}

impl DataValue {
fn unknown(width: WidthInt) -> Self {
Self {
value: BitVecValue::zero(width),
known: BitVecValue::zero(width),
}
}

fn get_known(&self) -> Option<BitVecValue> {
if self.known.is_all_ones() {
Some(self.value.clone())
} else {
None
}
}

#[allow(dead_code)]
fn bit_is_known(&self, bit: WidthInt) -> bool {
self.known.is_bit_set(bit)
}

#[allow(dead_code)]
fn define_bit(&mut self, bit: WidthInt, value: u8) {
debug_assert!(!self.bit_is_known(bit));
debug_assert!(bit < self.value.width());
if value != 0 {
self.value.set_bit(bit);
}
self.known.set_bit(bit);
}

fn define_value(&mut self, value: BitVecValue) {
debug_assert_eq!(self.value.width(), value.width());
self.value = value;
self.known = BitVecValue::ones(self.value.width());
}
}
Loading