Chi-Tech
chi_math_sparse_matrix.h
Go to the documentation of this file.
1#ifndef _chi_math_sparse_matrix_h
2#define _chi_math_sparse_matrix_h
3
4#include "../chi_math.h"
5
6//###################################################################
/**Sparse matrix utility. This is a basic CSR-type sparse matrix that
 * provides efficient matrix storage and multiplication. It is not
 * intended for solving linear systems (use PETSc for that instead).
 * It was originally developed for the transfer matrices of transport
 * cross-sections.*/
13{
14private:
15 size_t row_size_; ///< Maximum number of rows for this matrix
16 size_t col_size_; ///< Maximum number of columns for this matrix
17
18public:
19 /**rowI_indices[i] is a vector indices j for the
20 * non-zero columns.*/
21 std::vector<std::vector<size_t>> rowI_indices_;
22 /**rowI_values[i] corresponds to column indices and
23 * contains the non-zero value.*/
24 std::vector<std::vector<double>> rowI_values_;
25
26public:
27 SparseMatrix(size_t num_rows, size_t num_cols);
28 SparseMatrix(const SparseMatrix& in_matrix);
29
30 size_t NumRows() const {return row_size_;}
31 size_t NumCols() const {return col_size_;}
32
33 void Insert(size_t i, size_t j, double value);
34 void InsertAdd(size_t i, size_t j, double value);
35 double ValueIJ(size_t i, size_t j) const;
36 void SetDiagonal(const std::vector<double>& diag);
37
38 void Compress();
39
40 std::string PrintStr() const;
41
42private:
43 void CheckInitialized() const;
44
45public:
46 virtual ~SparseMatrix() = default;
47
48public:
50 {
51 const size_t& row_index;
52 const size_t& column_index;
53 double& value;
54
55 EntryReference(const size_t& row_id,
56 const size_t& column_id,
57 double& in_value) :
58 row_index(row_id),
59 column_index(column_id),
60 value(in_value) {}
61 };
62
64 {
65 public:
66 const size_t& row_index;
67 const size_t& column_index;
68 const double& value;
69
70 ConstEntryReference(const size_t& row_id,
71 const size_t& column_id,
72 const double& in_value) :
73 row_index(row_id),
74 column_index(column_id),
75 value(in_value) {}
76 };
77
79 {
80 private:
81 const std::vector<size_t>& ref_col_ids_;
82 std::vector<double>& ref_col_vals_;
83 const size_t ref_row_;
84 public:
85 RowIteratorContext(SparseMatrix& matrix, size_t ref_row) :
86 ref_col_ids_(matrix.rowI_indices_[ref_row]),
87 ref_col_vals_(matrix.rowI_values_[ref_row]),
88 ref_row_(ref_row){}
89
91 {
92 private:
93 typedef RowIterator It;
94 private:
96 size_t ref_entry_;
97 public:
98 RowIterator(RowIteratorContext& context, size_t ref_entry) :
99 context_{context}, ref_entry_{ref_entry} {}
100
101 It operator++() {It i = *this; ref_entry_++; return i;}
102 It operator++(int) {ref_entry_++; return *this;}
103
108
109 bool operator==(const It& rhs) const {return ref_entry_ == rhs.ref_entry_;}
110 bool operator!=(const It& rhs) const {return ref_entry_ != rhs.ref_entry_;}
111 };
112
113 RowIterator begin() {return {*this, 0};}
114 RowIterator end() {return {*this, ref_col_vals_.size()};}
115 };
116
117 RowIteratorContext Row(size_t row_id); //See .cc file
118
120 {
121 private:
122 const std::vector<size_t>& ref_col_ids_;
123 const std::vector<double>& ref_col_vals_;
124 const size_t ref_row_;
125 public:
126 ConstRowIteratorContext(const SparseMatrix& matrix, size_t ref_row) :
127 ref_col_ids_(matrix.rowI_indices_[ref_row]),
128 ref_col_vals_(matrix.rowI_values_[ref_row]),
129 ref_row_(ref_row){}
130
132 {
133 private:
135 private:
138 public:
139 ConstRowIterator(const ConstRowIteratorContext& context, size_t ref_entry) :
140 context_(context), ref_entry_{ref_entry} {}
141
142 It operator++() {It i = *this; ref_entry_++; return i;}
143 It operator++(int) {ref_entry_++; return *this;}
144
149
150 bool operator==(const It& rhs) const {return ref_entry_ == rhs.ref_entry_;}
151 bool operator!=(const It& rhs) const {return ref_entry_ != rhs.ref_entry_;}
152 };
153
154 ConstRowIterator begin() const {return {*this, 0};}
155 ConstRowIterator end() const {return {*this, ref_col_vals_.size()};}
156 };
157
158 ConstRowIteratorContext Row(size_t row_id) const;
159
/**Iterator over all stored matrix entries, visited row by row.*/
162 {
163 private:
165 private:
167 size_t ref_row_;
168 size_t ref_col_;
169 public:
170
171 explicit EntriesIterator(SparseMatrix& context, size_t row) :
172 sp_matrix{context}, ref_row_{row}, ref_col_(0)
173 {}
174
175 void Advance()
176 {
177 ref_col_++;
179 {
180 ref_row_++;
181 ref_col_ = 0;
182 while ((ref_row_ < sp_matrix.row_size_) and
184 ref_row_++;
185 }
186 }
187
188 EIt operator++() {EIt i = *this; Advance(); return i;}
189 EIt operator++(int) {Advance(); return *this;}
190
192 {
193 return {ref_row_,
196 }
197 bool operator==(const EIt& rhs) const
198 { return (ref_row_ == rhs.ref_row_) and
199 (ref_col_ == rhs.ref_col_); }
200 bool operator!=(const EIt& rhs) const
201 { return (ref_row_ != rhs.ref_row_) or
202 (ref_col_ != rhs.ref_col_); }
203 };
204
205 EntriesIterator begin();
206 EntriesIterator end();
207};
208
209
210#endif
ConstRowIterator(const ConstRowIteratorContext &context, size_t ref_entry)
ConstRowIteratorContext(const SparseMatrix &matrix, size_t ref_row)
EntriesIterator(SparseMatrix &context, size_t row)
RowIterator(RowIteratorContext &context, size_t ref_entry)
RowIteratorContext(SparseMatrix &matrix, size_t ref_row)
size_t col_size_
Maximum number of columns for this matrix.
void Insert(size_t i, size_t j, double value)
SparseMatrix(size_t num_rows, size_t num_cols)
RowIteratorContext Row(size_t row_id)
std::vector< std::vector< double > > rowI_values_
void InsertAdd(size_t i, size_t j, double value)
double ValueIJ(size_t i, size_t j) const
virtual ~SparseMatrix()=default
std::vector< std::vector< size_t > > rowI_indices_
size_t row_size_
Maximum number of rows for this matrix.
void SetDiagonal(const std::vector< double > &diag)
const double & value
const size_t & column_index
const size_t & row_index
ConstEntryReference(const size_t &row_id, const size_t &column_id, const double &in_value)
const size_t & row_index
double & value
EntryReference(const size_t &row_id, const size_t &column_id, double &in_value)
const size_t & column_index