@@ -9,6 +9,7 @@ use itertools::Itertools;
99use rayon:: prelude:: * ;
1010use thiserror:: Error ;
1111
12+ use crate :: error:: U32_TO_USIZE_ERR ;
1213use crate :: prelude:: * ;
1314
1415/// Indexes internal nodes of a [`MerkleTree`].
@@ -136,13 +137,9 @@ impl MerkleTree {
136137 /// - If the number of leafs is zero.
137138 /// - If the number of leafs is not a power of two.
138139 pub fn sequential_new ( leafs : & [ Digest ] ) -> Result < Self > {
139- let mut nodes = Self :: initialize_merkle_tree_nodes ( leafs) ?;
140-
141- for i in ( ROOT_INDEX ..leafs. len ( ) ) . rev ( ) {
142- nodes[ i] = Tip5 :: hash_pair ( nodes[ i * 2 ] , nodes[ i * 2 + 1 ] ) ;
143- }
144-
145- Ok ( MerkleTree { nodes } )
140+ let nodes = Self :: initialize_merkle_tree_nodes ( leafs) ?;
141+ let num_remaining_nodes = leafs. len ( ) ;
142+ Self :: sequentially_fill_tree ( nodes, num_remaining_nodes)
146143 }
147144
148145 /// Build a MerkleTree with the given leafs.
@@ -159,27 +156,114 @@ impl MerkleTree {
159156 let mut nodes = Self :: initialize_merkle_tree_nodes ( leafs) ?;
160157
161158 // parallel
162- let mut num_nodes_on_this_level = leafs. len ( ) ;
163- while num_nodes_on_this_level >= crate :: config:: merkle_tree_parallelization_cutoff ( ) {
164- num_nodes_on_this_level /= 2 ;
165- let node_indices_on_this_level = num_nodes_on_this_level..2 * num_nodes_on_this_level;
166- let nodes_on_this_level = node_indices_on_this_level
167- . clone ( )
168- . into_par_iter ( )
169- . map ( |i| Tip5 :: hash_pair ( nodes[ i * 2 ] , nodes[ i * 2 + 1 ] ) )
170- . collect :: < Vec < _ > > ( ) ;
171- nodes[ node_indices_on_this_level] . copy_from_slice ( & nodes_on_this_level) ;
159+ let mut num_remaining_nodes = leafs. len ( ) ;
160+ let mut num_threads = Self :: num_threads ( ) ;
161+ while num_remaining_nodes >= crate :: config:: merkle_tree_parallelization_cutoff ( ) {
162+ // If the number of threads is so large that the chunk size is 1
163+ // (or even 0), each individual thread performs no work anymore.
164+ // In such a case, the most reasonable course of action is to reduce
165+ // the effective number of worker threads, which increases the chunk
166+ // size.
167+ //
168+ // Since parallelization_cutoff >= 2, it follows that
169+ // num_nodes_missing / 2 >= 1. Hence, the loop terminates at latest
170+ // once num_threads equals 1.
171+ while num_threads > num_remaining_nodes / 2 {
172+ num_threads /= 2 ;
173+ }
174+
175+ // re-slice to only include
176+ // 1. the nodes that need to be computed and
177+ // 2. exactly those nodes required to compute them.
178+ let nodes = & mut nodes[ ..2 * num_remaining_nodes] ;
179+ let subtrees = Self :: subtrees_mut ( nodes, num_threads) ;
180+ subtrees. into_par_iter ( ) . for_each ( |mut tree_layers| {
181+ debug_assert ! ( tree_layers. len( ) > 1 , "internal error: infinite iteration" ) ;
182+ let mut previous_layer = tree_layers. pop ( ) . unwrap ( ) ;
183+ for next_layer in tree_layers. into_iter ( ) . rev ( ) {
184+ for ( node, ( & left, & right) ) in
185+ next_layer. iter_mut ( ) . zip ( previous_layer. iter ( ) . tuples ( ) )
186+ {
187+ * node = Tip5 :: hash_pair ( left, right) ;
188+ }
189+ previous_layer = next_layer;
190+ }
191+ } ) ;
192+
193+ // Update the number of remaining nodes by subtracting the number of
194+ // freshly computed nodes. Equivalently, record that the tree has
195+ // grown by subtree_height many layers.
196+ let current_tree_height = num_remaining_nodes. ilog2 ( ) ;
197+ let subtree_height = current_tree_height - num_threads. ilog2 ( ) ;
198+ num_remaining_nodes >>= subtree_height;
172199 }
173200
174- // sequential
175- let num_remaining_nodes = num_nodes_on_this_level;
201+ Self :: sequentially_fill_tree ( nodes, num_remaining_nodes)
202+ }
203+
204+ /// Internal helper function to de-duplicate code between
205+ /// [`Self::sequential_new`] and [`Self::par_new`].
206+ fn sequentially_fill_tree ( mut nodes : Vec < Digest > , num_remaining_nodes : usize ) -> Result < Self > {
176207 for i in ( ROOT_INDEX ..num_remaining_nodes) . rev ( ) {
177208 nodes[ i] = Tip5 :: hash_pair ( nodes[ i * 2 ] , nodes[ i * 2 + 1 ] ) ;
178209 }
179210
180211 Ok ( MerkleTree { nodes } )
181212 }
182213
214+ /// Divides the given, contiguous slice into `num_trees` subtrees.
215+ ///
216+ /// The passed-in slice must represent a complete Merkle tree. Each subtree
217+ /// is returned as a number of mutable slices, where each such slice
218+ /// represents one layer in the subtree. The first slice contains only one
219+ /// element, the subtree's root, the next slice contains two elements, the
220+ /// root's children, and so on.
221+ ///
222+ /// In general, the top-most nodes of the complete tree will not be covered
223+ /// by the returned subtrees. For example, when requesting 2 subtrees, the
224+ /// complete tree's root will not be covered; when requesting 4 subtrees,
225+ /// the root and its direct children will not be covered; and so on.
226+ ///
227+ /// The number of subtrees must be a power of two, and must not exceed the
228+ /// number of leafs in the complete Merkle tree represented by the given
229+ /// slice.
230+ ///
231+ /// Because this is an internal helper function (and only for that reason),
232+ /// it's the caller's responsibility to ensure that the arguments are
233+ /// integral. To recap:
234+ /// - the number of `nodes` must be a power of 2
235+ /// - the number of `num_trees` must be a power of 2
236+ /// - the number of `nodes` must be at least `2 * num_trees`
237+ fn subtrees_mut < T > ( nodes : & mut [ T ] , num_trees : usize ) -> Vec < Vec < & mut [ T ] > > {
238+ let num_leafs = nodes. len ( ) / 2 ;
239+ let total_tree_height = num_leafs. ilog2 ( ) ;
240+ let sub_tree_height =
241+ usize:: try_from ( total_tree_height - num_trees. ilog2 ( ) ) . expect ( U32_TO_USIZE_ERR ) ;
242+
243+ // a tree's “height” is the number of layers excluding the root,
244+ // but we want to include the root
245+ let num_layers = sub_tree_height + 1 ;
246+ let mut subtrees = ( 0 ..num_trees)
247+ . map ( |_| Vec :: with_capacity ( num_layers) )
248+ . collect_vec ( ) ;
249+
250+ // the number of nodes to skip includes the dummy node at index 0
251+ let nodes_to_skip = num_trees;
252+ let ( _, mut nodes) = nodes. split_at_mut ( nodes_to_skip) ;
253+
254+ for layer_idx in 0 ..num_layers {
255+ let nodes_at_this_layer = 1 << layer_idx;
256+ for tree in & mut subtrees {
257+ let ( layer, rest) = nodes. split_at_mut ( nodes_at_this_layer) ;
258+ tree. push ( layer) ;
259+ nodes = rest;
260+ }
261+ }
262+ debug_assert ! ( nodes. is_empty( ) ) ;
263+
264+ subtrees
265+ }
266+
183267 /// Compute the Merkle root from the given leafs without recording any
184268 /// internal nodes.
185269 ///
@@ -240,17 +324,8 @@ impl MerkleTree {
240324 return Err ( MerkleTreeError :: IncorrectNumberOfLeafs ) ;
241325 }
242326
243- // To guarantee that all chunks correspond to trees of the same height,
244- // the number of threads must divide the number of leafs cleanly.
245- let num_threads = rayon:: current_num_threads ( ) ;
246- let num_threads = if num_threads. is_power_of_two ( ) {
247- num_threads
248- } else {
249- num_threads. next_power_of_two ( ) / 2 // previous power of 2
250- } ;
251- let mut num_threads = num_threads. max ( 1 ) ; // avoid division by 0
252-
253327 // parallel
328+ let mut num_threads = Self :: num_threads ( ) ;
254329 let mut leafs = Cow :: Borrowed ( leafs) ;
255330 while leafs. len ( ) >= crate :: config:: merkle_tree_parallelization_cutoff ( ) {
256331 // If the number of threads is so large that the chunk size is 1
@@ -278,6 +353,30 @@ impl MerkleTree {
278353 Self :: sequential_frugal_root ( & leafs)
279354 }
280355
356+ /// Internal helper function to determine the number of threads to use for
357+ /// parallel Merkle tree construction.
358+ ///
359+ /// Can be used to figure out the number of chunks to split the work into,
360+ /// but take care that each chunk contains at least 2 nodes, else no
361+ /// meaningful work will be done and your iteration might run forever.
362+ ///
363+ /// Guaranteed to be a power of two.
364+ ///
365+ /// Respects the [`RAYON_NUM_THREADS`][rayon] environment variable, if set.
366+ fn num_threads ( ) -> usize {
367+ // To guarantee that all chunks correspond to trees of the same height,
368+ // the number of threads must divide the number of leafs cleanly.
369+ let num_threads = rayon:: current_num_threads ( ) ;
370+ let num_threads = if num_threads. is_power_of_two ( ) {
371+ num_threads
372+ } else {
373+ num_threads. next_power_of_two ( ) / 2 // previous power of 2
374+ } ;
375+
376+ // avoid division by 0
377+ num_threads. max ( 1 )
378+ }
379+
281380 /// Helps to kick off Merkle tree construction. Sets up the Merkle tree's
282381 /// internal nodes if (and only if) it is possible to construct a Merkle
283382 /// tree with the given leafs.
@@ -1360,4 +1459,62 @@ pub mod merkle_tree_test {
13601459
13611460 assert_eq ! ( expected_paths, auth_paths) ;
13621461 }
1462+
1463+ #[ test]
1464+ fn merkle_subtrees_are_sliced_correctly ( ) {
1465+ const TREE_HEIGHT : usize = 5 ;
1466+ const NUM_LEAFS : u32 = 1 << TREE_HEIGHT ;
1467+ const NUM_NODES : u32 = 2 * NUM_LEAFS ;
1468+ debug_assert_eq ! ( 64 , NUM_NODES ) ;
1469+
1470+ let mut nodes = ( 0 ..NUM_NODES ) . collect_vec ( ) ;
1471+
1472+ let all_nodes = MerkleTree :: subtrees_mut ( & mut nodes, 1 ) ;
1473+ assert_eq ! ( 1 , all_nodes. len( ) ) ;
1474+ let subtree = & all_nodes[ 0 ] ;
1475+ assert_eq ! ( [ 1 ] , subtree[ 0 ] ) ;
1476+ assert_eq ! ( [ 2 , 3 ] , subtree[ 1 ] ) ;
1477+ assert_eq ! ( ( 4 ..8 ) . collect_vec( ) . as_slice( ) , subtree[ 2 ] ) ;
1478+ assert_eq ! ( ( 8 ..16 ) . collect_vec( ) . as_slice( ) , subtree[ 3 ] ) ;
1479+ assert_eq ! ( ( 16 ..32 ) . collect_vec( ) . as_slice( ) , subtree[ 4 ] ) ;
1480+ assert_eq ! ( ( 32 ..64 ) . collect_vec( ) . as_slice( ) , subtree[ 5 ] ) ;
1481+
1482+ let two_trees = MerkleTree :: subtrees_mut ( & mut nodes, 2 ) ;
1483+ assert_eq ! ( 2 , two_trees. len( ) ) ;
1484+ let left_tree = & two_trees[ 0 ] ;
1485+ assert_eq ! ( [ 2 ] , left_tree[ 0 ] ) ;
1486+ assert_eq ! ( [ 4 , 5 ] , left_tree[ 1 ] ) ;
1487+ assert_eq ! ( ( 8 ..12 ) . collect_vec( ) . as_slice( ) , left_tree[ 2 ] ) ;
1488+ assert_eq ! ( ( 16 ..24 ) . collect_vec( ) . as_slice( ) , left_tree[ 3 ] ) ;
1489+ assert_eq ! ( ( 32 ..48 ) . collect_vec( ) . as_slice( ) , left_tree[ 4 ] ) ;
1490+ let right_tree = & two_trees[ 1 ] ;
1491+ assert_eq ! ( [ 3 ] , right_tree[ 0 ] ) ;
1492+ assert_eq ! ( [ 6 , 7 ] , right_tree[ 1 ] ) ;
1493+ assert_eq ! ( ( 12 ..16 ) . collect_vec( ) . as_slice( ) , right_tree[ 2 ] ) ;
1494+ assert_eq ! ( ( 24 ..32 ) . collect_vec( ) . as_slice( ) , right_tree[ 3 ] ) ;
1495+ assert_eq ! ( ( 48 ..64 ) . collect_vec( ) . as_slice( ) , right_tree[ 4 ] ) ;
1496+
1497+ let four_trees = MerkleTree :: subtrees_mut ( & mut nodes, 4 ) ;
1498+ assert_eq ! ( 4 , four_trees. len( ) ) ;
1499+ let left_left_tree = & four_trees[ 0 ] ;
1500+ assert_eq ! ( [ 4 ] , left_left_tree[ 0 ] ) ;
1501+ assert_eq ! ( [ 8 , 9 ] , left_left_tree[ 1 ] ) ;
1502+ assert_eq ! ( ( 16 ..20 ) . collect_vec( ) . as_slice( ) , left_left_tree[ 2 ] ) ;
1503+ assert_eq ! ( ( 32 ..40 ) . collect_vec( ) . as_slice( ) , left_left_tree[ 3 ] ) ;
1504+ let left_right_tree = & four_trees[ 1 ] ;
1505+ assert_eq ! ( [ 5 ] , left_right_tree[ 0 ] ) ;
1506+ assert_eq ! ( [ 10 , 11 ] , left_right_tree[ 1 ] ) ;
1507+ assert_eq ! ( ( 20 ..24 ) . collect_vec( ) . as_slice( ) , left_right_tree[ 2 ] ) ;
1508+ assert_eq ! ( ( 40 ..48 ) . collect_vec( ) . as_slice( ) , left_right_tree[ 3 ] ) ;
1509+ let right_left_tree = & four_trees[ 2 ] ;
1510+ assert_eq ! ( [ 6 ] , right_left_tree[ 0 ] ) ;
1511+ assert_eq ! ( [ 12 , 13 ] , right_left_tree[ 1 ] ) ;
1512+ assert_eq ! ( ( 24 ..28 ) . collect_vec( ) . as_slice( ) , right_left_tree[ 2 ] ) ;
1513+ assert_eq ! ( ( 48 ..56 ) . collect_vec( ) . as_slice( ) , right_left_tree[ 3 ] ) ;
1514+ let right_right_tree = & four_trees[ 3 ] ;
1515+ assert_eq ! ( [ 7 ] , right_right_tree[ 0 ] ) ;
1516+ assert_eq ! ( [ 14 , 15 ] , right_right_tree[ 1 ] ) ;
1517+ assert_eq ! ( ( 28 ..32 ) . collect_vec( ) . as_slice( ) , right_right_tree[ 2 ] ) ;
1518+ assert_eq ! ( ( 56 ..64 ) . collect_vec( ) . as_slice( ) , right_right_tree[ 3 ] ) ;
1519+ }
13631520}
0 commit comments