From e85f56483dcf6814e41958d9f0bf30e99c4bc8ca Mon Sep 17 00:00:00 2001 From: Tibo-lg Date: Wed, 22 Jan 2025 20:51:18 +0900 Subject: [PATCH] Prevent infinite loop on invalid arguments --- dlc-trie/src/digit_decomposition.rs | 12 +++++++++++ dlc-trie/src/multi_oracle_trie.rs | 21 ++++++++++++++++++++ dlc-trie/src/multi_oracle_trie_with_diff.rs | 22 +++++++++++++++++++++ sample/src/hex_utils.rs | 5 +---- 4 files changed, 56 insertions(+), 4 deletions(-) diff --git a/dlc-trie/src/digit_decomposition.rs b/dlc-trie/src/digit_decomposition.rs index 136d0b89..602ded6e 100644 --- a/dlc-trie/src/digit_decomposition.rs +++ b/dlc-trie/src/digit_decomposition.rs @@ -97,12 +97,18 @@ fn remove_tail(v: &mut Vec, to_remove: usize) { } /// Returns the set of decomposed prefixes that cover the range [start, end]. +/// +/// # Panics +/// +/// Panics if `start` is greater than `end`. pub fn group_by_ignoring_digits( start: usize, end: usize, base: usize, nb_digits: usize, ) -> Vec> { + assert!(start <= end); + let mut ds = decompose_value(start, base, nb_digits); let mut de = decompose_value(end, base, nb_digits); @@ -629,6 +635,12 @@ mod tests { } } + #[test] + #[should_panic] + fn group_by_ignoring_digits_start_greater_than_end_panics() { + super::group_by_ignoring_digits(11, 10, 2, 4); + } + #[test] fn get_max_ranges_test() { for test_case in get_max_range_test_cases() { diff --git a/dlc-trie/src/multi_oracle_trie.rs b/dlc-trie/src/multi_oracle_trie.rs index c66e8513..f632c569 100644 --- a/dlc-trie/src/multi_oracle_trie.rs +++ b/dlc-trie/src/multi_oracle_trie.rs @@ -200,6 +200,9 @@ impl<'a> DlcTrie<'a, MultiOracleTrieIter<'a>> for MultiOracleTrie { let mut trie_infos = Vec::new(); let oracle_numeric_infos = &self.oracle_numeric_infos; for (cet_index, outcome) in outcomes.iter().enumerate() { + if outcome.count == 0 { + return Err(Error::InvalidArgument); + } let groups = group_by_ignoring_digits( outcome.start, outcome.start + outcome.count - 1, @@ -405,4 +408,22 @@ mod tests { ]) .expect("Could not retrieve path with extra len."); } + + #[test] + fn test_invalid_range_payout() { + let range_payouts = vec![RangePayout { + start: 0, + count: 0, + payout: Payout { + offer: Amount::ZERO, + accept: Amount::from_sat(200000000), + }, + }]; + + let oracle_numeric_infos = get_variable_oracle_numeric_infos(&[13, 12], 2); + let mut multi_oracle_trie = MultiOracleTrie::new(&oracle_numeric_infos, 2).unwrap(); + multi_oracle_trie + .generate(0, &range_payouts) + .expect_err("Should fail when given a range payout with a count of 0"); + } } diff --git a/dlc-trie/src/multi_oracle_trie_with_diff.rs b/dlc-trie/src/multi_oracle_trie_with_diff.rs index da12be13..16999840 100644 --- a/dlc-trie/src/multi_oracle_trie_with_diff.rs +++ b/dlc-trie/src/multi_oracle_trie_with_diff.rs @@ -59,6 +59,9 @@ impl<'a> DlcTrie<'a, MultiOracleTrieWithDiffIter<'a>> for MultiOracleTrieWithDif let mut trie_infos = Vec::new(); for (cet_index, outcome) in outcomes.iter().enumerate() { + if outcome.count == 0 { + return Err(Error::InvalidArgument); + } let groups = group_by_ignoring_digits( outcome.start, outcome.start + outcome.count - 1, @@ -273,4 +276,23 @@ mod tests { &iter_res.paths ); } + + #[test] + fn test_invalid_range_payout() { + let range_payouts = vec![RangePayout { + start: 0, + count: 0, + payout: Payout { + offer: Amount::ZERO, + accept: Amount::from_sat(200000000), + }, + }]; + + let oracle_numeric_infos = get_variable_oracle_numeric_infos(&[13, 12], 2); + let mut multi_oracle_trie = + MultiOracleTrieWithDiff::new(&oracle_numeric_infos, 2, 1, 2).unwrap(); + multi_oracle_trie + .generate(0, &range_payouts) + .expect_err("Should fail when given a range payout with a count of 0"); + } } diff --git a/sample/src/hex_utils.rs b/sample/src/hex_utils.rs index 023555a4..e21ba794 100644 --- a/sample/src/hex_utils.rs +++ b/sample/src/hex_utils.rs @@ -42,10 +42,7 @@ pub fn hex_str(value: &[u8]) -> String { } pub fn to_compressed_pubkey(hex: &str) -> Option { - let data = match to_vec(&hex[0..33 * 2]) { - Some(bytes) => bytes, - None => return None, - }; + let data = to_vec(&hex[0..33 * 2])?; match PublicKey::from_slice(&data) { Ok(pk) => Some(pk), Err(_) => None,