Commit 203622cb authored by Marcel Koch's avatar Marcel Koch
Browse files

always convert distributed matrix to seq csr

parent 429321c4
......@@ -274,6 +274,11 @@ class GinkgoCGSolver : public Solver< OperatorType >
using FunctionType = typename OperatorType::srcType;
using valueType = typename FunctionType::valueType;
using mtx = gko::distributed::Matrix< valueType, int32_t >;
using csr = gko::matrix::Csr< valueType, gko::distributed::global_index_type >;
using vec = gko::distributed::Vector< valueType >;
using dense = gko::matrix::Dense< valueType >;
GinkgoCGSolver() = default;
GinkgoCGSolver( const std::shared_ptr< PrimitiveStorage >& storage,
const uint_t& level,
......@@ -284,6 +289,7 @@ class GinkgoCGSolver : public Solver< OperatorType >
std::shared_ptr< gko::Executor > solver_exec = gko::ReferenceExecutor::create() )
: storage_( storage )
, level_( level )
, comm_( gko::mpi::communicator::create( storage->getSplitCommunicatorByPrimitiveDistribution() ) )
, constraints_type_( constraints_type )
, host_exec_( solver_exec->get_master() )
, solver_exec_( std::move( solver_exec ) )
......@@ -310,7 +316,7 @@ class GinkgoCGSolver : public Solver< OperatorType >
void solve( const OperatorType& A, const FunctionType& x, const FunctionType& b, const walberla::uint_t level ) override
{
const auto num_local_dofs = numberOfLocalDoFs< typename FunctionType::Tag >( *storage_, level );
const auto num_global_dofs = numberOfGlobalDoFs< typename FunctionType::Tag >( *storage_, level );
const auto num_global_dofs = numberOfGlobalDoFs< typename FunctionType::Tag >( *storage_, level, comm_->get() );
// maybe called in parallel, thus need to keep it for empty processes
num_.copyBoundaryConditionFromFunction( x );
......@@ -318,9 +324,8 @@ class GinkgoCGSolver : public Solver< OperatorType >
auto rank = walberla::mpi::MPIManager::instance()->rank();
auto comm = gko::share( gko::mpi::communicator::create( walberla::mpi::MPIManager::instance()->comm() ) );
auto [start, end] = local_range( num_local_dofs, comm );
auto part = gko::share( gko::distributed::Partition<>::build_from_local_range( host_exec_, start, end, comm ) );
auto [start, end] = local_range( num_local_dofs, comm_ );
auto part = gko::share( gko::distributed::Partition<>::build_from_local_range( host_exec_, start, end, comm_ ) );
if ( printInfo_ )
{
......@@ -330,11 +335,10 @@ class GinkgoCGSolver : public Solver< OperatorType >
if ( num_local_dofs )
{
using mtx = gko::distributed::Matrix< valueType, int32_t >;
using vec = gko::distributed::Vector< valueType >;
auto x_vec = vec::create( host_exec_, comm, gko::dim< 2 >{ num_global_dofs, 1 }, gko::dim< 2 >{ num_local_dofs, 1 } );
auto b_vec = vec::create( host_exec_, comm, gko::dim< 2 >{ num_global_dofs, 1 }, gko::dim< 2 >{ num_local_dofs, 1 } );
auto x_vec =
vec::create( host_exec_, comm_, part, gko::dim< 2 >{ num_global_dofs, 1 }, gko::dim< 2 >{ num_local_dofs, 1 } );
auto b_vec =
vec::create( host_exec_, comm_, part, gko::dim< 2 >{ num_global_dofs, 1 }, gko::dim< 2 >{ num_local_dofs, 1 } );
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver" );
......@@ -357,13 +361,13 @@ class GinkgoCGSolver : public Solver< OperatorType >
All );
// Todo: add check if assembly is neccessary
const bool doAssemble = !matrix_ || reassembleMatrix_;
const bool doAssemble = !host_matrix_ || reassembleMatrix_;
if ( doAssemble )
{
host_matrix_ = gko::share( mtx::create( host_exec_, comm_ ) );
x.getStorage()->getTimingTree()->start( "Ginkgo System Matrix Assembly" );
matrix_ = gko::share( mtx::create( host_exec_, comm ) );
auto matrix_proxy = std::make_shared< GinkgoSparseMatrixProxy< mtx > >(
matrix_.get(), gko::dim< 2 >{ num_global_dofs, num_global_dofs }, part );
host_matrix_.get(), gko::dim< 2 >{ num_global_dofs, num_global_dofs }, part );
hyteg::petsc::createMatrix< OperatorType >( A, num_, num_, matrix_proxy, level, All );
matrix_proxy->finalize();
x.getStorage()->getTimingTree()->stop( "Ginkgo System Matrix Assembly" );
......@@ -372,11 +376,11 @@ class GinkgoCGSolver : public Solver< OperatorType >
std::unique_ptr< DirichletHandlerBase > dir_handler;
if ( constraints_type_ == constraints::penalty )
{
dir_handler = std::make_unique< PeneltyDirichletHandler >( bcIndices, b_vec.get(), matrix_, doAssemble );
dir_handler = std::make_unique< PeneltyDirichletHandler >( bcIndices, b_vec.get(), host_matrix_, doAssemble );
}
else if ( constraints_type_ == constraints::zero_row )
{
dir_handler = std::make_unique< ZeroRowsDirichletHandler >( bcIndices, b_vec.get(), matrix_, doAssemble );
dir_handler = std::make_unique< ZeroRowsDirichletHandler >( bcIndices, b_vec.get(), host_matrix_, doAssemble );
}
else
{
......@@ -390,9 +394,8 @@ class GinkgoCGSolver : public Solver< OperatorType >
if ( !solver_ || doAssemble )
{
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Set-Up" );
auto host_matrix = matrix_;
matrix_ = mtx::create( solver_exec_ );
host_matrix->convert_to( gko::lend( matrix_ ) );
matrix_ = csr::create( solver_exec_ );
host_matrix_->convert_to( gko::lend( matrix_ ) );
//auto par_ilu = gko::factorization::Ilu< valueType, int32_t >::build().on( solver_exec_ )->generate( matrix_ );
//auto ilu = gko::preconditioner::Ilu<>::build().on( solver_exec_ )->generate( gko::share( par_ilu ) );
......@@ -400,10 +403,18 @@ class GinkgoCGSolver : public Solver< OperatorType >
//solver_->set_preconditioner( gko::share( ilu ) );
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Set-Up" );
}
auto global_rhs = dense::create( solver_exec_ );
auto global_x0 = dense::create( solver_exec_ );
rhs->convert_to( gko::lend( global_rhs ) );
x0->convert_to( gko::lend( global_x0 ) );
x.getStorage()->getTimingTree()->start( "Ginkgo CG Solver Apply" );
solver_->apply( gko::lend( rhs ), gko::lend( x0 ) );
solver_->apply( gko::lend( global_rhs ), gko::lend( global_x0 ) );
x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Apply" );
gko::mpi::broadcast( global_x0->get_values(), global_x0->get_num_stored_elements(), 0, comm_ );
x0->read_distributed( gko::lend( global_x0 ), part );
dir_handler->update_solution( x0 );
hyteg::petsc::createFunctionFromVector(
......@@ -432,6 +443,8 @@ class GinkgoCGSolver : public Solver< OperatorType >
std::shared_ptr< PrimitiveStorage > storage_;
uint_t level_;
std::shared_ptr< gko::mpi::communicator > comm_;
constraints constraints_type_;
typename OperatorType::srcType::template FunctionType< int > num_;
......@@ -441,7 +454,9 @@ class GinkgoCGSolver : public Solver< OperatorType >
std::unique_ptr< typename gko::solver::Cg< valueType >::Factory > solver_factory_;
std::unique_ptr< typename gko::solver::Cg< valueType > > solver_;
std::shared_ptr< gko::distributed::Matrix< valueType, int32_t > > matrix_;
std::shared_ptr< mtx > host_matrix_;
std::shared_ptr< csr > matrix_;
bool printInfo_ = false;
bool reassembleMatrix_ = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment