From 672d668c473a046d663f658b57f8cd5a4a337851 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 6 Mar 2024 14:51:43 +0100
Subject: [PATCH] various refactorings and quick-fixes

---
 src/pystencils/backend/jit/legacy_cpu.py      |   2 +-
 .../backend/kernelcreation/context.py         |  12 +-
 .../backend/kernelcreation/freeze.py          |  14 +-
 .../boundaries/boundaryconditions.py          |  18 +--
 src/pystencils/boundaries/boundaryhandling.py |  23 ++--
 src/pystencils/boundaries/inkernel.py         |   3 +-
 src/pystencils/field.py                       |  17 ++-
 src/pystencils/gpu/__init__.py                |   0
 .../{old => }/gpu/gpu_array_handler.py        |   0
 src/pystencils/integer_functions.py           |   3 +-
 src/pystencils/kernelcreation.py              |   4 +-
 src/pystencils/old/gpu/__init__.py            |   4 +-
 src/pystencils/rng.py                         |  10 +-
 src/pystencils/sympyextensions/__init__.py    | 125 +++++++++++++++---
 14 files changed, 175 insertions(+), 60 deletions(-)
 create mode 100644 src/pystencils/gpu/__init__.py
 rename src/pystencils/{old => }/gpu/gpu_array_handler.py (100%)

diff --git a/src/pystencils/backend/jit/legacy_cpu.py b/src/pystencils/backend/jit/legacy_cpu.py
index 1d773dbe6..ca5701160 100644
--- a/src/pystencils/backend/jit/legacy_cpu.py
+++ b/src/pystencils/backend/jit/legacy_cpu.py
@@ -428,7 +428,7 @@ def compile_and_load(kernel: KernelFunction, custom_backend=None):
 
     code.create_code_string(compiler_config["restrict_qualifier"], function_prefix)
     code_hash_str = code.get_hash_of_code()
-    
+
     compile_flags = []
     #   TODO: replace
     # if kernel.instruction_set and "compile_flags" in kernel.instruction_set:
diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index 9496c3097..8c2b34fc6 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -145,9 +145,15 @@ class KernelCreationContext:
         Before adding the field to the collection, various sanity and constraint checks are applied.
         """
 
-        if field in self._fields_and_arrays:
-            #   Field was already added
-            return
+        if field.name in self._fields_and_arrays:
+            existing_field = self._fields_and_arrays[field.name].field
+            if existing_field != field:
+                raise KernelConstraintsError(
+                    "Encountered two fields with the same name, but different properties: "
+                    f"{field} and {existing_field}"
+                )
+            else:
+                return
 
         arr_shape: list[EllipsisType | int] | None = None
         arr_strides: list[EllipsisType | int] | None = None
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 2d10b0c0c..31d322aa1 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -1,4 +1,4 @@
-from typing import overload, cast
+from typing import overload, cast, Any
 from functools import reduce
 from operator import add, mul
 
@@ -77,6 +77,14 @@ class FreezeExpressions:
 
         raise FreezeError(f"Don't know how to freeze expression {node}")
 
+    def visit_expr_like(self, obj: Any) -> PsExpression:
+        if isinstance(obj, sp.Basic):
+            return self.visit_expr(obj)
+        elif isinstance(obj, (int, float, bool)):
+            return PsExpression.make(PsConstant(obj))
+        else:
+            raise FreezeError(f"Don't know how to freeze {obj}")
+
     def visit_expr(self, expr: sp.Basic):
         if not isinstance(expr, sp.Expr):
             raise FreezeError(f"Cannot freeze {expr} to an expression")
@@ -128,7 +136,7 @@ class FreezeExpressions:
         array = self._ctx.get_array(field)
         ptr = array.base_pointer
 
-        offsets: list[PsExpression] = [self.visit_expr(o) for o in access.offsets]
+        offsets: list[PsExpression] = [self.visit_expr_like(o) for o in access.offsets]
         indices: list[PsExpression]
 
         if not access.is_absolute_access:
@@ -174,7 +182,7 @@ class FreezeExpressions:
                 )
         else:
             struct_member_name = None
-            indices = [self.visit_expr(i) for i in access.index]
+            indices = [self.visit_expr_like(i) for i in access.index]
             if not indices:
                 # For canonical representation, there must always be at least one index dimension
                 indices = [PsExpression.make(PsConstant(0))]
diff --git a/src/pystencils/boundaries/boundaryconditions.py b/src/pystencils/boundaries/boundaryconditions.py
index 9aa4319fb..f52573bca 100644
--- a/src/pystencils/boundaries/boundaryconditions.py
+++ b/src/pystencils/boundaries/boundaryconditions.py
@@ -1,8 +1,8 @@
-from typing import Any, List, Tuple
+from typing import Any, List, Tuple, Sequence
 
-from pystencils.sympyextensions.astnodes import SympyAssignment
+from pystencils.sympyextensions import Assignment
 from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
-from pystencils.typing import create_type
+from pystencils.types import create_type
 
 
 class Boundary:
@@ -14,7 +14,7 @@ class Boundary:
     def __init__(self, name=None):
         self._name = name
 
-    def __call__(self, field, direction_symbol, index_field) -> List[SympyAssignment]:
+    def __call__(self, field, direction_symbol, index_field) -> List[Assignment]:
         """Defines the boundary behavior and must therefore be implemented by all boundaries.
 
         Here the boundary is defined as a list of sympy assignments, from which a boundary kernel is generated.
@@ -30,7 +30,7 @@ class Boundary:
         raise NotImplementedError("Boundary class has to overwrite __call__")
 
     @property
-    def additional_data(self) -> Tuple[str, Any]:
+    def additional_data(self) -> Sequence[Tuple[str, Any]]:
         """Return a list of (name, type) tuples for additional data items required in this boundary
         These data items can either be initialized in separate kernel see additional_data_kernel_init or by
         Python callbacks - see additional_data_callback """
@@ -63,13 +63,13 @@ class Neumann(Boundary):
 
         neighbor = BoundaryOffsetInfo.offset_from_dir(direction_symbol, field.spatial_dimensions)
         if field.index_dimensions == 0:
-            return [SympyAssignment(field.center, field[neighbor])]
+            return [Assignment(field.center, field[neighbor])]
         else:
             from itertools import product
             if not field.has_fixed_index_shape:
                 raise NotImplementedError("Neumann boundary works only for fields with fixed index shape")
             index_iter = product(*(range(i) for i in field.index_shape))
-            return [SympyAssignment(field(*idx), field[neighbor](*idx)) for idx in index_iter]
+            return [Assignment(field(*idx), field[neighbor](*idx)) for idx in index_iter]
 
     def __hash__(self):
         # All boundaries of these class behave equal -> should also be equal
@@ -103,11 +103,11 @@ class Dirichlet(Boundary):
     def __call__(self, field, direction_symbol, index_field, **kwargs):
 
         if field.index_dimensions == 0:
-            return [SympyAssignment(field.center, index_field("value") if self.additional_data else self._value)]
+            return [Assignment(field.center, index_field("value") if self.additional_data else self._value)]
         elif field.index_dimensions == 1:
             assert not self.additional_data
             if not field.has_fixed_index_shape:
                 raise NotImplementedError("Field needs fixed index shape")
             assert len(self._value) == field.index_shape[0], "Dirichlet value does not match index shape of field"
-            return [SympyAssignment(field(i), self._value[i]) for i in range(field.index_shape[0])]
+            return [Assignment(field(i), self._value[i]) for i in range(field.index_shape[0])]
         raise NotImplementedError("Dirichlet boundary not implemented for fields with more than one index dimension")
diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py
index 96c82e75d..a61f062be 100644
--- a/src/pystencils/boundaries/boundaryhandling.py
+++ b/src/pystencils/boundaries/boundaryhandling.py
@@ -4,14 +4,14 @@ import numpy as np
 import sympy as sp
 
 from pystencils import create_kernel, CreateKernelConfig, Target
-from pystencils.sympyextensions.astnodes import SympyAssignment
-from pystencils.backends.cbackend import CustomCodeNode
+from pystencils.sympyextensions import Assignment
 from pystencils.boundaries.createindexlist import (
     create_boundary_index_array, numpy_data_type_for_boundary_object)
-from pystencils.typing import TypedSymbol, create_type
+from pystencils.sympyextensions import TypedSymbol
+from pystencils.types import create_type
 from pystencils.gpu.gpu_array_handler import GPUArrayHandler
 from pystencils.field import Field
-from pystencils.typing.typed_sympy import FieldPointerSymbol
+from pystencils.backend.kernelfunction import FieldPointerParam
 
 try:
     # noinspection PyPep8Naming
@@ -246,9 +246,9 @@ class BoundaryHandling:
             for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items():
                 kwargs[self._field_name] = b[self._field_name]
                 kwargs['indexField'] = idx_arr
-                data_used_in_kernel = (p.fields[0].name
+                data_used_in_kernel = (p.field.name
                                        for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
-                                       if isinstance(p.symbol, FieldPointerSymbol) and p.fields[0].name not in kwargs)
+                                       if isinstance(p, FieldPointerParam) and p.field.name not in kwargs)
                 kwargs.update({name: b[name] for name in data_used_in_kernel})
 
                 self._boundary_object_to_boundary_info[b_obj].kernel(**kwargs)
@@ -262,9 +262,9 @@ class BoundaryHandling:
                 arguments = kwargs.copy()
                 arguments[self._field_name] = b[self._field_name]
                 arguments['indexField'] = idx_arr
-                data_used_in_kernel = (p.fields[0].name
+                data_used_in_kernel = (p.field.name
                                        for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
-                                       if isinstance(p.symbol, FieldPointerSymbol) and p.field_name not in arguments)
+                                       if isinstance(p, FieldPointerParam) and p.field.name not in arguments)
                 arguments.update({name: b[name] for name in data_used_in_kernel if name not in arguments})
 
                 kernel = self._boundary_object_to_boundary_info[b_obj].kernel
@@ -404,7 +404,8 @@ class BoundaryDataSetter:
         return self.index_array[item]
 
 
-class BoundaryOffsetInfo(CustomCodeNode):
+# class BoundaryOffsetInfo(CustomCodeNode): #   TODO nbackend: replace
+class BoundaryOffsetInfo:
 
     # --------------------------- Functions to be used by boundaries --------------------------
 
@@ -448,7 +449,7 @@ class BoundaryOffsetInfo(CustomCodeNode):
 def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args):
     elements = [BoundaryOffsetInfo(stencil)]
     dir_symbol = TypedSymbol("dir", np.int32)
-    elements += [SympyAssignment(dir_symbol, index_field[0]('dir'))]
+    elements += [Assignment(dir_symbol, index_field[0]('dir'))]
     elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field)
-    config = CreateKernelConfig(index_fields=[index_field], target=target, **kernel_creation_args)
+    config = CreateKernelConfig(index_field=index_field, target=target, **kernel_creation_args)
     return create_kernel(elements, config=config)
diff --git a/src/pystencils/boundaries/inkernel.py b/src/pystencils/boundaries/inkernel.py
index 479f30d22..7cd9e628b 100644
--- a/src/pystencils/boundaries/inkernel.py
+++ b/src/pystencils/boundaries/inkernel.py
@@ -1,7 +1,8 @@
 import sympy as sp
 
 from pystencils.boundaries.boundaryhandling import DEFAULT_FLAG_TYPE
-from pystencils.typing import TypedSymbol, create_type
+from pystencils.sympyextensions import TypedSymbol
+from pystencils.types import create_type
 from pystencils.field import Field
 from pystencils.integer_functions import bitwise_and
 
diff --git a/src/pystencils/field.py b/src/pystencils/field.py
index f1b7cb376..646cecb5d 100644
--- a/src/pystencils/field.py
+++ b/src/pystencils/field.py
@@ -17,6 +17,7 @@ from pystencils.stencil import direction_string_to_offset, inverse_direction, of
 from pystencils.types import PsType, PsStructType, create_type
 from pystencils.sympyextensions.typed_sympy import (FieldShapeSymbol, FieldStrideSymbol, TypedSymbol)
 from pystencils.sympyextensions.math import is_integer_sequence
+from pystencils.types.quick import UserTypeSpec
 
 
 __all__ = ['Field', 'fields', 'FieldType', 'Field']
@@ -122,7 +123,7 @@ class Field:
     """
 
     @staticmethod
-    def create_generic(field_name, spatial_dimensions, dtype=np.float64, index_dimensions=0, layout='numpy',
+    def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0, layout='numpy',
                        index_shape=None, field_type=FieldType.GENERIC) -> 'Field':
         """
         Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
@@ -156,7 +157,9 @@ class Field:
 
         strides = tuple([FieldStrideSymbol(field_name, i) for i in range(total_dimensions)])
 
-        np_data_type = np.dtype(dtype)
+        dtype = create_type(dtype)
+        np_data_type = dtype.numpy_dtype
+        assert np_data_type is not None
         if np_data_type.fields is not None:
             if index_dimensions != 0:
                 raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
@@ -245,7 +248,15 @@ class Field:
             spatial_layout.remove(i)
         return Field(field_name, field_type, dtype, tuple(spatial_layout), shape, strides)
 
-    def __init__(self, field_name, field_type, dtype, layout, shape, strides):
+    def __init__(
+        self,
+        field_name: str,
+        field_type: FieldType,
+        dtype: UserTypeSpec,
+        layout: tuple[int, ...],
+        shape,
+        strides
+    ):
         """Do not use directly. Use static create* methods"""
         self._field_name = field_name
         assert isinstance(field_type, FieldType)
diff --git a/src/pystencils/gpu/__init__.py b/src/pystencils/gpu/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/pystencils/old/gpu/gpu_array_handler.py b/src/pystencils/gpu/gpu_array_handler.py
similarity index 100%
rename from src/pystencils/old/gpu/gpu_array_handler.py
rename to src/pystencils/gpu/gpu_array_handler.py
diff --git a/src/pystencils/integer_functions.py b/src/pystencils/integer_functions.py
index cd0e6f231..90e79798d 100644
--- a/src/pystencils/integer_functions.py
+++ b/src/pystencils/integer_functions.py
@@ -2,7 +2,8 @@
 import numpy as np
 import sympy as sp
 
-from pystencils.typing import CastFunc, collate_types, create_type, get_type_of_expression
+from pystencils.sympyextensions import CastFunc
+from pystencils.types import create_type
 from pystencils.sympyextensions import is_integer_sequence
 
 
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 1e365cfed..535586cd8 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -96,7 +96,7 @@ def create_kernel(
 def create_kernel_function(
     ctx: KernelCreationContext,
     body: PsBlock,
-    name: str,
+    function_name: str,
     target_spec: Target,
     jit: JitBase,
 ):
@@ -123,5 +123,5 @@ def create_kernel_function(
     req_headers |= ctx.required_headers
 
     return KernelFunction(
-        body, target_spec, name, params, req_headers, ctx.constraints, jit
+        body, target_spec, function_name, params, req_headers, ctx.constraints, jit
     )
diff --git a/src/pystencils/old/gpu/__init__.py b/src/pystencils/old/gpu/__init__.py
index 0ee6f02ae..d1ae203b7 100644
--- a/src/pystencils/old/gpu/__init__.py
+++ b/src/pystencils/old/gpu/__init__.py
@@ -1,9 +1,7 @@
-from .gpu_array_handler import GPUArrayHandler, GPUNotAvailableHandler
 from .gpujit import make_python_function
 from .kernelcreation import create_cuda_kernel, created_indexed_cuda_kernel
 
 from .indexing import AbstractIndexing, BlockIndexing, LineIndexing
 
-__all__ = ['GPUArrayHandler', 'GPUNotAvailableHandler',
-           'create_cuda_kernel', 'created_indexed_cuda_kernel', 'make_python_function',
+__all__ = ['create_cuda_kernel', 'created_indexed_cuda_kernel', 'make_python_function',
            'AbstractIndexing', 'BlockIndexing', 'LineIndexing']
diff --git a/src/pystencils/rng.py b/src/pystencils/rng.py
index 5f74b62a4..221505a10 100644
--- a/src/pystencils/rng.py
+++ b/src/pystencils/rng.py
@@ -2,13 +2,13 @@ import copy
 import numpy as np
 import sympy as sp
 
-from pystencils.typing import TypedSymbol, CastFunc
-from pystencils.sympyextensions.astnodes import LoopOverCoordinate
-from pystencils.backends.cbackend import CustomCodeNode
+from pystencils.sympyextensions import TypedSymbol, CastFunc
+# from pystencils.sympyextensions.astnodes import LoopOverCoordinate # TODO nbackend: replace
+# from pystencils.backends.cbackend import CustomCodeNode # TODO nbackend: replace
 from pystencils.sympyextensions import fast_subs
 
-
-class RNGBase(CustomCodeNode):
+# class RNGBase(CustomCodeNode): TODO nbackend: replace
+class RNGBase:
 
     id = 0
 
diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py
index 36c66da89..1f0a4c5f4 100644
--- a/src/pystencils/sympyextensions/__init__.py
+++ b/src/pystencils/sympyextensions/__init__.py
@@ -1,22 +1,111 @@
-from .astnodes import Assignment, AugmentedAssignment, AddAugmentedAssignment, AssignmentCollection
+from .astnodes import (
+    Assignment,
+    AugmentedAssignment,
+    AddAugmentedAssignment,
+    AssignmentCollection,
+    SymbolGen
+)
+from .typed_sympy import TypedSymbol, CastFunc
 from .simplificationstrategy import SimplificationStrategy
-from .simplifications import (sympy_cse, sympy_cse_on_assignment_list, apply_to_all_assignments,
-                              apply_on_all_subexpressions, subexpression_substitution_in_existing_subexpressions,
-                              subexpression_substitution_in_main_assignments, add_subexpressions_for_constants,
-                              add_subexpressions_for_divisions, add_subexpressions_for_sums,
-                              add_subexpressions_for_field_reads)
+from .simplifications import (
+    sympy_cse,
+    sympy_cse_on_assignment_list,
+    apply_to_all_assignments,
+    apply_on_all_subexpressions,
+    subexpression_substitution_in_existing_subexpressions,
+    subexpression_substitution_in_main_assignments,
+    add_subexpressions_for_constants,
+    add_subexpressions_for_divisions,
+    add_subexpressions_for_sums,
+    add_subexpressions_for_field_reads
+)
 from .subexpression_insertion import (
-    insert_aliases, insert_zeros, insert_constants,
-    insert_constant_additions, insert_constant_multiples,
-    insert_squares, insert_symbol_times_minus_one)
+    insert_aliases,
+    insert_zeros,
+    insert_constants,
+    insert_constant_additions,
+    insert_constant_multiples,
+    insert_squares,
+    insert_symbol_times_minus_one,
+)
 
+from .math import (
+    prod,
+    remove_small_floats,
+    is_integer_sequence,
+    scalar_product,
+    kronecker_delta,
+    tanh_step_function_approximation,
+    multidimensional_sum,
+    normalize_product,
+    symmetric_product,
+    fast_subs,
+    is_constant,
+    subs_additive,
+    replace_second_order_products,
+    remove_higher_order_terms,
+    complete_the_square,
+    complete_the_squares_in_exp,
+    extract_most_common_factor,
+    recursive_collect,
+    summands,
+    simplify_by_equality,
+    count_operations,
+    count_operations_in_ast,
+    common_denominator,
+    get_symmetric_part,
+)
 
-__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment',
-           'AssignmentCollection', 'SimplificationStrategy',
-           'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
-           'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
-           'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants',
-           'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads',
-           'insert_aliases', 'insert_zeros', 'insert_constants',
-           'insert_constant_additions', 'insert_constant_multiples',
-           'insert_squares', 'insert_symbol_times_minus_one']
+
+__all__ = [
+    "Assignment",
+    "AugmentedAssignment",
+    "AddAugmentedAssignment",
+    "AssignmentCollection",
+    "SymbolGen",
+    "TypedSymbol",
+    "CastFunc",
+    "SimplificationStrategy",
+    "sympy_cse",
+    "sympy_cse_on_assignment_list",
+    "apply_to_all_assignments",
+    "apply_on_all_subexpressions",
+    "subexpression_substitution_in_existing_subexpressions",
+    "subexpression_substitution_in_main_assignments",
+    "add_subexpressions_for_constants",
+    "add_subexpressions_for_divisions",
+    "add_subexpressions_for_sums",
+    "add_subexpressions_for_field_reads",
+    "insert_aliases",
+    "insert_zeros",
+    "insert_constants",
+    "insert_constant_additions",
+    "insert_constant_multiples",
+    "insert_squares",
+    "insert_symbol_times_minus_one",
+    "remove_higher_order_terms",
+    "prod",
+    "remove_small_floats",
+    "is_integer_sequence",
+    "scalar_product",
+    "kronecker_delta",
+    "tanh_step_function_approximation",
+    "multidimensional_sum",
+    "normalize_product",
+    "symmetric_product",
+    "fast_subs",
+    "is_constant",
+    "subs_additive",
+    "replace_second_order_products",
+    "remove_higher_order_terms",
+    "complete_the_square",
+    "complete_the_squares_in_exp",
+    "extract_most_common_factor",
+    "recursive_collect",
+    "summands",
+    "simplify_by_equality",
+    "count_operations",
+    "count_operations_in_ast",
+    "common_denominator",
+    "get_symmetric_part",
+]
-- 
GitLab