tools/perf: one-sided (memory-mapped buffer) mode for the alltoall perftest

Adds a -J option that sets bench.onesided. When it is set,
ucc_pt_comm::init() allocates UCC_TEST_N_MEM_SEGMENTS buffers, passes them to
the context through ctx_params.mem_params and requests
UCC_TEAM_FLAG_COLL_WORK_BUFFER on the team. The alltoall benchmark then sets
UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS and uses those buffers as source,
destination and global work buffer.

Review fixes folded into this revision:
  * UCC_MALLOC_CHECK: "MPI_Abort(MPI_COMM_WORLD.-1)" did not parse; it is now
    "MPI_Abort(MPI_COMM_WORLD, -1)" and the macro is a single do/while(0)
    statement.
  * alltoall constructor: coll_args.mask is OR-ed instead of overwritten.
  * segment length is computed in size_t instead of int.
  * onesided_buffers is sized and bounds-checked with UCC_TEST_N_MEM_SEGMENTS
    and value-initialized in the class, so finalize() never reads it unset.
  * the buffers are freed after ucc_context_destroy(), not before it.

Notes for whoever applies this:
  * Context-line indentation and blank lines were reconstructed from the hunk
    line counts and may need "git apply --ignore-whitespace".
  * The bare "#include" lines in ucc_pt_coll.h, "const std::map" in
    ucc_pt_config.cc and the last print_help() hunk are incomplete in the text
    under review and are reproduced unchanged; restore them from the tree.
  * The "index" blob ids are the ones from the reviewed revision.

diff --git a/tools/perf/ucc_perftest.cc b/tools/perf/ucc_perftest.cc
index dd873ba05a..33e75ddbea 100644
--- a/tools/perf/ucc_perftest.cc
+++ b/tools/perf/ucc_perftest.cc
@@ -17,7 +17,7 @@ int main(int argc, char *argv[])
     ucc_pt_cuda_init();
     ucc_pt_rocm_init();
     try {
-        comm = new ucc_pt_comm(pt_config.comm);
+        comm = new ucc_pt_comm(pt_config.comm, pt_config.bench);
     } catch(std::exception &e) {
         std::cerr << e.what() << std::endl;
         std::exit(1);
diff --git a/tools/perf/ucc_pt_benchmark.cc b/tools/perf/ucc_pt_benchmark.cc
index e5c1fa1e68..007de422c4 100644
--- a/tools/perf/ucc_pt_benchmark.cc
+++ b/tools/perf/ucc_pt_benchmark.cc
@@ -31,7 +31,7 @@ ucc_pt_benchmark::ucc_pt_benchmark(ucc_pt_benchmark_config cfg,
         break;
     case UCC_PT_OP_TYPE_ALLTOALL:
         coll = new ucc_pt_coll_alltoall(cfg.dt, cfg.mt, cfg.inplace,
-                                        cfg.persistent, comm);
+                                        cfg.persistent, cfg.onesided, comm);
         break;
     case UCC_PT_OP_TYPE_ALLTOALLV:
         coll = new ucc_pt_coll_alltoallv(cfg.dt, cfg.mt, cfg.inplace,
diff --git a/tools/perf/ucc_pt_coll.h b/tools/perf/ucc_pt_coll.h
index 0b92039fab..1570932fdc 100644
--- a/tools/perf/ucc_pt_coll.h
+++ b/tools/perf/ucc_pt_coll.h
@@ -13,6 +13,11 @@ extern "C" {
 #include
 #include
 }
+/* Evaluates to true when the collective args have a valid flags field with
+ * UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS set. */
+#define UCC_IS_ONESIDED(_args)                                                 \
+    (((_args).mask & UCC_COLL_ARGS_FIELD_FLAGS) &&                             \
+     ((_args).flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS))
 
 ucc_status_t ucc_pt_alloc(ucc_mc_buffer_header_t **h_ptr, size_t len,
                           ucc_memory_type_t mem_type);
@@ -87,7 +92,7 @@ class ucc_pt_coll_allreduce: public ucc_pt_coll {
 class ucc_pt_coll_alltoall: public ucc_pt_coll {
 public:
     ucc_pt_coll_alltoall(ucc_datatype_t dt, ucc_memory_type mt,
-                         bool is_inplace, bool is_persistent,
+                         bool is_inplace, bool is_persistent, bool is_onesided,
                          ucc_pt_comm *communicator);
     ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
     void free_args(ucc_pt_test_args_t &args) override;
diff --git a/tools/perf/ucc_pt_coll_alltoall.cc b/tools/perf/ucc_pt_coll_alltoall.cc
index 77a2608f7f..b32852935f 100644
--- a/tools/perf/ucc_pt_coll_alltoall.cc
+++ b/tools/perf/ucc_pt_coll_alltoall.cc
@@ -12,7 +12,7 @@
 
 ucc_pt_coll_alltoall::ucc_pt_coll_alltoall(ucc_datatype_t dt,
                          ucc_memory_type mt, bool is_inplace,
-                         bool is_persistent,
+                         bool is_persistent, bool is_onesided,
                          ucc_pt_comm *communicator) : ucc_pt_coll(communicator)
 {
     has_inplace_ = true;
@@ -38,6 +38,12 @@ ucc_pt_coll_alltoall::ucc_pt_coll_alltoall(ucc_datatype_t dt,
         coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
         coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
     }
+    if (is_onesided) {
+        /* OR into the mask so bits set by the options above are kept. */
+        coll_args.mask  |= UCC_COLL_ARGS_FIELD_FLAGS |
+                           UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER;
+        coll_args.flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS;
+    }
 }
 
 ucc_status_t ucc_pt_coll_alltoall::init_args(size_t single_rank_count,
@@ -60,6 +66,13 @@ ucc_status_t ucc_pt_coll_alltoall::init_args(size_t single_rank_count,
                       free_dst, st);
         args.src.info.buffer = src_header->addr;
     }
+    if (UCC_IS_ONESIDED(args)) {
+        /* Use the communicator's buffers instead of the ones set above:
+         * 0 = source, 1 = destination, 2 = global work buffer. */
+        args.src.info.buffer    = comm->get_global_buffer(0);
+        args.dst.info.buffer    = comm->get_global_buffer(1);
+        args.global_work_buffer = comm->get_global_buffer(2);
+    }
     return UCC_OK;
 free_dst:
     ucc_pt_free(dst_header);
diff --git a/tools/perf/ucc_pt_comm.cc b/tools/perf/ucc_pt_comm.cc
index 1f6a91a651..5522fba274 100644
--- a/tools/perf/ucc_pt_comm.cc
+++ b/tools/perf/ucc_pt_comm.cc
@@ -9,10 +9,21 @@ extern "C" {
 #include "utils/ucc_coll_utils.h"
 #include "components/mc/ucc_mc.h"
 }
+/* Prints a message and aborts MPI_COMM_WORLD when _obj is null. Wrapped in
+ * do/while(0) so the macro expands to exactly one statement. */
+#define UCC_MALLOC_CHECK(_obj)                                                 \
+    do {                                                                       \
+        if (!(_obj)) {                                                         \
+            std::cerr << "*** UCC MALLOC FAIL \n";                             \
+            MPI_Abort(MPI_COMM_WORLD, -1);                                     \
+        }                                                                      \
+    } while (0)
 
-ucc_pt_comm::ucc_pt_comm(ucc_pt_comm_config config)
+ucc_pt_comm::ucc_pt_comm(ucc_pt_comm_config config,
+                         ucc_pt_benchmark_config ben_config)
 {
     cfg = config;
+    bcfg = ben_config;
     bootstrap = new ucc_pt_bootstrap_mpi();
 }
 
@@ -124,6 +135,7 @@ ucc_status_t ucc_pt_comm::init()
     ucc_status_t st;
     std::string cfg_mod;
 
+    ucc_mem_map_t segments[UCC_TEST_N_MEM_SEGMENTS];
     ee = nullptr;
     executor = nullptr;
     stream = nullptr;
@@ -157,6 +169,26 @@ ucc_status_t ucc_pt_comm::init()
                       UCC_CONTEXT_PARAM_FIELD_OOB;
     ctx_params.type = UCC_CONTEXT_SHARED;
     ctx_params.oob = bootstrap->get_context_oob();
+    if (bcfg.onesided) {
+        /* UCC_TEST_MEM_SEGMENT_SIZE is an int expression; widen it before
+         * multiplying so a large get_size() cannot overflow the product. */
+        const size_t seg_len = static_cast<size_t>(UCC_TEST_MEM_SEGMENT_SIZE) *
+                               bootstrap->get_size();
+
+        for (auto i = 0; i < UCC_TEST_N_MEM_SEGMENTS; i++) {
+            onesided_buffers[i] = ucc_calloc(UCC_TEST_MEM_SEGMENT_SIZE,
+                                             bootstrap->get_size(),
+                                             "onesided buffers");
+            UCC_MALLOC_CHECK(onesided_buffers[i]);
+            segments[i].address = onesided_buffers[i];
+            segments[i].len     = seg_len;
+        }
+        /* Hand the buffers to the context as its memory segments. */
+        ctx_params.mask |= UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS;
+        ctx_params.mem_params.segments   = segments;
+        ctx_params.mem_params.n_segments = UCC_TEST_N_MEM_SEGMENTS;
+    }
+
     UCCCHECK_GOTO(ucc_context_create(lib, &ctx_params, ctx_config,
                                      &context), free_ctx_config, st);
     team_params.mask = UCC_TEAM_PARAM_FIELD_EP |
@@ -165,6 +197,10 @@ ucc_status_t ucc_pt_comm::init()
     team_params.oob = bootstrap->get_team_oob();
     team_params.ep = bootstrap->get_rank();
     team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG;
+    if (bcfg.onesided) {
+        team_params.mask |= UCC_TEAM_PARAM_FIELD_FLAGS;
+        team_params.flags = UCC_TEAM_FLAG_COLL_WORK_BUFFER;
+    }
     UCCCHECK_GOTO(ucc_team_create_post(&context, 1, &team_params, &team),
                   free_ctx, st);
     do {
@@ -219,6 +255,13 @@ ucc_status_t ucc_pt_comm::finalize()
     if (status != UCC_OK) {
         std::cerr << "ucc team destroy error: " << ucc_status_string(status);
     }
     ucc_context_destroy(context);
+    /* The buffers were passed to the context through mem_params, so release
+     * them only after the context has been destroyed. */
+    if (onesided_buffers[0]) {
+        for (auto i = 0; i < UCC_TEST_N_MEM_SEGMENTS; i++) {
+            ucc_free(onesided_buffers[i]);
+        }
+    }
     ucc_finalize(lib);
     return UCC_OK;
diff --git a/tools/perf/ucc_pt_comm.h b/tools/perf/ucc_pt_comm.h
index 827246ee4e..bb34217460 100644
--- a/tools/perf/ucc_pt_comm.h
+++ b/tools/perf/ucc_pt_comm.h
@@ -14,8 +14,13 @@ extern "C" {
 #include "components/ec/ucc_ec.h"
 }
 
+/* One-sided mode: number of buffers allocated by ucc_pt_comm::init(), and the
+ * size factor that init() multiplies by bootstrap->get_size() for each one. */
+#define UCC_TEST_N_MEM_SEGMENTS 3
+#define UCC_TEST_MEM_SEGMENT_SIZE (1 << 21)
 
 class ucc_pt_comm {
+    ucc_pt_benchmark_config bcfg;
     ucc_pt_comm_config cfg;
     ucc_lib_h lib;
     ucc_context_h context;
@@ -24,11 +29,23 @@ class ucc_pt_comm {
     ucc_ee_h ee;
     ucc_ee_executor_t *executor;
     ucc_pt_bootstrap *bootstrap;
+    /* Filled by init() in one-sided mode, otherwise left null. The alltoall
+     * benchmark uses [0] as source, [1] as destination, [2] as work buffer. */
+    void *onesided_buffers[UCC_TEST_N_MEM_SEGMENTS] = {};
     void set_gpu_device();
 public:
-    ucc_pt_comm(ucc_pt_comm_config config);
+    ucc_pt_comm(ucc_pt_comm_config config, ucc_pt_benchmark_config ben_config);
     int get_rank();
     int get_size();
+    /* Returns onesided_buffers[index]; throws std::out_of_range when index is
+     * outside [0, UCC_TEST_N_MEM_SEGMENTS). */
+    void* get_global_buffer(int index)
+    {
+        if (index < 0 || index >= UCC_TEST_N_MEM_SEGMENTS) {
+            throw std::out_of_range("Index out range");
+        }
+        return onesided_buffers[index];
+    }
     ucc_ee_executor_t* get_executor();
     ucc_ee_h get_ee();
     ucc_team_h get_team();
diff --git a/tools/perf/ucc_pt_config.cc b/tools/perf/ucc_pt_config.cc
index 08469483e3..b1bd552267 100644
--- a/tools/perf/ucc_pt_config.cc
+++ b/tools/perf/ucc_pt_config.cc
@@ -31,6 +31,7 @@ ucc_pt_config::ucc_pt_config() {
     bench.root_shift = 0;
     bench.mult_factor = 2;
     comm.mt = bench.mt;
+    bench.onesided = false;
 }
 
 const std::map ucc_pt_reduction_op_map = {
@@ -91,7 +92,7 @@ ucc_status_t ucc_pt_config::process_args(int argc, char *argv[])
     int c;
     ucc_status_t st;
 
-    while ((c = getopt(argc, argv, "c:b:e:d:m:n:w:o:N:r:S:iphFT")) != -1) {
+    while ((c = getopt(argc, argv, "c:b:e:d:m:n:w:o:N:r:S:iphFTJ")) != -1) {
         switch (c) {
             case 'c':
                 if (ucc_pt_op_map.count(optarg) == 0) {
@@ -172,6 +173,9 @@ ucc_status_t ucc_pt_config::process_args(int argc, char *argv[])
             case 'F':
                 bench.full_print = true;
                 break;
+            case 'J':
+                bench.onesided = true;
+                break;
             case 'h':
             default:
                 print_help();
@@ -201,5 +205,6 @@ void ucc_pt_config::print_help()
     std::cout << " -F: enable full print"<: root shift for rooted collectives"<