From 15f203074942151983aa758e90bc60bb097b0b86 Mon Sep 17 00:00:00 2001
From: Michael Zikeli <michael.zikeli@fau.de>
Date: Tue, 19 Dec 2023 12:10:17 +0100
Subject: [PATCH] Extend missing implementations for float16 support

---
 CMakeLists.txt                    |  50 +++++----
 cmake/TestFloat16.cpp             |   7 ++
 src/core/CMakeLists.txt           |   9 ++
 src/core/DataTypes.cpp            |   5 +
 src/core/DataTypes.h              |  42 ++++++--
 src/core/mpi/MPIWrapper.h         |   3 +
 tests/core/CMakeLists.txt         |  14 +++
 tests/core/Float16SupportTest.cpp | 163 ++++++++++++++++++++++++++++++
 8 files changed, 262 insertions(+), 31 deletions(-)
 create mode 100644 cmake/TestFloat16.cpp
 create mode 100644 tests/core/Float16SupportTest.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ad983d059..e3b808b8d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1255,28 +1255,34 @@ endif()
 ##  Half precision
 ##
 ############################################################################################################################
-if (WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT)
-    if (WALBERLA_CXX_COMPILER_IS_GNU OR WALBERLA_CXX_COMPILER_IS_CLANG)
-        message(STATUS "Configuring with *experimental* half precision (float16) support. You better know what you are doing.")
-        if (WALBERLA_CXX_COMPILER_IS_GNU AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 12.0.0)
-            message(WARNING "[WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT] "
-                    "Half precision support for gcc has only been tested with version >= 12. "
-                    "You are using a previous version - it may not work correctly.")
-        endif ()
-        if (WALBERLA_CXX_COMPILER_IS_CLANG AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 15.0.0)
-            message(WARNING "[WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT] "
-                    "Half precision support for clang has only been tested with version >= 15. "
-                    "You are using a previous version - it may not work correctly.")
-        endif ()
-        if (NOT WALBERLA_OPTIMIZE_FOR_LOCALHOST)
-            message(WARNING "[WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT] "
-                    "You are not optimizing for localhost. You may encounter linker errors, or WORSE: silent incorrect fp16 arithmetic! Consider also enabling WALBERLA_OPTIMIZE_FOR_LOCALHOST!")
-        endif ()
-    else ()
-        message(FATAL_ERROR "[WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT] "
-                "Half precision support is currently only available for gcc and clang.")
-    endif ()
-endif ()
+if ( WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT )
+   ### Compiler requirements:
+   ### Within this project, there are several checks to ensure that the template parameter 'ValueType'
+   ### is a floating point number. The check is_floating_point<ValueType> is done primarily in our MPI implementation.
+   ### For the IEEE 754 half precision type _Float16, this check evaluates to true only if your compiler implements
+   ### the C++23 paper P1467R9 (Extended floating-point types and standard names).
+   ### Compare: https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p1467r9.html
+   ###
+   ### Right now (18.12.2023) this is the case only for gcc13.
+   ### For more information see:
+   ###   https://gcc.gnu.org/projects/cxx-status.html#:~:text=Extended%20floating%2Dpoint%20types%20and%20standard%20names
+   ###   https://clang.llvm.org/cxx_status.html#:~:text=Extended%20floating%2Dpoint%20types%20and%20standard%20names
+
+   # Probe the compiler in C++23 mode; the compiler output is kept so that a failing probe can be diagnosed.
+   try_compile( WALBERLA_SUPPORT_HALF_PRECISION "${CMAKE_CURRENT_BINARY_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/TestFloat16.cpp"
+         CXX_STANDARD 23 OUTPUT_VARIABLE TRY_COMPILE_OUTPUT )
+   if ( NOT WALBERLA_SUPPORT_HALF_PRECISION )
+      message( FATAL_ERROR "Compiler: ${CMAKE_CXX_COMPILER} Version: ${CMAKE_CXX_COMPILER_VERSION} does not support half precision\n"
+            "Output of the float16 compile check:\n${TRY_COMPILE_OUTPUT}" )
+   endif ()
+
+   # Local host optimization
+   if ( NOT WALBERLA_OPTIMIZE_FOR_LOCALHOST )
+      message( WARNING "[WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT] "
+            "You are not optimizing for localhost. You may encounter linker errors, or WORSE: silent incorrect fp16 arithmetic! Consider also enabling WALBERLA_OPTIMIZE_FOR_LOCALHOST!" )
+   endif () # Local host check
+
+endif () # Check if WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT is set
 
 ############################################################################################################################
 # Documentation Generation
diff --git a/cmake/TestFloat16.cpp b/cmake/TestFloat16.cpp
new file mode 100644
index 000000000..bae373fbb
--- /dev/null
+++ b/cmake/TestFloat16.cpp
@@ -0,0 +1,7 @@
+#include <type_traits>
+
+// Configure-time probe: compiles only if the standard library reports _Float16 as a floating-point type.
+int main()
+{
+   static_assert(std::is_floating_point_v<_Float16>);
+}
\ No newline at end of file
diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt
index 099e0b573..71bb209d7 100644
--- a/src/core/CMakeLists.txt
+++ b/src/core/CMakeLists.txt
@@ -29,6 +29,15 @@ add_library( core )
 if( MPI_FOUND )
    target_link_libraries( core PUBLIC MPI::MPI_CXX )
 endif()
+
+if( WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT )
+   # float16 is properly supported only as of C++23: in earlier standards
+   #   is_arithmetic and is_floating_point evaluate to false for it, and many
+   #   STL functions are compatible with float16 only since C++23 as well.
+   # Which of these features are actually available depends on the compiler.
+   target_compile_features( core PUBLIC cxx_std_23 )
+endif()
+
 target_link_libraries( core PUBLIC ${SERVICE_LIBS} )
 target_sources( core
       PRIVATE
diff --git a/src/core/DataTypes.cpp b/src/core/DataTypes.cpp
index ead9f6fb6..0b9dcad1f 100644
--- a/src/core/DataTypes.cpp
+++ b/src/core/DataTypes.cpp
@@ -26,6 +26,11 @@ namespace walberla {
 
 namespace real_comparison
 {
+   #ifdef WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT
+//   const    bfloat16 Epsilon<    bfloat16 >::value = static_cast<    bfloat16 >(1e-2); // machine eps is 2^-7
+   const     float16 Epsilon<     float16 >::value = static_cast<     float16 >(1e-3); // machine eps is 2^-10
+   // Note: the tolerance depends on the 16 bit format in use; a bfloat16 build would need the coarser Epsilon sketched above.
+   #endif
    const       float Epsilon<       float >::value = static_cast<       float >(1e-4);
    const      double Epsilon<      double >::value = static_cast<      double >(1e-8);
    const long double Epsilon< long double >::value = static_cast< long double >(1e-10);
diff --git a/src/core/DataTypes.h b/src/core/DataTypes.h
index bae5b7651..4e7c019a8 100644
--- a/src/core/DataTypes.h
+++ b/src/core/DataTypes.h
@@ -175,21 +175,33 @@ using real_t = float;
 /// Only bandwidth bound code may therefore benefit. None of this is guaranteed, and may change in the future.
 ///
 #ifdef WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT
-#   if defined(WALBERLA_CXX_COMPILER_IS_CLANG) || defined(WALBERLA_CXX_COMPILER_IS_GNU)
-/// Clang version must be 15 or higher for x86 half precision support.
-/// GCC version must be 12 or higher for x86 half precision support.
-/// Also support seems to require SSE, so ensure that respective instruction sets are enabled.
+/// FIXME: (not really right) Clang version must be 15 or higher for x86 half precision support.
+/// FIXME: (not really right) GCC version must be 12 or higher for x86 half precision support.
+/// FIXME: (I don't know) Also support seems to require SSE, so ensure that respective instruction sets are enabled.
 /// See
 ///   https://clang.llvm.org/docs/LanguageExtensions.html#half-precision-floating-point
 ///   https://gcc.gnu.org/onlinedocs/gcc/Half-Precision.html
 /// for more information.
+/// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+/// Compiler requirements:
+/// Within this project, there are several checks to ensure that the template parameter 'ValueType'
+/// is a floating point number. The check is_floating_point<ValueType> is done primarily in our MPI implementation.
+/// For the IEEE 754 half precision type _Float16, this check evaluates to true only if your compiler implements
+/// the C++23 paper P1467R9 (Extended floating-point types and standard names).
+/// Compare:
+///  https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p1467r9.html
+///
+/// Right now (18.12.2023) this is the case only for gcc13.
+/// For more information see:
+///   https://gcc.gnu.org/projects/cxx-status.html#:~:text=Extended%20floating%2Dpoint%20types%20and%20standard%20names
+///   https://clang.llvm.org/cxx_status.html#:~:text=Extended%20floating%2Dpoint%20types%20and%20standard%20names
+
 using half    = _Float16;
+// Note: there are two possible float16 formats.
+// The one used right now is the IEEE 754 float16 format, consisting of a 5 bit exponent and a 10 bit mantissa.
+// Another possible half precision format would be the one from Google Brain (bfloat16) with an 8 bit exponent and a 7 bit mantissa.
+// Compare https://i10git.cs.fau.de/ab04unyc/walberla/-/issues/23
 using float16 = half;
-#   else
-static_assert(false, "\n\n### Attempting to built walberla with half precision support.\n"
-                     "### However, the compiler you chose is not suited for that, or we simply have not implemented "
-                     "support for half precision and your compiler.\n");
-#   endif
 #endif
 using float32 = float;
 using float64 = double;
@@ -228,6 +240,10 @@ inline bool realIsIdentical( const real_t a, const real_t b )
 namespace real_comparison
 {
    template< class T > struct Epsilon;
+   #ifdef WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT
+   using walberla::float16; // half precision alias, only declared when half precision support is enabled
+   template<> struct Epsilon<     float16 > { static const     float16 value; }; // declaration only; the value is defined in DataTypes.cpp
+   #endif
    template<> struct Epsilon<       float > { static const       float value; };
    template<> struct Epsilon<      double > { static const      double value; };
    template<> struct Epsilon< long double > { static const long double value; };
@@ -254,6 +270,14 @@ inline bool floatIsEqual( float lhs, float rhs, const float epsilon = real_compa
    return std::fabs( lhs - rhs ) < epsilon;
 }
 
+#ifdef WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT
+inline bool floatIsEqual( walberla::float16 lhs, walberla::float16 rhs, const walberla::float16 epsilon = real_comparison::Epsilon<walberla::float16>::value )
+{
+   const auto delta = lhs - rhs; // half precision overload: true if |lhs - rhs| < epsilon, evaluated without std::fabs
+   return ( delta < 0 ) ? ( -delta < epsilon ) : ( delta < epsilon );
+}
+#endif
+
 } // namespace walberla
 
 #define WALBERLA_UNUSED(x)  (void)(x)
diff --git a/src/core/mpi/MPIWrapper.h b/src/core/mpi/MPIWrapper.h
index 51ab22e26..eecee3136 100644
--- a/src/core/mpi/MPIWrapper.h
+++ b/src/core/mpi/MPIWrapper.h
@@ -353,6 +353,9 @@ WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( unsigned short int , MPI_UNSIGNED_SHORT
 WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( unsigned int       , MPI_UNSIGNED           );
 WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( unsigned long int  , MPI_UNSIGNED_LONG      );
 WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( unsigned long long , MPI_UNSIGNED_LONG_LONG );
+#ifdef WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT
+   WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( walberla::float16  , MPI_UNSIGNED_SHORT     ); // 16 bit carrier type (MPI_WCHAR is 4 bytes on Linux); moves raw bits only, arithmetic reductions are not meaningful
+#endif
 WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( float              , MPI_FLOAT              );
 WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( double             , MPI_DOUBLE             );
 WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( long double        , MPI_LONG_DOUBLE        );
diff --git a/tests/core/CMakeLists.txt b/tests/core/CMakeLists.txt
index 46b98eb48..8d3f0298a 100644
--- a/tests/core/CMakeLists.txt
+++ b/tests/core/CMakeLists.txt
@@ -222,3 +222,17 @@ if( WALBERLA_BUILD_WITH_PARMETIS )
    waLBerla_compile_test( NAME PlainParMetisTest FILES load_balancing/PlainParMetisTest.cpp )
    waLBerla_execute_test( NAME PlainParMetisTest PROCESSES 3 )
 endif()
+
+###################
+# Mixed Precision #
+###################
+
+if ( WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT )
+   waLBerla_compile_test( NAME Float16SupportTest FILES Float16SupportTest.cpp DEPENDS core )
+   # Actual support for float16 is available only since C++23
+   #   before is_arithmetic and is_floating_point evaluated to false,
+   #   also many STL functions are compatible with float16 only since C++23.
+   # Which features are actually supported depends on the compiler
+   target_compile_features( Float16SupportTest PRIVATE cxx_std_23 )
+   waLBerla_execute_test( NAME Float16SupportTest )
+endif ()
\ No newline at end of file
diff --git a/tests/core/Float16SupportTest.cpp b/tests/core/Float16SupportTest.cpp
new file mode 100644
index 000000000..04ea9378f
--- /dev/null
+++ b/tests/core/Float16SupportTest.cpp
@@ -0,0 +1,163 @@
+//======================================================================================================================
+//
+//  This file is part of waLBerla. waLBerla is free software: you can
+//  redistribute it and/or modify it under the terms of the GNU General Public
+//  License as published by the Free Software Foundation, either version 3 of
+//  the License, or (at your option) any later version.
+//
+//  waLBerla is distributed in the hope that it will be useful, but WITHOUT
+//  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+//  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+//  for more details.
+//
+//  You should have received a copy of the GNU General Public License along
+//  with waLBerla (see COPYING.txt). If not, see <http://www.gnu.org/licenses/>.
+//
+//! \file Float16SupportTest.cpp
+//! \ingroup core
+//! \author Michael Zikeli <michael.zikeli@fau.de>
+//
+//======================================================================================================================
+
+#include <memory>
+#include <numeric>
+
+#include "core/DataTypes.h"
+#include "core/Environment.h"
+#include "core/logging/Logging.h"
+
+namespace walberla::simple_Float16_test {
+using walberla::floatIsEqual;
+using walberla::real_t;
+using walberla::uint_c;
+using walberla::uint_t;
+
+// === Choosing Accuracy ===
+// +++ Precision: fp16 (dst_t is the type under test, src_t the reference type it is compared against) +++
+using walberla::float16;
+using walberla::float32;
+using walberla::float64;
+using dst_t                         = float16;
+using src_t                         = real_t;
+constexpr real_t     precisionLimit = walberla::float16( 1e-3 ); // 1e-3 converted to float16 first, then stored as real_t
+const std::string    precisionType  = "float16"; // label used in the log output
+constexpr const auto maxLevel       = uint_t( 3 ); // NOTE(review): not referenced anywhere in this file - confirm whether it is still needed
+
+void simple_array_test()
+{
+   // Checks float16 storage in shared_ptr-managed arrays: filling, element access, comparison and std::swap.
+   auto fpSrc = std::make_shared< src_t[] >( 10 );
+   auto fpDst = std::make_shared< dst_t[] >( 10 );
+
+   std::fill_n( fpSrc.get(), 10, 17. );
+   std::fill_n( fpDst.get(), 10, (dst_t) 17. );
+   // Make entry 5 differ from the rest so that element access is covered, too.
+   fpSrc[5] = 8.;
+   fpDst[5] = (dst_t) 8.;
+
+   // Test equality with a manual comparison against precisionLimit
+   WALBERLA_CHECK_LESS( std::fabs( fpSrc[9] - (src_t) fpDst[9] ), precisionLimit );
+   WALBERLA_CHECK_LESS( std::fabs( fpSrc[5] - (src_t) fpDst[5] ), precisionLimit );
+   // Test floatIsEqual, once with src_t arguments and once with float16 (dst_t) arguments
+   WALBERLA_CHECK( floatIsEqual( fpSrc[9], (src_t) fpDst[9], (src_t) precisionLimit ) );
+   WALBERLA_CHECK( floatIsEqual( (dst_t) fpSrc[9], fpDst[9], (dst_t) precisionLimit ) );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[9], fpDst[9] );
+
+   // Test std::fill_n
+   auto other_fpDst = std::make_shared< dst_t[] >( 10 );
+   std::fill_n( other_fpDst.get(), 10, (dst_t) 2. );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) 2., other_fpDst[9] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) 2., other_fpDst[5] );
+
+   // Test std::swap
+   std::swap( fpDst, other_fpDst );
+   fpDst[5] = (dst_t) 9.;
+   // other_fpDst now owns the array that mirrors fpSrc; fpDst owns the array filled with 2 (entry 5 overwritten with 9).
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[9], other_fpDst[9] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[5], other_fpDst[5] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) 2., fpDst[9] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) 9., fpDst[5] );
+} // simple_Float16_test::simple_array_test()
+
+void vector_test()
+{
+   // Checks std::vector< float16 >: filling via assign / std::copy, element access and summation against src_t and float32.
+   auto fpSrc      = std::vector< src_t >( 10 );
+   auto fpDst_cast = std::vector< dst_t >( 10 );
+   auto fp32       = std::vector< walberla::float32 >( 10 );
+   auto fpDst      = std::vector< dst_t >( 10 );
+   fpSrc.assign( 10, 1.5 );
+   fpDst_cast.assign( 10, (dst_t) 1.5 );
+   fp32.assign( 10, 1.5f );
+   std::copy( fpSrc.begin(), fpSrc.end(), fpDst.begin() );
+   WALBERLA_LOG_WARNING_ON_ROOT( " std::vector.assign is not able to assign "
+       << typeid( src_t ).name() << " values to container of type " << precisionType << ".\n"
+       << " Therefore, the floating point value for assign must be cast beforehand or std::copy must be used, since it uses a static_cast internally." );
+
+   fpSrc[5]      = 2.3;
+   fpDst_cast[5] = (dst_t) 2.3;
+   fp32[5]       = 2.3f;
+   fpDst[5]      = (dst_t) 2.3;
+
+   WALBERLA_CHECK_FLOAT_EQUAL( (walberla::float32) fpSrc[0], fp32[0] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (walberla::float32) fpSrc[9], fp32[9] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (walberla::float32) fpSrc[5], fp32[5] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[0], fpDst_cast[0] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[9], fpDst_cast[9] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[5], fpDst_cast[5] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[0], fpDst[0] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[9], fpDst[9] );
+   WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[5], fpDst[5] );
+   WALBERLA_CHECK_EQUAL( typeid( fpDst ), typeid( fpDst_cast ) );
+
+   // Add up all elements of the vector to check whether the result is sufficiently correct.
+   {
+      const auto sumSrc = std::reduce(fpSrc.begin(), fpSrc.end());
+      const auto sumDst = std::reduce(fpDst.begin(), fpDst.end());
+      WALBERLA_CHECK_FLOAT_EQUAL( (dst_t)sumSrc, sumDst );
+   }
+   {
+      fpSrc.assign( 13, 1.3 );
+      fpDst.resize( fpSrc.size() ); // std::copy does not grow its destination: without this, 13 values are written into 10 slots
+      std::copy( fpSrc.begin(), fpSrc.end(), fpDst.begin() );
+      const auto sumSrc = std::reduce(fpSrc.begin(), fpSrc.end());
+      const auto sumDst = std::reduce(fpDst.begin(), fpDst.end());
+      WALBERLA_CHECK_FLOAT_UNEQUAL( (dst_t)sumSrc, sumDst );
+   }
+} // simple_Float16_test::vector_test()
+
+int main( int argc, char** argv )
+{
+   // Test driver: checks the type trait, sets up environment and logging, then runs the individual float16 tests.
+   // This check only works since C++23 and is used in many implementations, so it's important, that it works.
+   WALBERLA_CHECK( std::is_arithmetic< dst_t >::value );
+
+   walberla::Environment env( argc, argv );
+   walberla::logging::Logging::instance()->setLogLevel( walberla::logging::Logging::INFO );
+   walberla::MPIManager::instance()->useWorldComm();
+
+   WALBERLA_LOG_INFO_ON_ROOT( " This run is executed with " << precisionType );
+   WALBERLA_LOG_INFO_ON_ROOT( " machine precision limit is " << precisionLimit );
+   const std::string stringLine( 125, '=' );
+   WALBERLA_LOG_INFO_ON_ROOT( stringLine );
+
+   WALBERLA_LOG_INFO_ON_ROOT( " Start a test with shared_pointer<float16[]>." );
+   simple_array_test();
+
+   WALBERLA_LOG_INFO_ON_ROOT( " Start a test with std::vector<float16>." );
+   vector_test();
+
+   WALBERLA_LOG_INFO_ON_ROOT( " Start a test where float32 is sufficient but float16 is not." );
+   WALBERLA_CHECK_FLOAT_UNEQUAL( dst_t(1.0)-dst_t(0.3), 1.0-0.3 );
+   WALBERLA_CHECK_FLOAT_EQUAL( 1.0f-0.3f, 1.0-0.3 );
+   return 0;
+} // simple_Float16_test::main()
+
+} // namespace walberla::simple_Float16_test
+
+int main( int argc, char** argv )
+{
+   // Delegate to the namespaced test driver and pass its result on to the caller
+   // instead of unconditionally reporting success.
+   return walberla::simple_Float16_test::main( argc, argv );
+} // main()
-- 
GitLab