Skip to content

Commit 28f25a6

Browse files
authored
Merge pull request #284 from Loop3D/fix/speed-up
Fix/speed up
2 parents 4659b47 + 7b44228 commit 28f25a6

5 files changed

Lines changed: 113 additions & 80 deletions

File tree

LoopStructural/interpolators/_discrete_interpolator.py

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(self, support, data={}, c=None, up_to_date=False):
6565
self.apply_scaling_matrix = True
6666
self.add_ridge_regulatisation = True
6767
self.ridge_factor = 1e-8
68+
6869
def set_nelements(self, nelements: int) -> int:
6970
return self.support.set_nelements(nelements)
7071

@@ -247,11 +248,11 @@ def add_constraints_to_least_squares(self, A, B, idc, w=1.0, name="undefined"):
247248

248249
rows = np.tile(rows, (A.shape[-1], 1)).T
249250
self.constraints[name] = {
250-
'matrix': sparse.coo_matrix(
251+
"matrix": sparse.coo_matrix(
251252
(A.flatten(), (rows.flatten(), idc.flatten())), shape=(n_rows, self.dof)
252253
).tocsc(),
253-
'b': B.flatten(),
254-
'w': w,
254+
"b": B.flatten(),
255+
"w": w,
255256
}
256257

257258
@abstractmethod
@@ -305,7 +306,7 @@ def add_inequality_constraints_to_matrix(
305306
rows = np.tile(rows, (A.shape[-1], 1)).T
306307

307308
self.ineq_constraints[name] = {
308-
'matrix': sparse.coo_matrix(
309+
"matrix": sparse.coo_matrix(
309310
(A.flatten(), (rows.flatten(), idc.flatten())), shape=(rows.shape[0], self.dof)
310311
).tocsc(),
311312
"bounds": bounds,
@@ -320,7 +321,7 @@ def add_value_inequality_constraints(self, w: float = 1.0):
320321
rows = np.tile(rows, (a.shape[-1], 1)).T
321322
a = a[inside]
322323
cols = self.support.elements[element[inside]]
323-
self.add_inequality_constraints_to_matrix(a, points[:, 3:5], cols, 'inequality_value')
324+
self.add_inequality_constraints_to_matrix(a, points[:, 3:5], cols, "inequality_value")
324325

325326
def add_inequality_pairs_constraints(
326327
self,
@@ -354,11 +355,11 @@ def add_inequality_pairs_constraints(
354355
lower_interpolation = self.support.get_element_for_location(lower_points)
355356
if (~upper_interpolation[3]).sum() > 0:
356357
logger.warning(
357-
f'Upper points not in mesh {upper_points[~upper_interpolation[3]]}'
358+
f"Upper points not in mesh {upper_points[~upper_interpolation[3]]}"
358359
)
359360
if (~lower_interpolation[3]).sum() > 0:
360361
logger.warning(
361-
f'Lower points not in mesh {lower_points[~lower_interpolation[3]]}'
362+
f"Lower points not in mesh {lower_points[~lower_interpolation[3]]}"
362363
)
363364
ij = np.array(
364365
[
@@ -392,7 +393,7 @@ def add_inequality_pairs_constraints(
392393
bounds[:, 1] = upper_bound
393394

394395
self.add_inequality_constraints_to_matrix(
395-
a, bounds, cols, f'inequality_pairs_{pair[0]}_{pair[1]}'
396+
a, bounds, cols, f"inequality_pairs_{pair[0]}_{pair[1]}"
396397
)
397398

398399
def add_inequality_feature(
@@ -506,13 +507,14 @@ def build_matrix(self):
506507
for c in self.constraints.values():
507508
if len(c["w"]) == 0:
508509
continue
509-
mats.append(c['matrix'].multiply(c['w'][:, None]))
510-
bs.append(c['b'] * c['w'])
510+
mats.append(c["matrix"].multiply(c["w"][:, None]))
511+
bs.append(c["b"] * c["w"])
511512
A = sparse.vstack(mats)
512513
logger.info(f"Interpolation matrix is {A.shape[0]} x {A.shape[1]}")
513514

514515
B = np.hstack(bs)
515516
return A, B
517+
516518
def compute_column_scaling_matrix(self, A: sparse.csr_matrix) -> sparse.dia_matrix:
517519
"""Compute column scaling matrix S for matrix A so that A @ S has columns with unit norm.
518520
@@ -576,8 +578,8 @@ def build_inequality_matrix(self):
576578
mats = []
577579
bounds = []
578580
for c in self.ineq_constraints.values():
579-
mats.append(c['matrix'])
580-
bounds.append(c['bounds'])
581+
mats.append(c["matrix"])
582+
bounds.append(c["bounds"])
581583
if len(mats) == 0:
582584
return sparse.csr_matrix((0, self.dof), dtype=float), np.zeros((0, 3))
583585
Q = sparse.vstack(mats)
@@ -623,40 +625,40 @@ def solve_system(
623625

624626
Q, bounds = self.build_inequality_matrix()
625627
if callable(solver):
626-
logger.warning('Using custom solver')
628+
logger.warning("Using custom solver")
627629
self.c = solver(A.tocsr(), b)
628630
self.up_to_date = True
629631
elif isinstance(solver, str) or solver is None:
630-
if solver not in ['cg', 'lsmr', 'admm']:
632+
if solver not in ["cg", "lsmr", "admm"]:
631633
logger.warning(
632-
f'Unknown solver {solver} using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
634+
f"Unknown solver {solver} using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function"
633635
)
634-
solver = 'cg'
635-
if solver == 'cg':
636+
solver = "cg"
637+
if solver == "cg":
636638
logger.info("Solving using cg")
637-
if 'atol' not in solver_kwargs or 'rtol' not in solver_kwargs:
639+
if "atol" not in solver_kwargs or "rtol" not in solver_kwargs:
638640
if tol is not None:
639-
solver_kwargs['atol'] = tol
641+
solver_kwargs["atol"] = tol
640642

641643
logger.info(f"Solver kwargs: {solver_kwargs}")
642644

643645
res = sparse.linalg.cg(A.T @ A, A.T @ b, **solver_kwargs)
644646
if res[1] > 0:
645647
logger.warning(
646-
f'CG reached iteration limit ({res[1]})and did not converge, check input data. Setting solution to last iteration'
648+
f"CG reached iteration limit ({res[1]})and did not converge, check input data. Setting solution to last iteration"
647649
)
648650
self.c = res[0]
649651
self.up_to_date = True
650652

651-
elif solver == 'lsmr':
653+
elif solver == "lsmr":
652654
logger.info("Solving using lsmr")
653655
# if 'atol' not in solver_kwargs:
654656
# if tol is not None:
655657
# solver_kwargs['atol'] = tol
656-
if 'btol' not in solver_kwargs:
658+
if "btol" not in solver_kwargs:
657659
if tol is not None:
658-
solver_kwargs['btol'] = tol
659-
solver_kwargs['atol'] = 0.
660+
solver_kwargs["btol"] = tol
661+
solver_kwargs["atol"] = 0.0
660662
logger.info(f"Setting lsmr btol to {tol}")
661663
logger.info(f"Solver kwargs: {solver_kwargs}")
662664
res = sparse.linalg.lsmr(A, b, **solver_kwargs)
@@ -674,31 +676,31 @@ def solve_system(
674676
self.c = res[0]
675677
self.up_to_date = True
676678

677-
elif solver == 'admm':
679+
elif solver == "admm":
678680
logger.info("Solving using admm")
679681

680-
if 'x0' in solver_kwargs:
681-
x0 = solver_kwargs['x0'](self.support)
682+
if "x0" in solver_kwargs:
683+
x0 = solver_kwargs["x0"](self.support)
682684
else:
683685
x0 = np.zeros(A.shape[1])
684-
solver_kwargs.pop('x0', None)
686+
solver_kwargs.pop("x0", None)
685687
if Q is None:
686688
logger.warning("No inequality constraints, using lsmr")
687-
return self.solve_system('lsmr', solver_kwargs)
689+
return self.solve_system("lsmr", solver_kwargs=solver_kwargs)
688690

689691
try:
690692
from loopsolver import admm_solve
691693

692694
try:
693-
linsys_solver = solver_kwargs.pop('linsys_solver', 'lsmr')
695+
linsys_solver = solver_kwargs.pop("linsys_solver", "lsmr")
694696
res = admm_solve(
695697
A,
696698
b,
697699
Q,
698700
bounds,
699701
x0=x0,
700-
admm_weight=solver_kwargs.pop('admm_weight', 0.01),
701-
nmajor=solver_kwargs.pop('nmajor', 200),
702+
admm_weight=solver_kwargs.pop("admm_weight", 0.01),
703+
nmajor=solver_kwargs.pop("nmajor", 200),
702704
linsys_solver_kwargs=solver_kwargs,
703705
linsys_solver=linsys_solver,
704706
)
@@ -756,12 +758,7 @@ def evaluate_value(self, locations: np.ndarray) -> np.ndarray:
756758
"""
757759
self.update()
758760
evaluation_points = np.array(locations)
759-
evaluated = np.zeros(evaluation_points.shape[0])
760-
mask = np.any(evaluation_points == np.nan, axis=1)
761-
762-
if evaluation_points[~mask, :].shape[0] > 0:
763-
evaluated[~mask] = self.support.evaluate_value(evaluation_points[~mask], self.c)
764-
return evaluated
761+
return self.support.evaluate_value(evaluation_points, self.c)
765762

766763
def evaluate_gradient(self, locations: np.ndarray) -> np.ndarray:
767764
"""
@@ -792,4 +789,4 @@ def to_dict(self):
792789
def vtk(self):
793790
if self.up_to_date is False:
794791
self.update()
795-
return self.support.vtk({'c': self.c})
792+
return self.support.vtk({"c": self.c})

LoopStructural/interpolators/supports/_3d_base_structured.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,8 @@ def position_to_cell_global_index(self, pos):
285285

286286
def inside(self, pos):
287287
# check whether point is inside box
288-
inside = np.ones(pos.shape[0]).astype(bool)
289-
for i in range(3):
290-
inside *= pos[:, i] > self.origin[None, i]
291-
inside *= pos[:, i] < self.maximum[None, i]
288+
pos = self.check_position(pos)
289+
inside = np.all((pos > self.origin) & (pos < self.maximum), axis=1)
292290
return inside
293291

294292
def check_position(self, pos: np.ndarray) -> np.ndarray:
@@ -306,11 +304,21 @@ def check_position(self, pos: np.ndarray) -> np.ndarray:
306304
[type]
307305
[description]
308306
"""
309-
pos = np.array(pos)
307+
if not isinstance(pos, np.ndarray):
308+
try:
309+
pos = np.array(pos, dtype=float)
310+
except Exception as e:
311+
logger.error(
312+
f"Position array should be a numpy array or list of points, not {type(pos)}"
313+
)
314+
raise ValueError(
315+
f"Position array should be a numpy array or list of points, not {type(pos)}"
316+
) from e
317+
310318
if len(pos.shape) == 1:
311319
pos = np.array([pos])
312320
if len(pos.shape) != 2:
313-
print("Position array needs to be a list of points or a point")
321+
logger.error("Position array needs to be a list of points or a point")
314322
raise ValueError("Position array needs to be a list of points or a point")
315323
return pos
316324

@@ -379,20 +387,24 @@ def position_to_cell_corners(self, pos):
379387
----------
380388
pos : np.array
381389
(N, 3) array of xyz coordinates representing the positions of N points.
382-
390+
383391
Returns
384392
-------
385393
globalidx : np.array
386-
(N, 8) array of global indices corresponding to the 8 corner nodes of the cell
387-
each point lies in. If a point lies outside the support, its corresponding entry
394+
(N, 8) array of global indices corresponding to the 8 corner nodes of the cell
395+
each point lies in. If a point lies outside the support, its corresponding entry
388396
will be set to -1.
389397
inside : np.array
390398
(N,) boolean array indicating whether each point is inside the support domain.
391399
"""
392400
cell_indexes, inside = self.position_to_cell_index(pos)
393-
corner_indexes = self.cell_corner_indexes(cell_indexes)
394-
globalidx = self.global_node_indices(corner_indexes)
395-
# if global index is not inside the support set to -1
401+
nx, ny = self.nsteps[0], self.nsteps[1]
402+
offsets = np.array(
403+
[0, 1, nx, nx + 1, nx * ny, nx * ny + 1, nx * ny + nx, nx * ny + nx + 1],
404+
dtype=np.intp,
405+
)
406+
g = cell_indexes[:, 0] + nx * cell_indexes[:, 1] + nx * ny * cell_indexes[:, 2]
407+
globalidx = g[:, None] + offsets[None, :] # (N, 8)
396408
globalidx[~inside] = -1
397409
return globalidx, inside
398410

LoopStructural/modelling/features/_lambda_geological_feature.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,25 @@ def evaluate_value(self, pos: np.ndarray, ignore_regions=False) -> np.ndarray:
6666
v = np.zeros((pos.shape[0]))
6767
v[:] = np.nan
6868

69-
mask = self._calculate_mask(pos, ignore_regions=ignore_regions)
70-
pos = self._apply_faults(pos)
69+
# Precompute each fault's scalar value (gx = fault.__getitem__(0).evaluate_value)
70+
# once and reuse for both the region mask check and fault application.
71+
# FaultSegment.evaluate_value(pos) == self.__getitem__(0).evaluate_value(pos) == gx.
72+
# apply_to_points also needs gx to determine the displacement mask — passing it
73+
# avoids the duplicate evaluation on the np.copy of pos.
74+
precomputed_gx = {id(f): f.evaluate_value(pos) for f in self.faults}
75+
76+
# Region mask: pass precomputed gx so PositiveRegion/NegativeRegion skip re-evaluation
77+
mask = np.ones(pos.shape[0], dtype=bool)
78+
if not ignore_regions:
79+
for r in self.regions:
80+
pre = precomputed_gx.get(id(getattr(r, 'feature', None)))
81+
mask = np.logical_and(mask, r(pos, precomputed_val=pre))
82+
83+
# Apply faults: pass precomputed gx so apply_to_points skips one evaluate_value call
84+
if self.faults_enabled:
85+
for f in self.faults:
86+
pos = f.apply_to_points(pos, precomputed_gx=precomputed_gx.get(id(f)))
87+
7188
if self.function is not None:
7289
v[mask] = self.function(pos[mask,:])
7390
return v

LoopStructural/modelling/features/fault/_fault_segment.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -311,13 +311,14 @@ def evaluate_displacement(self, points):
311311
d[mask] = self.faultfunction(gx[mask] + self.fault_offset, gy[mask], gz[mask])
312312
return d * self.displacement
313313

314-
def apply_to_points(self, points, reverse=False):
314+
def apply_to_points(self, points, reverse=False, precomputed_gx=None):
315315
"""
316316
Unfault the array of points
317317
318318
Parameters
319319
----------
320320
points - numpy array Nx3
321+
precomputed_gx - optional pre-evaluated gx values (same points, avoids duplicate evaluation)
321322
322323
Returns
323324
-------
@@ -328,10 +329,12 @@ def apply_to_points(self, points, reverse=False):
328329
newp = np.copy(points).astype(float)
329330
# evaluate fault function for all points
330331
# then define mask for only points affected by fault
331-
gx = None
332-
gy = None
333-
gz = None
334-
if use_threads:
332+
# gx may be supplied by caller to avoid re-evaluation (precomputed from region check)
333+
if precomputed_gx is not None:
334+
gx = precomputed_gx
335+
gy = self.__getitem__(1).evaluate_value(newp)
336+
gz = self.__getitem__(2).evaluate_value(newp)
337+
elif use_threads:
335338
with ThreadPoolExecutor(max_workers=8) as executor:
336339
# all of these operations should be
337340
# independent so just run as different threads
@@ -361,27 +364,32 @@ def apply_to_points(self, points, reverse=False):
361364
d *= -1.0
362365
# calculate the fault frame for the evaluation points
363366
for _i in range(steps):
364-
gx = None
365-
gy = None
366-
gz = None
367367
g = None
368-
if use_threads:
368+
if _i == 0:
369+
# Reuse gx/gy/gz from the initial full-array evaluation above — newp[mask] hasn't
370+
# moved yet on the first iteration, so values are identical.
371+
gx_m = gx[mask]
372+
gy_m = gy[mask]
373+
gz_m = gz[mask]
374+
g = self.__getitem__(1).evaluate_gradient(newp[mask, :], ignore_regions=True)
375+
elif use_threads:
369376
with ThreadPoolExecutor(max_workers=8) as executor:
370377
# all of these operations should be
371378
# independent so just run as different threads
372379
gx_future = executor.submit(self.__getitem__(0).evaluate_value, newp[mask, :])
373380
g_future = executor.submit(self.__getitem__(1).evaluate_gradient, newp[mask, :])
374381
gy_future = executor.submit(self.__getitem__(1).evaluate_value, newp[mask, :])
375382
gz_future = executor.submit(self.__getitem__(2).evaluate_value, newp[mask, :])
376-
gx = gx_future.result()
383+
gx_m = gx_future.result()
377384
g = g_future.result()
378-
gy = gy_future.result()
379-
gz = gz_future.result()
385+
gy_m = gy_future.result()
386+
gz_m = gz_future.result()
380387
else:
381-
gx = self.__getitem__(0).evaluate_value(newp[mask, :])
382-
gy = self.__getitem__(1).evaluate_value(newp[mask, :])
383-
gz = self.__getitem__(2).evaluate_value(newp[mask, :])
388+
gx_m = self.__getitem__(0).evaluate_value(newp[mask, :])
389+
gy_m = self.__getitem__(1).evaluate_value(newp[mask, :])
390+
gz_m = self.__getitem__(2).evaluate_value(newp[mask, :])
384391
g = self.__getitem__(1).evaluate_gradient(newp[mask, :], ignore_regions=True)
392+
gx, gy, gz = gx_m, gy_m, gz_m # alias for block below
385393
# # get the fault frame val/grad for the points
386394
# determine displacement magnitude, for constant displacement
387395
# hanging wall should be > 0

0 commit comments

Comments
 (0)