Chi-Tech
wgs_linear_solver.cc
Go to the documentation of this file.
1#include "wgs_linear_solver.h"
2
5
8
9#include "chi_runtime.h"
10#include "chi_log.h"
11
12#include "utils/chi_timer.h"
13
14#include <petscksp.h>
15#include <memory>
16#include <iomanip>
17
// Shorthand cast macros used when handing sizes to PETSc.
// NOTE(review): sc_double is not referenced anywhere in the code shown here.
18#define sc_double static_cast<double>
19#define sc_int64_t static_cast<int64_t>
20
// Casts a shared_ptr (context_ptr_ at every use below) to the concrete
// WGSContext<Mat, Vec, KSP>. std::dynamic_pointer_cast yields nullptr on a
// type mismatch, and the call sites below dereference the result without a
// null check.
21#define GetGSContextPtr(x) \
22 std::dynamic_pointer_cast<WGSContext<Mat, Vec, KSP>>(x)
23
24namespace lbs
25{
26
/** Pre-setup hook: obtains the WGSContext<Mat, Vec, KSP> view of
 * context_ptr_ and forwards to its PreSetupCallback().
 * NOTE(review): the signature line (listing line 28) is missing from this
 * rendering; presumably WGSLinearSolver<Mat, Vec, KSP>::PreSetupCallback()
 * — confirm. */
27template <>
29{
30 auto gs_context_ptr = GetGSContextPtr(context_ptr_);
31
32 gs_context_ptr->PreSetupCallback();
33}
34
/** Installs GSConvergenceTest as the convergence test of solver_.
 * KSPSetConvergenceTest receives nullptr for both the user context and the
 * context-destroy routine.
 * NOTE(review): the signature line (listing line 36) is missing from this
 * rendering; presumably the SetConvergenceTest() override — confirm.
 * NOTE(review): the returned PetscErrorCode is ignored. */
35template <>
37{
38 KSPSetConvergenceTest(solver_, &GSConvergenceTest, nullptr, nullptr);
39}
40
/** Caches the system dimensions reported by the context's SystemSize().
 * The pair's first member is stored in num_local_dofs_ and its second
 * member in num_globl_dofs_.
 * NOTE(review): the signature line (listing line 42) is missing from this
 * rendering; presumably the SetSystemSize() override — confirm. */
41template <>
43{
44 auto gs_context_ptr = GetGSContextPtr(context_ptr_);
45 const auto sizes = gs_context_ptr->SystemSize();
46
47 num_local_dofs_ = sizes.first;
48 num_globl_dofs_ = sizes.second;
49}
50
/** Creates the PETSc objects for the linear system. Returns immediately if
 * IsSystemSet() is already true.
 *  - x_: solution vector, filled with 0.0. b_ is created by VecDuplicate
 *    from x_, which copies the layout but not the values.
 *  - A_: shell matrix with num_local_dofs_ local rows/columns and
 *    num_globl_dofs_ global rows/columns. Its MATOP_MULT is
 *    chi_math::LinearSolverMatrixAction<Mat, Vec>, and the object owned by
 *    context_ptr_ is passed as the shell context.
 *  - solver_: given A_ as both the operator and the preconditioning matrix,
 *    then KSPSetUp() is called.
 * NOTE(review): the signature line (listing line 52) is missing from this
 *   rendering. So is the first line of the x_ creation call (listing line
 *   56, presumably a CreateVector(local_size, global_size) call) — confirm.
 * NOTE(review): the shell context is a raw, non-owning pointer, so the
 *   object behind context_ptr_ presumably must outlive A_.
 * NOTE(review): PETSc error codes are not checked. */
51template <>
53{
54 if (IsSystemSet()) return;
55
57 sc_int64_t(num_globl_dofs_));
58
59 VecSet(x_, 0.0);
60 VecDuplicate(x_, &b_);
61
62 //============================================= Create the matrix-shell
63 MatCreateShell(PETSC_COMM_WORLD,
64 sc_int64_t(num_local_dofs_),
65 sc_int64_t(num_local_dofs_),
66 sc_int64_t(num_globl_dofs_),
67 sc_int64_t(num_globl_dofs_),
68 &(*context_ptr_),
69 &A_);
70
71 //============================================= Set the action-operator
72 MatShellSetOperation(
73 A_, MATOP_MULT, (void (*)())chi_math::LinearSolverMatrixAction<Mat, Vec>);
74
75 //============================================= Set solver operators
76 KSPSetOperators(solver_, A_, A_);
77 KSPSetUp(solver_);
78}
79
/** Delegates preconditioner configuration of solver_ to the context's
 * SetPreconditioner(). Does nothing once IsSystemSet() is true, so the
 * preconditioner is configured only while the system is not yet set.
 * NOTE(review): the signature line (listing line 81) is missing from this
 * rendering; presumably the SetPreconditioner() override — confirm. */
80template <>
82{
83 if (IsSystemSet()) return;
84 auto gs_context_ptr = GetGSContextPtr(context_ptr_);
85
86 gs_context_ptr->SetPreconditioner(solver_);
87}
88
/** Post-setup hook: forwards to the WGSContext's PostSetupCallback().
 * NOTE(review): the signature line (listing line 90) is missing from this
 * rendering; presumably the PostSetupCallback() override — confirm. */
89template <>
91{
92 auto gs_context_ptr = GetGSContextPtr(context_ptr_);
93
94 gs_context_ptr->PostSetupCallback();
95}
96
/** Pre-solve hook: forwards to the WGSContext's PreSolveCallback().
 * NOTE(review): the signature line (listing line 98) is missing from this
 * rendering; presumably the PreSolveCallback() override — confirm. */
97template <>
99{
100 auto gs_context_ptr = GetGSContextPtr(context_ptr_);
101
102 gs_context_ptr->PreSolveCallback();
103}
104
// NOTE(review): the signature line (listing line 109) is missing from this
// rendering; presumably the SetInitialGuess() override — confirm.
105/**Sets the initial guess for a gs solver. If the initial guess's norm
106 * is large enough the initial guess will be used, otherwise it is assumed
107 * zero.*/
108template <>
110{
111 auto gs_context_ptr = GetGSContextPtr(context_ptr_);
112
113 auto& groupset = gs_context_ptr->groupset_;
114 auto& lbs_solver = gs_context_ptr->lbs_solver_;
115
// Copy this groupset's PHI_OLD values into x_ as the candidate guess.
116 lbs_solver.SetGSPETScVecFromPrimarySTLvector(
117 groupset, x_, PhiSTLOption::PHI_OLD);
118
// "Large enough" means a 2-norm strictly greater than 1.0e-10.
// NOTE(review): at or below that threshold the nonzero-guess flag is not
// explicitly set to PETSC_FALSE. "Assumed zero" therefore relies on the flag
// not having been set to PETSC_TRUE by an earlier call.
119 double init_guess_norm = 0.0;
120 VecNorm(x_, NORM_2, &init_guess_norm);
121
122 if (init_guess_norm > 1.0e-10)
123 {
124 KSPSetInitialGuessNonzero(solver_, PETSC_TRUE);
125 if (gs_context_ptr->log_info_)
126 Chi::log.Log() << "Using phi_old as initial guess.";
127 }
128}
129
/** Builds the right-hand side for this solver.
 *
 * First saves lbs_solver.QMomentsLocal() into saved_q_moments_local_, which
 * the post-solve callback assigns back. The rest depends on whether the
 * solve is a single Richardson iteration (iterative_method_ == "richardson"
 * and tolerance_options_.maximum_iterations == 1):
 *  - Normal case: calls set_source_function_ with scope
 *    rhs_src_scope_ | ZERO_INCOMING_DELAYED_PSI, applies the inverse
 *    transport operator with that scope, and assembles the PHI_NEW result
 *    into b_.
 *  - Single-Richardson case: uses scope rhs_src_scope_ | lhs_src_scope_,
 *    assembles the PHI_NEW result directly into x_, and sets the KSP-solve
 *    suppression flag.
 * In both cases the 2-norm of the assembled vector is stored in
 * context_ptr_->rhs_norm. The 2-norm of PCApply(pc, vector) is stored in
 * context_ptr_->rhs_preconditioned_norm.
 *
 * NOTE(review): the signature line (listing line 131) is missing from this
 *   rendering; presumably the SetRHS() override — confirm.
 * NOTE(review): SetKSPSolveSuppressionFlag(true) is never reset to false on
 *   the normal path; confirm the flag is cleared elsewhere if the solver can
 *   be reconfigured between solves.
 * NOTE(review): the two branches duplicate the norm computation and ignore
 *   PETSc error codes.
 */
130template <>
132{
133 auto gs_context_ptr = GetGSContextPtr(context_ptr_);
134
135 auto& groupset = gs_context_ptr->groupset_;
136 auto& lbs_solver = gs_context_ptr->lbs_solver_;
137
138 if (gs_context_ptr->log_info_)
139 Chi::log.Log() << Chi::program_timer.GetTimeString() << " Computing b";
140
// Keep a copy of the q-moments because QMomentsLocal() is passed to
// set_source_function_ below (presumably as its output — confirm).
141 // SetSource for RHS
142 saved_q_moments_local_ = lbs_solver.QMomentsLocal();
143
144 const bool single_richardson = iterative_method_ == "richardson" and
145 tolerance_options_.maximum_iterations == 1;
146
147 if (not single_richardson)
148 {
149 const int scope =
150 gs_context_ptr->rhs_src_scope_ | ZERO_INCOMING_DELAYED_PSI;
151 gs_context_ptr->set_source_function_(
152 groupset, lbs_solver.QMomentsLocal(), lbs_solver.PhiOldLocal(), scope);
153
154 //=================================================== Apply transport
155 //operator
156 gs_context_ptr->ApplyInverseTransportOperator(scope);
157
158 //=================================================== Assemble PETSc vector
159 lbs_solver.SetGSPETScVecFromPrimarySTLvector(
160 groupset, b_, PhiSTLOption::PHI_NEW);
161
162 //============================================= Compute RHS norm
163 VecNorm(b_, NORM_2, &context_ptr_->rhs_norm);
164
165 //============================================= Compute precondition RHS
166 //norm
// temp_vec is scratch space for PCApply and is destroyed right after use.
167 PC pc;
168 KSPGetPC(solver_, &pc);
169 Vec temp_vec;
170 VecDuplicate(b_, &temp_vec);
171 PCApply(pc, b_, temp_vec);
172 VecNorm(temp_vec, NORM_2, &context_ptr_->rhs_preconditioned_norm);
173 VecDestroy(&temp_vec);
174 }
175 // If we have a single richardson iteration then the user probably wants
176 // only a single sweep. Therefore, we are going to combine the scattering
177 // source (normally included in the lhs_src_scope) into the sweep for the
178 // RHS, and just suppress the kspsolve part.
179 else
180 {
181 const int scope = gs_context_ptr->rhs_src_scope_ |
182 gs_context_ptr->lhs_src_scope_;
183 gs_context_ptr->set_source_function_(
184 groupset, lbs_solver.QMomentsLocal(), lbs_solver.PhiOldLocal(), scope);
185
186 //=================================================== Apply transport
187 //operator
188 gs_context_ptr->ApplyInverseTransportOperator(scope);
189
// The sweep result goes straight into x_ (not b_); with the KSP solve
// suppressed, x_ is what the post-solve step copies back into PHI_NEW/PHI_OLD.
190 //=================================================== Assemble PETSc vector
191 lbs_solver.SetGSPETScVecFromPrimarySTLvector(
192 groupset, x_, PhiSTLOption::PHI_NEW);
193
194 //============================================= Compute RHS norm
195 VecNorm(x_, NORM_2, &context_ptr_->rhs_norm);
196
197 //============================================= Compute precondition RHS
198 //norm
199 PC pc;
200 KSPGetPC(solver_, &pc);
201 Vec temp_vec;
202 VecDuplicate(x_, &temp_vec);
203 PCApply(pc, x_, temp_vec);
204 VecNorm(temp_vec, NORM_2, &context_ptr_->rhs_preconditioned_norm);
205 VecDestroy(&temp_vec);
206
207 SetKSPSolveSuppressionFlag(true);
208 }
209}
210
211/**Post-solve callback. Unless the KSP solve was suppressed, reads the KSP
 * convergence reason and logs a warning for any reason other than
 * KSP_CONVERGED_RTOL or KSP_DIVERGED_ITS. It then copies the solution x_
 * into both the PHI_NEW and PHI_OLD primary STL vectors of the groupset,
 * restores the q_moments_local vector from saved_q_moments_local_, and
 * finally calls the context-specific PostSolveCallback().
 * NOTE(review): the signature line (listing line 213) and one line of the
 * warning expression (listing line 223, presumably a call converting
 * `reason` to a string such as GetPETScConvergedReasonstring) are missing
 * from this rendering — confirm.*/
212template <>
214{
215 //============================================= Get convergence reason
216 if (not GetKSPSolveSuppressionFlag())
217 {
218 KSPConvergedReason reason;
219 KSPGetConvergedReason(solver_, &reason);
// NOTE(review): only KSP_CONVERGED_RTOL and KSP_DIVERGED_ITS pass silently.
// Every other reason, including other KSP_CONVERGED_* values, is reported
// as "Krylov solver failed."
220 if (reason != KSP_CONVERGED_RTOL and reason != KSP_DIVERGED_ITS)
221 Chi::log.Log0Warning() << "Krylov solver failed. "
222 << "Reason: "
224 reason);
225 }
226
227 //============================================= Copy x to local solution
228 auto gs_context_ptr = GetGSContextPtr(context_ptr_);
229
230 auto& groupset = gs_context_ptr->groupset_;
231 auto& lbs_solver = gs_context_ptr->lbs_solver_;
232
// The same x_ is written into both PHI_NEW and PHI_OLD.
233 lbs_solver.SetPrimarySTLvectorFromGSPETScVec(
234 groupset, x_, PhiSTLOption::PHI_NEW);
235 lbs_solver.SetPrimarySTLvectorFromGSPETScVec(
236 groupset, x_, PhiSTLOption::PHI_OLD);
237
238 //============================================= Restore saved q_moms
239 lbs_solver.QMomentsLocal() = saved_q_moments_local_;
240
241 //============================================= Context specific callback
242 gs_context_ptr->PostSolveCallback();
243}
244
/** Destructor: destroys the shell matrix A_.
 * NOTE(review): the signature line (listing line 246) is missing from this
 * rendering; presumably WGSLinearSolver<Mat, Vec, KSP>::~WGSLinearSolver()
 * — confirm.
 * NOTE(review): x_, b_ and solver_ are not destroyed here; presumably the
 * base class releases them — confirm. */
245template <>
247{
248 MatDestroy(&A_);
249}
250} // namespace lbs
static chi::Timer program_timer
Definition: chi_runtime.h:79
static chi::ChiLog & log
Definition: chi_runtime.h:81
LogStream Log0Warning()
Definition: chi_log.h:231
LogStream Log(LOG_LVL level=LOG_0)
Definition: chi_log.cc:35
std::string GetTimeString() const
Definition: chi_timer.cc:38
void PreSolveCallback() override
void PostSolveCallback() override
void PreSetupCallback() override
virtual ~WGSLinearSolver() override
void SetConvergenceTest() override
virtual void SetSystemSize() override
void PostSetupCallback() override
void SetRHS() override
void SetInitialGuess() override
virtual void SetSystem() override
void SetPreconditioner() override
Vec CreateVector(int64_t local_size, int64_t global_size)
std::string GetPETScConvergedReasonstring(KSPConvergedReason reason)
@ ZERO_INCOMING_DELAYED_PSI
Definition: lbs_structs.h:96
PetscErrorCode GSConvergenceTest(KSP ksp, PetscInt n, PetscReal rnorm, KSPConvergedReason *convergedReason, void *)
struct _p_Vec * Vec
#define GetGSContextPtr(x)
#define sc_int64_t