Commit d8c5a273 authored by Marcel Koch

use scatter instead of broadcast

parent 95bbef67
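
The change replaces a root-to-all broadcast of the assembled solution vector with a scatter of per-rank chunks: rank 0 reorders the global vector so that each rank's entries are contiguous (using gather indices derived from the Ginkgo partition) and then sends every rank only its own slice, instead of shipping the full vector to all ranks and letting each rank pull out its rows via read_distributed. Below is a minimal plain-MPI sketch of that pattern shift; the function and variable names (distribute_solution, global_x0, local_x0, counts, displs) are placeholders for illustration only and are not part of the HyTeG/Ginkgo API used in the diff.

#include <mpi.h>
#include <vector>

// Sketch only: contrasts the old broadcast pattern with the new scatter pattern.
void distribute_solution( const std::vector< double >& global_x0, // full solution vector, valid on rank 0
                          std::vector< double >&       local_x0,  // pre-sized to this rank's number of rows
                          const std::vector< int >&    counts,    // rows owned by each rank (needed on rank 0)
                          const std::vector< int >&    displs,    // offset of each rank's chunk (needed on rank 0)
                          MPI_Comm                     comm )
{
   // Old pattern: every rank receives the entire global vector ...
   // MPI_Bcast( global_x0.data(), static_cast< int >( global_x0.size() ), MPI_DOUBLE, 0, comm );
   // ... and afterwards copies out only the rows it owns.

   // New pattern: rank 0 sends each rank exactly its own contiguous chunk,
   // so the per-rank communication volume shrinks from the global size to the local size.
   MPI_Scatterv( global_x0.data(), counts.data(), displs.data(), MPI_DOUBLE,
                 local_x0.data(), static_cast< int >( local_x0.size() ), MPI_DOUBLE,
                 0, comm );
}
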
@@ -269,6 +269,70 @@ std::pair< int, int > local_range( const uint_t local_size, std::shared_ptr< gko
    return std::make_pair( start, end );
 }
+template < typename LocalIndexType >
+std::vector< gko::Array< gko::distributed::global_index_type > >
+compute_gather_idxs( std::shared_ptr< gko::distributed::Partition< LocalIndexType > > partition )
+{
+   auto ranges = partition->get_range_bounds();
+   auto range_to_pid = partition->get_const_part_ids();
+   std::vector< std::vector< gko::distributed::global_index_type > > idxs( partition->get_num_parts() );
+   for ( int i = 0; i < partition->get_num_ranges(); ++i )
+   {
+      auto pid = range_to_pid[i];
+      auto range_start = ranges[i];
+      auto range_end = ranges[i + 1];
+      for ( int j = range_start; j < range_end; ++j )
+      {
+         idxs[pid].push_back( j );
+      }
+   }
+   std::vector< gko::Array< gko::distributed::global_index_type > > arr_idxs( idxs.size() );
+   for ( size_t i = 0; i < idxs.size(); ++i )
+   {
+      arr_idxs[i] = gko::Array< gko::distributed::global_index_type >{ gko::ReferenceExecutor::create(), idxs[i].size() };
+      std::copy_n( idxs[i].data(), idxs[i].size(), arr_idxs[i].get_data() );
+   }
+   return arr_idxs;
+}
+
+template < typename ValueType, typename LocalIndexType >
+void scatter_global_vector( const gko::matrix::Dense< ValueType >* input,
+                            gko::distributed::Vector< ValueType, LocalIndexType >* output,
+                            const std::vector< gko::Array< gko::distributed::global_index_type > >& gather_idxs,
+                            std::shared_ptr< gko::mpi::communicator > comm )
+{
+   std::vector< int > counts( gather_idxs.size() );
+   std::vector< int > disp( counts.size() + 1 );
+   std::transform(
+       gather_idxs.cbegin(), gather_idxs.cend(), counts.begin(), []( const auto& idxs ) { return idxs.get_num_elems(); } );
+   std::partial_sum( counts.cbegin(), counts.cend(), disp.begin() + 1 );
+
+   auto exec = input->get_executor();
+   gko::Array< ValueType > sorted_buffer{ exec, input->get_num_stored_elements() };
+   if ( comm->rank() == 0 )
+   {
+      for ( size_t i = 0; i < gather_idxs.size(); ++i )
+      {
+         const auto& idxs = gather_idxs[i];
+         auto gathered_vec = input->row_gather( &idxs );
+         exec->copy( counts[i], gathered_vec->get_const_values(), sorted_buffer.get_data() + disp[i] );
+      }
+   }
+
+   gko::mpi::scatter( sorted_buffer.get_data(),
+                      counts.data(),
+                      disp.data(),
+                      output->get_local()->get_values(),
+                      output->get_local()->get_num_stored_elements(),
+                      0,
+                      comm );
+}
 template < class OperatorType >
 class GinkgoCGSolver : public Solver< OperatorType >
 {
@@ -299,7 +363,8 @@ class GinkgoCGSolver : public Solver< OperatorType >
    {
       auto rel_mode = constraints_type == constraints::penalty ? gko::stop::mode::initial_resnorm : gko::stop::mode::rhs_norm;
-      auto log_cout = gko::share( gko::log::Stream< valueType >::create( host_exec_, gko::log::Logger::criterion_check_completed_mask, std::cout, true ) );
+      auto log_cout = gko::share( gko::log::Stream< valueType >::create(
+          host_exec_, gko::log::Logger::criterion_check_completed_mask, std::cout, true ) );
       auto log = gko::share( gko::log::Convergence< valueType >::create( solver_exec_ ) );
       auto criteria = gko::stop::Combined::build()
                           .with_criteria( gko::stop::ResidualNorm< valueType >::build()
@@ -404,13 +469,16 @@ class GinkgoCGSolver : public Solver< OperatorType >
       solver_ = solver_factory_->generate( matrix_ );
       if ( matrix_->get_size()[0] > 0 && matrix_->get_size()[1] > 0 )
       {
-         auto ilu = gko::preconditioner::Ilu<>::build()
-                        .with_factorization_factory(
-                            gko::share( gko::factorization::Ilu< valueType, int32_t >::build().on( solver_exec_ ) ) )
+         // auto ilu = gko::preconditioner::Ilu<>::build()
+         //                .with_factorization_factory(
+         //                    gko::share( gko::factorization::Ilu< valueType, int32_t >::build().on( solver_exec_ ) ) )
+         // .on( solver_exec_ )
+         // ->generate( matrix_ );
+         auto jac = gko::preconditioner::Jacobi< valueType, int32_t >::build()
+                        .with_max_block_size( 1 )
                         .on( solver_exec_ )
                         ->generate( matrix_ );
-         auto jac = gko::preconditioner::Jacobi< valueType, int32_t >::build().on( solver_exec_ )->generate( matrix_ );
-         solver_->set_preconditioner( gko::share( ilu ) );
+         solver_->set_preconditioner( gko::share( jac ) );
       }
       x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Set-Up" );
    }
@@ -424,14 +492,10 @@ class GinkgoCGSolver : public Solver< OperatorType >
       {
          solver_->apply( gko::lend( global_rhs ), gko::lend( global_x0 ) );
       }
-      else
-      {
-         global_x0 = dense::create( host_exec_, x0->get_size() );
-      }
       x.getStorage()->getTimingTree()->stop( "Ginkgo CG Solver Apply" );
-      gko::mpi::broadcast( global_x0->get_values(), global_x0->get_num_stored_elements(), 0, comm_ );
-      x0->read_distributed( gko::lend( global_x0 ), part );
+      auto gather_idxs = compute_gather_idxs( part );
+      scatter_global_vector( global_x0.get(), x0, gather_idxs, comm_ );
       dir_handler->update_solution( x0 );
...
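
For intuition, the snippet below illustrates the index grouping that compute_gather_idxs performs, using made-up toy numbers and plain std containers instead of the Ginkgo partition and Array types: each rank ends up with the ascending list of global row indices that the partition assigns to it, which is what lets rank 0 sort the global vector into per-rank contiguous chunks before the scatter.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
   // Toy partition (invented for illustration): 3 contiguous ranges over 8 global rows.
   std::vector< std::int64_t > ranges{ 0, 3, 5, 8 };    // range bounds
   std::vector< int >          range_to_pid{ 1, 0, 1 }; // owning rank of each range
   const int                   num_parts = 2;

   // Same grouping loop as compute_gather_idxs, minus the Ginkgo types.
   std::vector< std::vector< std::int64_t > > idxs( num_parts );
   for ( std::size_t i = 0; i + 1 < ranges.size(); ++i )
      for ( std::int64_t j = ranges[i]; j < ranges[i + 1]; ++j )
         idxs[range_to_pid[i]].push_back( j );

   for ( int p = 0; p < num_parts; ++p )
   {
      std::cout << "rank " << p << ":";
      for ( auto j : idxs[p] )
         std::cout << " " << j;
      std::cout << "\n";
   }
   // Prints: rank 0: 3 4
   //         rank 1: 0 1 2 5 6 7
}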