Skip to content

Commit 66c833d

Browse files
committed
In the heartbeat handler, proactively check for missed min_time.
1 parent ac5ff22 commit 66c833d

File tree

6 files changed

+70
-23
lines changed

6 files changed

+70
-23
lines changed

crates/tako/src/internal/messages/worker.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ pub enum WorkerTaskUpdate {
141141
RunningPrefilled(TaskRunningMsg),
142142
RejectRequest {
143143
task_id: TaskId,
144-
rv_id: ResourceVariantId,
144+
rv_id: Option<ResourceVariantId>,
145145
},
146146
EnableRequest {
147147
resource_rq_id: ResourceRqId,

crates/tako/src/internal/server/reactor.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ fn task_reject(
357357
comm: &mut impl Comm,
358358
worker_id: WorkerId,
359359
task_id: TaskId,
360-
resource_rq_variant: ResourceVariantId,
360+
resource_rq_variant: Option<ResourceVariantId>,
361361
) -> bool {
362362
let CoreSplitMut {
363363
task_map,
@@ -371,23 +371,27 @@ fn task_reject(
371371
log::debug!("Unknown task rejected id={task_id}");
372372
return false;
373373
};
374-
log::debug!("Task id={task_id} (variant={resource_rq_variant}) rejected on worker={worker_id}");
374+
log::debug!(
375+
"Task id={task_id} (variant={resource_rq_variant:?}) rejected on worker={worker_id}"
376+
);
375377
let worker = worker_map.get_worker_mut(worker_id);
376378
let resource_rq_id = task.resource_rq_id;
377-
worker.block_request(resource_rq_id, resource_rq_variant);
379+
if let Some(rv_id) = resource_rq_variant {
380+
worker.block_request(resource_rq_id, rv_id);
381+
}
378382
match &task.state {
379383
TaskRuntimeState::Assigned {
380384
worker_id: w_id,
381385
rv_id,
382386
} => {
383387
if worker_id != *w_id {
384388
log::debug!("Rejection from invalid worker");
389+
} else if resource_rq_variant != Some(*rv_id) {
390+
log::debug!("Rejection invalid variant");
391+
} else {
392+
let rq = request_map.get(resource_rq_id).get(*rv_id);
393+
worker.remove_sn_task(task_id, rq);
385394
}
386-
if resource_rq_variant != *rv_id {
387-
log::debug!("Rejection from invalid worker");
388-
}
389-
let rq = request_map.get(resource_rq_id).get(resource_rq_variant);
390-
worker.remove_sn_task(task_id, rq);
391395
}
392396
TaskRuntimeState::Prefilled { worker_id: w_id } => {
393397
if worker_id != *w_id {

crates/tako/src/internal/tests/test_reactor.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ fn test_task_reject1() {
672672
w,
673673
smallvec![WorkerTaskUpdate::RejectRequest {
674674
task_id: t,
675-
rv_id: ResourceVariantId::new(0)
675+
rv_id: Some(ResourceVariantId::new(0))
676676
}],
677677
);
678678
comm.check_need_scheduling();
@@ -717,7 +717,7 @@ fn test_task_reject2() {
717717
w,
718718
smallvec![WorkerTaskUpdate::RejectRequest {
719719
task_id: t,
720-
rv_id: ResourceVariantId::new(0)
720+
rv_id: Some(ResourceVariantId::new(0))
721721
}],
722722
);
723723
comm.check_need_scheduling();
@@ -746,7 +746,7 @@ fn test_task_reject3() {
746746
w,
747747
smallvec![WorkerTaskUpdate::RejectRequest {
748748
task_id: t1,
749-
rv_id: ResourceVariantId::new(0)
749+
rv_id: Some(ResourceVariantId::new(0))
750750
}],
751751
);
752752
comm.check_need_scheduling();
@@ -760,7 +760,7 @@ fn test_task_reject3() {
760760
w,
761761
smallvec![WorkerTaskUpdate::RejectRequest {
762762
task_id: t2,
763-
rv_id: ResourceVariantId::new(0)
763+
rv_id: Some(ResourceVariantId::new(0))
764764
}],
765765
);
766766
comm.check_need_scheduling();
@@ -894,7 +894,7 @@ fn test_prefill_rejected() {
894894
let (w1, _t1, t2) = setup_prefill(&mut rt);
895895
let up = WorkerTaskUpdate::RejectRequest {
896896
task_id: t2,
897-
rv_id: 0.into(),
897+
rv_id: Some(0.into()),
898898
};
899899
let mut comm = TestComm::new();
900900
on_task_update(rt.core(), &mut comm, w1, smallvec![up]);
@@ -1080,7 +1080,7 @@ fn test_steal_rejected() {
10801080
let (w1, w2, t) = setup_retracting(&mut rt);
10811081
let up = WorkerTaskUpdate::RejectRequest {
10821082
task_id: t,
1083-
rv_id: 0.into(),
1083+
rv_id: Some(0.into()),
10841084
};
10851085
let mut comm = TestComm::new();
10861086
on_task_update(rt.core(), &mut comm, w1, smallvec![up]);

crates/tako/src/internal/worker/reactor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ fn try_alloc_and_start_task(
6262
state.blocked_requests.insert((task.resource_rq_id, rv_id));
6363
task_updates.push(WorkerTaskUpdate::RejectRequest {
6464
task_id: task.id,
65-
rv_id,
65+
rv_id: Some(rv_id),
6666
});
6767
return;
6868
};
@@ -111,7 +111,7 @@ fn try_start_task(
111111
// Hard reject, we never unblock this rejection so we do not need to update blocked requests
112112
task_updates.push(WorkerTaskUpdate::RejectRequest {
113113
task_id: task.id,
114-
rv_id,
114+
rv_id: Some(rv_id),
115115
});
116116
return Some(allocation);
117117
}

crates/tako/src/internal/worker/rpc.rs

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ use crate::internal::common::WrappedRcRefCell;
1919
use crate::internal::common::resources::Allocation;
2020
use crate::internal::common::resources::map::ResourceIdMap;
2121
use crate::internal::messages::worker::{
22-
FromWorkerMessage, RetractResponseMsg, TaskResourceAllocation, ToWorkerMessage, WorkerOverview,
23-
WorkerRegistrationResponse, WorkerStopReason,
22+
FromWorkerMessage, RetractResponseMsg, TaskResourceAllocation, TaskUpdates, ToWorkerMessage,
23+
WorkerOverview, WorkerRegistrationResponse, WorkerStopReason, WorkerTaskUpdate,
2424
};
2525
use crate::internal::server::rpc::ConnectionDescriptor;
2626
use crate::internal::transfer::auth::{
@@ -305,10 +305,39 @@ async fn heartbeat_process(heartbeat_interval: Duration, state_ref: WrappedRcRef
305305
let mut interval = tokio::time::interval(heartbeat_interval);
306306
loop {
307307
interval.tick().await;
308-
state_ref
309-
.get_mut()
310-
.comm()
311-
.send_message_to_server(FromWorkerMessage::Heartbeat);
308+
{
309+
let mut state = state_ref.get_mut();
310+
state
311+
.comm()
312+
.send_message_to_server(FromWorkerMessage::Heartbeat);
313+
if !state.prefilled_tasks.is_empty()
314+
&& let Some(remaining_time) = state.remaining_time()
315+
{
316+
let mut to_remove = Vec::new();
317+
let mut updates = TaskUpdates::new();
318+
for (rq_id, tasks) in &state.prefilled_tasks {
319+
let rqv = state.resource_rq_map.get(*rq_id);
320+
if remaining_time < rqv.min_time() {
321+
to_remove.push(*rq_id);
322+
for task in tasks {
323+
// Hard reject, we never unblock this rejection so we do not need to update blocked requests
324+
updates.push(WorkerTaskUpdate::RejectRequest {
325+
task_id: task.id,
326+
rv_id: None,
327+
});
328+
}
329+
}
330+
}
331+
if !updates.is_empty() {
332+
state
333+
.comm()
334+
.send_message_to_server(FromWorkerMessage::TaskUpdate(updates));
335+
for rq_id in to_remove {
336+
state.prefilled_tasks.remove(&rq_id);
337+
}
338+
}
339+
}
340+
}
312341
log::debug!("Heartbeat sent");
313342
}
314343
}

tests/test_job.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,3 +1427,17 @@ def test_resource_weight(hq_env: HqEnv):
14271427
hq_env.start_worker(cpus=2, args=["--resource", "gpus=sum(2)"])
14281428
wait_for_job_state(hq_env, 1, "RUNNING")
14291429
wait_for_job_state(hq_env, 2, "WAITING")
1430+
1431+
1432+
def test_fast_retracting(hq_env: HqEnv):
1433+
hq_env.start_server()
1434+
hq_env.command(["submit", "--array=1-350", "--time-request=78s", "--", "sleep", "200"])
1435+
hq_env.start_worker(cpus=4, args=["--time-limit=80s", "--heartbeat=500ms"])
1436+
wait_for_job_state(hq_env, 1, "RUNNING")
1437+
out = hq_env.command(["--output-mode=json", "job", "info", "1"], as_json=True)
1438+
assert sorted(t["id"] for t in out[0]["tasks"] if t["state"] == "running") == list(range(1, 5))
1439+
time.sleep(6)
1440+
hq_env.start_worker(cpus=4)
1441+
time.sleep(3)
1442+
out = hq_env.command(["--output-mode=json", "job", "info", "1"], as_json=True)
1443+
assert sorted(t["id"] for t in out[0]["tasks"] if t["state"] == "running") == list(range(1, 9))

0 commit comments

Comments
 (0)