// FunctionWrapper.hpp
/*
 * Copyright (c) 2021 Marcus Mohr.
 *
 * This file is part of HyTeG
 * (see https://i10git.cs.fau.de/hyteg/hyteg).
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

#pragma once

#include "core/DataTypes.h"

#include "hyteg/functions/FunctionTraits.hpp"
#include "hyteg/functions/GenericFunction.hpp"
#include "hyteg/sparseassembly/VectorProxy.hpp"

// A whole lot of includes, so that createVectorFromFunction below has
// a valid prototype for all possible cases
#include "hyteg/dgfunctionspace/DGPetsc.hpp"
#include "hyteg/edgedofspace/EdgeDoFPetsc.hpp"
#include "hyteg/facedofspace/FaceDoFPetsc.hpp"
#include "hyteg/p1functionspace/P1Petsc.hpp"
#include "hyteg/p2functionspace/P2Petsc.hpp"

// only needed for the PetscInt and PetscReal types used in to/fromVector() below!
#include "hyteg/petsc/PETScWrapper.hpp"

namespace hyteg {

using walberla::uint_t;

/// Adapter that exposes a concrete HyTeG function type through the type-erased
/// GenericFunction interface.
///
/// This class is `final`; it is not meant to be derived from. It owns an
/// instance of `func_t`, which is created in the constructor, and forwards
/// every call to it. Arguments that arrive as `GenericFunction< value_t >`
/// references are converted back to `func_t` via
/// `GenericFunction::unwrap< func_t >()` before being forwarded. Such arguments
/// are therefore expected to wrap a `func_t` as well. What happens otherwise is
/// decided by GenericFunction::unwrap.
///
/// @tparam func_t  wrapped function type; `FunctionTrait< func_t >` must provide
///                 `ValueType`, `Tag` and `kind`
template < typename func_t >
class FunctionWrapper final : public GenericFunction< typename FunctionTrait< func_t >::ValueType >
{
 public:
   /// value type of the wrapped function, as reported by FunctionTrait
   using value_t = typename FunctionTrait< func_t >::ValueType;

   /// tag type of the wrapped function, as reported by FunctionTrait
   using Tag = typename FunctionTrait< func_t >::Tag;

   /// the concrete type of the wrapped function
   using WrappedFuncType = func_t;

   /// the same kind of function as the wrapped one, but with value type VType
   template < typename VType >
   using WrappedFuncKind = typename WrappedFuncType::template FunctionType< VType >;

   /// Default construction is disabled, so wrappedFunc_ is always set by the
   /// constructor below (there is no setter for it).
   FunctionWrapper() = delete;

   /// Constructs the function that the wrapper owns and forwards all calls to.
   ///
   /// @param name      forwarded unchanged to the constructor of func_t
   /// @param storage   forwarded unchanged to the constructor of func_t
   /// @param minLevel  forwarded unchanged to the constructor of func_t
   /// @param maxLevel  forwarded unchanged to the constructor of func_t
   FunctionWrapper( const std::string&                         name,
                    const std::shared_ptr< PrimitiveStorage >& storage,
                    size_t                                     minLevel,
                    size_t                                     maxLevel )
   : wrappedFunc_( std::make_unique< func_t >( name, storage, minLevel, maxLevel ) )
   {}

   ~FunctionWrapper() = default;

   /// provide access to wrapped function
   /// @{
   func_t& unwrap() { return *wrappedFunc_; }

   const func_t& unwrap() const { return *wrappedFunc_; }
   /// @}

   uint_t getDimension() const { return wrappedFunc_->getDimension(); }

   const std::string& getFunctionName() const { return wrappedFunc_->getFunctionName(); }

   /// kind of the wrapped function type, taken from FunctionTrait (not queried at run time)
   functionTraits::FunctionKind getFunctionKind() const { return FunctionTrait< func_t >::kind; }

   std::shared_ptr< PrimitiveStorage > getStorage() const { return wrappedFunc_->getStorage(); }

   void multElementwise( const std::vector< std::reference_wrapper< const GenericFunction< value_t > > >& functions,
                         uint_t                                                                           level,
                         DoFType                                                                          flag = All ) const
   {
      auto realFuncs = unwrapAll( functions );
      wrappedFunc_->multElementwise( realFuncs, level, flag );
   }

   void interpolate( value_t constant, uint_t level, DoFType flag = All ) const
   {
      wrappedFunc_->interpolate( constant, level, flag );
   }

   void interpolate( const std::function< value_t( const hyteg::Point3D& ) >& expr, uint_t level, DoFType flag = All ) const
   {
      wrappedFunc_->interpolate( expr, level, flag );
   }

   void interpolate( const std::vector< std::function< value_t( const hyteg::Point3D& ) > >& expressions,
                     uint_t                                                                  level,
                     DoFType                                                                 flag = All ) const
   {
      wrappedFunc_->interpolate( expressions, level, flag );
   }

   value_t dotGlobal( const GenericFunction< value_t >& secondOp, const uint_t level, const DoFType flag = All ) const
   {
      return wrappedFunc_->dotGlobal( secondOp.template unwrap< func_t >(), level, flag );
   }

   value_t dotLocal( const GenericFunction< value_t >& secondOp, uint_t level, DoFType flag = All ) const
   {
      return wrappedFunc_->dotLocal( secondOp.template unwrap< func_t >(), level, flag );
   }

   void enableTiming( const std::shared_ptr< walberla::WcTimingTree >& timingTree ) { wrappedFunc_->enableTiming( timingTree ); }

   void setBoundaryCondition( BoundaryCondition bc ) { wrappedFunc_->setBoundaryCondition( bc ); }

   BoundaryCondition getBoundaryCondition() const { return wrappedFunc_->getBoundaryCondition(); }

   void add( const value_t scalar, uint_t level, DoFType flag = All ) const { wrappedFunc_->add( scalar, level, flag ); }

   void add( const std::vector< value_t >                                                     scalars,
             const std::vector< std::reference_wrapper< const GenericFunction< value_t > > >& functions,
             uint_t                                                                           level,
             DoFType                                                                          flag = All ) const
   {
      auto realFuncs = unwrapAll( functions );
      wrappedFunc_->add( scalars, realFuncs, level, flag );
   }

   void assign( const std::vector< value_t >                                                     scalars,
                const std::vector< std::reference_wrapper< const GenericFunction< value_t > > >& functions,
                uint_t                                                                           level,
                DoFType                                                                          flag = All ) const
   {
      auto realFuncs = unwrapAll( functions );
      wrappedFunc_->assign( scalars, realFuncs, level, flag );
   }

   void swap( const GenericFunction< value_t >& other, const uint_t& level, const DoFType& flag = All ) const
   {
      wrappedFunc_->swap( other.template unwrap< func_t >(), level, flag );
   }

   void copyFrom( const GenericFunction< value_t >&              other,
                  const uint_t&                                  level,
                  const std::map< PrimitiveID::IDType, uint_t >& localPrimitiveIDsToRank,
                  const std::map< PrimitiveID::IDType, uint_t >& otherPrimitiveIDsToRank ) const
   {
      wrappedFunc_->copyFrom( other.template unwrap< func_t >(), level, localPrimitiveIDsToRank, otherPrimitiveIDsToRank );
   }

   void enumerate( uint_t level ) const { wrappedFunc_->enumerate( level ); }

   void enumerate( uint_t level, value_t& offset ) const { wrappedFunc_->enumerate( level, offset ); }

   /// Number of DoFs of the wrapped function's kind on this process, computed
   /// from the function's storage and its Tag.
   uint_t getNumberOfLocalDoFs( uint_t level ) const
   {
      return numberOfLocalDoFs< Tag >( *wrappedFunc_->getStorage(), level );
   }

#ifdef HYTEG_BUILD_WITH_PETSC
   /// conversion to/from linear algebra representation
   ///
   /// Only supported when value_t is PetscReal; for any other value type the
   /// call ends in WALBERLA_ABORT. The numerator is the PetscInt-valued
   /// counterpart of the wrapped function (WrappedFuncKind< PetscInt >).
   /// @{
   void toVector( const GenericFunction< PetscInt >&    numerator,
                  const std::shared_ptr< VectorProxy >& vec,
                  uint_t                                level,
                  DoFType                               flag ) const
   {
      if constexpr ( std::is_same< value_t, PetscReal >::value )
      {
         using numer_t = WrappedFuncKind< PetscInt >;
         petsc::createVectorFromFunction( *wrappedFunc_, numerator.template unwrap< numer_t >(), vec, level, flag );
      }
      else
      {
         WALBERLA_ABORT( "FunctionWrapper::toVector() only works for ValueType being identical to PetscReal" );
      }
   }

   void fromVector( const GenericFunction< PetscInt >&    numerator,
                    const std::shared_ptr< VectorProxy >& vec,
                    uint_t                                level,
                    DoFType                               flag ) const
   {
      if constexpr ( std::is_same< value_t, PetscReal >::value )
      {
         using numer_t = WrappedFuncKind< PetscInt >;
         petsc::createFunctionFromVector( *wrappedFunc_, numerator.template unwrap< numer_t >(), vec, level, flag );
      }
      else
      {
         WALBERLA_ABORT( "FunctionWrapper::fromVector() only works for ValueType being identical to PetscReal" );
      }
   }
   /// @}
#endif

 private:
   /// Converts a list of GenericFunction references into references to the
   /// concrete wrapped type. This is shared by multElementwise(), add() and assign().
   static std::vector< std::reference_wrapper< const func_t > >
       unwrapAll( const std::vector< std::reference_wrapper< const GenericFunction< value_t > > >& functions )
   {
      std::vector< std::reference_wrapper< const func_t > > realFuncs;
      realFuncs.reserve( functions.size() );
      for ( const GenericFunction< value_t >& func : functions )
      {
         realFuncs.push_back( func.template unwrap< func_t >() );
      }
      return realFuncs;
   }

   /// the wrapped function; created in the constructor and owned by the wrapper
   std::unique_ptr< func_t > wrappedFunc_;
};

/// Free-function access to the function held by a wrapper (const version).
///
/// The wrapped type func_t is deduced from the wrapper's single template
/// argument. The call is forwarded to the wrapper's member function unwrap().
template < template < typename > class WrapperFunc, typename func_t >
const func_t& unwrap( const WrapperFunc< func_t >& wrapped )
{
   return wrapped.unwrap();
}

/// Free-function access to the function held by a wrapper (mutable version).
///
/// The wrapped type func_t is deduced from the wrapper's single template
/// argument. The call is forwarded to the wrapper's member function unwrap().
template < template < typename > class WrapperFunc, typename func_t >
func_t& unwrap( WrapperFunc< func_t >& wrapped )
{
   return wrapped.unwrap();
}

} // namespace hyteg