Skip to content
34 changes: 29 additions & 5 deletions job_creator/jobcreator/job_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,23 @@
from jobcreator.utils import load_kubernetes_config, logger


def _setup_smb_pv(pv_name: str, secret_name: str, secret_namespace: str, source: str, mount_options: list[str]) -> None:
def _setup_smb_pv(
pv_name: str,
secret_name: str,
secret_namespace: str,
source: str,
mount_options: list[str],
access_mode: str = "ReadOnlyMany",
) -> None:
"""
Sets up an smb PV using the loaded kubeconfig as a destination
:param pv_name: str, The name given to the smb-pv when it's made
:param secret_name: str, The name of the secret that contains the credentials for the smb share
:param secret_namespace: str, the namespace of the secret
:param source: str, The IP/url/uri that is used to mount the smb share
:param mount_options: list, The mount options for the smb share
:return: str, the name of the archive PV
:param access_mode: str, The access mode for the PV. Defaults to "ReadOnlyMany"
:return: str, the name of the PV
"""
metadata = client.V1ObjectMeta(name=pv_name, annotations={"pv.kubernetes.io/provisioned-by": "smb.csi.k8s.io"})
secret_ref = client.V1SecretReference(name=secret_name, namespace=secret_namespace)
Expand All @@ -30,13 +38,13 @@ def _setup_smb_pv(pv_name: str, secret_name: str, secret_namespace: str, source:
)
spec = client.V1PersistentVolumeSpec(
capacity={"storage": "1000Gi"},
access_modes=["ReadOnlyMany"],
access_modes=[access_mode],
persistent_volume_reclaim_policy="Retain",
mount_options=mount_options,
csi=csi,
)
archive_pv = client.V1PersistentVolume(api_version="v1", kind="PersistentVolume", metadata=metadata, spec=spec)
client.CoreV1Api().create_persistent_volume(archive_pv)
pv = client.V1PersistentVolume(api_version="v1", kind="PersistentVolume", metadata=metadata, spec=spec)
client.CoreV1Api().create_persistent_volume(pv)


def _setup_pvc(pvc_name: str, pv_name: str, namespace: str, access_mode: str = "ReadOnlyMany") -> None:
Expand Down Expand Up @@ -163,6 +171,15 @@ def _setup_ceph_pv(
return pv_name


def _setup_ngem_pv_and_pvcs(job_name: str, namespace: str, pv_names: list[str], pvc_names: list[str]) -> None:
    """
    Sets up the smb PV and the PVC used by ngem jobs with the "ReadWriteMany" access mode, then records their names
    :param job_name: str, The name of the job, used as the prefix of the PV and PVC names
    :param namespace: str, The namespace passed as the smb secret namespace and as the PVC namespace
    :param pv_names: list, appended in place with the ngem PV name
    :param pvc_names: list, appended in place with the ngem PVC name
    :return: None
    """
    access_mode = "ReadWriteMany"
    volume_name = f"{job_name}-ngem-pv-smb"
    claim_name = f"{job_name}-ngem-pvc"

    # Arguments are kept positional: the unit tests assert on the exact positional call signature.
    _setup_smb_pv(volume_name, "archive-creds", namespace, "//isis.cclrc.ac.uk/Science", [], access_mode)
    _setup_pvc(claim_name, volume_name, namespace, access_mode)

    # Only record the names once both setup calls have returned without raising.
    pv_names.append(volume_name)
    pvc_names.append(claim_name)


def _setup_imat_pv_and_pvcs(job_name: str, namespace: str, pv_names: list[str], pvc_names: list[str]) -> None:
imat_pv_name = f"{job_name}-ndximat-pv-smb"
imat_pvc_name = f"{job_name}-ndximat-pvc"
Expand Down Expand Up @@ -385,6 +402,13 @@ def spawn_job( # noqa: PLR0913
client.V1VolumeMount(name="extras-mount", mount_path="/extras"),
]
# Setup special PVs and add them to the volume mounts
if "ngem" in special_pvs:
_setup_ngem_pv_and_pvcs(job_name, job_namespace, pv_names, pvc_names)
ngem_pvc_source = client.V1PersistentVolumeClaimVolumeSource(
claim_name=f"{job_name}-ngem-pvc", read_only=False
)
volumes.append(client.V1Volume(name="ngem-mount", persistent_volume_claim=ngem_pvc_source))
volumes_mounts.append(client.V1VolumeMount(name="ngem-mount", mount_path="/ngem"))
if "imat" in special_pvs:
_setup_imat_pv_and_pvcs(job_name, job_namespace, pv_names, pvc_names)
imat_pvc_source = client.V1PersistentVolumeClaimVolumeSource(
Expand Down
63 changes: 47 additions & 16 deletions job_creator/jobcreator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,17 @@
CONSUMER_PASSWORD = os.environ.get("QUEUE_PASSWORD", "")
REDUCE_USER_ID = os.environ.get("REDUCE_USER_ID", "")
JOB_NAMESPACE = os.environ.get("JOB_NAMESPACE", "fia")
JOB_CREATOR = JobCreator(dev_mode=DEV_MODE, watcher_sha=WATCHER_SHA)
JOB_CREATOR: JobCreator | None = None


def get_job_creator() -> JobCreator:
    """
    Returns the module-level JobCreator, building it the first time it is requested rather than at import time
    :return: JobCreator, the shared instance stored in JOB_CREATOR
    :raises OSError: if WATCHER_SHA is None when the JobCreator has to be built
    """
    global JOB_CREATOR  # noqa: PLW0603
    if JOB_CREATOR is None:
        # Fail with an explicit message instead of handing a None sha to JobCreator.
        if WATCHER_SHA is None:
            raise OSError("WATCHER_SHA not set in the environment, please add it.")
        JOB_CREATOR = JobCreator(dev_mode=DEV_MODE, watcher_sha=WATCHER_SHA)
    return JOB_CREATOR


CEPH_CREDS_SECRET_NAME = os.environ.get("CEPH_CREDS_SECRET_NAME", "ceph-creds")
CEPH_CREDS_SECRET_NAMESPACE = os.environ.get("CEPH_CREDS_SECRET_NAMESPACE", "fia")
Expand All @@ -59,7 +69,7 @@
MAX_TIME_TO_COMPLETE = int(os.environ.get("MAX_TIME_TO_COMPLETE", str(60 * 60 * 6)))


def _generate_special_pvs(instrument: str) -> list[str]:
def _generate_special_pvs(instrument: str, additional_values: dict[str, Any]) -> list[str]:
"""
A generic function for, based on passed args, returning what the special persistent volumes should be.
"""
Expand All @@ -68,19 +78,31 @@ def _generate_special_pvs(instrument: str) -> list[str]:
match instrument.lower():
case "imat":
logger.info("Special PV for %s added.", instrument)
special_pvs.append("imat")
if "ngem" in additional_values and additional_values["ngem"] == "true":
special_pvs.append("ngem")
else:
special_pvs.append("imat")
case "ines":
logger.info("Special PV for %s added.", instrument)
if "ngem" in additional_values and additional_values["ngem"] == "true":
special_pvs.append("ngem")
else:
special_pvs.append("ines")
case _:
logger.info("No special PV needed for %s", instrument)

return special_pvs


def _select_runner_image(instrument: str) -> str:
def _select_runner_image(instrument: str, additional_values: dict[str, Any]) -> str:
"""
A generic function for, based on passed args, returning what the runner that should be used.
"""
match instrument.lower():
case "imat":
if "ngem" in additional_values and additional_values["ngem"] == "true":
# For ngem we want to return the default mantid runner. INES always wants mantid default runner.
return DEFAULT_RUNNER
if IMAGING_RUNNER_SHA is not None:
logger.info("Imaging runner image selected for %s ", instrument)
return IMAGING_RUNNER
Expand All @@ -91,7 +113,9 @@ def _select_runner_image(instrument: str) -> str:
return DEFAULT_RUNNER


def _select_taints_and_affinity(instrument: str) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
def _select_taints_and_affinity(
instrument: str, additional_values: dict[str, Any]
) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
"""
A generic function for, based on passed args, returning what the runner that should be used.
"""
Expand All @@ -100,9 +124,10 @@ def _select_taints_and_affinity(instrument: str) -> tuple[list[dict[str, Any]],

match instrument.lower():
case "imat":
logger.info("Applying taint to the job on instrument %s", instrument)
taints.append({"key": "nvidia.com/gpu", "effect": "NoSchedule", "operator": "Exists"})
affinity = {"key": "node-type", "operator": "In", "values": ["gpu-worker"]}
if "ngem" not in additional_values or additional_values["ngem"] != "true":
logger.info("Applying taint to the job on instrument %s", instrument)
taints.append({"key": "nvidia.com/gpu", "effect": "NoSchedule", "operator": "Exists"})
affinity = {"key": "node-type", "operator": "In", "values": ["gpu-worker"]}
case _:
logger.info("No taints applied to %s runners", instrument)

Expand Down Expand Up @@ -146,7 +171,7 @@ def process_simple_message(message: dict[str, Any]) -> None:
{"user_number": str(user_number)} if user_number else {"experiment_number": str(experiment_number)}
)
ceph_mount_path = create_ceph_mount_path_simple(**ceph_mount_path_kwargs)
JOB_CREATOR.spawn_job(
get_job_creator().spawn_job(
job_name=job_name,
script=script,
job_namespace=JOB_NAMESPACE,
Expand Down Expand Up @@ -185,12 +210,16 @@ def process_rerun_message(message: dict[str, Any]) -> None:
rb_number=str(message["rb_number"]),
)

special_pvs = _generate_special_pvs(instrument=message["instrument"])
taints, affinity = _select_taints_and_affinity(instrument=message["instrument"])
special_pvs = _generate_special_pvs(
instrument=message["instrument"], additional_values=message.get("additional_values", {})
)
taints, affinity = _select_taints_and_affinity(
instrument=message["instrument"], additional_values=message.get("additional_values", {})
)

# Add UUID which will avoid collisions for reruns
job_name = f"run-{str(message['filename']).lower()}-{uuid.uuid4().hex!s}"
JOB_CREATOR.spawn_job(
get_job_creator().spawn_job(
job_name=job_name,
script=script,
job_namespace=JOB_NAMESPACE,
Expand Down Expand Up @@ -226,7 +255,7 @@ def process_autoreduction_message(message: dict[str, Any]) -> None:
instrument_name = message["instrument"]
runner_image = message.get("runner_image")
if runner_image is None:
runner_image = _select_runner_image(instrument_name)
runner_image = _select_runner_image(instrument_name, message["additional_values"])
runner_image = find_sha256_of_image(runner_image)
autoreduction_request = {
"filename": filename,
Expand All @@ -242,8 +271,10 @@ def process_autoreduction_message(message: dict[str, Any]) -> None:
"runner_image": runner_image,
}

special_pvs = _generate_special_pvs(instrument=instrument_name)
taints, affinity = _select_taints_and_affinity(instrument=message["instrument"])
special_pvs = _generate_special_pvs(instrument=instrument_name, additional_values=message["additional_values"])
taints, affinity = _select_taints_and_affinity(
instrument=message["instrument"], additional_values=message["additional_values"]
)

# Add UUID which will avoid collisions for reruns
job_name = f"run-{filename.lower()}-{uuid.uuid4().hex!s}"
Expand All @@ -253,7 +284,7 @@ def process_autoreduction_message(message: dict[str, Any]) -> None:
autoreduction_request=autoreduction_request,
)
ceph_mount_path = create_ceph_mount_path_autoreduction(instrument_name, rb_number)
JOB_CREATOR.spawn_job(
get_job_creator().spawn_job(
job_name=job_name,
script=script,
job_namespace=JOB_NAMESPACE,
Expand Down
136 changes: 136 additions & 0 deletions job_creator/test/test_job_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

from jobcreator.job_creator import (
JobCreator,
_generate_affinities,
_generate_tolerations_from_taints,
_setup_ceph_pv,
_setup_extras_pv,
_setup_extras_pvc,
_setup_imat_pv_and_pvcs,
_setup_ngem_pv_and_pvcs,
_setup_pvc,
_setup_smb_pv,
)
Expand Down Expand Up @@ -155,6 +159,83 @@ def test_setup_extras_pv(client):
client.V1SecretReference.assert_called_once_with(name="manila-creds", namespace=secret_namespace)


# Number of tolerations expected from the two taints passed in test_generate_tolerations_from_taints.
EXPECTED_TOLERATIONS_COUNT = 2


@mock.patch("jobcreator.job_creator.client")
def test_generate_tolerations_from_taints(client):
    """Each taint dict yields one V1Toleration call; a taint without "value" is passed on as value=None."""
    equal_taint = {"key": "key1", "value": "value1", "operator": "Equal", "effect": "NoSchedule"}
    exists_taint = {"key": "key2", "operator": "Exists", "effect": "NoExecute"}  # deliberately has no "value"

    generated = _generate_tolerations_from_taints([equal_taint, exists_taint])

    expected_calls = [
        call(key="key1", value="value1", operator="Equal", effect="NoSchedule"),
        call(key="key2", value=None, operator="Exists", effect="NoExecute"),
    ]
    assert len(generated) == EXPECTED_TOLERATIONS_COUNT
    client.V1Toleration.assert_has_calls(expected_calls)


@mock.patch("jobcreator.job_creator.client")
def test_generate_affinities_none(client):
    """With no node affinity supplied, V1Affinity is built from the pod anti-affinity alone and returned."""
    result = _generate_affinities(None)

    anti_affinity = client.V1PodAntiAffinity.return_value
    assert result == client.V1Affinity.return_value
    client.V1Affinity.assert_called_once_with(pod_anti_affinity=anti_affinity)


@mock.patch("jobcreator.job_creator.logger")
@mock.patch("jobcreator.job_creator.client")
def test_generate_affinities_missing_key(client, logger):
    """An affinity dict lacking "values" logs exactly one error and only the pod anti-affinity is applied."""
    incomplete_affinity = {"key": "some-key", "operator": "In"}  # missing "values"

    _generate_affinities(incomplete_affinity)

    anti_affinity = client.V1PodAntiAffinity.return_value
    logger.error.assert_called_once()
    client.V1Affinity.assert_called_once_with(pod_anti_affinity=anti_affinity)


@mock.patch("jobcreator.job_creator.client")
def test_generate_affinities_valid(client):
    """A complete affinity dict results in V1Affinity receiving both pod anti-affinity and node affinity."""
    complete_affinity = {"key": "some-key", "operator": "In", "values": ["val1"]}

    _generate_affinities(complete_affinity)

    expected_kwargs = {
        "pod_anti_affinity": client.V1PodAntiAffinity.return_value,
        "node_affinity": client.V1NodeAffinity.return_value,
    }
    client.V1Affinity.assert_called_once_with(**expected_kwargs)


@mock.patch("jobcreator.job_creator._setup_pvc")
@mock.patch("jobcreator.job_creator._setup_smb_pv")
def test_setup_ngem_pv_and_pvcs(setup_smb_pv, setup_pvc):
    """ngem setup requests a ReadWriteMany smb PV plus a matching PVC and records both names."""
    recorded_pvs, recorded_pvcs = [], []
    expected_pv = "job1-ngem-pv-smb"
    expected_pvc = "job1-ngem-pvc"

    _setup_ngem_pv_and_pvcs("job1", "ns1", recorded_pvs, recorded_pvcs)

    setup_smb_pv.assert_called_once_with(
        expected_pv, "archive-creds", "ns1", "//isis.cclrc.ac.uk/Science", [], "ReadWriteMany"
    )
    setup_pvc.assert_called_once_with(expected_pvc, expected_pv, "ns1", "ReadWriteMany")
    assert recorded_pvs == [expected_pv]
    assert recorded_pvcs == [expected_pvc]


@mock.patch("jobcreator.job_creator._setup_pvc")
@mock.patch("jobcreator.job_creator._setup_smb_pv")
def test_setup_imat_pv_and_pvcs(setup_smb_pv, setup_pvc):
    """imat setup requests the NDXIMAT smb PV and its PVC with default access modes and records both names."""
    recorded_pvs, recorded_pvcs = [], []
    expected_pv = "job1-ndximat-pv-smb"
    expected_pvc = "job1-ndximat-pvc"

    _setup_imat_pv_and_pvcs("job1", "ns1", recorded_pvs, recorded_pvcs)

    # No access mode argument is expected here: imat relies on the helpers' defaults.
    setup_smb_pv.assert_called_once_with(expected_pv, "imat-creds", "ns1", "//NDXIMAT.isis.cclrc.ac.uk/data$/", [])
    setup_pvc.assert_called_once_with(expected_pvc, expected_pv, "ns1")
    assert recorded_pvs == [expected_pv]
    assert recorded_pvcs == [expected_pvc]


@mock.patch("jobcreator.job_creator.client")
def test_setup_ceph_pv(client):
pv_name = mock.MagicMock()
Expand Down Expand Up @@ -219,6 +300,61 @@ def test_jobcreator_init(mock_load_kubernetes_config):
mock_load_kubernetes_config.assert_called_once()


@mock.patch("jobcreator.job_creator._setup_ngem_pv_and_pvcs")
@mock.patch("jobcreator.job_creator._setup_imat_pv_and_pvcs")
@mock.patch("jobcreator.job_creator._setup_extras_pv")
@mock.patch("jobcreator.job_creator._setup_extras_pvc")
@mock.patch("jobcreator.job_creator._setup_smb_pv")
@mock.patch("jobcreator.job_creator._setup_pvc")
@mock.patch("jobcreator.job_creator._setup_ceph_pv")
@mock.patch("jobcreator.job_creator.load_kubernetes_config")
@mock.patch("jobcreator.job_creator.client")
def test_jobcreator_spawn_job_ngem(
    client,
    _,  # noqa: PT019
    setup_ceph_pv,
    setup_pvc,
    setup_smb_pv,
    setup_extras_pvc,
    setup_extras_pv,
    setup_imat_pv,
    setup_ngem_pv,
):
    """
    Spawning a job with special_pvs=["ngem"] sets up the ngem PV/PVC (and not the imat one) and wires the
    resulting claim into the job as a writable volume mounted at /ngem.
    """
    job_name = "test-job"
    script = "test-script"
    job_namespace = "test-ns"
    watcher_sha = "test-sha"
    job_creator = JobCreator(watcher_sha, False)

    job_creator.spawn_job(
        job_name=job_name,
        script=script,
        job_namespace=job_namespace,
        ceph_creds_k8s_secret_name="some-secret-name",  # noqa: S106
        ceph_creds_k8s_namespace="ns",
        cluster_id="id",
        fs_name="fs",
        ceph_mount_path="/path",
        job_id=1,
        max_time_to_complete_job=100,
        fia_api_host="host",
        fia_api_api_key="key",
        runner_image="image",
        manila_share_id="mid",
        manila_share_access_id="maid",
        special_pvs=["ngem"],
        taints=[],
        affinity=None,
    )

    # The ngem helper is invoked once for this job and namespace. The trailing pv/pvc name lists are
    # mutated by spawn_job, so only the leading positional arguments are pinned here.
    setup_ngem_pv.assert_called_once()
    assert setup_ngem_pv.call_args.args[:2] == (job_name, job_namespace)
    # An ngem-only job must not set up the imat share as well.
    setup_imat_pv.assert_not_called()

    # The claim must reference the ngem PVC and be writable.
    ngem_claim_source = client.V1PersistentVolumeClaimVolumeSource
    ngem_claim_source.assert_any_call(claim_name=f"{job_name}-ngem-pvc", read_only=False)
    # The volume is built from that claim source and mounted at /ngem.
    client.V1Volume.assert_any_call(name="ngem-mount", persistent_volume_claim=ngem_claim_source.return_value)
    client.V1VolumeMount.assert_any_call(name="ngem-mount", mount_path="/ngem")


@mock.patch("jobcreator.job_creator._setup_extras_pv")
@mock.patch("jobcreator.job_creator._setup_extras_pvc")
@mock.patch("jobcreator.job_creator._setup_smb_pv")
Expand Down
Loading