From 6bb97cb579d9141e5af3b728982a5b44738feefe Mon Sep 17 00:00:00 2001
From: Sebastian Eibl <sebastian.eibl@fau.de>
Date: Mon, 27 Jan 2020 10:20:52 +0100
Subject: [PATCH] moved DEM into functor

---
 .../GranularGas/MESA_PD_GranularGas.cpp       | 105 +++++++++++++-----
 1 file changed, 76 insertions(+), 29 deletions(-)

diff --git a/apps/benchmarks/GranularGas/MESA_PD_GranularGas.cpp b/apps/benchmarks/GranularGas/MESA_PD_GranularGas.cpp
index 7ff78bd1a..cb19bf5cd 100644
--- a/apps/benchmarks/GranularGas/MESA_PD_GranularGas.cpp
+++ b/apps/benchmarks/GranularGas/MESA_PD_GranularGas.cpp
@@ -71,6 +71,73 @@
 namespace walberla {
 namespace mesa_pd {
 
+class DEM
+{
+public:
+   DEM(const std::shared_ptr<domain::BlockForestDomain>& domain)
+   : domain_(domain)
+   {
+      dem_.setStiffness(0, 0, real_t(0));
+      dem_.setDampingN(0, 0, real_t(0));
+      dem_.setDampingT(0, 0, real_t(0));
+      dem_.setFriction(0, 0, real_t(0));
+   }
+
+   inline
+   void operator()(const size_t idx1, const size_t idx2, ParticleAccessorWithShape& ac)
+   {
+
+      ++contactsChecked_;
+      if (double_cast_(idx1, idx2, ac, acd_, ac))
+      {
+         ++contactsDetected_;
+         if (contact_filter_(acd_.getIdx1(), acd_.getIdx2(), ac, acd_.getContactPoint(),
+                             *domain_))
+         {
+            ++contactsTreated_;
+            dem_(acd_.getIdx1(), acd_.getIdx2(), ac, acd_.getContactPoint(),
+                 acd_.getContactNormal(), acd_.getPenetrationDepth());
+         }
+      }
+   }
+
+   inline
+   void resetCounters()
+   {
+      contactsChecked_ = 0;
+      contactsDetected_ = 0;
+      contactsTreated_ = 0;
+   }
+
+   inline
+   int64_t getContactsChecked() const
+   {
+      return contactsChecked_;
+   }
+
+   inline
+   int64_t getContactsDetected() const
+   {
+      return contactsDetected_;
+   }
+
+   inline
+   int64_t getContactsTreated() const
+   {
+      return contactsTreated_;
+   }
+
+private:
+   kernel::DoubleCast double_cast_;
+   mpi::ContactFilter contact_filter_;
+   std::shared_ptr<domain::BlockForestDomain> domain_;
+   kernel::SpringDashpot dem_ = kernel::SpringDashpot(1);
+   collision_detection::AnalyticContactDetection acd_;
+   int64_t contactsChecked_ = 0;
+   int64_t contactsDetected_ = 0;
+   int64_t contactsTreated_ = 0;
+};
+
 int main( int argc, char ** argv )
 {
    using namespace walberla::timing;
@@ -170,16 +237,9 @@ int main( int argc, char ** argv )
    WALBERLA_LOG_INFO_ON_ROOT("*** SIMULATION - START ***");
    // Init kernels
    kernel::ExplicitEulerWithShape        explicitEulerWithShape( params.dt );
+   DEM dem(domain);
    kernel::InsertParticleIntoLinkedCells ipilc;
-   kernel::SpringDashpot                 dem(1);
-   dem.setStiffness(0, 0, real_t(0));
-   dem.setDampingN (0, 0, real_t(0));
-   dem.setDampingT (0, 0, real_t(0));
-   dem.setFriction (0, 0, real_t(0));
-   collision_detection::AnalyticContactDetection              acd;
    kernel::AssocToBlock                  assoc(forest);
-   kernel::DoubleCast                    double_cast;
-   mpi::ContactFilter                    contact_filter;
    mpi::ReduceProperty                   RP;
    mpi::SyncNextNeighborsBlockForest     SNN;
 
@@ -203,10 +263,7 @@ int main( int argc, char ** argv )
       auto    RPBytesReceived  = RP.getBytesReceived();
       auto    RPSends          = RP.getNumberOfSends();
       auto    RPReceives       = RP.getNumberOfReceives();
-      int64_t contactsChecked  = 0;
-      int64_t contactsDetected = 0;
-      int64_t contactsTreated  = 0;
-      if (params.bBarrier) WALBERLA_MPI_BARRIER();
+      WALBERLA_MPI_BARRIER();
       timer.start();
       for (int64_t i=0; i < params.simulationSteps; ++i)
       {
@@ -226,27 +283,13 @@ int main( int argc, char ** argv )
          if (params.bBarrier) WALBERLA_MPI_BARRIER();
          tp["GenerateLinkedCells"].end();
 
+         dem.resetCounters();
          tp["DEM"].start();
-         contactsChecked  = 0;
-         contactsDetected = 0;
-         contactsTreated  = 0;
          lc.forEachParticlePairHalf(true,
                                     kernel::SelectAll(),
                                     accessor,
-                                    [&](const size_t idx1, const size_t idx2, auto& ac)
-         {
-            ++contactsChecked;
-            if (double_cast(idx1, idx2, ac, acd, ac ))
-            {
-               ++contactsDetected;
-               if (contact_filter(acd.getIdx1(), acd.getIdx2(), ac, acd.getContactPoint(), *domain))
-               {
-                  ++contactsTreated;
-                  dem(acd.getIdx1(), acd.getIdx2(), ac, acd.getContactPoint(), acd.getContactNormal(), acd.getPenetrationDepth());
-               }
-            }
-         },
-         accessor );
+                                    dem,
+                                    accessor);
          if (params.bBarrier) WALBERLA_MPI_BARRIER();
          tp["DEM"].end();
 
@@ -268,6 +311,10 @@ int main( int argc, char ** argv )
       }
       timer.end();
 
+      int64_t contactsChecked = dem.getContactsChecked();
+      int64_t contactsDetected = dem.getContactsDetected();
+      int64_t contactsTreated = dem.getContactsTreated();
+      
       SNNBytesSent     = SNN.getBytesSent();
       SNNBytesReceived = SNN.getBytesReceived();
       SNNSends         = SNN.getNumberOfSends();
-- 
GitLab