Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
hyteg
hyteg
Commits
03be1737
Commit
03be1737
authored
Nov 29, 2021
by
Marcel Koch
Browse files
separate GinkgoBlockSolver.hpp into multiple files
parent
820a959a
Changes
3
Hide whitespace changes
Inline
Side-by-side
src/hyteg/ginkgo/GinkgoCGSolver.hpp
View file @
03be1737
...
...
@@ -14,6 +14,8 @@
#include "hyteg/ginkgo/GinkgoSparseMatrixProxy.hpp"
#include "hyteg/ginkgo/GinkgoUtilities.hpp"
#include "hyteg/ginkgo/GinkgoVectorProxy.hpp"
#include "hyteg/ginkgo/GinkgoCommunication.hpp"
#include "hyteg/ginkgo/GinkgoDirichletHandlig.hpp"
#include "hyteg/p1functionspace/P1Petsc.hpp"
#include "hyteg/p2functionspace/P2Petsc.hpp"
#include "hyteg/p2functionspace/P2ProjectNormalOperator.hpp"
...
...
@@ -38,305 +40,6 @@
namespace
hyteg
{
class
DirichletHandlerBase
{
public:
using
vec
=
gko
::
distributed
::
Vector
<
real_t
,
int32_t
>
;
using
mtx
=
gko
::
distributed
::
Matrix
<
real_t
,
int32_t
>
;
DirichletHandlerBase
(
std
::
vector
<
int32_t
>
bcIndices
,
const
vec
*
dir_vals
,
std
::
shared_ptr
<
mtx
>
matrix
,
bool
doUpdate
=
true
)
:
bcIndices_
(
std
::
move
(
bcIndices
)
)
,
dir_vals_
(
dir_vals
)
,
matrix_
(
std
::
move
(
matrix
)
)
,
doUpdate_
(
doUpdate
)
{}
virtual
void
update_matrix
()
=
0
;
virtual
void
update_solution
(
vec
*
)
=
0
;
virtual
vec
*
get_initial_guess
(
const
vec
*
)
=
0
;
virtual
vec
*
get_rhs
(
const
vec
*
,
const
vec
*
)
=
0
;
protected:
std
::
vector
<
int32_t
>
bcIndices_
;
const
vec
*
dir_vals_
;
std
::
shared_ptr
<
mtx
>
matrix_
;
bool
doUpdate_
;
};
template
<
typename
T
>
std
::
optional
<
T
>
to_local_idx
(
T
global_idx
,
const
gko
::
distributed
::
Partition
<
T
>*
part
,
T
rank
)
{
auto
local_start
=
part
->
get_const_range_bounds
()[
rank
];
auto
local_end
=
part
->
get_const_range_bounds
()[
rank
+
1
];
if
(
local_start
<=
global_idx
&&
global_idx
<
local_end
)
{
return
global_idx
-
local_start
;
}
else
{
return
{};
}
}
class
PeneltyDirichletHandler
:
public
DirichletHandlerBase
{
public:
using
DirichletHandlerBase
::
DirichletHandlerBase
;
void
update_matrix
()
override
{
if
(
this
->
doUpdate_
)
{
auto
local_mat
=
this
->
matrix_
->
get_local_diag
();
auto
row_ptrs
=
local_mat
->
get_const_row_ptrs
();
auto
cols
=
local_mat
->
get_const_col_idxs
();
auto
vals
=
local_mat
->
get_values
();
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
for
(
int
i
=
row_ptrs
[
*
lidx
];
i
<
row_ptrs
[
*
lidx
+
1
];
++
i
)
{
if
(
*
lidx
==
cols
[
i
]
)
vals
[
i
]
+=
tgv
;
}
}
}
}
}
void
update_solution
(
vec
*
cur_solution
)
override
{}
vec
*
get_initial_guess
(
const
vec
*
cur_initial_guess
)
override
{
initial_guess_
=
gko
::
clone
(
cur_initial_guess
);
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
initial_guess_
->
get_local
()
->
at
(
*
lidx
)
=
this
->
dir_vals_
->
get_local
()
->
at
(
*
lidx
);
}
}
return
initial_guess_
.
get
();
}
vec
*
get_rhs
(
const
vec
*
cur_rhs
,
const
vec
*
)
override
{
rhs_
=
gko
::
clone
(
cur_rhs
);
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
// use something like a scattered view + view->scale
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
rhs_
->
get_local
()
->
at
(
*
lidx
)
*=
tgv
;
}
}
return
rhs_
.
get
();
}
private:
std
::
unique_ptr
<
vec
>
initial_guess_
;
std
::
unique_ptr
<
vec
>
rhs_
;
double
tgv
=
1e30
;
};
class
ZeroRowsDirichletHandler
:
public
DirichletHandlerBase
{
using
valueType
=
typename
vec
::
value_type
;
public:
using
DirichletHandlerBase
::
DirichletHandlerBase
;
void
update_matrix
()
override
{
if
(
this
->
doUpdate_
)
{
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
auto
local_mat
=
this
->
matrix_
->
get_local_diag
();
auto
row_ptrs
=
local_mat
->
get_const_row_ptrs
();
auto
cols
=
local_mat
->
get_const_col_idxs
();
auto
vals
=
local_mat
->
get_values
();
auto
local_offdiag_mat
=
this
->
matrix_
->
get_local_offdiag
();
auto
od_row_ptrs
=
local_offdiag_mat
->
get_const_row_ptrs
();
auto
od_vals
=
local_offdiag_mat
->
get_values
();
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
for
(
int
i
=
row_ptrs
[
*
lidx
];
i
<
row_ptrs
[
*
lidx
+
1
];
++
i
)
{
if
(
*
lidx
!=
cols
[
i
]
)
vals
[
i
]
=
gko
::
zero
<
valueType
>
();
else
vals
[
i
]
=
gko
::
one
<
valueType
>
();
}
for
(
int
i
=
od_row_ptrs
[
*
lidx
];
i
<
od_row_ptrs
[
*
lidx
+
1
];
++
i
)
{
od_vals
[
i
]
=
gko
::
zero
<
valueType
>
();
}
}
}
}
}
void
update_solution
(
vec
*
cur_solution
)
override
{
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
auto
one
=
gko
::
initialize
<
gko
::
matrix
::
Dense
<
valueType
>
>
(
{
1
},
cur_solution
->
get_executor
()
);
cur_solution
->
add_scaled
(
one
.
get
(),
orig_initial_guess_
);
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
cur_solution
->
get_local
()
->
at
(
*
lidx
)
=
this
->
dir_vals_
->
get_local
()
->
at
(
*
lidx
);
}
}
}
vec
*
get_rhs
(
const
vec
*
cur_rhs
,
const
vec
*
orig_initial_guess
)
override
{
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
auto
one
=
gko
::
initialize
<
gko
::
matrix
::
Dense
<
valueType
>
>
(
{
1
},
cur_rhs
->
get_executor
()
);
auto
neg_one
=
gko
::
initialize
<
gko
::
matrix
::
Dense
<
valueType
>
>
(
{
-
1
},
cur_rhs
->
get_executor
()
);
rhs_
=
gko
::
clone
(
cur_rhs
);
this
->
matrix_
->
apply
(
neg_one
.
get
(),
orig_initial_guess
,
one
.
get
(),
rhs_
.
get
()
);
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
rhs_
->
get_local
()
->
at
(
*
lidx
)
=
gko
::
zero
<
valueType
>
();
}
}
return
rhs_
.
get
();
}
vec
*
get_initial_guess
(
const
vec
*
cur_initial_guess
)
override
{
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
orig_initial_guess_
=
cur_initial_guess
;
z_
=
gko
::
clone
(
cur_initial_guess
);
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
z_
->
get_local
()
->
at
(
*
lidx
)
=
gko
::
zero
<
valueType
>
();
}
}
return
z_
.
get
();
}
private:
const
vec
*
orig_initial_guess_
=
nullptr
;
std
::
unique_ptr
<
vec
>
rhs_
=
nullptr
;
std
::
unique_ptr
<
vec
>
z_
=
nullptr
;
};
enum
class
constraints
{
penalty
,
zero_row
};
std
::
pair
<
int
,
int
>
local_range
(
const
uint_t
local_size
,
std
::
shared_ptr
<
gko
::
mpi
::
communicator
>
comm
)
{
uint_t
start
=
0
;
uint_t
end
=
0
;
MPI_Exscan
(
&
local_size
,
&
start
,
1
,
MPI_UINT64_T
,
MPI_SUM
,
comm
->
get
()
);
MPI_Scan
(
&
local_size
,
&
end
,
1
,
MPI_UINT64_T
,
MPI_SUM
,
comm
->
get
()
);
return
std
::
make_pair
(
start
,
end
);
}
template
<
typename
LocalIndexType
>
std
::
vector
<
gko
::
Array
<
gko
::
distributed
::
global_index_type
>
>
compute_gather_idxs
(
std
::
shared_ptr
<
gko
::
distributed
::
Partition
<
LocalIndexType
>
>
partition
)
{
auto
ranges
=
partition
->
get_range_bounds
();
auto
range_to_pid
=
partition
->
get_const_part_ids
();
std
::
vector
<
std
::
vector
<
gko
::
distributed
::
global_index_type
>
>
idxs
(
partition
->
get_num_parts
()
);
for
(
int
i
=
0
;
i
<
partition
->
get_num_ranges
();
++
i
)
{
auto
pid
=
range_to_pid
[
i
];
auto
range_start
=
ranges
[
i
];
auto
range_end
=
ranges
[
i
+
1
];
for
(
int
j
=
range_start
;
j
<
range_end
;
++
j
)
{
idxs
[
pid
].
push_back
(
j
);
}
}
std
::
vector
<
gko
::
Array
<
gko
::
distributed
::
global_index_type
>
>
arr_idxs
(
idxs
.
size
()
);
for
(
size_t
i
=
0
;
i
<
idxs
.
size
();
++
i
)
{
arr_idxs
[
i
]
=
gko
::
Array
<
gko
::
distributed
::
global_index_type
>
{
gko
::
ReferenceExecutor
::
create
(),
idxs
[
i
].
size
()
};
std
::
copy_n
(
idxs
[
i
].
data
(),
idxs
[
i
].
size
(),
arr_idxs
[
i
].
get_data
()
);
}
return
arr_idxs
;
}
template
<
typename
ValueType
,
typename
LocalIndexType
>
void
scatter_global_vector
(
const
gko
::
matrix
::
Dense
<
ValueType
>*
input
,
gko
::
distributed
::
Vector
<
ValueType
,
LocalIndexType
>*
output
,
const
std
::
vector
<
gko
::
Array
<
gko
::
distributed
::
global_index_type
>
>&
gather_idxs
,
std
::
shared_ptr
<
gko
::
mpi
::
communicator
>
comm
)
{
std
::
vector
<
int
>
counts
(
gather_idxs
.
size
()
);
std
::
vector
<
int
>
disp
(
counts
.
size
()
+
1
);
std
::
transform
(
gather_idxs
.
cbegin
(),
gather_idxs
.
cend
(),
counts
.
begin
(),
[](
const
auto
&
idxs
)
{
return
idxs
.
get_num_elems
();
}
);
std
::
partial_sum
(
counts
.
cbegin
(),
counts
.
cend
(),
disp
.
begin
()
+
1
);
auto
exec
=
input
->
get_executor
();
gko
::
Array
<
ValueType
>
sorted_buffer
{
exec
,
input
->
get_num_stored_elements
()
};
if
(
comm
->
rank
()
==
0
)
{
for
(
size_t
i
=
0
;
i
<
gather_idxs
.
size
();
++
i
)
{
const
auto
&
idxs
=
gather_idxs
[
i
];
auto
gathered_vec
=
input
->
row_gather
(
&
idxs
);
exec
->
copy
(
counts
[
i
],
gathered_vec
->
get_const_values
(),
sorted_buffer
.
get_data
()
+
disp
[
i
]
);
}
}
gko
::
mpi
::
scatter
(
sorted_buffer
.
get_data
(),
counts
.
data
(),
disp
.
data
(),
output
->
get_local
()
->
get_values
(),
output
->
get_local
()
->
get_num_stored_elements
(),
0
,
comm
);
}
template
<
class
OperatorType
>
class
GinkgoCGSolver
:
public
Solver
<
OperatorType
>
{
...
...
src/hyteg/ginkgo/GinkgoCommunication.hpp
0 → 100644
View file @
03be1737
#pragma once
#include <memory>
#include <utility>
#include <vector>
#include "ginkgo/core/base/array.hpp"
#include "ginkgo/core/base/mpi.hpp"
#include "ginkgo/core/distributed/partition.hpp"
namespace
hyteg
{
std
::
pair
<
int
,
int
>
local_range
(
const
uint_t
local_size
,
std
::
shared_ptr
<
gko
::
mpi
::
communicator
>
comm
)
{
uint_t
start
=
0
;
uint_t
end
=
0
;
MPI_Exscan
(
&
local_size
,
&
start
,
1
,
MPI_UINT64_T
,
MPI_SUM
,
comm
->
get
()
);
MPI_Scan
(
&
local_size
,
&
end
,
1
,
MPI_UINT64_T
,
MPI_SUM
,
comm
->
get
()
);
return
std
::
make_pair
(
start
,
end
);
}
template
<
typename
LocalIndexType
>
std
::
vector
<
gko
::
Array
<
gko
::
distributed
::
global_index_type
>
>
compute_gather_idxs
(
std
::
shared_ptr
<
gko
::
distributed
::
Partition
<
LocalIndexType
>
>
partition
)
{
auto
ranges
=
partition
->
get_range_bounds
();
auto
range_to_pid
=
partition
->
get_const_part_ids
();
std
::
vector
<
std
::
vector
<
gko
::
distributed
::
global_index_type
>
>
idxs
(
partition
->
get_num_parts
()
);
for
(
int
i
=
0
;
i
<
partition
->
get_num_ranges
();
++
i
)
{
auto
pid
=
range_to_pid
[
i
];
auto
range_start
=
ranges
[
i
];
auto
range_end
=
ranges
[
i
+
1
];
for
(
int
j
=
range_start
;
j
<
range_end
;
++
j
)
{
idxs
[
pid
].
push_back
(
j
);
}
}
std
::
vector
<
gko
::
Array
<
gko
::
distributed
::
global_index_type
>
>
arr_idxs
(
idxs
.
size
()
);
for
(
size_t
i
=
0
;
i
<
idxs
.
size
();
++
i
)
{
arr_idxs
[
i
]
=
gko
::
Array
<
gko
::
distributed
::
global_index_type
>
{
gko
::
ReferenceExecutor
::
create
(),
idxs
[
i
].
size
()
};
std
::
copy_n
(
idxs
[
i
].
data
(),
idxs
[
i
].
size
(),
arr_idxs
[
i
].
get_data
()
);
}
return
arr_idxs
;
}
template
<
typename
ValueType
,
typename
LocalIndexType
>
void
scatter_global_vector
(
const
gko
::
matrix
::
Dense
<
ValueType
>*
input
,
gko
::
distributed
::
Vector
<
ValueType
,
LocalIndexType
>*
output
,
const
std
::
vector
<
gko
::
Array
<
gko
::
distributed
::
global_index_type
>
>&
gather_idxs
,
std
::
shared_ptr
<
gko
::
mpi
::
communicator
>
comm
)
{
std
::
vector
<
int
>
counts
(
gather_idxs
.
size
()
);
std
::
vector
<
int
>
disp
(
counts
.
size
()
+
1
);
std
::
transform
(
gather_idxs
.
cbegin
(),
gather_idxs
.
cend
(),
counts
.
begin
(),
[](
const
auto
&
idxs
)
{
return
idxs
.
get_num_elems
();
}
);
std
::
partial_sum
(
counts
.
cbegin
(),
counts
.
cend
(),
disp
.
begin
()
+
1
);
auto
exec
=
input
->
get_executor
();
gko
::
Array
<
ValueType
>
sorted_buffer
{
exec
,
input
->
get_num_stored_elements
()
};
if
(
comm
->
rank
()
==
0
)
{
for
(
size_t
i
=
0
;
i
<
gather_idxs
.
size
();
++
i
)
{
const
auto
&
idxs
=
gather_idxs
[
i
];
auto
gathered_vec
=
input
->
row_gather
(
&
idxs
);
exec
->
copy
(
counts
[
i
],
gathered_vec
->
get_const_values
(),
sorted_buffer
.
get_data
()
+
disp
[
i
]
);
}
}
gko
::
mpi
::
scatter
(
sorted_buffer
.
get_data
(),
counts
.
data
(),
disp
.
data
(),
output
->
get_local
()
->
get_values
(),
output
->
get_local
()
->
get_num_stored_elements
(),
0
,
comm
);
}
}
// namespace hyteg
\ No newline at end of file
src/hyteg/ginkgo/GinkgoDirichletHandlig.hpp
0 → 100644
View file @
03be1737
#pragma once
#include <vector>
#include <optional>
#include <utility>
#include "ginkgo/core/matrix/csr.hpp"
#include "ginkgo/core/matrix/dense.hpp"
#include "ginkgo/core/distributed/partition.hpp"
namespace
hyteg
{
class
DirichletHandlerBase
{
public:
using
vec
=
gko
::
distributed
::
Vector
<
real_t
,
int32_t
>
;
using
mtx
=
gko
::
distributed
::
Matrix
<
real_t
,
int32_t
>
;
DirichletHandlerBase
(
std
::
vector
<
int32_t
>
bcIndices
,
const
vec
*
dir_vals
,
std
::
shared_ptr
<
mtx
>
matrix
,
bool
doUpdate
=
true
)
:
bcIndices_
(
std
::
move
(
bcIndices
)
)
,
dir_vals_
(
dir_vals
)
,
matrix_
(
std
::
move
(
matrix
)
)
,
doUpdate_
(
doUpdate
)
{}
virtual
void
update_matrix
()
=
0
;
virtual
void
update_solution
(
vec
*
)
=
0
;
virtual
vec
*
get_initial_guess
(
const
vec
*
)
=
0
;
virtual
vec
*
get_rhs
(
const
vec
*
,
const
vec
*
)
=
0
;
protected:
std
::
vector
<
int32_t
>
bcIndices_
;
const
vec
*
dir_vals_
;
std
::
shared_ptr
<
mtx
>
matrix_
;
bool
doUpdate_
;
};
template
<
typename
T
>
std
::
optional
<
T
>
to_local_idx
(
T
global_idx
,
const
gko
::
distributed
::
Partition
<
T
>*
part
,
T
rank
)
{
auto
local_start
=
part
->
get_const_range_bounds
()[
rank
];
auto
local_end
=
part
->
get_const_range_bounds
()[
rank
+
1
];
if
(
local_start
<=
global_idx
&&
global_idx
<
local_end
)
{
return
global_idx
-
local_start
;
}
else
{
return
{};
}
}
class
PeneltyDirichletHandler
:
public
DirichletHandlerBase
{
public:
using
DirichletHandlerBase
::
DirichletHandlerBase
;
void
update_matrix
()
override
{
if
(
this
->
doUpdate_
)
{
auto
local_mat
=
this
->
matrix_
->
get_local_diag
();
auto
row_ptrs
=
local_mat
->
get_const_row_ptrs
();
auto
cols
=
local_mat
->
get_const_col_idxs
();
auto
vals
=
local_mat
->
get_values
();
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
for
(
int
i
=
row_ptrs
[
*
lidx
];
i
<
row_ptrs
[
*
lidx
+
1
];
++
i
)
{
if
(
*
lidx
==
cols
[
i
]
)
vals
[
i
]
+=
tgv
;
}
}
}
}
}
void
update_solution
(
vec
*
cur_solution
)
override
{}
vec
*
get_initial_guess
(
const
vec
*
cur_initial_guess
)
override
{
initial_guess_
=
gko
::
clone
(
cur_initial_guess
);
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
initial_guess_
->
get_local
()
->
at
(
*
lidx
)
=
this
->
dir_vals_
->
get_local
()
->
at
(
*
lidx
);
}
}
return
initial_guess_
.
get
();
}
vec
*
get_rhs
(
const
vec
*
cur_rhs
,
const
vec
*
)
override
{
rhs_
=
gko
::
clone
(
cur_rhs
);
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
// use something like a scattered view + view->scale
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
rhs_
->
get_local
()
->
at
(
*
lidx
)
*=
tgv
;
}
}
return
rhs_
.
get
();
}
private:
std
::
unique_ptr
<
vec
>
initial_guess_
;
std
::
unique_ptr
<
vec
>
rhs_
;
double
tgv
=
1e30
;
};
class
ZeroRowsDirichletHandler
:
public
DirichletHandlerBase
{
using
valueType
=
typename
vec
::
value_type
;
public:
using
DirichletHandlerBase
::
DirichletHandlerBase
;
void
update_matrix
()
override
{
if
(
this
->
doUpdate_
)
{
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
auto
local_mat
=
this
->
matrix_
->
get_local_diag
();
auto
row_ptrs
=
local_mat
->
get_const_row_ptrs
();
auto
cols
=
local_mat
->
get_const_col_idxs
();
auto
vals
=
local_mat
->
get_values
();
auto
local_offdiag_mat
=
this
->
matrix_
->
get_local_offdiag
();
auto
od_row_ptrs
=
local_offdiag_mat
->
get_const_row_ptrs
();
auto
od_vals
=
local_offdiag_mat
->
get_values
();
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
for
(
int
i
=
row_ptrs
[
*
lidx
];
i
<
row_ptrs
[
*
lidx
+
1
];
++
i
)
{
if
(
*
lidx
!=
cols
[
i
]
)
vals
[
i
]
=
gko
::
zero
<
valueType
>
();
else
vals
[
i
]
=
gko
::
one
<
valueType
>
();
}
for
(
int
i
=
od_row_ptrs
[
*
lidx
];
i
<
od_row_ptrs
[
*
lidx
+
1
];
++
i
)
{
od_vals
[
i
]
=
gko
::
zero
<
valueType
>
();
}
}
}
}
}
void
update_solution
(
vec
*
cur_solution
)
override
{
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();
auto
part
=
this
->
matrix_
->
get_partition
();
auto
one
=
gko
::
initialize
<
gko
::
matrix
::
Dense
<
valueType
>
>
(
{
1
},
cur_solution
->
get_executor
()
);
cur_solution
->
add_scaled
(
one
.
get
(),
orig_initial_guess_
);
for
(
auto
idx
:
this
->
bcIndices_
)
{
if
(
auto
lidx
=
to_local_idx
(
idx
,
part
,
rank
);
lidx
)
{
cur_solution
->
get_local
()
->
at
(
*
lidx
)
=
this
->
dir_vals_
->
get_local
()
->
at
(
*
lidx
);
}
}
}
vec
*
get_rhs
(
const
vec
*
cur_rhs
,
const
vec
*
orig_initial_guess
)
override
{
auto
rank
=
this
->
matrix_
->
get_communicator
()
->
rank
();