Chi-Tech
chi_mpi_utils_map_all2all.h
Go to the documentation of this file.
1#ifndef CHI_MPI_MAP_ALL2ALL_H
2#define CHI_MPI_MAP_ALL2ALL_H
3
#include <map>
#include <set>
#include <vector>
#include <type_traits>

#include "chi_runtime.h"
#include "chi_mpi.h"
11
12namespace chi_mpi_utils
13{
14
15/**Given a map with keys indicating the destination process-ids and the
16 * values for each key a list of values of type T (T must have an MPI_Datatype).
17 * Returns a map with the keys indicating the source process-ids and the
18 * values for each key a list of values of type T (sent by the respective
19 * process).
20 *
21 * The keys must be "castable" to `int`.
22 *
23 * Also expects the MPI_Datatype of T.*/
24template<typename K, class T> std::map<K, std::vector<T>>
25 MapAllToAll(const std::map<K, std::vector<T>>& pid_data_pairs,
26 const MPI_Datatype data_mpi_type,
27 const MPI_Comm communicator=Chi::mpi.comm)
28{
29 static_assert(std::is_integral<K>::value, "Integral datatype required.");
30
31 //============================================= Make sendcounts and
32 // senddispls
33 std::vector<int> sendcounts(Chi::mpi.process_count, 0);
34 std::vector<int> senddispls(Chi::mpi.process_count, 0);
35 {
36 size_t accumulated_displ = 0;
37 for (const auto& [pid, data] : pid_data_pairs)
38 {
39 sendcounts[pid] = static_cast<int>(data.size());
40 senddispls[pid] = static_cast<int>(accumulated_displ);
41 accumulated_displ += data.size();
42 }
43 }
44
45 //============================================= Communicate sendcounts to
46 // get recvcounts
47 std::vector<int> recvcounts(Chi::mpi.process_count, 0);
48
49 MPI_Alltoall(sendcounts.data(), //sendbuf
50 1, MPI_INT, //sendcount, sendtype
51 recvcounts.data(), //recvbuf
52 1, MPI_INT, //recvcount, recvtype
53 communicator); //communicator
54
55 //============================================= Populate recvdispls,
56 // sender_pids_set, and
57 // total_recv_count
58 // All three these quantities are constructed
59 // from recvcounts.
60 std::vector<int> recvdispls(Chi::mpi.process_count, 0);
61 std::set<K> sender_pids_set; //set of neighbor-partitions sending data
62 size_t total_recv_count;
63 {
64 int displacement=0;
65 for (int pid=0; pid < Chi::mpi.process_count; ++pid)
66 {
67 recvdispls[pid] = displacement;
68 displacement += recvcounts[pid];
69
70 if (recvcounts[pid] > 0)
71 sender_pids_set.insert(static_cast<K>(pid));
72 }//for pid
73 total_recv_count = displacement;
74 }
75
76 //============================================= Make sendbuf
77 // The data for each partition is now loaded
78 // into a single buffer
79 std::vector<T> sendbuf;
80 for (const auto& pid_data_pair : pid_data_pairs)
81 sendbuf.insert(sendbuf.end(),
82 pid_data_pair.second.begin(),
83 pid_data_pair.second.end());
84
85 //============================================= Make recvbuf
86 std::vector<T> recvbuf(total_recv_count);
87
88 //============================================= Communicate serial data
89 MPI_Alltoallv(sendbuf.data(), //sendbuf
90 sendcounts.data(), //sendcounts
91 senddispls.data(), //senddispls
92 data_mpi_type, //sendtype
93 recvbuf.data(), //recvbuf
94 recvcounts.data(), //recvcounts
95 recvdispls.data(), //recvdispls
96 data_mpi_type, //recvtype
97 communicator); //comm
98
99 std::map<K, std::vector<T>> output_data;
100 {
101 for (K pid : sender_pids_set)
102 {
103 const int data_count = recvcounts.at(pid);
104 const int data_displ = recvdispls.at(pid);
105
106 auto& data = output_data[pid];
107 data.resize(data_count);
108
109 for (int i=0; i<data_count; ++i)
110 data.at(i) = recvbuf.at(data_displ + i);
111 }
112 }
113
114 return output_data;
115}
116
117}//namespace chi_mpi_utils
118
119#endif//CHI_MPI_MAP_ALL2ALL_H
static chi::MPI_Info & mpi
Definition: chi_runtime.h:78
const int & process_count
Total number of processes.
Definition: mpi_info.h:27
std::map< K, std::vector< T > > MapAllToAll(const std::map< K, std::vector< T > > &pid_data_pairs, const MPI_Datatype data_mpi_type, const MPI_Comm communicator=Chi::mpi.comm)