Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 27 additions & 42 deletions simpeg/dask/electromagnetics/time_domain/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def compute_J(self, m, f=None):
self.survey.source_list,
compute_row_size,
thread_count=self.n_threads(client=client, worker=worker),
optimize=False,
)
fields_array = f[:, ftype, :]

Expand Down Expand Up @@ -157,6 +156,7 @@ def compute_J(self, m, f=None):
AdiagTinv,
ATinv_df_duT_v[ind],
time_mask,
client,
)

if client:
Expand All @@ -166,7 +166,7 @@ def compute_J(self, m, f=None):

for block_ind in range(len(blocks)):

if len(blocks[block_ind]) == 0:
if len(block) == 0:
continue

if client:
Comment on lines 166 to 172
Expand Down Expand Up @@ -337,52 +337,48 @@ def get_field_deriv_block(
AdiagTinv,
ATinv_df_duT_v,
time_mask,
client,
):
"""
Stack the blocks of field derivatives for a given timestep and call the direct solver.
"""
if len(block) == 0:
return None
if len(ATinv_df_duT_v) == 0:
ATinv_df_duT_v = [[] for _ in block]

Asubdiag = None
if tInd < self.nT - 1:
Asubdiag = self.getAsubdiag(tInd + 1)

time_blocks = []
colm_indices = []
colm_count = 0
for (_, (rx_ind, _, shape)), field_deriv in zip(block, field_derivs):
updated_ATinv_df_duT_v = []

for (_, (rx_ind, _, shape)), field_deriv, ATinv_chunk in zip(
block, field_derivs, ATinv_df_duT_v
):

# Cut out early data
time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind]
local_ind = np.arange(rx_ind.shape[0])[time_check]

if len(ATinv_df_duT_v) == 0:
if len(ATinv_chunk) == 0:
# last timestep (first to be solved)
time_block = field_deriv.toarray()[:, local_ind]
shape = (
field_deriv.shape[0],
len(rx_ind),
)
ATinv_chunk = np.zeros(shape, dtype=np.float32)
else:
time_block = np.asarray(
field_deriv[:, local_ind]
- Asubdiag.T
* ATinv_df_duT_v[:, colm_count : colm_count + rx_ind.shape[0]][
:, local_ind
]
field_deriv[:, local_ind] - Asubdiag.T * ATinv_chunk[:, local_ind]
)

time_blocks.append(time_block)
colm_indices.append(local_ind + colm_count)
colm_count += rx_ind.shape[0]

if len(ATinv_df_duT_v) == 0:
ATinv_df_duT_v = np.zeros((field_deriv.shape[0], colm_count), dtype=np.float32)
if time_block.ndim == 2 and time_block.shape[1] > 0:
solve = (AdiagTinv * time_block).reshape(time_block.shape)
ATinv_chunk[:, local_ind] = solve

if len(time_blocks) > 0:
solve = AdiagTinv * np.hstack(time_blocks).reshape(
(ATinv_df_duT_v.shape[0], -1)
)
ATinv_df_duT_v[:, np.hstack(colm_indices)] = solve
updated_ATinv_df_duT_v.append(ATinv_chunk)

return ATinv_df_duT_v
return updated_ATinv_df_duT_v


def block_deriv(
Expand Down Expand Up @@ -466,14 +462,11 @@ def compute_rows(
Compute the rows of the sensitivity matrix for a given source and receiver.
"""
rows = []
colm_count = 0
for address, ind_array in blocks[block_ind]:
for ind, (address, ind_array) in enumerate(blocks[block_ind]):
# for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
src = simulation.survey.source_list[address[0]]
time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]]

n_rec = len(ind_array[0])
local_ind = np.arange(n_rec)[time_check]
local_ind = np.arange(len(ind_array[0]))[time_check]

if len(local_ind) < 1:
row_block = np.zeros(
Expand All @@ -485,24 +478,18 @@ def compute_rows(
dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
tInd,
fields[:, address[0], tInd],
field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind],
field_derivs[block_ind][ind][:, local_ind],
adjoint=True,
)

dRHST_dm_v = simulation.getRHSDeriv(
tInd + 1,
src,
field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind],
adjoint=True,
tInd + 1, src, field_derivs[block_ind][ind][:, local_ind], adjoint=True
) # on nodes of time mesh

un_src = fields[:, address[0], tInd + 1]
# cell centered on time mesh
dAT_dm_v = simulation.getAdiagDeriv(
tInd,
un_src,
field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind],
adjoint=True,
tInd, un_src, field_derivs[block_ind][ind][:, local_ind], adjoint=True
)
row_block = np.zeros(
(len(ind_array[1]), simulation.model.size), dtype=np.float32
Expand All @@ -519,8 +506,6 @@ def compute_rows(
else:
Jmatrix[ind_array[1], :] += row_block

colm_count += n_rec


def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields):
"""
Expand Down
Loading