Skip to content

Commit 02720be

Browse files
committed
perf(MerkleTree): Parallelize over subtrees
Previously, building a Merkle tree in parallel meant that each layer was built in parallel, with all worker threads synchronizing after each layer. This is simple to implement but, as it turns out, has more overhead than expected. A more performant strategy is to take a number of subtrees of the big Merkle tree and compute all subtrees in parallel.
1 parent e9fae6a commit 02720be

1 file changed

Lines changed: 186 additions & 29 deletions

File tree

twenty-first/src/util_types/merkle_tree.rs

Lines changed: 186 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use itertools::Itertools;
99
use rayon::prelude::*;
1010
use thiserror::Error;
1111

12+
use crate::error::U32_TO_USIZE_ERR;
1213
use crate::prelude::*;
1314

1415
/// Indexes internal nodes of a [`MerkleTree`].
@@ -136,13 +137,9 @@ impl MerkleTree {
136137
/// - If the number of leafs is zero.
137138
/// - If the number of leafs is not a power of two.
138139
pub fn sequential_new(leafs: &[Digest]) -> Result<Self> {
139-
let mut nodes = Self::initialize_merkle_tree_nodes(leafs)?;
140-
141-
for i in (ROOT_INDEX..leafs.len()).rev() {
142-
nodes[i] = Tip5::hash_pair(nodes[i * 2], nodes[i * 2 + 1]);
143-
}
144-
145-
Ok(MerkleTree { nodes })
140+
let nodes = Self::initialize_merkle_tree_nodes(leafs)?;
141+
let num_remaining_nodes = leafs.len();
142+
Self::sequentially_fill_tree(nodes, num_remaining_nodes)
146143
}
147144

148145
/// Build a MerkleTree with the given leafs.
@@ -159,27 +156,114 @@ impl MerkleTree {
159156
let mut nodes = Self::initialize_merkle_tree_nodes(leafs)?;
160157

161158
// parallel
162-
let mut num_nodes_on_this_level = leafs.len();
163-
while num_nodes_on_this_level >= crate::config::merkle_tree_parallelization_cutoff() {
164-
num_nodes_on_this_level /= 2;
165-
let node_indices_on_this_level = num_nodes_on_this_level..2 * num_nodes_on_this_level;
166-
let nodes_on_this_level = node_indices_on_this_level
167-
.clone()
168-
.into_par_iter()
169-
.map(|i| Tip5::hash_pair(nodes[i * 2], nodes[i * 2 + 1]))
170-
.collect::<Vec<_>>();
171-
nodes[node_indices_on_this_level].copy_from_slice(&nodes_on_this_level);
159+
let mut num_remaining_nodes = leafs.len();
160+
let mut num_threads = Self::num_threads();
161+
while num_remaining_nodes >= crate::config::merkle_tree_parallelization_cutoff() {
162+
// If the number of threads is so large that the chunk size is 1
163+
// (or even 0), each individual thread performs no work anymore.
164+
// In such a case, the most reasonable course of action is to reduce
165+
// the effective number of worker threads, which increases the chunk
166+
// size.
167+
//
168+
// Since parallelization_cutoff >= 2, it follows that
169+
// num_remaining_nodes / 2 >= 1. Hence, the loop terminates at the latest
170+
// once num_threads equals 1.
171+
while num_threads > num_remaining_nodes / 2 {
172+
num_threads /= 2;
173+
}
174+
175+
// re-slice to only include
176+
// 1. the nodes that need to be computed and
177+
// 2. exactly those nodes required to compute them.
178+
let nodes = &mut nodes[..2 * num_remaining_nodes];
179+
let subtrees = Self::subtrees_mut(nodes, num_threads);
180+
subtrees.into_par_iter().for_each(|mut tree_layers| {
181+
debug_assert!(tree_layers.len() > 1, "internal error: infinite iteration");
182+
let mut previous_layer = tree_layers.pop().unwrap();
183+
for next_layer in tree_layers.into_iter().rev() {
184+
for (node, (&left, &right)) in
185+
next_layer.iter_mut().zip(previous_layer.iter().tuples())
186+
{
187+
*node = Tip5::hash_pair(left, right);
188+
}
189+
previous_layer = next_layer;
190+
}
191+
});
192+
193+
// Update the number of remaining nodes by subtracting the number of
194+
// freshly computed nodes. Equivalently, record that the tree has
195+
// grown by subtree_height many layers.
196+
let current_tree_height = num_remaining_nodes.ilog2();
197+
let subtree_height = current_tree_height - num_threads.ilog2();
198+
num_remaining_nodes >>= subtree_height;
172199
}
173200

174-
// sequential
175-
let num_remaining_nodes = num_nodes_on_this_level;
201+
Self::sequentially_fill_tree(nodes, num_remaining_nodes)
202+
}
203+
204+
/// Internal helper function to de-duplicate code between
205+
/// [`Self::sequential_new`] and [`Self::par_new`].
206+
fn sequentially_fill_tree(mut nodes: Vec<Digest>, num_remaining_nodes: usize) -> Result<Self> {
176207
for i in (ROOT_INDEX..num_remaining_nodes).rev() {
177208
nodes[i] = Tip5::hash_pair(nodes[i * 2], nodes[i * 2 + 1]);
178209
}
179210

180211
Ok(MerkleTree { nodes })
181212
}
182213

214+
/// Divides the given, contiguous slice into `num_trees` subtrees.
215+
///
216+
/// The passed-in slice must represent a complete Merkle tree. Each subtree
217+
/// is returned as a number of mutable slices, where each such slice
218+
/// represents one layer in the subtree. The first slice contains only one
219+
/// element, the subtree's root, the next slice contains two elements, the
220+
/// root's children, and so on.
221+
///
222+
/// In general, the top-most nodes of the complete tree will not be covered
223+
/// by the returned subtrees. For example, when requesting 2 subtrees, the
224+
/// complete tree's root will not be covered; when requesting 4 subtrees,
225+
/// the root and its direct children will not be covered; and so on.
226+
///
227+
/// The number of subtrees must be a power of two, and must not exceed the
228+
/// number of leafs in the complete Merkle tree represented by the given
229+
/// slice.
230+
///
231+
/// Because this is an internal helper function (and only for that reason),
232+
/// it's the caller's responsibility to ensure that the arguments are
233+
/// valid. To recap:
234+
/// - the number of `nodes` must be a power of 2
235+
/// - `num_trees` must be a power of 2
236+
/// - the number of `nodes` must be at least `2 * num_trees`
237+
fn subtrees_mut<T>(nodes: &mut [T], num_trees: usize) -> Vec<Vec<&mut [T]>> {
238+
let num_leafs = nodes.len() / 2;
239+
let total_tree_height = num_leafs.ilog2();
240+
let sub_tree_height =
241+
usize::try_from(total_tree_height - num_trees.ilog2()).expect(U32_TO_USIZE_ERR);
242+
243+
// a tree's “height” is the number of layers excluding the root,
244+
// but we want to include the root
245+
let num_layers = sub_tree_height + 1;
246+
let mut subtrees = (0..num_trees)
247+
.map(|_| Vec::with_capacity(num_layers))
248+
.collect_vec();
249+
250+
// the number of nodes to skip includes the dummy node at index 0
251+
let nodes_to_skip = num_trees;
252+
let (_, mut nodes) = nodes.split_at_mut(nodes_to_skip);
253+
254+
for layer_idx in 0..num_layers {
255+
let nodes_at_this_layer = 1 << layer_idx;
256+
for tree in &mut subtrees {
257+
let (layer, rest) = nodes.split_at_mut(nodes_at_this_layer);
258+
tree.push(layer);
259+
nodes = rest;
260+
}
261+
}
262+
debug_assert!(nodes.is_empty());
263+
264+
subtrees
265+
}
266+
183267
/// Compute the Merkle root from the given leafs without recording any
184268
/// internal nodes.
185269
///
@@ -240,17 +324,8 @@ impl MerkleTree {
240324
return Err(MerkleTreeError::IncorrectNumberOfLeafs);
241325
}
242326

243-
// To guarantee that all chunks correspond to trees of the same height,
244-
// the number of threads must divide the number of leafs cleanly.
245-
let num_threads = rayon::current_num_threads();
246-
let num_threads = if num_threads.is_power_of_two() {
247-
num_threads
248-
} else {
249-
num_threads.next_power_of_two() / 2 // previous power of 2
250-
};
251-
let mut num_threads = num_threads.max(1); // avoid division by 0
252-
253327
// parallel
328+
let mut num_threads = Self::num_threads();
254329
let mut leafs = Cow::Borrowed(leafs);
255330
while leafs.len() >= crate::config::merkle_tree_parallelization_cutoff() {
256331
// If the number of threads is so large that the chunk size is 1
@@ -278,6 +353,30 @@ impl MerkleTree {
278353
Self::sequential_frugal_root(&leafs)
279354
}
280355

356+
/// Internal helper function to determine the number of threads to use for
357+
/// parallel Merkle tree construction.
358+
///
359+
/// Can be used to figure out the number of chunks to split the work into,
360+
/// but take care that each chunk contains at least 2 nodes, else no
361+
/// meaningful work will be done and your iteration might run forever.
362+
///
363+
/// Guaranteed to be a power of two.
364+
///
365+
/// Respects the [`RAYON_NUM_THREADS`][rayon] environment variable, if set.
366+
fn num_threads() -> usize {
367+
// To guarantee that all chunks correspond to trees of the same height,
368+
// the number of threads must divide the number of leafs cleanly.
369+
let num_threads = rayon::current_num_threads();
370+
let num_threads = if num_threads.is_power_of_two() {
371+
num_threads
372+
} else {
373+
num_threads.next_power_of_two() / 2 // previous power of 2
374+
};
375+
376+
// avoid division by 0
377+
num_threads.max(1)
378+
}
379+
281380
/// Helps to kick off Merkle tree construction. Sets up the Merkle tree's
282381
/// internal nodes if (and only if) it is possible to construct a Merkle
283382
/// tree with the given leafs.
@@ -1360,4 +1459,62 @@ pub mod merkle_tree_test {
13601459

13611460
assert_eq!(expected_paths, auth_paths);
13621461
}
1462+
1463+
#[test]
1464+
fn merkle_subtrees_are_sliced_correctly() {
1465+
const TREE_HEIGHT: usize = 5;
1466+
const NUM_LEAFS: u32 = 1 << TREE_HEIGHT;
1467+
const NUM_NODES: u32 = 2 * NUM_LEAFS;
1468+
debug_assert_eq!(64, NUM_NODES);
1469+
1470+
let mut nodes = (0..NUM_NODES).collect_vec();
1471+
1472+
let all_nodes = MerkleTree::subtrees_mut(&mut nodes, 1);
1473+
assert_eq!(1, all_nodes.len());
1474+
let subtree = &all_nodes[0];
1475+
assert_eq!([1], subtree[0]);
1476+
assert_eq!([2, 3], subtree[1]);
1477+
assert_eq!((4..8).collect_vec().as_slice(), subtree[2]);
1478+
assert_eq!((8..16).collect_vec().as_slice(), subtree[3]);
1479+
assert_eq!((16..32).collect_vec().as_slice(), subtree[4]);
1480+
assert_eq!((32..64).collect_vec().as_slice(), subtree[5]);
1481+
1482+
let two_trees = MerkleTree::subtrees_mut(&mut nodes, 2);
1483+
assert_eq!(2, two_trees.len());
1484+
let left_tree = &two_trees[0];
1485+
assert_eq!([2], left_tree[0]);
1486+
assert_eq!([4, 5], left_tree[1]);
1487+
assert_eq!((8..12).collect_vec().as_slice(), left_tree[2]);
1488+
assert_eq!((16..24).collect_vec().as_slice(), left_tree[3]);
1489+
assert_eq!((32..48).collect_vec().as_slice(), left_tree[4]);
1490+
let right_tree = &two_trees[1];
1491+
assert_eq!([3], right_tree[0]);
1492+
assert_eq!([6, 7], right_tree[1]);
1493+
assert_eq!((12..16).collect_vec().as_slice(), right_tree[2]);
1494+
assert_eq!((24..32).collect_vec().as_slice(), right_tree[3]);
1495+
assert_eq!((48..64).collect_vec().as_slice(), right_tree[4]);
1496+
1497+
let four_trees = MerkleTree::subtrees_mut(&mut nodes, 4);
1498+
assert_eq!(4, four_trees.len());
1499+
let left_left_tree = &four_trees[0];
1500+
assert_eq!([4], left_left_tree[0]);
1501+
assert_eq!([8, 9], left_left_tree[1]);
1502+
assert_eq!((16..20).collect_vec().as_slice(), left_left_tree[2]);
1503+
assert_eq!((32..40).collect_vec().as_slice(), left_left_tree[3]);
1504+
let left_right_tree = &four_trees[1];
1505+
assert_eq!([5], left_right_tree[0]);
1506+
assert_eq!([10, 11], left_right_tree[1]);
1507+
assert_eq!((20..24).collect_vec().as_slice(), left_right_tree[2]);
1508+
assert_eq!((40..48).collect_vec().as_slice(), left_right_tree[3]);
1509+
let right_left_tree = &four_trees[2];
1510+
assert_eq!([6], right_left_tree[0]);
1511+
assert_eq!([12, 13], right_left_tree[1]);
1512+
assert_eq!((24..28).collect_vec().as_slice(), right_left_tree[2]);
1513+
assert_eq!((48..56).collect_vec().as_slice(), right_left_tree[3]);
1514+
let right_right_tree = &four_trees[3];
1515+
assert_eq!([7], right_right_tree[0]);
1516+
assert_eq!([14, 15], right_right_tree[1]);
1517+
assert_eq!((28..32).collect_vec().as_slice(), right_right_tree[2]);
1518+
assert_eq!((56..64).collect_vec().as_slice(), right_right_tree[3]);
1519+
}
13631520
}

0 commit comments

Comments
 (0)