Skip to content

Commit cd14fbc

Browse files
committed
Adapt Bit Machine
1 parent ff8b92f commit cd14fbc

File tree

1 file changed

+129
-123
lines changed

1 file changed

+129
-123
lines changed

src/bit_machine/exec.rs

Lines changed: 129 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@
1919
//!
2020
2121
use super::frame::Frame;
22-
use crate::core::Value;
23-
use crate::jet::AppError;
24-
use std::error;
22+
use crate::core::node::NodeInner;
23+
use crate::core::types::TypeInner;
24+
use crate::core::{Node, Value};
25+
use crate::decode;
26+
use crate::jet::{AppError, Application};
2527
use std::fmt;
28+
use std::{cmp, error};
2629

2730
/// An execution context for a Simplicity program
2831
pub struct BitMachine {
@@ -38,21 +41,18 @@ pub struct BitMachine {
3841
}
3942

4043
impl BitMachine {
41-
/*
42-
/// Construct a Bit Machine with enough space to execute
43-
/// the given program
44-
pub fn for_program<App: Application>(program: &Program<App>) -> BitMachine {
45-
let prog = program.root();
46-
let io_width = prog.source_ty().bit_width + prog.target_ty().bit_width;
47-
BitMachine {
48-
data: vec![0; (io_width + prog.extra_cells_bound + 7) / 8],
44+
/// Construct a Bit Machine with enough space to execute the given program.
45+
pub fn for_program<App: Application>(program: &Node<Value, App>) -> Self {
46+
let io_width = program.ty.source.bit_width + program.ty.target.bit_width;
47+
48+
Self {
49+
data: vec![0; (io_width + program.bounds.extra_cells + 7) / 8],
4950
next_frame_start: 0,
50-
// +1's for input and output; these are used only for nontrivial
51-
read: Vec::with_capacity(prog.frame_count_bound + 1),
52-
write: Vec::with_capacity(prog.frame_count_bound + 1),
51+
// +1 for input and output frame
52+
read: Vec::with_capacity(program.bounds.frame_count + 1),
53+
write: Vec::with_capacity(program.bounds.frame_count + 1),
5354
}
5455
}
55-
*/
5656

5757
/// Push a new frame of given size onto the write frame stack
5858
fn new_frame(&mut self, len: usize) {
@@ -265,152 +265,160 @@ impl BitMachine {
265265
self.move_frame();
266266
}
267267

268-
/*
269-
/// Execute a program in the Bit Machine
270-
pub fn exec<'a, App: Application>(
268+
/// Execute the given program on the Bit Machine, using the given environment.
269+
///
270+
/// Make sure the Bit Machine has enough space by constructing it via [`Self::for_program()`].
271+
pub fn exec<'a, App: Application + std::fmt::Debug>(
271272
&mut self,
272-
program: &'a Program<App>,
273+
program: &'a Node<Value, App>,
273274
env: &App::Environment,
274275
) -> Result<Value, ExecutionError<'a>> {
275-
enum CallStack {
276-
Goto(usize),
276+
// Rust cannot use `App` from parent function
277+
enum CallStack<'a, App: Application> {
278+
Goto(&'a Node<Value, App>),
277279
MoveFrame,
278280
DropFrame,
279281
CopyFwd(usize),
280282
Back(usize),
281283
}
282284

283-
let mut ip = &program.root().typed;
285+
let mut ip = program;
284286
let mut call_stack = vec![];
285-
let mut iters = 0u64;
287+
let mut iterations = 0u64;
286288

287-
let input_width = ip.source_ty.bit_width;
289+
let input_width = ip.ty.source.bit_width;
290+
// TODO: convert into crate::Error
288291
if input_width > 0 && self.read.is_empty() {
289-
panic!(
290-
"Pleas call `Program::input` to add an input value for this program {}",
291-
ip
292-
);
292+
panic!("Program requires a non-empty input to execute");
293293
}
294-
let output_width = ip.target_ty.bit_width;
294+
let output_width = ip.ty.target.bit_width;
295295
if output_width > 0 {
296296
self.new_frame(output_width);
297297
}
298298

299299
'main_loop: loop {
300-
iters += 1;
301-
if iters % 1_000_000_000 == 0 {
302-
println!("({:5} M) exec {}", iters / 1_000_000, ip);
300+
iterations += 1;
301+
if iterations % 1_000_000_000 == 0 {
302+
println!("({:5} M) exec {:?}", iterations / 1_000_000, ip);
303303
}
304304

305-
match ip.term {
306-
Term::Unit => {}
307-
Term::Iden => self.copy(ip.source_ty.bit_width),
308-
Term::InjL(t) => {
309-
self.write_bit(false);
310-
if let TypeInner::Sum(ref a, _) = ip.target_ty.ty {
311-
let aw = a.bit_width;
312-
self.skip(ip.target_ty.bit_width - aw - 1);
313-
call_stack.push(CallStack::Goto(ip.index - t));
305+
match &ip.inner {
306+
NodeInner::Unit => {}
307+
NodeInner::Iden => {
308+
let size_a = ip.ty.source.bit_width;
309+
self.copy(size_a);
310+
}
311+
NodeInner::InjL(left) => {
312+
let padl_b_c = if let TypeInner::Sum(b, _) = &ip.ty.target.ty {
313+
ip.ty.target.bit_width - b.bit_width - 1
314314
} else {
315-
panic!("type error")
316-
}
315+
unreachable!()
316+
};
317+
318+
self.write_bit(false);
319+
self.skip(padl_b_c);
320+
call_stack.push(CallStack::Goto(left));
317321
}
318-
Term::InjR(t) => {
319-
self.write_bit(true);
320-
if let TypeInner::Sum(_, ref b) = ip.target_ty.ty {
321-
let bw = b.bit_width;
322-
self.skip(ip.target_ty.bit_width - bw - 1);
323-
call_stack.push(CallStack::Goto(ip.index - t));
322+
NodeInner::InjR(left) => {
323+
let padr_b_c = if let TypeInner::Sum(_, c) = &ip.ty.target.ty {
324+
ip.ty.target.bit_width - c.bit_width - 1
324325
} else {
325-
panic!("type error")
326-
}
326+
unreachable!()
327+
};
328+
329+
self.write_bit(true);
330+
self.skip(padr_b_c);
331+
call_stack.push(CallStack::Goto(left));
327332
}
328-
Term::Pair(s, t) => {
329-
call_stack.push(CallStack::Goto(ip.index - t));
330-
call_stack.push(CallStack::Goto(ip.index - s));
333+
NodeInner::Pair(left, right) => {
334+
call_stack.push(CallStack::Goto(right));
335+
call_stack.push(CallStack::Goto(left));
331336
}
332-
Term::Comp(s, t) => {
333-
let size = program.nodes[ip.index - s].target_ty().bit_width;
334-
self.new_frame(size);
337+
NodeInner::Comp(left, right) => {
338+
let size_b = left.ty.target.bit_width;
335339

340+
self.new_frame(size_b);
336341
call_stack.push(CallStack::DropFrame);
337-
call_stack.push(CallStack::Goto(ip.index - t));
342+
call_stack.push(CallStack::Goto(right));
338343
call_stack.push(CallStack::MoveFrame);
339-
call_stack.push(CallStack::Goto(ip.index - s));
344+
call_stack.push(CallStack::Goto(left));
340345
}
341-
Term::Disconnect(s, t) => {
342-
// Write `t`'s CMR followed by `s` input to a new read frame
343-
let size = program.nodes[ip.index - s].source_ty().bit_width;
344-
assert!(size >= 256);
345-
self.new_frame(size);
346-
self.write_bytes(program.nodes[ip.index - t].cmr().as_ref());
347-
self.copy(size - 256);
346+
NodeInner::Disconnect(left, right) => {
347+
let size_prod_256_a = left.ty.source.bit_width;
348+
let size_a = size_prod_256_a - 256;
349+
let size_prod_b_c = left.ty.target.bit_width;
350+
let size_b = size_prod_b_c - right.ty.source.bit_width;
351+
352+
self.new_frame(size_prod_256_a);
353+
self.write_bytes(right.cmr.as_ref());
354+
self.copy(size_a);
348355
self.move_frame();
356+
self.new_frame(size_prod_b_c);
349357

350-
let s_target_size = program.nodes[ip.index - s].target_ty().bit_width;
351-
self.new_frame(s_target_size);
352-
// Then recurse. Remembering that call stack pushes are executed
353-
// in reverse order:
354-
355-
// 3. Delete the two frames we created, which have both moved to the read stack
358+
// Remember that call stack pushes are executed in reverse order
356359
call_stack.push(CallStack::DropFrame);
357360
call_stack.push(CallStack::DropFrame);
358-
let b_size = s_target_size - program.nodes[ip.index - t].source_ty().bit_width;
359-
// Back not required since we are dropping the frame anyways
360-
// call_stack.push(CallStack::Back(b_size));
361-
// 2. Copy the first half of `s`s output directly then execute `t` on the second half
362-
call_stack.push(CallStack::Goto(ip.index - t));
363-
call_stack.push(CallStack::CopyFwd(b_size));
364-
// 1. Execute `s` then move the write frame to the read frame for `t`
361+
call_stack.push(CallStack::Goto(right));
362+
call_stack.push(CallStack::CopyFwd(size_b));
365363
call_stack.push(CallStack::MoveFrame);
366-
call_stack.push(CallStack::Goto(ip.index - s));
364+
call_stack.push(CallStack::Goto(left));
367365
}
368-
Term::Take(t) => call_stack.push(CallStack::Goto(ip.index - t)),
369-
Term::Drop(t) => {
370-
if let TypeInner::Product(ref a, _) = ip.source_ty.ty {
371-
let aw = a.bit_width;
372-
self.fwd(aw);
373-
call_stack.push(CallStack::Back(aw));
374-
call_stack.push(CallStack::Goto(ip.index - t));
366+
NodeInner::Take(left) => call_stack.push(CallStack::Goto(left)),
367+
NodeInner::Drop(left) => {
368+
let size_a = if let TypeInner::Product(a, _) = &ip.ty.source.ty {
369+
a.bit_width
375370
} else {
376-
panic!("type error")
377-
}
371+
unreachable!()
372+
};
373+
374+
self.fwd(size_a);
375+
call_stack.push(CallStack::Back(size_a));
376+
call_stack.push(CallStack::Goto(left));
378377
}
379-
Term::Case(s, t) | Term::AssertL(s, t) | Term::AssertR(s, t) => {
380-
let sw = self.read[self.read.len() - 1].peek_bit(&self.data);
381-
let aw;
382-
let bw;
383-
if let TypeInner::Product(ref a, _) = ip.source_ty.ty {
384-
if let TypeInner::Sum(ref a, ref b) = a.ty {
385-
aw = a.bit_width;
386-
bw = b.bit_width;
378+
NodeInner::Case(left, right)
379+
| NodeInner::AssertL(left, right)
380+
| NodeInner::AssertR(left, right) => {
381+
let choice_bit = self.read[self.read.len() - 1].peek_bit(&self.data);
382+
383+
let (size_a, size_b) = if let TypeInner::Product(sum_a_b, _c) = &ip.ty.source.ty
384+
{
385+
if let TypeInner::Sum(a, b) = &sum_a_b.ty {
386+
(a.bit_width, b.bit_width)
387387
} else {
388-
panic!("type error");
388+
unreachable!()
389389
}
390390
} else {
391-
panic!("type error");
392-
}
393-
394-
if sw {
395-
self.fwd(1 + cmp::max(aw, bw) - bw);
396-
call_stack.push(CallStack::Back(1 + cmp::max(aw, bw) - bw));
397-
call_stack.push(CallStack::Goto(ip.index - t));
391+
unreachable!()
392+
};
393+
394+
if choice_bit {
395+
let padr_a_b = cmp::max(size_a, size_b) - size_b;
396+
self.fwd(1 + padr_a_b);
397+
call_stack.push(CallStack::Back(1 + padr_a_b));
398+
call_stack.push(CallStack::Goto(right));
398399
} else {
399-
self.fwd(1 + cmp::max(aw, bw) - aw);
400-
call_stack.push(CallStack::Back(1 + cmp::max(aw, bw) - aw));
401-
call_stack.push(CallStack::Goto(ip.index - s));
400+
let padl_a_b = cmp::max(size_a, size_b) - size_a;
401+
self.fwd(1 + padl_a_b);
402+
call_stack.push(CallStack::Back(1 + padl_a_b));
403+
call_stack.push(CallStack::Goto(left));
402404
}
403405
}
404-
Term::Witness(ref value) => self.write_value(value),
405-
Term::Hidden(ref h) => panic!("Hit hidden node {} at iter {}: {}", ip, iters, h),
406-
Term::Jet(j) => App::exec_jet(j, self, env)
406+
NodeInner::Witness(value) => self.write_value(value),
407+
NodeInner::Hidden(h) => {
408+
// TODO: Convert into crate::Error
409+
panic!(
410+
"Hit hidden node {:?} at iteration {}: {}",
411+
ip, iterations, h
412+
)
413+
}
414+
NodeInner::Jet(j) => App::exec_jet(j, self, env)
407415
.map_err(|x| ExecutionError::AppError(Box::new(x)))?,
408-
Term::Fail(..) => return Err(ExecutionError::ReachedFailNode),
416+
NodeInner::Fail(..) => return Err(ExecutionError::ReachedFailNode),
409417
}
410418

411419
ip = loop {
412420
match call_stack.pop() {
413-
Some(CallStack::Goto(next)) => break &program.nodes[next].typed,
421+
Some(CallStack::Goto(next)) => break next,
414422
Some(CallStack::MoveFrame) => self.move_frame(),
415423
Some(CallStack::DropFrame) => self.drop_frame(),
416424
Some(CallStack::CopyFwd(n)) => {
@@ -423,20 +431,18 @@ impl BitMachine {
423431
};
424432
}
425433

426-
let res = if output_width > 0 {
434+
if output_width > 0 {
427435
let out_frame = self.write.last_mut().unwrap();
428436
out_frame.reset_cursor();
429-
decode::decode_value(
430-
program.root().target_ty(),
431-
&mut out_frame.to_frame_data(&self.data),
432-
)
433-
.expect("decoding output value")
437+
let value =
438+
decode::decode_value(&program.ty.target, &mut out_frame.to_frame_data(&self.data))
439+
.expect("Decode value of output frame");
440+
441+
Ok(value)
434442
} else {
435-
Value::Unit
436-
};
437-
Ok(res)
443+
Ok(Value::Unit)
444+
}
438445
}
439-
*/
440446
}
441447

442448
/// Errors related to simplicity Execution

0 commit comments

Comments
 (0)