Skip to content

Commit 8b1e13b

Browse files
committed
correclty handle input requirements
1 parent 7defa5c commit 8b1e13b

3 files changed

Lines changed: 85 additions & 43 deletions

File tree

crates/cwl-execution/src/lib.rs

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ pub fn execute_cwlfile(cwlfile: impl AsRef<Path>, raw_inputs: &[String], outdir:
8282
}
8383
}
8484

85-
let output_values = execute(cwlfile, input_values, outdir, None)?;
85+
let output_values = execute(cwlfile, &input_values, outdir, None)?;
8686
let json = serde_json::to_string_pretty(&output_values)?;
8787
println!("{json}");
8888

@@ -91,7 +91,7 @@ pub fn execute_cwlfile(cwlfile: impl AsRef<Path>, raw_inputs: &[String], outdir:
9191

9292
pub fn execute(
9393
cwlfile: impl AsRef<Path>,
94-
input_values: InputObject,
94+
input_values: &InputObject,
9595
outdir: Option<impl AsRef<Path>>,
9696
cwl_doc: Option<&CWLDocument>,
9797
) -> Result<HashMap<String, DefaultValue>, Box<dyn Error>> {
@@ -130,6 +130,11 @@ pub struct InputObject {
130130
#[serde(rename = "cwl:hints")]
131131
#[serde(default)]
132132
pub hints: Vec<Requirement>,
133+
134+
#[serde(skip)]
135+
cwl_requirements: Vec<Requirement>,
136+
#[serde(skip)]
137+
cwl_hints: Vec<Requirement>,
133138
}
134139

135140
impl InputObject {
@@ -141,21 +146,56 @@ impl InputObject {
141146
}
142147

143148
pub fn add_requirement(&mut self, requirement: &Requirement) {
144-
if self
145-
.requirements
146-
.iter()
147-
.any(|r| std::mem::discriminant(r) == std::mem::discriminant(requirement))
149+
if let Some(r) = self
150+
.cwl_requirements
151+
.iter_mut()
152+
.find(|r| std::mem::discriminant(*r) == std::mem::discriminant(requirement))
148153
{
149-
return;
154+
*r = requirement.clone();
155+
} else {
156+
self.cwl_requirements.push(requirement.clone());
150157
}
151-
self.requirements.push(requirement.clone());
152158
}
153159

154160
pub fn add_hint(&mut self, hint: &Requirement) {
155-
if self.hints.iter().any(|r| std::mem::discriminant(r) == std::mem::discriminant(hint)) {
156-
return;
161+
if let Some(r) = self
162+
.cwl_hints
163+
.iter_mut()
164+
.find(|r| std::mem::discriminant(*r) == std::mem::discriminant(hint))
165+
{
166+
*r = hint.clone();
167+
} else {
168+
self.cwl_hints.push(hint.clone());
169+
}
170+
}
171+
172+
pub fn handle_requirements(&self, requirements: &[Requirement], hints: &[Requirement]) -> Self {
173+
let mut new_obj = self.clone();
174+
for hint in hints {
175+
new_obj.add_hint(hint);
176+
}
177+
178+
for req in requirements {
179+
new_obj.add_requirement(req);
157180
}
158-
self.hints.push(hint.clone());
181+
new_obj
182+
}
183+
184+
pub fn lock(&mut self) {
185+
fn merge(dst: &mut Vec<Requirement>, src: &[Requirement]) {
186+
for req in src {
187+
if let Some(r) = dst.iter_mut().find(|r| std::mem::discriminant(*r) == std::mem::discriminant(req)) {
188+
*r = req.clone();
189+
} else {
190+
dst.push(req.clone());
191+
}
192+
}
193+
}
194+
merge(&mut self.cwl_requirements, &self.requirements);
195+
self.requirements = self.cwl_requirements.clone();
196+
197+
merge(&mut self.cwl_hints, &self.hints);
198+
self.hints = self.cwl_hints.clone();
159199
}
160200
}
161201

@@ -264,3 +304,26 @@ pub fn set_container_engine(value: ContainerEngine) {
264304
pub fn container_engine() -> ContainerEngine {
265305
CONTAINER_ENGINE.with(|engine| *engine.borrow())
266306
}
307+
308+
#[cfg(test)]
309+
mod tests {
310+
use super::*;
311+
use cwl::{requirements::EnvVarRequirement, types::EnviromentDefs};
312+
313+
#[test]
314+
fn test_add_requirement() {
315+
let mut input = InputObject::default();
316+
let base_req = Requirement::EnvVarRequirement(EnvVarRequirement {
317+
env_def: EnviromentDefs::Map(HashMap::from([("MY_ENV".to_string(), "BASE".to_string())])),
318+
});
319+
input.add_requirement(&base_req);
320+
assert_eq!(input.requirements.len(), 1);
321+
322+
let requirement = Requirement::EnvVarRequirement(EnvVarRequirement {
323+
env_def: EnviromentDefs::Map(HashMap::from([("MY_ENV".to_string(), "OVERWRITE".to_string())])),
324+
});
325+
input.add_requirement(&requirement);
326+
assert_eq!(input.requirements.len(), 1);
327+
assert_eq!(input.requirements[0], requirement);
328+
}
329+
}

crates/cwl-execution/src/runner.rs

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use wait_timeout::ChildExt;
4141

4242
pub fn run_workflow(
4343
workflow: &mut Workflow,
44-
input_values: InputObject,
44+
input_values: &InputObject,
4545
cwl_path: Option<&PathBuf>,
4646
out_dir: Option<String>,
4747
) -> Result<HashMap<String, DefaultValue>, Box<dyn Error>> {
@@ -59,14 +59,7 @@ pub fn run_workflow(
5959
};
6060

6161
let workflow_folder = cwl_path.unwrap().parent().unwrap_or(Path::new("."));
62-
63-
let mut input_values = input_values;
64-
for req in &workflow.requirements {
65-
input_values.add_requirement(req);
66-
}
67-
for hint in &workflow.hints {
68-
input_values.add_hint(hint);
69-
}
62+
let input_values = input_values.handle_requirements(&workflow.requirements, &workflow.hints);
7063

7164
//prevent tool from outputting
7265
set_print_output(false);
@@ -124,21 +117,12 @@ pub fn run_workflow(
124117
}
125118
}
126119
}
127-
let mut input_values = InputObject {
128-
inputs: step_inputs,
129-
requirements: input_values.requirements.clone(),
130-
hints: input_values.hints.clone(),
131-
};
132-
for req in &step.requirements {
133-
input_values.add_requirement(req);
134-
}
135-
for hint in &step.hints {
136-
input_values.add_hint(hint);
137-
}
120+
let input_values = input_values.handle_requirements(&step.requirements, &step.hints);
121+
138122
let step_outputs = if let Some(path) = path {
139-
execute(&path, input_values, Some(tmp_path.clone()), None)?
123+
execute(&path, &input_values, Some(tmp_path.clone()), None)?
140124
} else if let StringOrDocument::Document(doc) = &step.run {
141-
execute(workflow_folder, input_values, Some(tmp_path.clone()), Some(doc))?
125+
execute(workflow_folder, &input_values, Some(tmp_path.clone()), Some(doc))?
142126
} else {
143127
unreachable!()
144128
};
@@ -227,7 +211,7 @@ pub fn run_workflow(
227211

228212
pub fn run_tool(
229213
tool: &mut CWLDocument,
230-
input_values: InputObject,
214+
input_values: &InputObject,
231215
cwl_path: Option<&PathBuf>,
232216
out_dir: Option<String>,
233217
) -> Result<HashMap<String, DefaultValue>, Box<dyn Error>> {
@@ -254,13 +238,8 @@ pub fn run_tool(
254238
//create runtime tmpdir
255239
let tmp_dir = tempdir()?;
256240

257-
let mut input_values = input_values;
258-
for req in &tool.requirements {
259-
input_values.add_requirement(req);
260-
}
261-
for hint in &tool.hints {
262-
input_values.add_hint(hint);
263-
}
241+
let mut input_values = input_values.handle_requirements(&tool.requirements, &tool.hints);
242+
input_values.lock();
264243

265244
//build runtime object
266245
let mut runtime = RuntimeEnvironment::initialize(tool, &input_values, dir.path(), tool_path, tmp_dir.path())?;

tests/runner_test.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ baseCommand:
188188
";
189189

190190
let mut tool: CWLDocument = serde_yaml::from_str(cwl).expect("Tool parsing failed");
191-
let result = run_tool(&mut tool, Default::default(), None, None);
191+
let result = run_tool(&mut tool, &Default::default(), None, None);
192192
assert!(result.is_ok());
193193
//delete results.txt
194194
let _ = fs::remove_file("results.txt");
@@ -203,6 +203,6 @@ baseCommand:
203203
pub fn test_run_commandlinetool_array_glob() {
204204
let dir = tempdir().unwrap();
205205
let mut tool = CWLDocument::CommandLineTool(load_tool("tests/test_data/array_test.cwl").expect("Tool parsing failed"));
206-
let result = run_tool(&mut tool, Default::default(), None, Some(dir.path().to_string_lossy().into_owned()));
206+
let result = run_tool(&mut tool, &Default::default(), None, Some(dir.path().to_string_lossy().into_owned()));
207207
assert!(result.is_ok(), "{result:?}");
208208
}

0 commit comments

Comments
 (0)