diff --git a/src/ansys/dpf/core/plotter.py b/src/ansys/dpf/core/plotter.py index dc3f0b865be..ad319f1e81c 100644 --- a/src/ansys/dpf/core/plotter.py +++ b/src/ansys/dpf/core/plotter.py @@ -30,10 +30,12 @@ from __future__ import annotations +from enum import Enum, unique +import importlib.util from pathlib import Path import sys import tempfile -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import warnings import numpy as np @@ -49,15 +51,65 @@ from ansys.dpf.core import Operator, Result from ansys.dpf.core.field import Field from ansys.dpf.core.fields_container import FieldsContainer + from ansys.dpf.core.helpers import Streamlines, StreamlinesSource from ansys.dpf.core.meshed_region import MeshedRegion +@unique +class PlotterBackend(Enum): + """Plotter backend types for DPF visualization. + + Specifies which plotting backend to use for mesh and field visualization. + """ + + AUTO = "auto" + """Automatically select backend: prefer visualization_interface if available, fallback to pyvista.""" + + PYVISTA = "pyvista" + """Use the legacy PyVista-based plotter directly.""" + + VISUALIZATION_INTERFACE = "visualization_interface" + """Use ansys-tools-visualization-interface Plotter with PyVista backend.""" + + class _InternalPlotterFactory: - """Factory for _InternalPlotter based on the backend.""" + """Factory for _InternalPlotter based on the backend. + + By default, uses the new _VisualizationInterfacePlotter if ansys-tools-visualization-interface + is available, otherwise falls back to the legacy _PyVistaPlotter. Use the plotter_type parameter + to explicitly select a specific plotter implementation. + """ @staticmethod - def get_plotter_class(): - return _PyVistaPlotter + def get_plotter_class(plotter_type: PlotterBackend = PlotterBackend.AUTO): + """Get the plotter class based on the specified type. + + Parameters + ---------- + plotter_type : PlotterBackend, optional + The type of plotter to use. + + Returns + ------- + type + The plotter class to instantiate. + """ + if plotter_type == PlotterBackend.PYVISTA: + return _PyVistaPlotter + elif plotter_type == PlotterBackend.VISUALIZATION_INTERFACE: + return _VisualizationInterfacePlotter + elif plotter_type == PlotterBackend.AUTO: + # Check if visualization interface is available + if importlib.util.find_spec("ansys.tools.visualization_interface") is not None: + return _VisualizationInterfacePlotter + else: + # Fall back to legacy plotter if visualization-interface is not installed + return _PyVistaPlotter + else: + raise ValueError( + f"Invalid plotter_type '{plotter_type}'. " + f"Must be one of: {[e.name for e in PlotterBackend]}" + ) class _PyVistaPlotter: @@ -485,6 +537,647 @@ def _set_scalar_bar_title(kwargs): return kwargs +class _VisualizationInterfacePlotter: + """Plotter based on ansys-tools-visualization-interface. + + This class provides the same interface as _PyVistaPlotter but uses + the ansys-tools-visualization-interface Plotter class for rendering, + enabling future backend flexibility. + """ + + def __init__(self, **kwargs: Any) -> None: + """Initialize the visualization interface plotter. + + Parameters + ---------- + **kwargs : dict + Keyword arguments passed to the PyVistaBackend constructor. + """ + import pyvista as pv + + from ansys.tools.visualization_interface import Plotter + from ansys.tools.visualization_interface.backends.pyvista import PyVistaBackend + + # Filter kwargs for pv.Plotter.__init__ (final destination) + kwargs_in = _sort_supported_kwargs(bound_method=pv.Plotter.__init__, **kwargs) + + # Create the backend with filtered kwargs + self._backend = PyVistaBackend(**kwargs_in) + self._plotter = Plotter(backend=self._backend) + + def add_scale_factor_legend(self, scale_factor: float, **kwargs: Any) -> None: + """Add a scale factor legend text to the plotter. + + Parameters + ---------- + scale_factor : float + The scale factor value to display. + **kwargs : dict + Additional keyword arguments (unused, for compatibility). + """ + self._plotter.add_text( + text=f"Scale factor: {scale_factor}", + position="upper_right", + font_size=12, + ) + + def add_points( + self, points: Any, field: Optional[Field], point_size: float = 10.0, **kwargs: Any + ) -> None: + """Add points to the plotter. + + Parameters + ---------- + points : array-like + Point coordinates as Nx3 array. + field : Field, optional + Field containing scalar data for coloring. + point_size : float, default: 10.0 + Size of the point markers in pixels or display units + **kwargs : dict + Additional keyword arguments for add_points. + """ + import pyvista as pv + + # Filter kwargs for pv.Plotter.add_mesh (final destination for both paths) + kwargs_in = _sort_supported_kwargs(bound_method=pv.Plotter.add_mesh, **kwargs) + + point_cloud = pv.PolyData(points) + if field: + point_cloud[f"{field.name}"] = field.data + # Use plot for colored points + self._plotter.plot(point_cloud, **kwargs_in) + else: + self._plotter.add_points(points, size=point_size, **kwargs_in) + + def add_line( + self, points: Any, field: Optional[Field] = None, width: float = 1.0, **kwargs: Any + ) -> None: + """Add a line to the plotter. + + Parameters + ---------- + points : array-like + Point coordinates defining the line. + field : Field, optional + Field containing scalar data for coloring. + width : float, optional + Width of the line. Default is 1.0. + **kwargs : dict + Additional keyword arguments. + """ + import pyvista as pv + + # Filter kwargs for pv.Plotter.add_mesh (final destination for both paths) + kwargs_in = _sort_supported_kwargs(bound_method=pv.Plotter.add_mesh, **kwargs) + + line_field = pv.PolyData(np.array(points)) + if field: + line_field[f"{field.name}"] = field.data + self._plotter.plot(line_field, **kwargs_in) + else: + self._plotter.add_lines(points, width=width, **kwargs_in) + + def add_plane(self, plane: Any, field: Optional[Field] = None, **kwargs: Any) -> None: + """Add a plane to the plotter. + + Parameters + ---------- + plane : object + Plane object with center, normal_dir, width, height, n_cells_x, n_cells_y. + field : Field, optional + Field containing scalar data for coloring. + **kwargs : dict + Additional keyword arguments. + """ + import pyvista as pv + + # Filter kwargs for pv.Plotter.add_mesh (final destination) + kwargs_in = _sort_supported_kwargs(bound_method=pv.Plotter.add_mesh, **kwargs) + + plane_plot = pv.Plane( + center=plane.center, + direction=plane.normal_dir, + i_size=plane.width, + j_size=plane.height, + i_resolution=plane.n_cells_x, + j_resolution=plane.n_cells_y, + ) + if field: + plane_plot[f"{field.name}"] = field.data + self._plotter.plot(plane_plot, **kwargs_in) + + def add_mesh( + self, + meshed_region: MeshedRegion, + deform_by: Optional[Field] = None, + scale_factor: float = 1.0, + as_linear: bool = True, + **kwargs: Any, + ) -> None: + """Add a DPF mesh to the plotter. + + Parameters + ---------- + meshed_region : MeshedRegion + The DPF mesh to plot. + deform_by : Field, optional + Field to use for mesh deformation. + scale_factor : float, optional + Scale factor for deformation. Default is 1.0. + as_linear : bool, optional + Whether to treat elements as linear. Default is True. + **kwargs : dict + Additional keyword arguments. + """ + kwargs = self._set_scalar_bar_title(kwargs) + + # Set defaults for PyDPF + kwargs.setdefault("show_edges", True) + kwargs.setdefault("nan_color", "grey") + + # If deformed geometry, print the scale_factor + if deform_by: + self.add_scale_factor_legend(scale_factor, **kwargs) + + # Get the grid + if not deform_by: + if as_linear != meshed_region.as_linear: + grid = meshed_region._as_vtk( + meshed_region.nodes.coordinates_field, as_linear=as_linear + ) + meshed_region._full_grid = grid + meshed_region.as_linear = as_linear + else: + grid = meshed_region.grid + else: + grid = meshed_region._as_vtk( + meshed_region.deform_by(deform_by, scale_factor), as_linear=as_linear + ) + + # show axes + show_axes = kwargs.pop("show_axes", None) + if show_axes: + self._backend.base_plotter.add_axes() + + # Filter kwargs for pv.Plotter.add_mesh (final destination) + import pyvista as pv + + kwargs_in = _sort_supported_kwargs(bound_method=pv.Plotter.add_mesh, **kwargs) + + grid.set_active_scalars(None) + self._plotter.plot(grid, **kwargs_in) + + def add_point_labels( + self, + nodes: Union[Nodes, List[Node], List[int]], + meshed_region: MeshedRegion, + labels: Union[List[str], None] = None, + **kwargs, + ) -> List: + """Add labels at node locations. + + Parameters + ---------- + nodes : Nodes, List[Node], or List[int] + Nodes to label (as object, list of Node objects, or list of node IDs). + meshed_region : MeshedRegion + The mesh containing the nodes. + labels : List[str], optional + Custom labels. If None, uses scalar values or node IDs. + **kwargs : dict + Additional keyword arguments for add_point_labels. + + Returns + ------- + List + List of label actors. + """ + from packaging.version import parse + import pyvista as pv + + # Filter kwargs for pv.Plotter.add_point_labels (final destination) + kwargs_in = _sort_supported_kwargs(bound_method=pv.Plotter.add_point_labels, **kwargs) + + label_actors = [] + if isinstance(nodes, Nodes): + nodes = nodes.scoping.ids + elif isinstance(nodes, list): + if isinstance(nodes[0], Node): + nodes = [node.id for node in nodes] + node_indexes = [meshed_region.nodes.mapping_id_to_index.get(node_id) for node_id in nodes] + grid_points = [meshed_region.grid.points[node_index] for node_index in node_indexes] + + def get_label_at_grid_point(index): + try: + label = labels[index] + except: + label = None + return label + + # The scalar data used will be the one of the last field added. + active_scalars = None + if parse(pv.__version__) >= parse("0.42.0"): + # Get actors of active renderer + actors = list(self._backend.base_plotter.actors.values()) + for actor in actors: + mapper = actor.mapper if hasattr(actor, "mapper") else None + if mapper: + dataset = mapper.dataset + if type(dataset) is pv.core.pointset.UnstructuredGrid: + active_scalars = dataset.active_scalars + break + elif parse(pv.__version__) >= parse("0.35.2"): + for data_set in self._backend.base_plotter._datasets: + if type(data_set) is pv.core.pointset.UnstructuredGrid: + active_scalars = data_set.active_scalars + else: + active_scalars = meshed_region.grid.active_scalars + if active_scalars is None: + self.add_mesh(meshed_region=meshed_region) + + # For all grid_points given + for index, grid_point in enumerate(grid_points): + # Check for existing label at that point + label_at_grid_point = get_label_at_grid_point(index) + if label_at_grid_point: + # If there is already a label, create the associated actor + label_actors.append( + self._plotter.add_labels([grid_point], [labels[index]], **kwargs_in) + ) + else: + if active_scalars is not None: + # get the value of the current scalar field if present + scalar_at_index = active_scalars[node_indexes[index]] + value = f"{scalar_at_index:.2f}" + else: + # if no scalar field is present, print the node id + value = nodes[index] + label_actors.append( + self._plotter.add_labels([grid_point], [str(value)], **kwargs_in) + ) + return label_actors + + def add_scoping( + self, + scoping: core.Scoping, + mesh: core.MeshedRegion, + show_mesh: bool = False, + **kwargs: Any, + ) -> None: + """Add a scoping visualization to the plotter. + + Parameters + ---------- + scoping : Scoping + The scoping to visualize. + mesh : MeshedRegion + The mesh containing the entities. + show_mesh : bool, optional + Whether to show the base mesh with low opacity. Default is False. + **kwargs : dict + Additional keyword arguments. + """ + import pyvista as pv + + # Filter kwargs for pv.Plotter.add_mesh (final destination) + kwargs_in = _sort_supported_kwargs(bound_method=pv.Plotter.add_mesh, **kwargs) + + # Add the mesh to the scene with low opacity + if show_mesh: + self._plotter.plot(mesh.grid, opacity=0.3) + + scoping_mesh = None + + # If the scoping is nodal, use the add_points_label method + if scoping.location == locations.nodal: + node_indexes = np.where(np.isin(mesh.nodes.scoping.ids, scoping.ids))[0] + scoping_mesh = mesh.grid.extract_points(ind=node_indexes, include_cells=False) + # If the scoping is elemental, extract their edges and use active scalars to color them + if scoping.location == locations.elemental: + element_indexes = np.where(np.isin(mesh.elements.scoping.ids, scoping.ids))[0] + scoping_mesh = mesh.grid.extract_cells(ind=element_indexes) + + # If the scoping is faces, raise not implemented + if scoping.location == locations.faces: + raise NotImplementedError("Cannot plot a face scoping.") + + self._plotter.plot(scoping_mesh, **kwargs_in) + + def add_field( + self, + field: Field, + meshed_region: Optional[MeshedRegion] = None, + show_max: bool = False, + show_min: bool = False, + label_text_size: int = 30, + label_point_size: int = 20, + deform_by: Optional[Field] = None, + scale_factor: float = 1.0, + scale_factor_legend: Optional[float] = None, + as_linear: bool = True, + shell_layer: eshell_layers = eshell_layers.top, + **kwargs: Any, + ) -> None: + """Add a field visualization to the plotter. + + Parameters + ---------- + field : Field + The field to plot. + meshed_region : MeshedRegion, optional + The mesh to plot on. If None, uses the field's mesh. + show_max : bool, optional + Whether to show maximum value label. Default is False. + show_min : bool, optional + Whether to show minimum value label. Default is False. + label_text_size : int, optional + Font size for min/max labels. Default is 30. + label_point_size : int, optional + Point size for min/max labels. Default is 20. + deform_by : Field, optional + Field to use for mesh deformation. + scale_factor : float, optional + Scale factor for deformation. Default is 1.0. + scale_factor_legend : float, optional + Custom scale factor for legend. If None, uses scale_factor. + as_linear : bool, optional + Whether to treat elements as linear. Default is True. + shell_layer : shell_layers, optional + Shell layer to use. Default is top. + **kwargs : dict + Additional keyword arguments. + """ + # Get the field name + name = field.name.split("_")[0] + unit = field.unit + kwargs.setdefault("stitle", f"{name} ({unit})") + + kwargs = self._set_scalar_bar_title(kwargs) + + kwargs.setdefault("show_edges", True) + kwargs.setdefault("nan_color", "grey") + + # show axes + show_axes = kwargs.pop("show_axes", None) + if show_axes: + self._backend.base_plotter.add_axes() + + # get the meshed region location + if meshed_region is None: + meshed_region = field.meshed_region + + location = field.location + if location == locations.nodal: + mesh_location = meshed_region.nodes + elif location == locations.elemental: + mesh_location = meshed_region.elements + if show_max or show_min: + warnings.warn("`show_max` and `show_min` is only supported for Nodal results.") + show_max = False + show_min = False + elif location == locations.faces: + mesh_location = meshed_region.faces + if len(mesh_location) == 0: + raise ValueError("No faces found to plot on") + if show_max or show_min: + warnings.warn("`show_max` and `show_min` is only supported for Nodal results.") + show_max = False + show_min = False + elif location == locations.overall: + mesh_location = meshed_region.elements + elif location == locations.elemental_nodal: + mesh_location = meshed_region.elements + # If ElementalNodal, first extend results to mid-nodes + field = dpf.core.operators.averaging.extend_to_mid_nodes(field=field).eval() + else: + raise ValueError( + "Only elemental, elemental nodal, nodal, faces, or overall location are supported for plotting." + ) + + # Treat multilayered shells + if not isinstance(shell_layer, eshell_layers): + raise TypeError("shell_layer attribute must be a core.shell_layers instance.") + if field.shell_layers in [ + eshell_layers.topbottom, + eshell_layers.topbottommid, + ]: + change_shell_layer_op = core.operators.utility.change_shell_layers( + fields_container=field, + e_shell_layer=shell_layer, + ) + field = change_shell_layer_op.get_output(0, core.types.field) + + location_data_len = meshed_region.location_data_len(location) + component_count = field.component_count + if component_count > 1: + overall_data = np.full((location_data_len, component_count), np.nan) + else: + overall_data = np.full(location_data_len, np.nan) + if location != locations.overall: + ind, mask = mesh_location.map_scoping(field.scoping) + + # Rework ind and mask to take into account n_nodes per element if ElementalNodal + if location == locations.elemental_nodal: + n_nodes_list = meshed_region.get_elemental_nodal_size_list().astype(np.int32) + first_index = np.insert(np.cumsum(n_nodes_list)[:-1], 0, 0).astype(np.int32) + mask_2 = np.asarray( + [mask_i for i, mask_i in enumerate(mask) for _ in range(n_nodes_list[ind[i]])] + ) + ind_2 = np.asarray( + [first_index[ind_i] + j for ind_i in ind for j in range(n_nodes_list[ind_i])] + ) + mask = mask_2 + ind = ind_2 + overall_data[ind] = field.data[mask] + else: + overall_data[:] = field.data[0] + + # Have to remove any active scalar field from the pre-existing grid object, + # otherwise we get two scalar bars when calling several plot_contour on the same mesh + # but not for the same field. The PyVista UnstructuredGrid keeps memory of it. + if location == locations.elemental_nodal: + as_linear = False + if deform_by: + grid = meshed_region._as_vtk( + meshed_region.deform_by(deform_by, scale_factor), as_linear=as_linear + ) + else: + if as_linear != meshed_region.as_linear: + grid = meshed_region._as_vtk( + meshed_region.nodes.coordinates_field, as_linear=as_linear + ) + meshed_region.as_linear = as_linear + else: + grid = meshed_region.grid + if location == locations.elemental_nodal: + grid = grid.shrink(1.0) + grid.set_active_scalars(None) + + # Filter kwargs for pv.Plotter.add_mesh (final destination) + import pyvista as pv + + kwargs_in = _sort_supported_kwargs(bound_method=pv.Plotter.add_mesh, **kwargs) + self._plotter.plot(grid, scalars=overall_data, **kwargs_in) + + # If deformed geometry, print the scale_factor + if deform_by and scale_factor_legend is not False: + if scale_factor_legend is None: + scale_factor_legend = scale_factor + self.add_scale_factor_legend(scale_factor_legend, **kwargs) + + if show_max or show_min: + # Get Min-Max for the field + min_max = core.operators.min_max.min_max() + min_max.inputs.connect(field) + + # Add Min and Max Labels + labels = [] + node_ids = [] + if show_max: + max_field = min_max.outputs.field_max() + # Get Node ID at max. + node_id_at_max = max_field.scoping.id(0) + labels.append( + f"Max: {((max_field.data ** 2).sum() ** 0.5):.2e}\nNodeID: {node_id_at_max}" + ) + # Get Node index at max value. + node_ids.append(node_id_at_max) + + if show_min: + min_field = min_max.outputs.field_min() + # Get Node ID at min. + node_id_at_min = min_field.scoping.id(0) + labels.append( + f"Min: {((min_field.data ** 2).sum() ** 0.5):.2e}\nNodeID: {node_id_at_min}" + ) + # Get Node index at min. value. + node_ids.append(node_id_at_min) + + # Plot labels: + for index, node_id in enumerate(node_ids): + self.add_point_labels( + [node_id], + meshed_region, + [labels[index]], + font_size=label_text_size, + point_size=label_point_size, + ) + + def add_streamlines( + self, + streamlines: Streamlines, + source: Optional[StreamlinesSource] = None, + radius: float = 1.0, + **kwargs: Any, + ) -> None: + """Add streamlines to the plotter. + + Parameters + ---------- + streamlines : Streamlines + The streamlines object to plot. + source : StreamlinesSource, optional + The source object for streamlines. + radius : float, optional + Tube radius for streamlines. Default is 1.0. + **kwargs : dict + Additional keyword arguments. + """ + import pyvista as pv + + permissive = kwargs.pop("permissive", True) + + # Filter kwargs for pv.Plotter.add_mesh (final destination) + kwargs_in = _sort_supported_kwargs(bound_method=pv.Plotter.add_mesh, **kwargs) + + # set streamline on plotter + sargs = dict(vertical=False) + streamlines_vtk = streamlines._as_pyvista_data_set() + if not (permissive and streamlines_vtk.n_points == 0): + self._plotter.plot( + streamlines_vtk.tube(radius=radius), scalar_bar_args=sargs, **kwargs_in + ) + if source is not None: + src = source._as_pyvista_data_set() + self._plotter.plot(src, **kwargs_in) + + def show_figure(self, **kwargs: Any) -> Tuple[Any, Any]: + """Show the figure. + + Parameters + ---------- + **kwargs : dict + Keyword arguments including: + - text: Text to display at lower edge + - background: Background color + - show_axes: Whether to show axes + - parallel_projection: Whether to use parallel projection + - cpos: Camera position + - zoom: Zoom factor + + Returns + ------- + tuple + (result of show(), plotter object) + """ + text = kwargs.pop("text", None) + if text is not None: + self._plotter.add_text(text, position="lower_edge") + + background = kwargs.pop("background", None) + if background is not None: + self._backend.base_plotter.set_background(background) + + # show result + show_axes = kwargs.pop("show_axes", None) + if show_axes: + self._backend.base_plotter.add_axes() + + if kwargs.pop("parallel_projection", False): + self._backend.base_plotter.parallel_projection = True + + # Set cpos + cpos = kwargs.pop("cpos", None) + if cpos is not None: + self._backend.base_plotter.camera_position = cpos + + zoom = kwargs.pop("zoom", None) + if zoom is not None: + self._backend.base_plotter.camera.zoom(zoom) + + # Filter remaining kwargs for pv.Plotter.show (final destination) + import pyvista as pv + + kwargs_in = _sort_supported_kwargs(bound_method=pv.Plotter.show, **kwargs) + + # Show + result = self._plotter.show(**kwargs_in) + return result, self._backend.base_plotter + + @staticmethod + def _set_scalar_bar_title(kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Set the scalar bar title from kwargs. + + Parameters + ---------- + kwargs : dict + Keyword arguments containing 'stitle' or 'scalar_bar_args'. + + Returns + ------- + dict + Updated kwargs with scalar_bar_args set. + """ + stitle = kwargs.pop("stitle", None) + # use scalar_bar_args + scalar_bar_args = kwargs.pop("scalar_bar_args", None) + if not scalar_bar_args: + scalar_bar_args = {"title": stitle} + kwargs.setdefault("scalar_bar_args", scalar_bar_args) + return kwargs + + class DpfPlotter: """DpfPlotter class. Can be used in order to plot results over a mesh. @@ -497,7 +1190,7 @@ class DpfPlotter: available at :class:`pyvista.Plotter`. """ - def __init__(self, **kwargs): + def __init__(self, plotter_type: PlotterBackend | str = PlotterBackend.AUTO, **kwargs): """Create a DpfPlotter object. The current DpfPlotter is a PyVista based object. @@ -510,6 +1203,9 @@ def __init__(self, **kwargs): Parameters ---------- + plotter_type : PlotterBackend or str, optional + The type of plotter to use. Can be a ``PlotterBackend`` enum value or + its string equivalent. **kwargs : optional Additional keyword arguments for the plotter. More information are available at :class:`pyvista.Plotter`. @@ -518,9 +1214,19 @@ def __init__(self, **kwargs): -------- >>> from ansys.dpf.core.plotter import DpfPlotter >>> pl = DpfPlotter(notebook=False) - """ - _InternalPlotterClass = _InternalPlotterFactory.get_plotter_class() + if isinstance(plotter_type, str): + try: + plotter_type = PlotterBackend(plotter_type) + except ValueError as e: + raise ValueError(f"Unsupported value for plotter_type: {plotter_type!r}") from e + + if not isinstance(plotter_type, PlotterBackend): + raise ValueError( + f"plotter_type must be an instance of PlotterBackend or str, got {type(plotter_type)}" + ) + + _InternalPlotterClass = _InternalPlotterFactory.get_plotter_class(plotter_type) self._internal_plotter = _InternalPlotterClass(**kwargs) self._labels = [] @@ -880,7 +1586,7 @@ class Plotter: """ def __init__(self, mesh, **kwargs): - _InternalPlotterClass = _InternalPlotterFactory.get_plotter_class() + _InternalPlotterClass = _InternalPlotterFactory.get_plotter_class(PlotterBackend.PYVISTA) self._internal_plotter = _InternalPlotterClass(mesh=mesh, **kwargs) self._mesh = mesh