Skip to content

Commit fd3bc3a

Browse files
authored
UCP/PROTO: Add option to force ZCOPY (#11289)
1 parent f7aaaf2 commit fd3bc3a

6 files changed

Lines changed: 191 additions & 28 deletions

File tree

src/ucp/core/ucp_context.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,12 @@ static ucs_config_field_t ucp_context_config_table[] = {
471471
"lane without waiting for remote completion.",
472472
ucs_offsetof(ucp_context_config_t, rndv_put_force_flush), UCS_CONFIG_TYPE_BOOL},
473473

474+
{"PROTO_EMULATION_ENABLE", "y",
475+
"When set to 'no', emulation protocols for put and get are disabled. If no native\n"
476+
"zero-copy RMA protocol exist for the memory type pair, RMA requests will be\n"
477+
"cancelled.",
478+
ucs_offsetof(ucp_context_config_t, proto_emulation_enable), UCS_CONFIG_TYPE_BOOL},
479+
474480
{"SA_DATA_VERSION", "v2",
475481
"Defines the minimal header version the client will use for establishing\n"
476482
"client/server connection",

src/ucp/core/ucp_context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ typedef struct ucp_context_config {
167167
uint64_t reg_whole_alloc_bitmap;
168168
/** Always use flush operation in rendezvous put */
169169
int rndv_put_force_flush;
170+
/** Allow RMA emulation protocols. When disabled, provide an explicit error
171+
* if no suitable proto is found */
172+
int proto_emulation_enable;
170173
/** Maximum size of mem type direct rndv*/
171174
size_t rndv_memtype_direct_size;
172175
/** UCP sockaddr private data format version */

src/ucp/proto/proto_reconfig.c

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2021. ALL RIGHTS RESERVED.
2+
* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2021-2026. ALL RIGHTS RESERVED.
33
*
44
* See file LICENSE for terms.
55
*/
@@ -15,6 +15,7 @@
1515

1616
#include <ucp/am/ucp_am.inl>
1717
#include <ucp/core/ucp_worker.inl>
18+
#include <ucs/memory/memory_type.h>
1819
#include <ucs/sys/math.h>
1920

2021

@@ -43,6 +44,37 @@ static void ucp_proto_reconfig_abort(ucp_request_t *req, ucs_status_t status)
4344
ucp_request_complete_send(req, status);
4445
}
4546

47+
static int
48+
ucp_proto_reconfig_report_no_rma_emulation_no_proto(ucp_request_t *req,
49+
ucp_ep_h ep)
50+
{
51+
ucp_operation_id_t op_id;
52+
ucs_memory_type_t local_mem_type, remote_mem_type;
53+
54+
if (ep->worker->context->config.ext.proto_emulation_enable) {
55+
return 0;
56+
}
57+
58+
op_id = ucp_proto_select_op_id(&req->send.proto_config->select_param);
59+
if (((op_id != UCP_OP_ID_PUT) && (op_id != UCP_OP_ID_GET))) {
60+
return 0;
61+
}
62+
63+
local_mem_type = req->send.proto_config->select_param.mem_type;
64+
remote_mem_type = req->send.rma.rkey->mem_type;
65+
66+
ucs_error("No zero-copy protocol found for %s %s %s %s, %zu bytes. "
67+
"Please check for proper GPU and/or HCA support, or set "
68+
"UCX_PROTO_EMULATION_ENABLE=y to proceed by allowing slower "
69+
"software emulation.",
70+
(op_id == UCP_OP_ID_PUT) ? "put from" : "get into",
71+
ucs_memory_type_names[local_mem_type],
72+
(op_id == UCP_OP_ID_PUT) ? "to" : "from",
73+
ucs_memory_type_names[remote_mem_type],
74+
req->send.state.dt_iter.length);
75+
return 1;
76+
}
77+
4678
static ucs_status_t ucp_proto_reconfig_progress(uct_pending_req_t *self)
4779
{
4880
ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct);
@@ -52,6 +84,11 @@ static ucs_status_t ucp_proto_reconfig_progress(uct_pending_req_t *self)
5284

5385
/* This protocol should not be selected for valid and connected endpoint */
5486
if (ep->flags & UCP_EP_FLAG_REMOTE_CONNECTED) {
87+
if (ucp_proto_reconfig_report_no_rma_emulation_no_proto(req, ep)) {
88+
ucp_proto_request_abort(req, UCS_ERR_CANCELED);
89+
return UCS_OK;
90+
}
91+
5592
ucp_ep_config_name(ep->worker, req->send.proto_config->ep_cfg_index,
5693
&strb);
5794
ucs_string_buffer_appendf(&strb, " | ");

src/ucp/rma/get_am.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2020. ALL RIGHTS RESERVED.
2+
* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2020-2026. ALL RIGHTS RESERVED.
33
*
44
* See file LICENSE for terms.
55
*/
@@ -111,6 +111,10 @@ ucp_proto_get_am_bcopy_probe(const ucp_proto_init_params_t *init_params)
111111
return;
112112
}
113113

114+
if (!init_params->worker->context->config.ext.proto_emulation_enable) {
115+
return;
116+
}
117+
114118
ucp_proto_single_probe(&params);
115119
}
116120

src/ucp/rma/put_am.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2020. ALL RIGHTS RESERVED.
2+
* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2020-2026. ALL RIGHTS RESERVED.
33
*
44
* See file LICENSE for terms.
55
*/
@@ -119,6 +119,10 @@ ucp_proto_put_am_bcopy_probe(const ucp_proto_init_params_t *init_params)
119119
return;
120120
}
121121

122+
if (!init_params->worker->context->config.ext.proto_emulation_enable) {
123+
return;
124+
}
125+
122126
ucp_proto_multi_probe(&params);
123127
}
124128

test/gtest/ucp/test_ucp_rma.cc

100644100755
Lines changed: 134 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2001-2015. ALL RIGHTS RESERVED.
2+
* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2001-2026. ALL RIGHTS RESERVED.
33
* Copyright (c) UT-Battelle, LLC. 2015. ALL RIGHTS RESERVED.
44
*
55
* See file LICENSE for terms.
@@ -53,6 +53,21 @@ class test_ucp_rma : public test_ucp_memheap {
5353
}
5454
}
5555

56+
ucs_status_ptr_t do_put(size_t size, void *expected_data, ucp_mem_h memh,
57+
void *target_ptr, ucp_rkey_h rkey, void *arg)
58+
{
59+
ucs_memory_type_t *mem_types = reinterpret_cast<ucs_memory_type_t*>(
60+
arg);
61+
mem_buffer::pattern_fill(expected_data, size, ucs::rand(),
62+
mem_types[0]);
63+
64+
ucp_request_param_t param;
65+
request_param_init(&param, memh);
66+
67+
return ucp_put_nbx(sender().ep(), expected_data, size,
68+
(uintptr_t)target_ptr, rkey, &param);
69+
}
70+
5671
void put_b(size_t size, void *expected_data, ucp_mem_h memh,
5772
void *target_ptr, ucp_rkey_h rkey, void *arg)
5873
{
@@ -76,6 +91,16 @@ class test_ucp_rma : public test_ucp_memheap {
7691
rkey, arg);
7792
}
7893

94+
ucs_status_ptr_t do_get(size_t size, void *expected_data, ucp_mem_h memh,
95+
void *target_ptr, ucp_rkey_h rkey)
96+
{
97+
ucp_request_param_t param;
98+
request_param_init(&param, memh);
99+
100+
return ucp_get_nbx(sender().ep(), expected_data, size,
101+
(uintptr_t)target_ptr, rkey, &param);
102+
}
103+
79104
void get_b(size_t size, void *expected_data, ucp_mem_h memh,
80105
void *target_ptr, ucp_rkey_h rkey, void *arg)
81106
{
@@ -118,7 +143,6 @@ class test_ucp_rma : public test_ucp_memheap {
118143
ucs::supported_mem_type_pairs();
119144

120145
for (size_t i = 0; i < pairs.size(); ++i) {
121-
122146
/* Memory type put/get is fully supported only with new protocols */
123147
if (!is_proto_enabled() && (!UCP_MEM_IS_HOST(pairs[i][0]) ||
124148
!UCP_MEM_IS_HOST(pairs[i][1]))) {
@@ -209,19 +233,6 @@ class test_ucp_rma : public test_ucp_memheap {
209233
param->memh = memh;
210234
}
211235

212-
ucs_status_ptr_t do_put(size_t size, void *expected_data, ucp_mem_h memh,
213-
void *target_ptr, ucp_rkey_h rkey, void *arg)
214-
{
215-
ucs_memory_type_t *mem_types = reinterpret_cast<ucs_memory_type_t*>(arg);
216-
mem_buffer::pattern_fill(expected_data, size, ucs::rand(), mem_types[0]);
217-
218-
ucp_request_param_t param;
219-
request_param_init(&param, memh);
220-
221-
return ucp_put_nbx(sender().ep(), expected_data, size,
222-
(uintptr_t)target_ptr, rkey, &param);
223-
}
224-
225236
ucs_status_ptr_t do_put_iov(size_t size, void *expected_data,
226237
ucp_request_param_t *param, void *target_ptr,
227238
ucp_rkey_h rkey, ucp_dt_iov_t *iov,
@@ -241,16 +252,6 @@ class test_ucp_rma : public test_ucp_memheap {
241252
rkey, param);
242253
}
243254

244-
ucs_status_ptr_t do_get(size_t size, void *expected_data, ucp_mem_h memh,
245-
void *target_ptr, ucp_rkey_h rkey)
246-
{
247-
ucp_request_param_t param;
248-
request_param_init(&param, memh);
249-
250-
return ucp_get_nbx(sender().ep(), expected_data, size,
251-
(uintptr_t)target_ptr, rkey, &param);
252-
}
253-
254255
ucs_status_ptr_t do_get_iov(size_t size, void *expected_data,
255256
ucp_request_param_t *param, void *target_ptr,
256257
ucp_rkey_h rkey, ucp_dt_iov_t *iov,
@@ -345,6 +346,114 @@ UCS_TEST_P(test_ucp_rma, proto_disabled_unsupported, "PROTO_ENABLE=n")
345346
UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(test_ucp_rma)
346347

347348

349+
class test_ucp_proto_emulation_enable : public test_ucp_rma {
350+
public:
351+
static constexpr size_t SMALL_SIZE = 8;
352+
static constexpr size_t BIG_SIZE = 512 * UCS_KBYTE;
353+
354+
static void get_test_variants(std::vector<ucp_test_variant> &variants)
355+
{
356+
add_variant_with_value(variants, UCP_FEATURE_RMA, 0, "");
357+
}
358+
359+
test_ucp_proto_emulation_enable()
360+
{
361+
modify_config("PROTO_EMULATION_ENABLE", "n");
362+
modify_config("IB_TX_INLINE_RESP", "0", SETENV_IF_NOT_EXIST);
363+
}
364+
365+
protected:
366+
enum rma_op_t {
367+
RMA_OP_PUT,
368+
RMA_OP_GET
369+
};
370+
371+
void run_expect_canceled(rma_op_t op, size_t size)
372+
{
373+
mapped_buffer rbuf(size * 2, receiver());
374+
ucs::handle<ucp_rkey_h> rkey = rbuf.rkey(sender());
375+
mem_buffer lbuf(size, UCS_MEMORY_TYPE_HOST);
376+
ucs_memory_type_t mem_types[] = {UCS_MEMORY_TYPE_HOST,
377+
UCS_MEMORY_TYPE_HOST};
378+
379+
scoped_log_handler slh(wrap_errors_logger);
380+
381+
ucs_status_ptr_t req;
382+
if (op == RMA_OP_PUT) {
383+
req = do_put(size, lbuf.ptr(), NULL, rbuf.ptr(), rkey.get(),
384+
mem_types);
385+
} else {
386+
req = do_get(size, lbuf.ptr(), NULL, rbuf.ptr(), rkey.get());
387+
}
388+
ucs_status_t status = request_wait(req);
389+
EXPECT_EQ(UCS_ERR_CANCELED, status)
390+
<< (op == RMA_OP_PUT ? "put" : "get") << " should be canceled";
391+
392+
/* Verify PROTO_EMULATION_ENABLE error message was logged */
393+
bool found_rma_msg = false;
394+
for (const auto &err : m_errors) {
395+
if (err.find("set UCX_PROTO_EMULATION_ENABLE=y to proceed") !=
396+
std::string::npos) {
397+
found_rma_msg = true;
398+
break;
399+
}
400+
}
401+
EXPECT_TRUE(found_rma_msg) << "Expected error message with "
402+
"UCX_PROTO_EMULATION_ENABLE=y advice";
403+
}
404+
405+
void test_forced_message_sizes(send_func_t send_func)
406+
{
407+
for (const auto &pair : ucs::supported_mem_type_pairs()) {
408+
if (check_reg_mem_types(sender(), pair[0]) &&
409+
check_reg_mem_types(sender(), pair[1])) {
410+
test_message_sizes(send_func, SMALL_SIZE, BIG_SIZE, pair[0],
411+
pair[1], 0);
412+
}
413+
}
414+
}
415+
};
416+
417+
UCS_TEST_P(test_ucp_proto_emulation_enable, no_zcopy_proto_fails_put_small,
418+
"PROTOS=put/am/*,get/am/*,reconfig")
419+
{
420+
run_expect_canceled(RMA_OP_PUT, SMALL_SIZE);
421+
}
422+
423+
UCS_TEST_P(test_ucp_proto_emulation_enable, no_zcopy_proto_fails_put_big,
424+
"PROTOS=put/am/*,get/am/*,reconfig")
425+
{
426+
run_expect_canceled(RMA_OP_PUT, BIG_SIZE);
427+
}
428+
429+
UCS_TEST_P(test_ucp_proto_emulation_enable, no_zcopy_proto_fails_get_small,
430+
"PROTOS=put/am/*,get/am/*,reconfig")
431+
{
432+
run_expect_canceled(RMA_OP_GET, SMALL_SIZE);
433+
}
434+
435+
UCS_TEST_P(test_ucp_proto_emulation_enable, no_zcopy_proto_fails_get_big,
436+
"PROTOS=put/am/*,get/am/*,reconfig")
437+
{
438+
run_expect_canceled(RMA_OP_GET, BIG_SIZE);
439+
}
440+
441+
UCS_TEST_P(test_ucp_proto_emulation_enable, get_zcopy_forced_success,
442+
"PROTOS=get/bcopy,get/zcopy,reconfig")
443+
{
444+
test_forced_message_sizes(static_cast<send_func_t>(&test_ucp_rma::get_b));
445+
}
446+
447+
UCS_TEST_P(test_ucp_proto_emulation_enable, put_zcopy_forced_success,
448+
"PROTOS=put/offload/*,reconfig")
449+
{
450+
test_forced_message_sizes(static_cast<send_func_t>(&test_ucp_rma::put_b));
451+
}
452+
453+
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(test_ucp_proto_emulation_enable, ib,
454+
"ib")
455+
456+
348457
class test_ucp_rma_reg : public test_ucp_rma {
349458
public:
350459
static void get_test_variants(std::vector<ucp_test_variant>& variants) {

0 commit comments

Comments
 (0)