Skip to content
Snippets Groups Projects
Commit ede1dc5c authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

Move boundary generation into pystencils_walberla

parent 16df5598
Branches
Tags
No related merge requests found
......@@ -2,13 +2,11 @@ import numpy as np
from jinja2 import Environment, PackageLoader, StrictUndefined
from lbmpy.boundaries.boundaryhandling import create_lattice_boltzmann_boundary_kernel
from lbmpy_walberla.walberla_lbm_generation import KernelInfo
from pystencils import Field, FieldType
from pystencils.boundaries.createindexlist import (
boundary_index_array_coordinate_names, direction_member_name,
numpy_data_type_for_boundary_object)
from pystencils.boundaries.createindexlist import numpy_data_type_for_boundary_object
from pystencils.data_types import TypedSymbol, create_type
from pystencils_walberla.codegen import default_create_kernel_parameters
from pystencils_walberla.boundary import struct_from_numpy_dtype
from pystencils_walberla.codegen import default_create_kernel_parameters, KernelInfo
from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env
......@@ -54,7 +52,7 @@ def generate_boundary(generation_context, class_name, boundary_object, lb_method
'namespace': 'lbm',
}
env = Environment(loader=PackageLoader('lbmpy_walberla'), undefined=StrictUndefined)
env = Environment(loader=PackageLoader('pystencils_walberla'), undefined=StrictUndefined)
add_pystencils_filters_to_jinja_env(env)
header = env.get_template('Boundary.tmpl.h').render(**context)
......@@ -63,30 +61,3 @@ def generate_boundary(generation_context, class_name, boundary_object, lb_method
source_extension = "cpp" if create_kernel_params.get("target", "cpu") == "cpu" else "cu"
generation_context.write_file("{}.h".format(class_name), header)
generation_context.write_file("{}.{}".format(class_name, source_extension), source)
def struct_from_numpy_dtype(struct_name, numpy_dtype):
result = "struct %s { \n" % (struct_name,)
equality_compare = []
constructor_params = []
constructor_initializer_list = []
for name, (sub_type, offset) in numpy_dtype.fields.items():
pystencils_type = create_type(sub_type)
result += " %s %s;\n" % (pystencils_type, name)
if name in boundary_index_array_coordinate_names or name == direction_member_name:
constructor_params.append("%s %s_" % (pystencils_type, name))
constructor_initializer_list.append("%s(%s_)" % (name, name))
else:
constructor_initializer_list.append("%s()" % name)
if pystencils_type.is_float():
equality_compare.append("floatIsEqual(%s, o.%s)" % (name, name))
else:
equality_compare.append("%s == o.%s" % (name, name))
result += " %s(%s) : %s {}\n" % \
(struct_name, ", ".join(constructor_params), ", ".join(constructor_initializer_list))
result += " bool operator==(const %s & o) const {\n return %s;\n }\n" % \
(struct_name, " && ".join(equality_compare))
result += "};\n"
return result
import numpy as np
from jinja2 import Environment, PackageLoader, StrictUndefined
from pystencils_walberla.codegen import KernelInfo
from pystencils import Field, FieldType
from pystencils.boundaries.createindexlist import (
boundary_index_array_coordinate_names, direction_member_name,
numpy_data_type_for_boundary_object)
from pystencils.data_types import TypedSymbol, create_type
from pystencils_walberla.codegen import default_create_kernel_parameters
from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env
def struct_from_numpy_dtype(struct_name, numpy_dtype):
result = "struct %s { \n" % (struct_name,)
equality_compare = []
constructor_params = []
constructor_initializer_list = []
for name, (sub_type, offset) in numpy_dtype.fields.items():
pystencils_type = create_type(sub_type)
result += " %s %s;\n" % (pystencils_type, name)
if name in boundary_index_array_coordinate_names or name == direction_member_name:
constructor_params.append("%s %s_" % (pystencils_type, name))
constructor_initializer_list.append("%s(%s_)" % (name, name))
else:
constructor_initializer_list.append("%s()" % name)
if pystencils_type.is_float():
equality_compare.append("floatIsEqual(%s, o.%s)" % (name, name))
else:
equality_compare.append("%s == o.%s" % (name, name))
result += " %s(%s) : %s {}\n" % \
(struct_name, ", ".join(constructor_params), ", ".join(constructor_initializer_list))
result += " bool operator==(const %s & o) const {\n return %s;\n }\n" % \
(struct_name, " && ".join(equality_compare))
result += "};\n"
return result
......@@ -14,8 +14,7 @@
// with waLBerla (see COPYING.txt). If not, see <http://www.gnu.org/licenses/>.
//
//! \\file {{class_name}}.cpp
//! \\ingroup lbm
//! \\author lbmpy
//! \\author pystencils
//======================================================================================================================
#include <cmath>
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment