Commit d61a0109 authored by Marcel Koch

refactor ginkgo block solver

parent 77041849
@@ -103,22 +103,22 @@ class GinkgoBlockSolver : public Solver< OperatorType >
const typename OperatorType::dstType& b,
const walberla::uint_t level ) override
{
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver" );
const auto& timingTree = x.getStorage()->getTimingTree();
timingTree->start( "Ginkgo CG Solver" );
const auto num_local_dofs = numberOfLocalDoFs< typename FunctionType::Tag >( *storage_, level );
const auto num_global_dofs = numberOfGlobalDoFs< typename FunctionType::Tag >( *storage_, level, comm_->get() );
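// build the Ginkgo row partition from the locally owned DoF range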
auto [start, end] = local_range( num_local_dofs, comm_ );
part_ = gko::share( gko::distributed::Partition<>::build_from_local_range( host_exec_, start, end, comm_ ) );
if ( !part_ )
{
auto [start, end] = local_range( num_local_dofs, comm_ );
part_ = gko::share( gko::distributed::Partition<>::build_from_local_range( host_exec_, start, end, comm_ ) );
}
// may be called in parallel, thus it needs to be kept for empty processes
num_.copyBoundaryConditionFromFunction( x );
num_.enumerate( level );
std::vector< int32_t > bcIndices;
hyteg::petsc::applyDirichletBC( num_, bcIndices, level );
std::sort( std::begin( bcIndices ), std::end( bcIndices ) );
auto x_gko =
vec::create( host_exec_, comm_, part_, gko::dim< 2 >{ num_global_dofs, 1 }, gko::dim< 2 >{ num_local_dofs, 1 } );
auto b_gko =
@@ -130,19 +130,94 @@ class GinkgoBlockSolver : public Solver< OperatorType >
hyteg::petsc::createVectorFromFunction(
b, num_, std::make_shared< GinkgoVectorProxy >( b_gko.get(), gko::dim< 2 >{ num_global_dofs, 1 }, part_ ), level, All );
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Set-up distributed matrix" );
if ( !host_monolithic_matrix_ || reassembleMatrix_ )
{
this->setup_ginkgo( A, x, level, b_gko.get() );
}
auto rhs = dir_handler_->get_rhs( b_gko.get(), x_gko.get() );
auto x0 = dir_handler_->get_initial_guess( x_gko.get() );
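// gather the distributed right-hand side and initial guess into dense vectors on the host executor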
timingTree->start( "Ginkgo CG Solver Gather vectors" );
auto global_rhs = dense::create( host_exec_ );
auto global_x0 = dense::create( host_exec_ );
rhs->convert_to( gko::lend( global_rhs ) );
x0->convert_to( gko::lend( global_x0 ) );
timingTree->stop( "Ginkgo CG Solver Gather vectors" );
auto log = gko::share( gko::log::Convergence< valueType >::create( solver_exec_ ) );
if ( monolithic_matrix_->get_size() )
{
auto factory = gko::clone( solver_->get_stop_criterion_factory() );
factory->add_logger( log );
solver_->set_stop_criterion_factory( gko::share( factory ) );
timingTree->start( "Ginkgo CG Solver Apply" );
if ( comm_->size() > 1 )
{
auto permuted_global_rhs = gko::as< dense >( global_rhs->row_permute( &perm_ ) );
auto permuted_global_x0 = gko::as< dense >( global_x0->row_permute( &perm_ ) );
solver_->apply( gko::lend( permuted_global_rhs ), gko::lend( permuted_global_x0 ) );
permuted_global_x0->inverse_row_permute( &perm_, global_x0.get() );
}
else
{
solver_->apply( gko::lend( global_rhs ), gko::lend( global_x0 ) );
}
timingTree->stop( "Ginkgo CG Solver Apply" );
}
WALBERLA_MPI_BARRIER(); // only necessary to get correct timings
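// scatter the gathered global solution back into the distributed vector layout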
timingTree->start( "Ginkgo CG Solver Scatter vectors" );
gather_idxs_ = compute_gather_idxs( part_ );
scatter_global_vector( global_x0.get(), x0, gather_idxs_, comm_ );
timingTree->stop( "Ginkgo CG Solver Scatter vectors" );
dir_handler_->update_solution( x0 );
hyteg::petsc::createFunctionFromVector(
x, num_, std::make_shared< GinkgoVectorProxy >( x0, gko::dim< 2 >{ num_global_dofs, 1 }, part_ ), level, All );
if ( printInfo_ && comm_->rank() == 0 )
{
WALBERLA_LOG_INFO_ON_ROOT(
"[Ginkgo CG]" << ( !log->has_converged() ? " NOT " : " " ) << "converged after " << log->get_num_iterations()
<< " iterations, residual norm: "
<< solver_exec_->copy_val_to_host(
gko::as< gko::matrix::Dense< valueType > >( log->get_residual_norm() )->get_const_values() ) );
}
timingTree->stop( "Ginkgo CG Solver" );
}
void setPrintInfo( bool printInfo ) { printInfo_ = printInfo; }
void setReassembleMatrix( bool reassembleMatrix ) { reassembleMatrix_ = reassembleMatrix; }
private:
void setup_ginkgo( const OperatorType& A,
const typename OperatorType::srcType& x,
const walberla::uint_t level,
const vec* b_gko )
{
const auto& timingTree = x.getStorage()->getTimingTree();
const auto num_global_dofs = b_gko->get_size()[0];
std::vector< int32_t > bcIndices;
hyteg::petsc::applyDirichletBC( num_, bcIndices, level );
std::sort( std::begin( bcIndices ), std::end( bcIndices ) );
timingTree->start( "Ginkgo CG Solver Set-up distributed matrix" );
host_monolithic_matrix_ = gko::share( mtx::create( host_exec_, comm_ ) );
auto dir_handler = std::make_unique< ZeroRowsDirichletHandler >( bcIndices, b_gko.get(), host_monolithic_matrix_, true );
dir_handler_ = std::make_unique< ZeroRowsDirichletHandler >( bcIndices, b_gko, host_monolithic_matrix_, true );
{
auto proxy = std::make_shared< GinkgoSparseMatrixProxy< mtx > >(
host_monolithic_matrix_.get(), gko::dim< 2 >{ num_global_dofs, num_global_dofs }, part_ );
hyteg::petsc::createMatrix( A, num_, num_, proxy, level, All );
proxy->finalize();
dir_handler->update_matrix();
dir_handler_->update_matrix();
}
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Set-up distributed matrix" );
timingTree->stop( "Ginkgo CG Solver Set-up distributed matrix" );
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Set-up distributed preconditioner" );
timingTree->start( "Ginkgo CG Solver Set-up distributed preconditioner" );
auto monolithic_preconditioner_ = gko::share( mtx::create( host_exec_, comm_ ) );
{
auto proxy = std::make_shared< GinkgoSparseMatrixProxy< mtx > >(
@@ -150,24 +225,21 @@ class GinkgoBlockSolver : public Solver< OperatorType >
hyteg::petsc::createMatrix( blockPreconditioner_, num_, num_, proxy, level, All );
proxy->finalize();
auto dir_handler_p =
std::make_unique< ZeroRowsDirichletHandler >( bcIndices, b_gko.get(), monolithic_preconditioner_, true );
std::make_unique< ZeroRowsDirichletHandler >( bcIndices, b_gko, monolithic_preconditioner_, true );
dir_handler_p->update_matrix();
}
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Set-up distributed preconditioner" );
auto rhs = dir_handler->get_rhs( b_gko.get(), x_gko.get() );
auto x0 = dir_handler->get_initial_guess( x_gko.get() );
timingTree->stop( "Ginkgo CG Solver Set-up distributed preconditioner" );
std::vector< int32_t > vIndices;
std::vector< int32_t > pIndices;
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Gather index sets" );
timingTree->start( "Ginkgo CG Solver Gather index sets" );
gatherIndices( vIndices, pIndices, *storage_, level, num_ );
std::sort( vIndices.begin(), vIndices.end() );
std::sort( pIndices.begin(), pIndices.end() );
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Gather index sets" );
timingTree->stop( "Ginkgo CG Solver Gather index sets" );
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Set-up permutation" );
timingTree->start( "Ginkgo CG Solver Set-up permutation" );
std::vector< int32_t > perm_vec;
std::vector< int32_t > recv_sizes( comm_->size() );
@@ -188,8 +260,8 @@ class GinkgoBlockSolver : public Solver< OperatorType >
gko::mpi::gather(
pIndices.data(), local_size, perm_vec.data() + global_v_size, recv_sizes.data(), recv_offsets.data(), 0, comm_ );
gko::Array< gko::int32 > perm{ solver_exec_, perm_vec.begin(), perm_vec.end() };
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Set-up permutation" );
perm_ = gko::Array< gko::int32 >{ solver_exec_, perm_vec.begin(), perm_vec.end() };
timingTree->stop( "Ginkgo CG Solver Set-up permutation" );
auto [p_v_min, p_v_max] = std::minmax_element( std::begin( vIndices ), std::end( vIndices ) );
auto [p_p_min, p_p_max] = std::minmax_element( std::begin( pIndices ), std::end( pIndices ) );
@@ -200,35 +272,35 @@ class GinkgoBlockSolver : public Solver< OperatorType >
WALBERLA_ABORT( "Indices are NOT blocked: v" << v_span << ", p" << p_span )
}
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Gather matrix" );
timingTree->start( "Ginkgo CG Solver Gather matrix" );
monolithic_matrix_ = gko::share( csr::create( solver_exec_ ) );
gko::as< mtx >( host_monolithic_matrix_ )->convert_to( monolithic_matrix_.get() );
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Gather matrix" );
timingTree->stop( "Ginkgo CG Solver Gather matrix" );
if ( comm_->size() > 1 && monolithic_matrix_->get_size() )
{
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Permute matrix" );
monolithic_matrix_ = gko::as< csr >( monolithic_matrix_->permute( &perm ) );
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Permute matrix" );
timingTree->start( "Ginkgo CG Solver Permute matrix" );
monolithic_matrix_ = gko::as< csr >( monolithic_matrix_->permute( &perm_ ) );
timingTree->stop( "Ginkgo CG Solver Permute matrix" );
}
{
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Gather preconditioner" );
timingTree->start( "Ginkgo CG Solver Gather preconditioner" );
auto device_monolithic_pre = gko::share( csr::create( solver_exec_ ) );
monolithic_preconditioner_->convert_to( device_monolithic_pre.get() );
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Gather preconditioner" );
timingTree->stop( "Ginkgo CG Solver Gather preconditioner" );
if ( device_monolithic_pre->get_size() )
{
gko::span block_v_span{ 0, global_v_size };
gko::span block_p_span{ global_v_size, global_v_size + global_p_size };
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Permute preconditioner" );
device_monolithic_pre = gko::as< csr >( device_monolithic_pre->permute( &perm ) );
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Permute preconditioner" );
timingTree->start( "Ginkgo CG Solver Permute preconditioner" );
device_monolithic_pre = gko::as< csr >( device_monolithic_pre->permute( &perm_ ) );
timingTree->stop( "Ginkgo CG Solver Permute preconditioner" );
std::shared_ptr< gko::LinOp > b_pre_v;
std::shared_ptr< gko::LinOp > b_pre_p;
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Build block preconditioner" );
timingTree->start( "Ginkgo CG Solver Build block preconditioner" );
auto jac_gen =
gko::share( gko::preconditioner::Jacobi< valueType >::build().with_max_block_size( 1u ).on( solver_exec_ ) );
if ( velocityPreconditionerType_ == 0 )
@@ -277,7 +349,7 @@ class GinkgoBlockSolver : public Solver< OperatorType >
.on( solver_exec_ )
->generate( gko::share( device_monolithic_pre->create_submatrix( block_v_span, block_v_span ) ) ) );
if(printInfo_)
if ( printInfo_ )
{
auto amg = gko::as< gko::solver::Multigrid >( b_pre_v );
WALBERLA_LOG_INFO_ON_ROOT( "[Ginkgo AMGX] Number of levels: " << amg->get_mg_level_list().size() + 1 )
@@ -320,12 +392,11 @@ class GinkgoBlockSolver : public Solver< OperatorType >
gko::dim< 2 >{ global_v_size + global_p_size, global_v_size + global_p_size },
std::vector< std::vector< std::shared_ptr< gko::LinOp > > >{ { b_pre_v, b_pre_vp }, { b_pre_pv, b_pre_p } },
std::vector< gko::span >{ block_v_span, block_p_span } ) );
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Build block preconditioner" );
timingTree->stop( "Ginkgo CG Solver Build block preconditioner" );
}
WALBERLA_MPI_BARRIER();
}
auto log = gko::share( gko::log::Convergence< valueType >::create( solver_exec_ ) );
if ( monolithic_matrix_->get_size() )
{
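// generate the GMRES solver on the monolithic matrix with the previously built block preconditioner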
solver_ = gko::solver::Gmres< valueType >::build()
@@ -341,61 +412,9 @@ class GinkgoBlockSolver : public Solver< OperatorType >
.with_generated_preconditioner( block_preconditioner_ )
.on( solver_exec_ )
->generate( monolithic_matrix_ );
auto factory = gko::clone( solver_->get_stop_criterion_factory() );
factory->add_logger( log );
solver_->set_stop_criterion_factory( gko::share( factory ) );
}
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Gather vectors" );
auto global_rhs = dense::create( host_exec_ );
auto global_x0 = dense::create( host_exec_ );
rhs->convert_to( gko::lend( global_rhs ) );
x0->convert_to( gko::lend( global_x0 ) );
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Gather vectors" );
if ( monolithic_matrix_->get_size() )
{
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Apply" );
if ( comm_->size() > 1 )
{
auto permuted_global_rhs = gko::as< dense >( global_rhs->row_permute( &perm ) );
auto permuted_global_x0 = gko::as< dense >( global_x0->row_permute( &perm ) );
solver_->apply( gko::lend( permuted_global_rhs ), gko::lend( permuted_global_x0 ) );
permuted_global_x0->inverse_row_permute( &perm, global_x0.get() );
}
else
{
solver_->apply( gko::lend( global_rhs ), gko::lend( global_x0 ) );
}
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Apply" );
}
WALBERLA_MPI_BARRIER();
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Scatter vectors" );
gather_idxs_ = compute_gather_idxs( part_ );
scatter_global_vector( global_x0.get(), x0, gather_idxs_, comm_ );
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Scatter vectors" );
dir_handler->update_solution( x0 );
hyteg::petsc::createFunctionFromVector(
x, num_, std::make_shared< GinkgoVectorProxy >( x0, gko::dim< 2 >{ num_global_dofs, 1 }, part_ ), level, All );
if ( printInfo_ && comm_->rank() == 0 )
{
WALBERLA_LOG_INFO_ON_ROOT(
"[Ginkgo CG]" << ( !log->has_converged() ? " NOT " : " " ) << "converged after " << log->get_num_iterations()
<< " iterations, residual norm: "
<< solver_exec_->copy_val_to_host(
gko::as< gko::matrix::Dense< valueType > >( log->get_residual_norm() )->get_const_values() ) );
}
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver" );
}
void setPrintInfo( bool printInfo ) { printInfo_ = printInfo; }
private:
std::shared_ptr< PrimitiveStorage > storage_;
std::shared_ptr< const gko::Executor > host_exec_;
@@ -409,18 +428,22 @@ class GinkgoBlockSolver : public Solver< OperatorType >
BlockPreconditioner_T blockPreconditioner_;
typename OperatorType::srcType::template FunctionType< int > num_;
std::shared_ptr< gko::LinOp > block_preconditioner_;
std::shared_ptr< gko::matrix::BlockMatrix > host_block_preconditioner_;
std::shared_ptr< gko::LinOp > block_preconditioner_;
std::shared_ptr< csr > monolithic_matrix_;
std::shared_ptr< mtx > host_monolithic_matrix_;
gko::Array< indexType > perm_;
std::unique_ptr< ZeroRowsDirichletHandler > dir_handler_;
std::vector< gko::Array< gko::distributed::global_index_type > > gather_idxs_;
uint_t velocityPreconditionerType_;
uint_t pressurePreconditionerType_;
bool printInfo_ = false;
bool reassembleMatrix_ = false;
bool printInfo_ = false;
};
} // namespace hyteg