Skip to content

Commit ee4c150

Browse files
committed
Fixed invalid semantics for finished tasks
1 parent 83d4774 commit ee4c150

8 files changed

Lines changed: 127 additions & 149 deletions

File tree

crates/tako/src/internal/scheduler/metrics.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use crate::TaskId;
21
use crate::internal::common::{Map, Set};
32
use crate::internal::server::task::Task;
43
use crate::internal::server::taskmap::TaskMap;
4+
use crate::TaskId;
55

66
pub fn compute_b_level_metric(tasks: &mut TaskMap) {
77
crawl(tasks, |t| t.get_consumers());
@@ -30,26 +30,25 @@ fn crawl<F1: Fn(&Task) -> &Set<TaskId>>(tasks: &mut TaskMap, predecessor_fn: F1)
3030
task.set_scheduler_priority(level + 1);
3131

3232
for t in task.task_deps.iter() {
33-
let v: &mut u32 = neighbours
34-
.get_mut(t)
35-
.expect("Couldn't find task neighbour in level computation");
36-
if *v <= 1 {
37-
assert_eq!(*v, 1);
38-
stack.push(*t);
39-
} else {
40-
*v -= 1;
33+
if let Some(v) = neighbours.get_mut(t) {
34+
if *v <= 1 {
35+
assert_eq!(*v, 1);
36+
stack.push(*t);
37+
} else {
38+
*v -= 1;
39+
}
4140
}
4241
}
4342
}
4443
}
4544

4645
#[cfg(test)]
4746
mod tests {
48-
use crate::TaskId;
4947
use crate::internal::common::index::ItemId;
5048
use crate::internal::scheduler::metrics::compute_b_level_metric;
5149
use crate::internal::server::core::Core;
5250
use crate::internal::tests::utils::workflows::submit_example_2;
51+
use crate::TaskId;
5352

5453
#[test]
5554
fn b_level_simple_graph() {

crates/tako/src/internal/scheduler/state.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ impl SchedulerState {
276276
}
277277
TaskRuntimeState::Running { .. }
278278
| TaskRuntimeState::RunningMultiNode(_)
279-
| TaskRuntimeState::Finished(_) => {
279+
| TaskRuntimeState::Finished => {
280280
panic!("Invalid state {:?}", task.state);
281281
}
282282
};
@@ -379,14 +379,13 @@ impl SchedulerState {
379379
//log::debug!("Task {} initially assigned to {}", task.id, worker_id);
380380
};
381381
if let Some(worker_id) = worker_id {
382-
debug_assert!(
383-
core.get_worker_map()
384-
.get_worker(worker_id)
385-
.is_capable_to_run_rqv(
386-
&core.get_task(task_id).configuration.resources,
387-
self.now
388-
)
389-
);
382+
debug_assert!(core
383+
.get_worker_map()
384+
.get_worker(worker_id)
385+
.is_capable_to_run_rqv(
386+
&core.get_task(task_id).configuration.resources,
387+
self.now
388+
));
390389
self.assign(core, task_id, worker_id);
391390
} else {
392391
core.add_sleeping_sn_task(task_id);

crates/tako/src/internal/server/client.rs

Lines changed: 26 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,9 @@ pub(crate) async fn process_client_message(
6565
),
6666
FromGatewayMessage::ServerInfo => {
6767
let core = core_ref.get();
68-
assert!(
69-
client_sender
70-
.send(ToGatewayMessage::ServerInfo(core.get_server_info()))
71-
.is_ok()
72-
);
68+
assert!(client_sender
69+
.send(ToGatewayMessage::ServerInfo(core.get_server_info()))
70+
.is_ok());
7371
None
7472
}
7573
FromGatewayMessage::GetTaskInfo(request) => {
@@ -91,18 +89,16 @@ pub(crate) async fn process_client_message(
9189
TaskRuntimeState::Stealing(_, _) => TaskState::Waiting,
9290
TaskRuntimeState::Running { .. } => TaskState::Waiting,
9391
TaskRuntimeState::RunningMultiNode(_) => TaskState::Waiting,
94-
TaskRuntimeState::Finished(_) => TaskState::Finished,
92+
TaskRuntimeState::Finished => TaskState::Finished,
9593
},
9694
}
9795
})
9896
.collect();
99-
assert!(
100-
client_sender
101-
.send(ToGatewayMessage::TaskInfo(TasksInfoResponse {
102-
tasks: task_infos
103-
}))
104-
.is_ok()
105-
);
97+
assert!(client_sender
98+
.send(ToGatewayMessage::TaskInfo(TasksInfoResponse {
99+
tasks: task_infos
100+
}))
101+
.is_ok());
106102
None
107103
}
108104
FromGatewayMessage::CancelTasks(msg) => {
@@ -111,14 +107,12 @@ pub(crate) async fn process_client_message(
111107
let mut comm = comm_ref.get_mut();
112108
let (cancelled_tasks, already_finished) =
113109
on_cancel_tasks(&mut core, &mut *comm, &msg.tasks);
114-
assert!(
115-
client_sender
116-
.send(ToGatewayMessage::CancelTasksResponse(CancelTasksResponse {
117-
cancelled_tasks,
118-
already_finished
119-
}))
120-
.is_ok()
121-
);
110+
assert!(client_sender
111+
.send(ToGatewayMessage::CancelTasksResponse(CancelTasksResponse {
112+
cancelled_tasks,
113+
already_finished
114+
}))
115+
.is_ok());
122116
None
123117
}
124118
FromGatewayMessage::StopWorker(msg) => {
@@ -151,11 +145,9 @@ pub(crate) async fn process_client_message(
151145
let core = core_ref.get();
152146
compute_new_worker_query(&core, &msg.worker_queries)
153147
};
154-
assert!(
155-
client_sender
156-
.send(ToGatewayMessage::NewWorkerAllocationQueryResponse(response))
157-
.is_ok()
158-
);
148+
assert!(client_sender
149+
.send(ToGatewayMessage::NewWorkerAllocationQueryResponse(response))
150+
.is_ok());
159151
None
160152
}
161153
FromGatewayMessage::TryReleaseMemory => {
@@ -169,11 +161,9 @@ pub(crate) async fn process_client_message(
169161
.get_worker_map()
170162
.get(&worker_id)
171163
.map(|w| w.worker_info(core.task_map()));
172-
assert!(
173-
client_sender
174-
.send(ToGatewayMessage::WorkerInfo(response))
175-
.is_ok()
176-
);
164+
assert!(client_sender
165+
.send(ToGatewayMessage::WorkerInfo(response))
166+
.is_ok());
177167
None
178168
}
179169
}
@@ -227,12 +217,10 @@ fn handle_new_tasks(
227217
}
228218
on_new_tasks(core, comm, tasks);
229219

230-
assert!(
231-
client_sender
232-
.send(ToGatewayMessage::NewTasksResponse(NewTasksResponse {
233-
n_waiting_for_workers: 0 // TODO
234-
}))
235-
.is_ok()
236-
);
220+
assert!(client_sender
221+
.send(ToGatewayMessage::NewTasksResponse(NewTasksResponse {
222+
n_waiting_for_workers: 0 // TODO
223+
}))
224+
.is_ok());
237225
None
238226
}

crates/tako/src/internal/server/core.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,13 @@ impl Core {
313313
.tasks
314314
.remove(task_id)
315315
.expect("Trying to remove non-existent task");
316-
assert!(!task.has_consumers());
316+
if matches!(&task.state, TaskRuntimeState::Waiting(w) if w.unfinished_deps > 0) {
317+
for input_id in task.task_deps {
318+
if let Some(input) = self.find_task_mut(input_id) {
319+
assert!(input.remove_consumer(task_id));
320+
}
321+
}
322+
}
317323
task.state
318324
}
319325

@@ -364,7 +370,7 @@ impl Core {
364370
pub fn sanity_check(&self) {
365371
let fw_check = |task: &Task| {
366372
for task_dep in &task.task_deps {
367-
assert!(self.tasks.get_task(*task_dep).is_finished());
373+
assert!(self.tasks.find_task(*task_dep).is_none());
368374
}
369375
for &task_id in task.get_consumers() {
370376
assert!(self.tasks.get_task(task_id).is_waiting());
@@ -399,7 +405,11 @@ impl Core {
399405
TaskRuntimeState::Waiting(winfo) => {
400406
let mut count = 0;
401407
for task_dep in &task.task_deps {
402-
if !self.tasks.get_task(*task_dep).is_finished() {
408+
if !self
409+
.tasks
410+
.find_task(*task_dep)
411+
.is_none_or(|t| t.is_finished())
412+
{
403413
count += 1;
404414
}
405415
}
@@ -424,9 +434,9 @@ impl Core {
424434
worker_check_sn(self, task.id, target.unwrap_or(WorkerId::new(0)));
425435
}
426436

427-
TaskRuntimeState::Finished(_) => {
437+
TaskRuntimeState::Finished => {
428438
for task_dep in &task.task_deps {
429-
assert!(self.tasks.get_task(*task_dep).is_finished());
439+
assert!(self.tasks.find_task(*task_dep).is_none());
430440
}
431441
}
432442
TaskRuntimeState::RunningMultiNode(ws) => {

crates/tako/src/internal/server/reactor.rs

Lines changed: 23 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::internal::messages::worker::{
77
};
88
use crate::internal::server::comm::Comm;
99
use crate::internal::server::core::Core;
10-
use crate::internal::server::task::{FinishInfo, WaitingInfo};
10+
use crate::internal::server::task::WaitingInfo;
1111
use crate::internal::server::task::{Task, TaskRuntimeState};
1212
use crate::internal::server::worker::Worker;
1313
use crate::internal::server::workermap::WorkerMap;
@@ -84,7 +84,7 @@ pub(crate) fn on_remove_worker(
8484
continue;
8585
}
8686
}
87-
TaskRuntimeState::Finished(_finfo) => {
87+
TaskRuntimeState::Finished => {
8888
continue;
8989
}
9090
TaskRuntimeState::RunningMultiNode(ws) => {
@@ -164,13 +164,17 @@ pub(crate) fn on_new_tasks(core: &mut Core, comm: &mut impl Comm, new_tasks: Vec
164164
assert!(!new_tasks.is_empty());
165165
for mut task in new_tasks.into_iter() {
166166
let mut count = 0;
167-
for t in task.task_deps.iter() {
168-
let task_dep = core.get_task_mut(*t);
169-
task_dep.add_consumer(task.id);
170-
if !task_dep.is_finished() {
171-
count += 1
167+
task.task_deps.retain(|t| {
168+
if let Some(task_dep) = core.find_task_mut(*t) {
169+
task_dep.add_consumer(task.id);
170+
if !task_dep.is_finished() {
171+
count += 1
172+
}
173+
true
174+
} else {
175+
false
172176
}
173-
}
177+
});
174178
assert!(matches!(
175179
task.state,
176180
TaskRuntimeState::Waiting(WaitingInfo { unfinished_deps: 0 })
@@ -225,7 +229,7 @@ pub(crate) fn on_task_running(
225229
}
226230
TaskRuntimeState::Running { .. }
227231
| TaskRuntimeState::Waiting(_)
228-
| TaskRuntimeState::Finished(_) => {
232+
| TaskRuntimeState::Finished => {
229233
unreachable!()
230234
}
231235
};
@@ -276,7 +280,7 @@ pub(crate) fn on_task_finished(
276280
assert_eq!(*w_id, worker_id);
277281
/* Do nothing */
278282
}
279-
TaskRuntimeState::Waiting(_) | TaskRuntimeState::Finished(_) => {
283+
TaskRuntimeState::Waiting(_) | TaskRuntimeState::Finished => {
280284
unreachable!();
281285
}
282286
}
@@ -287,7 +291,7 @@ pub(crate) fn on_task_finished(
287291
placement.insert(worker_id);
288292
}
289293

290-
task.state = TaskRuntimeState::Finished(FinishInfo {});
294+
task.state = TaskRuntimeState::Finished;
291295
comm.ask_for_scheduling();
292296
comm.send_client_task_finished(task.id);
293297
} else {
@@ -315,8 +319,8 @@ pub(crate) fn on_task_finished(
315319
core.add_ready_to_assign(id);
316320
comm.ask_for_scheduling();
317321
}
318-
unregister_as_consumer(core, comm, msg.id);
319-
remove_task_if_possible(core, comm, msg.id);
322+
let state = core.remove_task(msg.id);
323+
assert!(matches!(state, TaskRuntimeState::Finished));
320324
}
321325

322326
pub(crate) fn on_steal_response(
@@ -423,18 +427,13 @@ fn fail_task_helper(
423427
}
424428
};
425429

426-
// TODO: take taskmap in `unregister_as_consumer`
427-
unregister_as_consumer(core, comm, task_id);
428-
429430
for &consumer in &consumers {
430-
{
431-
let task = core.get_task(consumer);
432-
log::debug!("Task={} canceled because of failed dependency", task.id);
433-
assert!(task.is_waiting());
434-
}
435-
unregister_as_consumer(core, comm, consumer);
431+
log::debug!("Task={} canceled because of failed dependency", consumer);
432+
assert!(matches!(
433+
core.remove_task(consumer),
434+
TaskRuntimeState::Waiting(_)
435+
));
436436
}
437-
438437
let state = core.remove_task(task_id);
439438
if worker_id.is_some() {
440439
assert!(matches!(
@@ -448,14 +447,6 @@ fn fail_task_helper(
448447
assert!(matches!(state, TaskRuntimeState::Waiting(_)));
449448
}
450449
drop(state);
451-
452-
for &consumer in &consumers {
453-
// We can drop the resulting state as checks was done earlier
454-
assert!(matches!(
455-
core.remove_task(consumer),
456-
TaskRuntimeState::Waiting(_)
457-
));
458-
}
459450
comm.send_client_task_error(task_id, consumers, error_info);
460451
}
461452

@@ -522,17 +513,13 @@ pub(crate) fn on_cancel_tasks(
522513
}
523514
running_ids.entry(from_id).or_default().push(task_id);
524515
}
525-
TaskRuntimeState::Finished(_) => {
516+
TaskRuntimeState::Finished => {
526517
already_finished.push(task_id);
527518
}
528519
};
529520
}
530521
}
531522

532-
for &task_id in &to_unregister {
533-
unregister_as_consumer(core, comm, task_id);
534-
}
535-
536523
core.remove_tasks_batched(&to_unregister);
537524

538525
for (w_id, ids) in running_ids {
@@ -542,23 +529,3 @@ pub(crate) fn on_cancel_tasks(
542529
comm.ask_for_scheduling();
543530
(to_unregister.into_iter().collect(), already_finished)
544531
}
545-
546-
fn unregister_as_consumer(core: &mut Core, comm: &mut impl Comm, task_id: TaskId) {
547-
let inputs: Vec<TaskId> = core.get_task(task_id).task_deps.iter().copied().collect();
548-
for input_id in inputs {
549-
let input = core.get_task_mut(input_id);
550-
assert!(input.remove_consumer(task_id));
551-
remove_task_if_possible(core, comm, input_id);
552-
}
553-
}
554-
555-
fn remove_task_if_possible(core: &mut Core, _comm: &mut impl Comm, task_id: TaskId) {
556-
if !core.get_task(task_id).is_removable() {
557-
return;
558-
}
559-
match core.remove_task(task_id) {
560-
TaskRuntimeState::Finished(_finfo) => { /* Ok */ }
561-
_ => unreachable!(),
562-
};
563-
log::debug!("Task id={task_id} is no longer needed");
564-
}

0 commit comments

Comments
 (0)