Skip to content

Commit

Permalink
Fix record searching algorithm to select largest record first
Browse files Browse the repository at this point in the history
  • Loading branch information
iamalwaysuncomfortable committed Jun 21, 2023
1 parent 4eede44 commit f2e5e0e
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 36 deletions.
66 changes: 35 additions & 31 deletions rust/src/api/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,19 +166,27 @@ impl<N: Network> AleoAPIClient<N> {
block_heights: Range<u32>,
max_gates: Option<u64>,
specified_amounts: Option<&Vec<u64>>,
) -> Result<Vec<(Field<N>, Record<N, Ciphertext<N>>)>> {
) -> Result<Vec<(Field<N>, Record<N, Plaintext<N>>)>> {
let view_key = ViewKey::try_from(private_key)?;
let address_x_coordinate = view_key.to_address().to_x_coordinate();

let step_size = 49;
let required_amounts = if let Some(amounts) = specified_amounts {
ensure!(!amounts.is_empty(), "If specific amounts are specified, there must be one amount specified");
let mut required_amounts = amounts.clone();
required_amounts.sort_by(|a, b| b.cmp(a));
required_amounts
} else {
vec![]
};

ensure!(
block_heights.start < block_heights.end,
"The start block height must be less than the end block height"
);

// Initialize a vector for the records.
let mut records = Vec::new();
let mut records = vec![];

let mut total_gates = 0u64;
let mut end_height = block_heights.end;
Expand All @@ -196,23 +204,19 @@ impl<N: Network> AleoAPIClient<N> {
if start_height < block_heights.start {
start_height = block_heights.start
};

// Filter the records by the view key.
records.extend(records_iter.filter_map(|(commitment, record)| {
match record.is_owner_with_address_x_coordinate(&view_key, &address_x_coordinate) {
true => {
let sn = Record::<N, Ciphertext<N>>::serial_number(*private_key, commitment).ok()?;
if self.find_transition_id(sn).is_err() {
if max_gates.is_some() {
let _ = record
.decrypt(&view_key)
.map(|record| {
total_gates += record.microcredits().unwrap_or(0);
record
})
.ok();
let record = record.decrypt(&view_key);
if let Ok(record) = record {
total_gates += record.microcredits().unwrap_or(0);
Some((commitment, record))
} else {
None
}
Some((commitment, record))
} else {
None
}
Expand All @@ -222,32 +226,32 @@ impl<N: Network> AleoAPIClient<N> {
}));
// If a maximum number of gates is specified, stop searching when the total gates
// exceeds the specified limit
if max_gates.is_some() && total_gates > max_gates.unwrap() {
if max_gates.is_some() && total_gates >= max_gates.unwrap() {
break;
}
// If a list of specified amounts is specified, stop searching when records matching
// those amounts are found
if let Some(specified_amounts) = specified_amounts {
let found_records = specified_amounts
.iter()
.filter_map(|amount| {
let position = records.iter().position(|(_, record)| {
if let Ok(decrypted_record) = record.decrypt(&view_key) {
decrypted_record.microcredits().unwrap_or(0) > *amount
} else {
false
}
});
position.map(|index| records.remove(index))
})
.collect::<Vec<_>>();
records.extend(found_records);
if records.len() >= specified_amounts.len() {
return Ok(records);
if !required_amounts.is_empty() {
records.sort_by(|a, b| b.1.microcredits().unwrap_or(0).cmp(&a.1.microcredits().unwrap_or(0)));
let mut found_indices = std::collections::HashSet::<usize>::new();
required_amounts.iter().for_each(|amount| {
for (pos, record) in records.iter().enumerate() {
if !found_indices.contains(&pos) && record.1.microcredits().unwrap_or(0) >= *amount {
found_indices.insert(pos);
}
}
});
if found_indices.len() >= required_amounts.len() {
let found_records = records[0..required_amounts.len()].to_vec();
return Ok(found_records);
}
}
}

if !required_amounts.is_empty() {
bail!(
"Could not find enough records with the specified amounts, consider splitting records into smaller amounts"
);
}
Ok(records)
}

Expand Down
3 changes: 1 addition & 2 deletions rust/src/program/helpers/records.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ impl<N: Network> RecordFinder<N> {
max_microcredits: Option<u64>,
private_key: &PrivateKey<N>,
) -> Result<Vec<Record<N, Plaintext<N>>>> {
let view_key = ViewKey::try_from(private_key)?;
let latest_height = self.api_client.latest_height()?;
let records = self.api_client.get_unspent_records(private_key, 0..latest_height, max_microcredits, amounts)?;
Ok(records.into_iter().filter_map(|(_, record)| record.decrypt(&view_key).ok()).collect())
Ok(records.into_iter().map(|(_, record)| record).collect())
}
}
3 changes: 1 addition & 2 deletions rust/src/program/transfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ mod tests {
let records = api_client.get_unspent_records(&recipient_private_key, 0..height, None, None).unwrap();
if !records.is_empty() {
let (_, record) = &records[0];
let record_plaintext = record.decrypt(&recipient_view_key).unwrap();
let amount = record_plaintext.microcredits().unwrap();
let amount = record.microcredits().unwrap();
if amount == 100 {
break;
}
Expand Down
2 changes: 1 addition & 1 deletion rust/src/test_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,5 +201,5 @@ pub fn transfer_to_test_account(
let client = program_manager.api_client()?;
let latest_height = client.latest_height()?;
let records = client.get_unspent_records(&recipient_private_key, 0..latest_height, None, None)?;
Ok(records.iter().map(|(_cm, record)| record.decrypt(&recipient_view_key).unwrap()).collect())
Ok(records.into_iter().map(|(_cm, record)| record).collect())
}

0 comments on commit f2e5e0e

Please sign in to comment.