#ifndef CHI_MPI_MAP_ALL2ALL_H
#define CHI_MPI_MAP_ALL2ALL_H

#include <map>
#include <set>
#include <vector>
#include <type_traits>

#include <mpi.h>

#include "chi_runtime.h" // assumed project header declaring Chi::mpi
/**Given a map keyed on destination process-ids, where each value is a list of
 * items of type T, exchanges the data via MPI_Alltoallv and returns a map
 * keyed on source process-ids containing the items received from each rank.
 * The key type K must be an integral type castable to int, and data_mpi_type
 * must describe T.*/
template <typename K, class T>
std::map<K, std::vector<T>> MapAllToAll(
  const std::map<K, std::vector<T>>& pid_data_pairs,
  const MPI_Datatype data_mpi_type,
  const MPI_Comm communicator = Chi::mpi.comm)
{
  static_assert(std::is_integral<K>::value, "Integral datatype required.");
  // Build sendcounts and senddispls: how many items go to each rank and at
  // which offset they start inside the flattened send buffer.
  std::vector<int> sendcounts(Chi::mpi.process_count, 0);
  std::vector<int> senddispls(Chi::mpi.process_count, 0);
  {
    size_t accumulated_displ = 0;
    for (const auto& [pid, data] : pid_data_pairs)
    {
      sendcounts[pid] = static_cast<int>(data.size());
      senddispls[pid] = static_cast<int>(accumulated_displ);
      accumulated_displ += data.size();
    }
  }
  // Exchange the send counts so every rank learns how much it will receive
  // from every other rank.
  std::vector<int> recvcounts(Chi::mpi.process_count, 0);

  MPI_Alltoall(sendcounts.data(), 1, MPI_INT, // send one count per rank
               recvcounts.data(), 1, MPI_INT, // receive one count per rank
               communicator);
  // Build recvdispls, the set of ranks that actually send data, and the
  // total number of items to be received.
  std::vector<int> recvdispls(Chi::mpi.process_count, 0);
  std::set<K> sender_pids_set;
  size_t total_recv_count = 0;
  {
    int displacement = 0;
    for (int pid = 0; pid < Chi::mpi.process_count; ++pid)
    {
      recvdispls[pid] = displacement;
      displacement += recvcounts[pid];

      if (recvcounts[pid] > 0)
        sender_pids_set.insert(static_cast<K>(pid));
    }
    total_recv_count = displacement;
  }
  // Flatten the per-rank data into a single contiguous send buffer; std::map
  // iterates in ascending key order, matching the displacements computed above.
  std::vector<T> sendbuf;
  for (const auto& pid_data_pair : pid_data_pairs)
    sendbuf.insert(sendbuf.end(),
                   pid_data_pair.second.begin(),
                   pid_data_pair.second.end());

  std::vector<T> recvbuf(total_recv_count);
  // Exchange the actual payload.
  MPI_Alltoallv(sendbuf.data(), sendcounts.data(), senddispls.data(),
                data_mpi_type,
                recvbuf.data(), recvcounts.data(), recvdispls.data(),
                data_mpi_type,
                communicator);
  // Unpack the receive buffer into per-sender vectors.
  std::map<K, std::vector<T>> output_data;
  for (K pid : sender_pids_set)
  {
    const int data_count = recvcounts.at(pid);
    const int data_displ = recvdispls.at(pid);

    auto& data = output_data[pid];
    data.resize(data_count);

    for (int i = 0; i < data_count; ++i)
      data.at(i) = recvbuf.at(data_displ + i);
  }

  return output_data;
}
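// Usage sketch (illustrative only; the ranks and payload below are
// hypothetical and assume the communicator has at least two processes):
//
//   std::map<int, std::vector<double>> outgoing;
//   outgoing[0] = {1.0, 2.0};      // items destined for rank 0
//   outgoing[1] = {3.0, 4.0, 5.0}; // items destined for rank 1
//
//   auto incoming = MapAllToAll(outgoing, MPI_DOUBLE);
//   // incoming[p] now holds the items that rank p sent to this rank.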
#endif // CHI_MPI_MAP_ALL2ALL_H