From 77fffd81cb99dcbb4a979b31c45b94d246a87418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Hensgen?= <24550538+sebhmg@users.noreply.github.com> Date: Sat, 20 Jun 2026 11:41:17 -0400 Subject: [PATCH 1/2] Revert "GEOPY-2910: Reduce chunking of sensitivities for TEM inversions" --- .../time_domain/simulation.py | 69 ++++++++----------- 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 3574a5de0d..5986f57ffa 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -99,7 +99,6 @@ def compute_J(self, m, f=None): self.survey.source_list, compute_row_size, thread_count=self.n_threads(client=client, worker=worker), - optimize=False, ) fields_array = f[:, ftype, :] @@ -157,6 +156,7 @@ def compute_J(self, m, f=None): AdiagTinv, ATinv_df_duT_v[ind], time_mask, + client, ) if client: @@ -166,7 +166,7 @@ def compute_J(self, m, f=None): for block_ind in range(len(blocks)): - if len(blocks[block_ind]) == 0: + if len(block) == 0: continue if client: @@ -337,52 +337,48 @@ def get_field_deriv_block( AdiagTinv, ATinv_df_duT_v, time_mask, + client, ): """ Stack the blocks of field derivatives for a given timestep and call the direct solver. """ - if len(block) == 0: - return None + if len(ATinv_df_duT_v) == 0: + ATinv_df_duT_v = [[] for _ in block] Asubdiag = None if tInd < self.nT - 1: Asubdiag = self.getAsubdiag(tInd + 1) - time_blocks = [] - colm_indices = [] - colm_count = 0 - for (_, (rx_ind, _, shape)), field_deriv in zip(block, field_derivs): + updated_ATinv_df_duT_v = [] + + for (_, (rx_ind, _, shape)), field_deriv, ATinv_chunk in zip( + block, field_derivs, ATinv_df_duT_v + ): # Cut out early data time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] local_ind = np.arange(rx_ind.shape[0])[time_check] - if len(ATinv_df_duT_v) == 0: + if len(ATinv_chunk) == 0: # last timestep (first to be solved) time_block = field_deriv.toarray()[:, local_ind] + shape = ( + field_deriv.shape[0], + len(rx_ind), + ) + ATinv_chunk = np.zeros(shape, dtype=np.float32) else: time_block = np.asarray( - field_deriv[:, local_ind] - - Asubdiag.T - * ATinv_df_duT_v[:, colm_count : colm_count + rx_ind.shape[0]][ - :, local_ind - ] + field_deriv[:, local_ind] - Asubdiag.T * ATinv_chunk[:, local_ind] ) - time_blocks.append(time_block) - colm_indices.append(local_ind + colm_count) - colm_count += rx_ind.shape[0] - - if len(ATinv_df_duT_v) == 0: - ATinv_df_duT_v = np.zeros((field_deriv.shape[0], colm_count), dtype=np.float32) + if time_block.ndim == 2 and time_block.shape[1] > 0: + solve = (AdiagTinv * time_block).reshape(time_block.shape) + ATinv_chunk[:, local_ind] = solve - if len(time_blocks) > 0: - solve = AdiagTinv * np.hstack(time_blocks).reshape( - (ATinv_df_duT_v.shape[0], -1) - ) - ATinv_df_duT_v[:, np.hstack(colm_indices)] = solve + updated_ATinv_df_duT_v.append(ATinv_chunk) - return ATinv_df_duT_v + return updated_ATinv_df_duT_v def block_deriv( @@ -466,14 +462,11 @@ def compute_rows( Compute the rows of the sensitivity matrix for a given source and receiver. """ rows = [] - colm_count = 0 - for address, ind_array in blocks[block_ind]: + for ind, (address, ind_array) in enumerate(blocks[block_ind]): # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): src = simulation.survey.source_list[address[0]] time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] - - n_rec = len(ind_array[0]) - local_ind = np.arange(n_rec)[time_check] + local_ind = np.arange(len(ind_array[0]))[time_check] if len(local_ind) < 1: row_block = np.zeros( @@ -485,24 +478,18 @@ def compute_rows( dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( tInd, fields[:, address[0], tInd], - field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind], + field_derivs[block_ind][ind][:, local_ind], adjoint=True, ) dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, - src, - field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind], - adjoint=True, + tInd + 1, src, field_derivs[block_ind][ind][:, local_ind], adjoint=True ) # on nodes of time mesh un_src = fields[:, address[0], tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( - tInd, - un_src, - field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind], - adjoint=True, + tInd, un_src, field_derivs[block_ind][ind][:, local_ind], adjoint=True ) row_block = np.zeros( (len(ind_array[1]), simulation.model.size), dtype=np.float32 @@ -519,8 +506,6 @@ def compute_rows( else: Jmatrix[ind_array[1], :] += row_block - colm_count += n_rec - def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields): """ From 97e39db70e6b66fceaaeacb068e9388f14e72c92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Hensgen?= <24550538+sebhmg@users.noreply.github.com> Date: Sat, 20 Jun 2026 11:44:30 -0400 Subject: [PATCH 2/2] Revert "GEOPY-2910: Reduce chunking of sensitivities for TEM inversions" --- .../time_domain/simulation.py | 69 ++++++++----------- 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 2f38e87544..5986f57ffa 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -99,7 +99,6 @@ def compute_J(self, m, f=None): self.survey.source_list, compute_row_size, thread_count=self.n_threads(client=client, worker=worker), - optimize=False, ) fields_array = f[:, ftype, :] @@ -157,6 +156,7 @@ def compute_J(self, m, f=None): AdiagTinv, ATinv_df_duT_v[ind], time_mask, + client, ) if client: @@ -166,7 +166,7 @@ def compute_J(self, m, f=None): for block_ind in range(len(blocks)): - if len(blocks[block_ind]) == 0: + if len(block) == 0: continue if client: @@ -337,52 +337,48 @@ def get_field_deriv_block( AdiagTinv, ATinv_df_duT_v, time_mask, + client, ): """ Stack the blocks of field derivatives for a given timestep and call the direct solver. """ - if len(block) == 0: - return None + if len(ATinv_df_duT_v) == 0: + ATinv_df_duT_v = [[] for _ in block] Asubdiag = None if tInd < self.nT - 1: Asubdiag = self.getAsubdiag(tInd + 1) - time_blocks = [] - colm_indices = [] - colm_count = 0 - for (_, (rx_ind, _, shape)), field_deriv in zip(block, field_derivs): + updated_ATinv_df_duT_v = [] + + for (_, (rx_ind, _, shape)), field_deriv, ATinv_chunk in zip( + block, field_derivs, ATinv_df_duT_v + ): # Cut out early data time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] local_ind = np.arange(rx_ind.shape[0])[time_check] - if len(ATinv_df_duT_v) == 0: + if len(ATinv_chunk) == 0: # last timestep (first to be solved) time_block = field_deriv.toarray()[:, local_ind] + shape = ( + field_deriv.shape[0], + len(rx_ind), + ) + ATinv_chunk = np.zeros(shape, dtype=np.float32) else: time_block = np.asarray( - field_deriv[:, local_ind] - - Asubdiag.T - * ATinv_df_duT_v[:, colm_count : colm_count + rx_ind.shape[0]][ - :, local_ind - ] + field_deriv[:, local_ind] - Asubdiag.T * ATinv_chunk[:, local_ind] ) - time_blocks.append(time_block) - colm_indices.append(local_ind + colm_count) - colm_count += rx_ind.shape[0] - - if len(ATinv_df_duT_v) == 0: - ATinv_df_duT_v = np.zeros((field_deriv.shape[0], colm_count)) + if time_block.ndim == 2 and time_block.shape[1] > 0: + solve = (AdiagTinv * time_block).reshape(time_block.shape) + ATinv_chunk[:, local_ind] = solve - if len(time_blocks) > 0: - solve = AdiagTinv * np.hstack(time_blocks).reshape( - (ATinv_df_duT_v.shape[0], -1) - ) - ATinv_df_duT_v[:, np.hstack(colm_indices)] = solve + updated_ATinv_df_duT_v.append(ATinv_chunk) - return ATinv_df_duT_v + return updated_ATinv_df_duT_v def block_deriv( @@ -466,14 +462,11 @@ def compute_rows( Compute the rows of the sensitivity matrix for a given source and receiver. """ rows = [] - colm_count = 0 - for address, ind_array in blocks[block_ind]: + for ind, (address, ind_array) in enumerate(blocks[block_ind]): # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): src = simulation.survey.source_list[address[0]] time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] - - n_rec = len(ind_array[0]) - local_ind = np.arange(n_rec)[time_check] + local_ind = np.arange(len(ind_array[0]))[time_check] if len(local_ind) < 1: row_block = np.zeros( @@ -485,24 +478,18 @@ def compute_rows( dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( tInd, fields[:, address[0], tInd], - field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind], + field_derivs[block_ind][ind][:, local_ind], adjoint=True, ) dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, - src, - field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind], - adjoint=True, + tInd + 1, src, field_derivs[block_ind][ind][:, local_ind], adjoint=True ) # on nodes of time mesh un_src = fields[:, address[0], tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( - tInd, - un_src, - field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind], - adjoint=True, + tInd, un_src, field_derivs[block_ind][ind][:, local_ind], adjoint=True ) row_block = np.zeros( (len(ind_array[1]), simulation.model.size), dtype=np.float32 @@ -519,8 +506,6 @@ def compute_rows( else: Jmatrix[ind_array[1], :] += row_block - colm_count += n_rec - def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields): """