Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions src/underworld3/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2784,39 +2784,55 @@ def checkpoint_xdmf(

attributes = ""
for var in meshVars:
var_filename = filename + f"mesh.{var.clean_name}.{index:05}.h5"
var_filename = filename + f".mesh.{var.clean_name}.{index:05}.h5"

def get_cell_field_shape(h5_filename):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on second inspection I don't think this function should be called get..._shape. Consider get...._size.

Also can you not define this function in the for var in meshVars loop. This is not a good design. Move the function out of the loop. That will make it much less coupled to the loop.

with h5py.File(h5_filename, 'r') as f:

Copilot AI Jun 26, 2025

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The helper function 'get_cell_field_shape' assumes that the expected dataset exists in the HDF5 file. Consider adding error handling to provide a clearer message if the dataset is missing.

Copilot uses AI. Check for mistakes.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not cell fields are found what does this return? Is it well handled?

shape = f[f'cell_fields/{var.clean_name}_{var.clean_name}'].shape[0]
return shape

if var.num_components == 1:
variable_type = "Scalar"
else:
variable_type = "Vector"
# We should add a tensor type here ...

# Determine if data is stored on nodes (vertex_fields) or cells (cell_fields)
if not getattr(var, "continuous") or getattr(var, "degree")==0:
center = "Cell"
numItems = get_cell_field_shape(var_filename)
field_group = "cell_fields"
else:
center = "Node"
numItems = numVertices
field_group = "vertex_fields"

var_attribute = f"""
<Attribute
Name="{var.clean_name}"
Type="{variable_type}"
Center="Node">
Center="{center}">
<DataItem ItemType="HyperSlab"
Dimensions="1 {numVertices} {var.num_components}"
Type="HyperSlab">
Dimensions="1 {numItems} {var.num_components}"
Type="HyperSlab">
<DataItem
Dimensions="3 3"
Format="XML">
0 0 0
1 1 1
1 {numVertices} {var.num_components}
1 {numItems} {var.num_components}
</DataItem>
<DataItem
DataType="Float" Precision="8"
Dimensions="1 {numVertices} {var.num_components}"
Dimensions="1 {numItems} {var.num_components}"
Format="HDF">
&{var.clean_name+"_Data"};:/vertex_fields/{var.clean_name+"_P"+str(var.degree)}
&{var.clean_name+"_Data"};:/{field_group}/{var.clean_name+"_"+var.clean_name}
</DataItem>
</DataItem>
</Attribute>
"""
"""
attributes += var_attribute


for var in swarmVars:
var_filename = filename + f".proxy.{var.clean_name}.{index:05}.h5"
if var.num_components == 1:
Expand Down
188 changes: 188 additions & 0 deletions tests/test_0005_check_xdmf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import pytest
import sympy
import underworld3 as uw

import re
import h5py
import glob

import os
import numpy as np

# +
structured_quad_box = uw.meshing.StructuredQuadBox(elementRes=(3,) * 2)

unstructured_quad_box_irregular = uw.meshing.UnstructuredSimplexBox(
cellSize=0.2, regular=False, qdegree=2, refinement=1
)
unstructured_quad_box_regular = uw.meshing.UnstructuredSimplexBox(
cellSize=0.2, regular=True, qdegree=2, refinement=2
)

unstructured_quad_box_irregular_3D = uw.meshing.UnstructuredSimplexBox(
minCoords=(0.0, 0.0, 0.0),
maxCoords=(1.0, 1.0, 1.0),
cellSize=0.25,
regular=False,
qdegree=2,
)


# -

def check_xdmf_vertex_fields_exist_in_h5(xdmf_filename, tmp_path=""):
errors = []
with open(xdmf_filename, 'r') as f:
content = f.read()
doctype_match = re.search(r'<!DOCTYPE\s+Xdmf.*?\[(.*?)\]>', content, re.DOTALL)
if not doctype_match:
raise AssertionError("No DOCTYPE entity block found in XDMF file.")
entity_block = doctype_match.group(1)
entities = dict(re.findall(r'<!ENTITY\s+(\w+)\s+"([^"]+\.h5)"\s*>', entity_block))
refs = re.findall(r'&(\w+);:(/vertex_fields/[A-Za-z0-9_]+)', content)
print("Checking vertex field dataset references in XDMF:")
for entity_name, dataset_path in refs:
h5_file = entities.get(entity_name)
if not h5_file:
err = f"[ENTITY NOT FOUND] {entity_name}: {dataset_path}"
print(err)
errors.append(err)
continue
h5_full_path = os.path.join(tmp_path, h5_file)
try:
with h5py.File(h5_full_path, 'r') as f:
h5_path = dataset_path.lstrip('/')
if h5_path in f:
print(f"[OK] {h5_file}: {dataset_path} found")
else:
err = f"[MISSING] {h5_file}: {dataset_path} not found"
print(err)
errors.append(err)
except OSError as e:
err = f"[ERROR] Cannot open {h5_file}: {e}"
print(err)
errors.append(err)
if errors:
raise AssertionError("Missing or inaccessible vertex fields:\n" + "\n".join(errors))


def remove_test_mesh_files(directory='.'):
"""Delete all files starting with test.mesh. in the specified directory."""
pattern = os.path.join(directory, 'test.mesh.*')
for file_path in glob.glob(pattern):
try:
os.remove(file_path)
print(f"Removed: {file_path}")
except Exception as e:
print(f"Could not remove {file_path}: {e}")


@pytest.mark.parametrize(
"mesh",
[
structured_quad_box,
unstructured_quad_box_irregular,
unstructured_quad_box_regular,
unstructured_quad_box_irregular_3D,
],
)
def test_stokes_boxmesh(mesh, tmp_path):
print(f"Mesh - Coordinates: {mesh.CoordinateSystem.type}")
mesh.dm.view()

if mesh.dim == 2:
x, y = mesh.X
else:
x, y, z = mesh.X

u = uw.discretisation.MeshVariable(
'u', mesh, mesh.dim, vtype=uw.VarType.VECTOR, degree=2
)
p = uw.discretisation.MeshVariable(
'p', mesh, 1, vtype=uw.VarType.SCALAR, degree=1
)
u2 = uw.discretisation.MeshVariable(
'u2', mesh, mesh.dim, vtype=uw.VarType.VECTOR, degree=2
)
p2 = uw.discretisation.MeshVariable(
'p2', mesh, 1, vtype=uw.VarType.SCALAR, degree=1
)

stokes = uw.systems.Stokes(mesh, velocityField=u, pressureField=p)
stokes.constitutive_model = uw.constitutive_models.ViscousFlowModel
stokes.constitutive_model.Parameters.shear_viscosity_0 = 1

stokes.petsc_options["snes_type"] = "newtonls"
stokes.petsc_options["ksp_type"] = "fgmres"

stokes.petsc_options["snes_type"] = "newtonls"
stokes.petsc_options["ksp_type"] = "fgmres"
Comment thread
julesghub marked this conversation as resolved.
Outdated
stokes.petsc_options["ksp_monitor"] = None
stokes.petsc_options["snes_monitor"] = None
stokes.tolerance = 1.0e-3

# stokes.petsc_options.setValue("fieldsplit_velocity_pc_type", "mg")
stokes.petsc_options.setValue("fieldsplit_velocity_pc_mg_type", "kaskade")
stokes.petsc_options.setValue("fieldsplit_velocity_pc_mg_cycle_type", "w")

stokes.petsc_options["fieldsplit_velocity_mg_coarse_pc_type"] = "svd"
stokes.petsc_options[f"fieldsplit_velocity_ksp_type"] = "fcg"
stokes.petsc_options[f"fieldsplit_velocity_mg_levels_ksp_type"] = "chebyshev"
stokes.petsc_options[f"fieldsplit_velocity_mg_levels_ksp_max_it"] = 7
stokes.petsc_options[f"fieldsplit_velocity_mg_levels_ksp_converged_maxits"] = None

stokes.petsc_options.setValue("fieldsplit_pressure_pc_type", "gamg")
stokes.petsc_options.setValue("fieldsplit_pressure_pc_mg_type", "multiplicative")
stokes.petsc_options.setValue("fieldsplit_pressure_pc_mg_cycle_type", "v")

if mesh.dim == 2:
stokes.bodyforce = 1.0e6 * sympy.Matrix([0, x])

stokes.add_dirichlet_bc((0.0, 0.0), "Bottom")
stokes.add_dirichlet_bc((0.0, None), "Top")

stokes.add_dirichlet_bc((0.0, None), "Left")
stokes.add_condition(conds=(0.0, None),
label="Right",
f_id=0,
c_type='dirichlet' )
else:
stokes.bodyforce = 1.0e6 * sympy.Matrix([0, x, 0])

stokes.add_dirichlet_bc((0.0, 0.0, 0.0), "Bottom")
stokes.add_dirichlet_bc((0.0, 0.0, 0.0), "Top")

stokes.add_dirichlet_bc((0.0, None, sympy.oo), "Left")
stokes.add_dirichlet_bc((0.0, sympy.oo, None), "Right")

stokes.add_dirichlet_bc((sympy.oo, 0.0, sympy.oo), "Front")
stokes.add_dirichlet_bc((sympy.oo, 0.0, sympy.oo), "Back")

stokes.solve()

print(f"Mesh dimensions {mesh.dim}", flush=True)
# stokes.dm.ds.view()

assert stokes.snes.getConvergedReason() > 0

mesh.write_timestep("test", meshUpdates=False, meshVars=[u, p], outputPath=tmp_path, index=0)

# Call XDMF/HDF5 checker (assume xdmf file is named "test.mesh.xdmf" and written in tmp_path)
xdmf_filename = os.path.join(tmp_path, "test.mesh.00000.xdmf")
check_xdmf_vertex_fields_exist_in_h5(xdmf_filename, tmp_path=str(tmp_path))

u2.read_timestep("test", "u", 0, outputPath=tmp_path)
p2.read_timestep("test", "p", 0, outputPath=tmp_path)

with mesh.access():
assert np.allclose(u.data, u2.data)
assert np.allclose(p.data, p2.data)

remove_test_mesh_files(directory=tmp_path)

del mesh
del stokes
print('----------------------------------------------------------------------------------------')
return