@@ -65,6 +65,7 @@ def __init__(self, support, data={}, c=None, up_to_date=False):
6565 self .apply_scaling_matrix = True
6666 self .add_ridge_regulatisation = True
6767 self .ridge_factor = 1e-8
68+
6869 def set_nelements (self , nelements : int ) -> int :
6970 return self .support .set_nelements (nelements )
7071
@@ -247,11 +248,11 @@ def add_constraints_to_least_squares(self, A, B, idc, w=1.0, name="undefined"):
247248
248249 rows = np .tile (rows , (A .shape [- 1 ], 1 )).T
249250 self .constraints [name ] = {
250- ' matrix' : sparse .coo_matrix (
251+ " matrix" : sparse .coo_matrix (
251252 (A .flatten (), (rows .flatten (), idc .flatten ())), shape = (n_rows , self .dof )
252253 ).tocsc (),
253- 'b' : B .flatten (),
254- 'w' : w ,
254+ "b" : B .flatten (),
255+ "w" : w ,
255256 }
256257
257258 @abstractmethod
@@ -305,7 +306,7 @@ def add_inequality_constraints_to_matrix(
305306 rows = np .tile (rows , (A .shape [- 1 ], 1 )).T
306307
307308 self .ineq_constraints [name ] = {
308- ' matrix' : sparse .coo_matrix (
309+ " matrix" : sparse .coo_matrix (
309310 (A .flatten (), (rows .flatten (), idc .flatten ())), shape = (rows .shape [0 ], self .dof )
310311 ).tocsc (),
311312 "bounds" : bounds ,
@@ -320,7 +321,7 @@ def add_value_inequality_constraints(self, w: float = 1.0):
320321 rows = np .tile (rows , (a .shape [- 1 ], 1 )).T
321322 a = a [inside ]
322323 cols = self .support .elements [element [inside ]]
323- self .add_inequality_constraints_to_matrix (a , points [:, 3 :5 ], cols , ' inequality_value' )
324+ self .add_inequality_constraints_to_matrix (a , points [:, 3 :5 ], cols , " inequality_value" )
324325
325326 def add_inequality_pairs_constraints (
326327 self ,
@@ -354,11 +355,11 @@ def add_inequality_pairs_constraints(
354355 lower_interpolation = self .support .get_element_for_location (lower_points )
355356 if (~ upper_interpolation [3 ]).sum () > 0 :
356357 logger .warning (
357- f' Upper points not in mesh { upper_points [~ upper_interpolation [3 ]]} '
358+ f" Upper points not in mesh { upper_points [~ upper_interpolation [3 ]]} "
358359 )
359360 if (~ lower_interpolation [3 ]).sum () > 0 :
360361 logger .warning (
361- f' Lower points not in mesh { lower_points [~ lower_interpolation [3 ]]} '
362+ f" Lower points not in mesh { lower_points [~ lower_interpolation [3 ]]} "
362363 )
363364 ij = np .array (
364365 [
@@ -392,7 +393,7 @@ def add_inequality_pairs_constraints(
392393 bounds [:, 1 ] = upper_bound
393394
394395 self .add_inequality_constraints_to_matrix (
395- a , bounds , cols , f' inequality_pairs_{ pair [0 ]} _{ pair [1 ]} '
396+ a , bounds , cols , f" inequality_pairs_{ pair [0 ]} _{ pair [1 ]} "
396397 )
397398
398399 def add_inequality_feature (
@@ -506,13 +507,14 @@ def build_matrix(self):
506507 for c in self .constraints .values ():
507508 if len (c ["w" ]) == 0 :
508509 continue
509- mats .append (c [' matrix' ].multiply (c ['w' ][:, None ]))
510- bs .append (c ['b' ] * c ['w' ])
510+ mats .append (c [" matrix" ].multiply (c ["w" ][:, None ]))
511+ bs .append (c ["b" ] * c ["w" ])
511512 A = sparse .vstack (mats )
512513 logger .info (f"Interpolation matrix is { A .shape [0 ]} x { A .shape [1 ]} " )
513514
514515 B = np .hstack (bs )
515516 return A , B
517+
516518 def compute_column_scaling_matrix (self , A : sparse .csr_matrix ) -> sparse .dia_matrix :
517519 """Compute column scaling matrix S for matrix A so that A @ S has columns with unit norm.
518520
@@ -576,8 +578,8 @@ def build_inequality_matrix(self):
576578 mats = []
577579 bounds = []
578580 for c in self .ineq_constraints .values ():
579- mats .append (c [' matrix' ])
580- bounds .append (c [' bounds' ])
581+ mats .append (c [" matrix" ])
582+ bounds .append (c [" bounds" ])
581583 if len (mats ) == 0 :
582584 return sparse .csr_matrix ((0 , self .dof ), dtype = float ), np .zeros ((0 , 3 ))
583585 Q = sparse .vstack (mats )
@@ -623,40 +625,40 @@ def solve_system(
623625
624626 Q , bounds = self .build_inequality_matrix ()
625627 if callable (solver ):
626- logger .warning (' Using custom solver' )
628+ logger .warning (" Using custom solver" )
627629 self .c = solver (A .tocsr (), b )
628630 self .up_to_date = True
629631 elif isinstance (solver , str ) or solver is None :
630- if solver not in ['cg' , ' lsmr' , ' admm' ]:
632+ if solver not in ["cg" , " lsmr" , " admm" ]:
631633 logger .warning (
632- f' Unknown solver { solver } using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
634+ f" Unknown solver { solver } using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function"
633635 )
634- solver = 'cg'
635- if solver == 'cg' :
636+ solver = "cg"
637+ if solver == "cg" :
636638 logger .info ("Solving using cg" )
637- if ' atol' not in solver_kwargs or ' rtol' not in solver_kwargs :
639+ if " atol" not in solver_kwargs or " rtol" not in solver_kwargs :
638640 if tol is not None :
639- solver_kwargs [' atol' ] = tol
641+ solver_kwargs [" atol" ] = tol
640642
641643 logger .info (f"Solver kwargs: { solver_kwargs } " )
642644
643645 res = sparse .linalg .cg (A .T @ A , A .T @ b , ** solver_kwargs )
644646 if res [1 ] > 0 :
645647 logger .warning (
646- f' CG reached iteration limit ({ res [1 ]} )and did not converge, check input data. Setting solution to last iteration'
648+ f" CG reached iteration limit ({ res [1 ]} )and did not converge, check input data. Setting solution to last iteration"
647649 )
648650 self .c = res [0 ]
649651 self .up_to_date = True
650652
651- elif solver == ' lsmr' :
653+ elif solver == " lsmr" :
652654 logger .info ("Solving using lsmr" )
653655 # if 'atol' not in solver_kwargs:
654656 # if tol is not None:
655657 # solver_kwargs['atol'] = tol
656- if ' btol' not in solver_kwargs :
658+ if " btol" not in solver_kwargs :
657659 if tol is not None :
658- solver_kwargs [' btol' ] = tol
659- solver_kwargs [' atol' ] = 0.
660+ solver_kwargs [" btol" ] = tol
661+ solver_kwargs [" atol" ] = 0.0
660662 logger .info (f"Setting lsmr btol to { tol } " )
661663 logger .info (f"Solver kwargs: { solver_kwargs } " )
662664 res = sparse .linalg .lsmr (A , b , ** solver_kwargs )
@@ -674,31 +676,31 @@ def solve_system(
674676 self .c = res [0 ]
675677 self .up_to_date = True
676678
677- elif solver == ' admm' :
679+ elif solver == " admm" :
678680 logger .info ("Solving using admm" )
679681
680- if 'x0' in solver_kwargs :
681- x0 = solver_kwargs ['x0' ](self .support )
682+ if "x0" in solver_kwargs :
683+ x0 = solver_kwargs ["x0" ](self .support )
682684 else :
683685 x0 = np .zeros (A .shape [1 ])
684- solver_kwargs .pop ('x0' , None )
686+ solver_kwargs .pop ("x0" , None )
685687 if Q is None :
686688 logger .warning ("No inequality constraints, using lsmr" )
687- return self .solve_system (' lsmr' , solver_kwargs )
689+ return self .solve_system (" lsmr" , solver_kwargs = solver_kwargs )
688690
689691 try :
690692 from loopsolver import admm_solve
691693
692694 try :
693- linsys_solver = solver_kwargs .pop (' linsys_solver' , ' lsmr' )
695+ linsys_solver = solver_kwargs .pop (" linsys_solver" , " lsmr" )
694696 res = admm_solve (
695697 A ,
696698 b ,
697699 Q ,
698700 bounds ,
699701 x0 = x0 ,
700- admm_weight = solver_kwargs .pop (' admm_weight' , 0.01 ),
701- nmajor = solver_kwargs .pop (' nmajor' , 200 ),
702+ admm_weight = solver_kwargs .pop (" admm_weight" , 0.01 ),
703+ nmajor = solver_kwargs .pop (" nmajor" , 200 ),
702704 linsys_solver_kwargs = solver_kwargs ,
703705 linsys_solver = linsys_solver ,
704706 )
@@ -756,12 +758,7 @@ def evaluate_value(self, locations: np.ndarray) -> np.ndarray:
756758 """
757759 self .update ()
758760 evaluation_points = np .array (locations )
759- evaluated = np .zeros (evaluation_points .shape [0 ])
760- mask = np .any (evaluation_points == np .nan , axis = 1 )
761-
762- if evaluation_points [~ mask , :].shape [0 ] > 0 :
763- evaluated [~ mask ] = self .support .evaluate_value (evaluation_points [~ mask ], self .c )
764- return evaluated
761+ return self .support .evaluate_value (evaluation_points , self .c )
765762
766763 def evaluate_gradient (self , locations : np .ndarray ) -> np .ndarray :
767764 """
@@ -792,4 +789,4 @@ def to_dict(self):
792789 def vtk (self ):
793790 if self .up_to_date is False :
794791 self .update ()
795- return self .support .vtk ({'c' : self .c })
792+ return self .support .vtk ({"c" : self .c })
0 commit comments