Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate more recursion, and Liftable impl for Terminal #725

Merged
merged 7 commits into from
Aug 26, 2024
55 changes: 55 additions & 0 deletions src/iter/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ pub trait TreeLike: Clone + Sized {
fn post_order_iter(self) -> PostOrderIter<Self> {
PostOrderIter { index: 0, stack: vec![IterStackItem::unprocessed(self, None)] }
}

/// Obtains an iterator of all the nodes rooted at the DAG, in right-to-left post order.
///
/// This ordering is useful for "translation" algorithms which iterate over a
/// structure, pushing translated nodes and popping children.
fn rtl_post_order_iter(self) -> RtlPostOrderIter<Self> {
RtlPostOrderIter { inner: Rtl(self).post_order_iter() }
}
}

/// Element stored internally on the stack of a [`PostOrderIter`].
Expand Down Expand Up @@ -202,6 +210,53 @@ impl<T: TreeLike> Iterator for PostOrderIter<T> {
}
}

/// Adaptor structure to allow iterating in right-to-left order.
#[derive(Clone, Debug)]
struct Rtl<T>(pub T);

impl<T: TreeLike> TreeLike for Rtl<T> {
type NaryChildren = T::NaryChildren;

fn nary_len(tc: &Self::NaryChildren) -> usize { T::nary_len(tc) }
fn nary_index(tc: Self::NaryChildren, idx: usize) -> Self {
let rtl_idx = T::nary_len(&tc) - idx - 1;
Rtl(T::nary_index(tc, rtl_idx))
}

fn as_node(&self) -> Tree<Self, Self::NaryChildren> {
match self.0.as_node() {
Tree::Nullary => Tree::Nullary,
Tree::Unary(a) => Tree::Unary(Rtl(a)),
Tree::Binary(a, b) => Tree::Binary(Rtl(b), Rtl(a)),
Tree::Ternary(a, b, c) => Tree::Ternary(Rtl(c), Rtl(b), Rtl(a)),
Tree::Nary(data) => Tree::Nary(data),
}
}
}

/// Iterates over a DAG in _right-to-left post order_.
///
/// That means nodes are yielded in the order (right child, left child, parent).
#[derive(Clone, Debug)]
pub struct RtlPostOrderIter<T> {
inner: PostOrderIter<Rtl<T>>,
}

impl<T: TreeLike> Iterator for RtlPostOrderIter<T> {
type Item = PostOrderIterItem<T>;

fn next(&mut self) -> Option<Self::Item> {
self.inner.next().map(|mut item| {
item.child_indices.reverse();
PostOrderIterItem {
child_indices: item.child_indices,
index: item.index,
node: item.node.0,
}
})
}
}

/// Iterates over a [`TreeLike`] in _pre order_.
///
/// Unlike the post-order iterator, this one does not keep track of indices
Expand Down
220 changes: 169 additions & 51 deletions src/miniscript/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ mod private {
/// and they can call `Miniscript::clone`.
fn clone(&self) -> Self {
let mut stack = vec![];
for item in self.post_order_iter() {
let child_n = |n| Arc::clone(&stack[item.child_indices[n]]);

for item in self.rtl_post_order_iter() {
let new_term = match item.node.node {
Terminal::PkK(ref p) => Terminal::PkK(p.clone()),
Terminal::PkH(ref p) => Terminal::PkH(p.clone()),
Expand All @@ -101,23 +99,31 @@ mod private {
Terminal::Hash160(ref x) => Terminal::Hash160(x.clone()),
Terminal::True => Terminal::True,
Terminal::False => Terminal::False,
Terminal::Alt(..) => Terminal::Alt(child_n(0)),
Terminal::Swap(..) => Terminal::Swap(child_n(0)),
Terminal::Check(..) => Terminal::Check(child_n(0)),
Terminal::DupIf(..) => Terminal::DupIf(child_n(0)),
Terminal::Verify(..) => Terminal::Verify(child_n(0)),
Terminal::NonZero(..) => Terminal::NonZero(child_n(0)),
Terminal::ZeroNotEqual(..) => Terminal::ZeroNotEqual(child_n(0)),
Terminal::AndV(..) => Terminal::AndV(child_n(0), child_n(1)),
Terminal::AndB(..) => Terminal::AndB(child_n(0), child_n(1)),
Terminal::AndOr(..) => Terminal::AndOr(child_n(0), child_n(1), child_n(2)),
Terminal::OrB(..) => Terminal::OrB(child_n(0), child_n(1)),
Terminal::OrD(..) => Terminal::OrD(child_n(0), child_n(1)),
Terminal::OrC(..) => Terminal::OrC(child_n(0), child_n(1)),
Terminal::OrI(..) => Terminal::OrI(child_n(0), child_n(1)),
Terminal::Thresh(ref thresh) => Terminal::Thresh(
thresh.map_from_post_order_iter(&item.child_indices, &stack),
Terminal::Alt(..) => Terminal::Alt(stack.pop().unwrap()),
Terminal::Swap(..) => Terminal::Swap(stack.pop().unwrap()),
Terminal::Check(..) => Terminal::Check(stack.pop().unwrap()),
Terminal::DupIf(..) => Terminal::DupIf(stack.pop().unwrap()),
Terminal::Verify(..) => Terminal::Verify(stack.pop().unwrap()),
Terminal::NonZero(..) => Terminal::NonZero(stack.pop().unwrap()),
Terminal::ZeroNotEqual(..) => Terminal::ZeroNotEqual(stack.pop().unwrap()),
Terminal::AndV(..) => {
Terminal::AndV(stack.pop().unwrap(), stack.pop().unwrap())
}
Terminal::AndB(..) => {
Terminal::AndB(stack.pop().unwrap(), stack.pop().unwrap())
}
Terminal::AndOr(..) => Terminal::AndOr(
stack.pop().unwrap(),
stack.pop().unwrap(),
stack.pop().unwrap(),
),
Terminal::OrB(..) => Terminal::OrB(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrD(..) => Terminal::OrD(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrC(..) => Terminal::OrC(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrI(..) => Terminal::OrI(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::Thresh(ref thresh) => {
Terminal::Thresh(thresh.map_ref(|_| stack.pop().unwrap()))
}
Terminal::Multi(ref thresh) => Terminal::Multi(thresh.clone()),
Terminal::MultiA(ref thresh) => Terminal::MultiA(thresh.clone()),
};
Expand All @@ -130,6 +136,7 @@ mod private {
}));
}

assert_eq!(stack.len(), 1);
Arc::try_unwrap(stack.pop().unwrap()).unwrap()
}
}
Expand Down Expand Up @@ -536,9 +543,7 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> Miniscript<Pk, Ctx> {
T: Translator<Pk, Q, FuncError>,
{
let mut translated = vec![];
for data in Arc::new(self.clone()).post_order_iter() {
let child_n = |n| Arc::clone(&translated[data.child_indices[n]]);

for data in self.rtl_post_order_iter() {
let new_term = match data.node.node {
Terminal::PkK(ref p) => Terminal::PkK(t.pk(p)?),
Terminal::PkH(ref p) => Terminal::PkH(t.pk(p)?),
Expand All @@ -551,23 +556,39 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> Miniscript<Pk, Ctx> {
Terminal::Hash160(ref x) => Terminal::Hash160(t.hash160(x)?),
Terminal::True => Terminal::True,
Terminal::False => Terminal::False,
Terminal::Alt(..) => Terminal::Alt(child_n(0)),
Terminal::Swap(..) => Terminal::Swap(child_n(0)),
Terminal::Check(..) => Terminal::Check(child_n(0)),
Terminal::DupIf(..) => Terminal::DupIf(child_n(0)),
Terminal::Verify(..) => Terminal::Verify(child_n(0)),
Terminal::NonZero(..) => Terminal::NonZero(child_n(0)),
Terminal::ZeroNotEqual(..) => Terminal::ZeroNotEqual(child_n(0)),
Terminal::AndV(..) => Terminal::AndV(child_n(0), child_n(1)),
Terminal::AndB(..) => Terminal::AndB(child_n(0), child_n(1)),
Terminal::AndOr(..) => Terminal::AndOr(child_n(0), child_n(1), child_n(2)),
Terminal::OrB(..) => Terminal::OrB(child_n(0), child_n(1)),
Terminal::OrD(..) => Terminal::OrD(child_n(0), child_n(1)),
Terminal::OrC(..) => Terminal::OrC(child_n(0), child_n(1)),
Terminal::OrI(..) => Terminal::OrI(child_n(0), child_n(1)),
Terminal::Thresh(ref thresh) => Terminal::Thresh(
thresh.map_from_post_order_iter(&data.child_indices, &translated),
Terminal::Alt(..) => Terminal::Alt(translated.pop().unwrap()),
Terminal::Swap(..) => Terminal::Swap(translated.pop().unwrap()),
Terminal::Check(..) => Terminal::Check(translated.pop().unwrap()),
Terminal::DupIf(..) => Terminal::DupIf(translated.pop().unwrap()),
Terminal::Verify(..) => Terminal::Verify(translated.pop().unwrap()),
Terminal::NonZero(..) => Terminal::NonZero(translated.pop().unwrap()),
Terminal::ZeroNotEqual(..) => Terminal::ZeroNotEqual(translated.pop().unwrap()),
Terminal::AndV(..) => {
Terminal::AndV(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::AndB(..) => {
Terminal::AndB(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::AndOr(..) => Terminal::AndOr(
translated.pop().unwrap(),
translated.pop().unwrap(),
translated.pop().unwrap(),
),
Terminal::OrB(..) => {
Terminal::OrB(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::OrD(..) => {
Terminal::OrD(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::OrC(..) => {
Terminal::OrC(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::OrI(..) => {
Terminal::OrI(translated.pop().unwrap(), translated.pop().unwrap())
}
Terminal::Thresh(ref thresh) => {
Terminal::Thresh(thresh.map_ref(|_| translated.pop().unwrap()))
}
Terminal::Multi(ref thresh) => Terminal::Multi(thresh.translate_ref(|k| t.pk(k))?),
Terminal::MultiA(ref thresh) => {
Terminal::MultiA(thresh.translate_ref(|k| t.pk(k))?)
Expand All @@ -582,22 +603,58 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> Miniscript<Pk, Ctx> {

/// Substitutes raw public keys hashes with the public keys as provided by map.
pub fn substitute_raw_pkh(&self, pk_map: &BTreeMap<hash160::Hash, Pk>) -> Miniscript<Pk, Ctx> {
let mut translated = vec![];
for data in Arc::new(self.clone()).post_order_iter() {
let new_term = if let Terminal::RawPkH(ref p) = data.node.node {
match pk_map.get(p) {
Some(pk) => Terminal::PkH(pk.clone()),
None => Terminal::RawPkH(*p),
let mut stack = vec![];
for item in self.rtl_post_order_iter() {
let new_term = match item.node.node {
Terminal::PkK(ref p) => Terminal::PkK(p.clone()),
Terminal::PkH(ref p) => Terminal::PkH(p.clone()),
// This algorithm is identical to Clone::clone except for this line.
Terminal::RawPkH(ref hash) => match pk_map.get(hash) {
Some(p) => Terminal::PkH(p.clone()),
None => Terminal::RawPkH(*hash),
},
Terminal::After(ref n) => Terminal::After(*n),
Terminal::Older(ref n) => Terminal::Older(*n),
Terminal::Sha256(ref x) => Terminal::Sha256(x.clone()),
Terminal::Hash256(ref x) => Terminal::Hash256(x.clone()),
Terminal::Ripemd160(ref x) => Terminal::Ripemd160(x.clone()),
Terminal::Hash160(ref x) => Terminal::Hash160(x.clone()),
Terminal::True => Terminal::True,
Terminal::False => Terminal::False,
Terminal::Alt(..) => Terminal::Alt(stack.pop().unwrap()),
Terminal::Swap(..) => Terminal::Swap(stack.pop().unwrap()),
Terminal::Check(..) => Terminal::Check(stack.pop().unwrap()),
Terminal::DupIf(..) => Terminal::DupIf(stack.pop().unwrap()),
Terminal::Verify(..) => Terminal::Verify(stack.pop().unwrap()),
Terminal::NonZero(..) => Terminal::NonZero(stack.pop().unwrap()),
Terminal::ZeroNotEqual(..) => Terminal::ZeroNotEqual(stack.pop().unwrap()),
Terminal::AndV(..) => Terminal::AndV(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::AndB(..) => Terminal::AndB(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::AndOr(..) => Terminal::AndOr(
stack.pop().unwrap(),
stack.pop().unwrap(),
stack.pop().unwrap(),
),
Terminal::OrB(..) => Terminal::OrB(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrD(..) => Terminal::OrD(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrC(..) => Terminal::OrC(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::OrI(..) => Terminal::OrI(stack.pop().unwrap(), stack.pop().unwrap()),
Terminal::Thresh(ref thresh) => {
Terminal::Thresh(thresh.map_ref(|_| stack.pop().unwrap()))
}
} else {
data.node.node.clone()
Terminal::Multi(ref thresh) => Terminal::Multi(thresh.clone()),
Terminal::MultiA(ref thresh) => Terminal::MultiA(thresh.clone()),
};

let new_ms = Miniscript::from_ast(new_term).expect("typeck");
translated.push(Arc::new(new_ms));
stack.push(Arc::new(Miniscript::from_components_unchecked(
new_term,
item.node.ty,
item.node.ext,
)));
}

Arc::try_unwrap(translated.pop().unwrap()).unwrap()
assert_eq!(stack.len(), 1);
Arc::try_unwrap(stack.pop().unwrap()).unwrap()
}
}

Expand Down Expand Up @@ -822,6 +879,7 @@ mod tests {
}
let roundtrip = Miniscript::from_str(&display).expect("parse string serialization");
assert_eq!(roundtrip, script);
assert_eq!(roundtrip.clone(), script);
}

fn string_display_debug_test<Ctx: ScriptContext>(
Expand Down Expand Up @@ -1373,8 +1431,12 @@ mod tests {
#[test]
fn expr_features() {
// test that parsing raw hash160 does not work with
let hash160_str = "e9f171df53e04b270fa6271b42f66b0f4a99c5a2";
let ms_str = &format!("c:expr_raw_pkh({})", hash160_str);
let pk = bitcoin::PublicKey::from_str(
"02c2fd50ceae468857bb7eb32ae9cd4083e6c7e42fbbec179d81134b3e3830586c",
)
.unwrap();
let hash160 = pk.pubkey_hash().to_raw_hash();
let ms_str = &format!("c:expr_raw_pkh({})", hash160);
type SegwitMs = Miniscript<bitcoin::PublicKey, Segwitv0>;

// Test that parsing raw hash160 from string does not work without extra features
Expand All @@ -1387,6 +1449,12 @@ mod tests {
SegwitMs::parse(&script).unwrap_err();
SegwitMs::parse_insane(&script).unwrap_err();
SegwitMs::parse_with_ext(&script, &ExtParams::allow_all()).unwrap();

// Try replacing the raw_pkh with a pkh
let mut map = BTreeMap::new();
map.insert(hash160, pk);
let ms_no_raw = ms.substitute_raw_pkh(&map);
assert_eq!(ms_no_raw.to_string(), format!("pkh({})", pk),);
}

#[test]
Expand All @@ -1408,6 +1476,56 @@ mod tests {
ms.translate_pk(&mut t).unwrap_err();
}

#[test]
fn duplicate_keys() {
// You cannot parse a Miniscript that has duplicate keys
let err = Miniscript::<String, Segwitv0>::from_str("and_v(v:pk(A),pk(A))").unwrap_err();
assert_eq!(err, Error::AnalysisError(crate::AnalysisError::RepeatedPubkeys));

// ...though you can parse one with from_str_insane
let ok_insane =
Miniscript::<String, Segwitv0>::from_str_insane("and_v(v:pk(A),pk(A))").unwrap();
// ...but this cannot be sanity checked.
assert_eq!(ok_insane.sanity_check().unwrap_err(), crate::AnalysisError::RepeatedPubkeys);
// ...it can be lifted, though it's unclear whether this is a deliberate
// choice or just an accident. It seems weird given that duplicate public
// keys are forbidden in several other places.
ok_insane.lift().unwrap();
}

#[test]
fn mixed_timelocks() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am surprised that this did not exist before. Anyways more tests are always welcome.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly one existed somewhere. I grepped for the timelock-specific errors and found nothing. But there are a ton of tests that just assert .unwrap_err() and don't check more specifically.

And I'm pretty confident that there is no test that covers all three separate checks (in Miniscript, Semantic and Concrete).

// You cannot parse a Miniscript that mixes timelocks.
let err = Miniscript::<String, Segwitv0>::from_str(
"and_v(v:and_v(v:older(4194304),pk(A)),and_v(v:older(1),pk(B)))",
)
.unwrap_err();
assert_eq!(err, Error::AnalysisError(crate::AnalysisError::HeightTimelockCombination));

// Though you can in an or() rather than and()
let ok_or = Miniscript::<String, Segwitv0>::from_str(
"or_i(and_v(v:older(4194304),pk(A)),and_v(v:older(1),pk(B)))",
)
.unwrap();
ok_or.sanity_check().unwrap();
ok_or.lift().unwrap();

// ...and you can parse one with from_str_insane
let ok_insane = Miniscript::<String, Segwitv0>::from_str_insane(
"and_v(v:and_v(v:older(4194304),pk(A)),and_v(v:older(1),pk(B)))",
)
.unwrap();
// ...but this cannot be sanity checked or lifted
assert_eq!(
ok_insane.sanity_check().unwrap_err(),
crate::AnalysisError::HeightTimelockCombination
);
assert_eq!(
ok_insane.lift().unwrap_err(),
Error::LiftError(crate::policy::LiftError::HeightTimelockCombination)
);
}

#[test]
fn template_timelocks() {
use crate::{AbsLockTime, RelLockTime};
Expand Down
6 changes: 3 additions & 3 deletions src/miniscript/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,19 @@ impl fmt::Display for Error {
),
ErrorKind::SwapNonOne => write!(
f,
"fragment «{}» attempts to use `SWAP` to prefix something
"fragment «{}» attempts to use `SWAP` to prefix something \
which does not take exactly one input",
self.fragment_string,
),
ErrorKind::NonZeroZero => write!(
f,
"fragment «{}» attempts to use use the `j:` wrapper around a
"fragment «{}» attempts to use use the `j:` wrapper around a \
fragment which might be satisfied by an input of size zero",
self.fragment_string,
),
ErrorKind::LeftNotUnit => write!(
f,
"fragment «{}» requires its left child be a unit (outputs
"fragment «{}» requires its left child be a unit (outputs \
exactly 1 given a satisfying input)",
self.fragment_string,
),
Expand Down
Loading
Loading