Skip to content

Commit 5e29fa1

Browse files
author
Lachlan Grose
committed
fix: numpy speed improvements
1 parent 54fba19 commit 5e29fa1

1 file changed

Lines changed: 23 additions & 12 deletions

File tree

LoopStructural/interpolators/supports/_3d_base_structured.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,7 @@ def position_to_cell_global_index(self, pos):
285285

286286
def inside(self, pos):
287287
# check whether point is inside box
288-
inside = np.ones(pos.shape[0]).astype(bool)
289-
for i in range(3):
290-
inside *= pos[:, i] > self.origin[None, i]
291-
inside *= pos[:, i] < self.maximum[None, i]
288+
inside = np.all((pos > self.origin) & (pos < self.maximum), axis=1)
292289
return inside
293290

294291
def check_position(self, pos: np.ndarray) -> np.ndarray:
@@ -306,11 +303,21 @@ def check_position(self, pos: np.ndarray) -> np.ndarray:
306303
[type]
307304
[description]
308305
"""
309-
pos = np.array(pos)
306+
if not isinstance(pos, np.ndarray):
307+
try:
308+
pos = np.array(pos, dtype=float)
309+
except Exception as e:
310+
logger.error(
311+
f"Position array should be a numpy array or list of points, not {type(pos)}"
312+
)
313+
raise ValueError(
314+
f"Position array should be a numpy array or list of points, not {type(pos)}"
315+
) from e
316+
310317
if len(pos.shape) == 1:
311318
pos = np.array([pos])
312319
if len(pos.shape) != 2:
313-
print("Position array needs to be a list of points or a point")
320+
logger.error("Position array needs to be a list of points or a point")
314321
raise ValueError("Position array needs to be a list of points or a point")
315322
return pos
316323

@@ -379,20 +386,24 @@ def position_to_cell_corners(self, pos):
379386
----------
380387
pos : np.array
381388
(N, 3) array of xyz coordinates representing the positions of N points.
382-
389+
383390
Returns
384391
-------
385392
globalidx : np.array
386-
(N, 8) array of global indices corresponding to the 8 corner nodes of the cell
387-
each point lies in. If a point lies outside the support, its corresponding entry
393+
(N, 8) array of global indices corresponding to the 8 corner nodes of the cell
394+
each point lies in. If a point lies outside the support, its corresponding entry
388395
will be set to -1.
389396
inside : np.array
390397
(N,) boolean array indicating whether each point is inside the support domain.
391398
"""
392399
cell_indexes, inside = self.position_to_cell_index(pos)
393-
corner_indexes = self.cell_corner_indexes(cell_indexes)
394-
globalidx = self.global_node_indices(corner_indexes)
395-
# if global index is not inside the support set to -1
400+
nx, ny = self.nsteps[0], self.nsteps[1]
401+
offsets = np.array(
402+
[0, 1, nx, nx + 1, nx * ny, nx * ny + 1, nx * ny + nx, nx * ny + nx + 1],
403+
dtype=np.intp,
404+
)
405+
g = cell_indexes[:, 0] + nx * cell_indexes[:, 1] + nx * ny * cell_indexes[:, 2]
406+
globalidx = g[:, None] + offsets[None, :] # (N, 8)
396407
globalidx[~inside] = -1
397408
return globalidx, inside
398409

0 commit comments

Comments
 (0)