@@ -362,6 +362,7 @@ static void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh,
362362
363363 memh -> uct [md_index ] = NULL ;
364364 }
365+ memh -> md_map &= ~md_map ;
365366
366367 ucs_assert (comp .count == 1 );
367368}
@@ -473,7 +474,8 @@ ucp_memh_register_internal(ucp_context_h context, ucp_mem_h memh,
473474 uct_flags |= UCT_MD_MEM_FLAG_NONBLOCK ;
474475 }
475476
476- reg_params .flags = uct_flags ;
477+ /* When adding registrations, existing access flags must be supported */
478+ reg_params .flags = uct_flags | memh -> uct_flags ;
477479 reg_params .dmabuf_fd = UCT_DMABUF_FD_INVALID ;
478480 reg_params .dmabuf_offset = 0 ;
479481
@@ -576,27 +578,38 @@ static size_t ucp_memh_size(ucp_context_h context)
576578 return sizeof (ucp_mem_t ) + (sizeof (uct_mem_h ) * context -> num_mds );
577579}
578580
579- static void ucp_memh_set (ucp_mem_h memh , ucp_context_h context , void * address ,
580- size_t length , ucs_memory_type_t mem_type ,
581- uint8_t memh_flags , uct_alloc_method_t method )
581+ static void ucp_memh_set_uct_flags (ucp_mem_h memh , unsigned uct_flags )
582+ {
583+ /* When changing memh->uct_flags, must not have any existing registrations,
584+ since those may not support the new flags */
585+ ucs_assertv (memh -> md_map == 0 ,
586+ "memh=%p memh->md_map=0x%" PRIx64
587+ " memh->uct_flags=0x%x uct_flags=0x%x" ,
588+ memh , memh -> md_map , memh -> uct_flags , uct_flags );
589+ memh -> uct_flags = uct_flags & UCP_MM_UCT_ACCESS_MASK ;
590+ }
591+
592+ static void ucp_memh_init (ucp_mem_h memh , ucp_context_h context ,
593+ uint8_t memh_flags , unsigned uct_flags ,
594+ uct_alloc_method_t method , ucs_memory_type_t mem_type )
582595{
583596 ucp_memory_info_t info ;
584597
585- ucp_memory_detect (context , address , length , & info );
586- memh -> super .super .start = (uintptr_t )address ;
587- memh -> super .super .end = (uintptr_t )address + length ;
588- memh -> flags = memh_flags ;
598+ ucp_memory_detect (context , ucp_memh_address (memh ), ucp_memh_length (memh ),
599+ & info );
600+ ucp_memh_set_uct_flags (memh , uct_flags );
589601 memh -> context = context ;
602+ memh -> flags = memh_flags ;
603+ memh -> alloc_md_index = UCP_NULL_RESOURCE ;
604+ memh -> alloc_method = method ;
590605 memh -> mem_type = mem_type ;
591606 memh -> sys_dev = info .sys_dev ;
592- memh -> alloc_method = method ;
593- memh -> alloc_md_index = UCP_NULL_RESOURCE ;
594607}
595608
596609static ucs_status_t
597610ucp_memh_create (ucp_context_h context , void * address , size_t length ,
598611 ucs_memory_type_t mem_type , uct_alloc_method_t method ,
599- uint8_t memh_flags , ucp_mem_h * memh_p )
612+ uint8_t memh_flags , unsigned uct_flags , ucp_mem_h * memh_p )
600613{
601614 ucp_mem_h memh ;
602615
@@ -605,7 +618,9 @@ ucp_memh_create(ucp_context_h context, void *address, size_t length,
605618 return UCS_ERR_NO_MEMORY ;
606619 }
607620
608- ucp_memh_set (memh , context , address , length , mem_type , memh_flags , method );
621+ memh -> super .super .start = (uintptr_t )address ;
622+ memh -> super .super .end = (uintptr_t )address + length ;
623+ ucp_memh_init (memh , context , memh_flags , uct_flags , method , mem_type );
609624
610625 * memh_p = memh ;
611626 return UCS_OK ;
@@ -658,13 +673,14 @@ static ucp_md_index_t ucp_mem_get_md_index(ucp_context_h context,
658673
659674static ucs_status_t ucp_memh_create_from_mem (ucp_context_h context ,
660675 const uct_allocated_memory_t * mem ,
676+ unsigned uct_flags ,
661677 ucp_mem_h * memh_p )
662678{
663679 ucs_status_t status ;
664680 ucp_mem_h memh ;
665681
666682 status = ucp_memh_create (context , mem -> address , mem -> length , mem -> mem_type ,
667- mem -> method , 0 , & memh );
683+ mem -> method , 0 , uct_flags , & memh );
668684 if (status != UCS_OK ) {
669685 return status ;
670686 }
@@ -787,21 +803,30 @@ ucs_status_t ucp_memh_get_slow(ucp_context_h context, void *address,
787803 UCP_THREAD_CS_ENTER (& context -> mt_lock );
788804 if (context -> rcache == NULL ) {
789805 status = ucp_memh_create (context , reg_address , reg_length , mem_type ,
790- UCT_ALLOC_METHOD_LAST , 0 , & memh );
806+ UCT_ALLOC_METHOD_LAST , 0 , uct_flags , & memh );
807+ if (status != UCS_OK ) {
808+ goto out ;
809+ }
791810 } else {
792811 status = ucp_memh_rcache_get (context -> rcache , reg_address , reg_length ,
793812 reg_align , mem_type , reg_md_map , uct_flags ,
794813 alloc_name , & memh );
814+ if (status != UCS_OK ) {
815+ goto out ;
816+ }
817+
818+ if (!ucs_test_all_flags (memh -> uct_flags ,
819+ uct_flags & UCP_MM_UCT_ACCESS_MASK )) {
820+ reg_md_map |= memh -> md_map ; /* Re-register previous MDs */
821+ ucp_memh_dereg (context , memh , memh -> md_map );
822+ ucp_memh_set_uct_flags (memh , uct_flags );
823+ }
795824
796825 ucs_assert (memh -> mem_type == mem_type );
797826 ucs_assert (ucs_padding ((intptr_t )ucp_memh_address (memh ), reg_align ) == 0 );
798827 ucs_assert (ucs_padding (ucp_memh_length (memh ), reg_align ) == 0 );
799828 }
800829
801- if (status != UCS_OK ) {
802- goto out ;
803- }
804-
805830 ucs_trace (
806831 "memh_get_slow: %s address %p/%p length %zu/%zu %s md_map %" PRIx64
807832 " flags 0x%x" ,
@@ -847,7 +872,7 @@ ucp_memh_alloc(ucp_context_h context, void *address, size_t length,
847872 goto out ;
848873 }
849874
850- status = ucp_memh_create_from_mem (context , & mem , & memh );
875+ status = ucp_memh_create_from_mem (context , & mem , uct_flags , & memh );
851876 if (status != UCS_OK ) {
852877 goto err_dealloc ;
853878 }
@@ -974,7 +999,7 @@ ucs_status_t ucp_mem_map(ucp_context_h context, const ucp_mem_map_params_t *para
974999 alloc_name , & memh );
9751000 } else {
9761001 status = ucp_memh_create (context , address , length , mem_type ,
977- UCT_ALLOC_METHOD_LAST , 0 , & memh );
1002+ UCT_ALLOC_METHOD_LAST , 0 , uct_flags , & memh );
9781003 if (status != UCS_OK ) {
9791004 goto out ;
9801005 }
@@ -1412,15 +1437,9 @@ ucp_mem_rcache_mem_reg_cb(void *ctx, ucs_rcache_t *rcache, void *arg,
14121437 ucp_context_h context = (ucp_context_h )ctx ;
14131438 ucp_mem_rcache_reg_ctx_t * reg_ctx = arg ;
14141439 ucp_mem_h memh = ucs_derived_of (rregion , ucp_mem_t );
1415- ucp_memory_info_t info ;
14161440
1417- ucp_memory_detect (context , (void * )memh -> super .super .start ,
1418- memh -> super .super .end - memh -> super .super .start , & info );
1419- memh -> context = context ;
1420- memh -> alloc_md_index = UCP_NULL_RESOURCE ;
1421- memh -> alloc_method = UCT_ALLOC_METHOD_LAST ;
1422- memh -> mem_type = reg_ctx -> mem_type ;
1423- memh -> sys_dev = info .sys_dev ;
1441+ ucp_memh_init (memh , context , 0 , reg_ctx -> uct_flags , UCT_ALLOC_METHOD_LAST ,
1442+ reg_ctx -> mem_type );
14241443
14251444 if (rcache_mem_reg_flags & UCS_RCACHE_MEM_REG_HIDE_ERRORS ) {
14261445 /* Hide errors during registration but fail if any memory domain failed
@@ -1730,7 +1749,7 @@ ucp_memh_import(ucp_context_h context, const void *export_mkey_buffer,
17301749
17311750 status = ucp_memh_create (context , unpacked_memh .address ,
17321751 unpacked_memh .length , unpacked_memh .mem_type ,
1733- UCT_ALLOC_METHOD_LAST , 0 , & memh );
1752+ UCT_ALLOC_METHOD_LAST , 0 , 0 , & memh );
17341753 if (status != UCS_OK ) {
17351754 goto out ;
17361755 }
0 commit comments