From 96c04e6b453901188e0137af62a1b7c7f7527946 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Fri, 19 Nov 2021 17:27:55 +0100
Subject: [PATCH] Start with refactoring of the type system

---
 pystencils/data_types.py                      | 68 +++++++++++++++----
 pystencils/gpucuda/periodicity.py             |  1 +
 pystencils/kernelcreation.py                  |  4 ++
 pystencils/kernelparameters.py                |  5 ++
 pystencils/transformations.py                 | 11 ++-
 .../test_subexpression_insertion.py           |  1 -
 pystencils_tests/test_types.py                | 26 +++++++
 7 files changed, 101 insertions(+), 15 deletions(-)

diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 331de62bd..84af2f7f5 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -29,11 +29,13 @@ def typed_symbols(names, dtype, *args):
 
 
 def type_all_numbers(expr, dtype):
+    # TODO: move to pystencils_walberla
     substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)}
     return expr.subs(substitutions)
 
 
 def matrix_symbols(names, dtype, rows, cols):
+    # TODO: check if needed. (lbmpy, walberla)
     if isinstance(names, str):
         names = names.replace(' ', '').split(',')
 
@@ -46,6 +48,7 @@ def matrix_symbols(names, dtype, rows, cols):
 
 
 def assumptions_from_dtype(dtype):
+    # TODO: add type hints and check that dtype is a correct type from NumPy
     """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
 
     Args:
@@ -76,6 +79,9 @@ def assumptions_from_dtype(dtype):
 
 # noinspection PyPep8Naming
 class address_of(sp.Function):
+    # TODO: ask Martin
+    # TODO: documentation
+    # TODO: move function to `functions.py`
     is_Atom = True
 
     def __new__(cls, arg):
@@ -103,6 +109,8 @@ class address_of(sp.Function):
 
 # noinspection PyPep8Naming
 class cast_func(sp.Function):
+    # TODO: documentation
+    # TODO: move function to `functions.py`
     is_Atom = True
 
     def __new__(cls, *args, **kwargs):
@@ -190,22 +198,30 @@ class cast_func(sp.Function):
 
 # noinspection PyPep8Naming
 class boolean_cast_func(cast_func, Boolean):
+    # TODO: documentation
+    # TODO: move function to `functions.py`
     pass
 
 
 # noinspection PyPep8Naming
 class vector_memory_access(cast_func):
+    # TODO: documentation
+    # TODO: move function to `functions.py`
     # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
     nargs = (6,)
 
 
 # noinspection PyPep8Naming
 class reinterpret_cast_func(cast_func):
+    # TODO: documentation
+    # TODO: move function to `functions.py`
     pass
 
 
 # noinspection PyPep8Naming
 class pointer_arithmetic_func(sp.Function, Boolean):
+    # TODO: documentation
+    # TODO: move function to `functions.py`
     @property
     def canonical(self):
         if hasattr(self.args[0], 'canonical'):
@@ -272,6 +288,8 @@ class TypedSymbol(sp.Symbol):
 
 
 def create_type(specification):
+    # TODO: HERE
+    # TODO: type hint -> np.type
     """Creates a subclass of Type according to a string or an object of subclass Type.
 
     Args:
@@ -292,6 +310,7 @@ def create_type(specification):
 
 @memorycache(maxsize=64)
 def create_composite_type_from_string(specification):
+    # TODO: can be removed after llvm removal and fix of kernelparameters
     """Creates a new Type object from a c-like string specification.
 
     Args:
@@ -338,12 +357,15 @@ def create_composite_type_from_string(specification):
 
 
 def get_base_type(data_type):
+    # TODO: WTF is this?? DOCS!!!
+    # TODO: Can be removed after removal of kerncraft and fix in FieldPointer Symbol
     while data_type.base_type is not None:
         data_type = data_type.base_type
     return data_type
 
 
 def to_ctypes(data_type):
+    # TODO: can be removed with llvm
     """
     Transforms a given Type into ctypes
     :param data_type: Subclass of Type
@@ -356,7 +378,7 @@ def to_ctypes(data_type):
     else:
         return to_ctypes.map[data_type.numpy_dtype]
 
-
+# TODO: can be removed with llvm
 to_ctypes.map = {
     np.dtype(np.int8): ctypes.c_int8,
     np.dtype(np.int16): ctypes.c_int16,
@@ -374,6 +396,7 @@ to_ctypes.map = {
 
 
 def ctypes_from_llvm(data_type):
+    # TODO can be removed with LLVM
     if not ir:
         raise _ir_importerror
     if isinstance(data_type, ir.PointerType):
@@ -404,6 +427,7 @@ def ctypes_from_llvm(data_type):
 
 
 def to_llvm_type(data_type, nvvm_target=False):
+    # TODO: can be removed with LLVM
     """
     Transforms a given type into ctypes
     :param data_type: Subclass of Type
@@ -417,6 +441,7 @@ def to_llvm_type(data_type, nvvm_target=False):
         return to_llvm_type.map[data_type.numpy_dtype]
 
 
+# TODO: can be removed with LLVM
 if ir:
     to_llvm_type.map = {
         np.dtype(np.int8): ir.IntType(8),
@@ -435,16 +460,19 @@ if ir:
 
 
 def peel_off_type(dtype, type_to_peel_off):
+    # TODO: WTF is this??? DOCS!!!
+    # TODO: used only once.... can be a lambda there
     while type(dtype) is type_to_peel_off:
         dtype = dtype.base_type
     return dtype
 
 
+############################# This is basically our type system ########################################################
 def collate_types(types,
-                  forbid_collation_to_complex=False,
-                  forbid_collation_to_float=False,
-                  default_float_type='float64',
-                  default_int_type='int64'):
+                  forbid_collation_to_complex=False,  # TODO: type system shouldn't need this!!!
+                  forbid_collation_to_float=False,    # TODO: type system shouldn't need this!!!
+                  default_float_type='float64',       # TODO: AST leaves should be typed. Expressions should be able to find out correct type
+                  default_int_type='int64'):          # TODO: AST leaves should be typed. Expressions should be able to find out correct type
     """
     Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
     Uses the collation rules from numpy.
@@ -495,9 +523,9 @@ def collate_types(types,
 
 @memorycache_if_hashable(maxsize=2048)
 def get_type_of_expression(expr,
-                           default_float_type='double',
-                           default_int_type='int',
-                           symbol_type_dict=None):
+                           default_float_type='double', # TODO: we shouldn't need to have default. AST leaves should have a type
+                           default_int_type='int', # TODO: we shouldn't need to have default. AST leaves should have a type
+                           symbol_type_dict=None): # TODO: we shouldn't need to have default. AST leaves should have a type
     from pystencils.astnodes import ResolvedFieldAccess
     from pystencils.cpu.vectorization import vec_all, vec_any
 
@@ -582,6 +610,7 @@ def get_type_of_expression(expr,
                 return create_type(default_float_type)
 
     raise NotImplementedError("Could not determine type for", expr, type(expr))
+############################# End This is basically our type system ##################################################
 
 
 sympy_version = sp.__version__.split('.')
@@ -614,6 +643,8 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
 
 
 class Type(sp.Atom):
+    # TODO: why is our type system dependent on sympy???
+    # TODO: ask Martin
     def __new__(cls, *args, **kwargs):
         return sp.Basic.__new__(cls)
 
@@ -622,8 +653,15 @@ class Type(sp.Atom):
 
 
 class BasicType(Type):
+    # TODO: check if Type inheritance is needed
+    # TODO: should be a sensible interface to np.dtype
+    # TODO: read numpy docs (Jan)
     @staticmethod
     def numpy_name_to_c(name):
+        # TODO: this should be a free function
+        # TODO: also check if numpy has this functionality
+        # TODO: docs!!!
+        # TODO: is this C?
         if name == 'float64':
             return 'double'
         elif name == 'float32':
@@ -644,9 +682,10 @@ class BasicType(Type):
             raise NotImplementedError(f"Can map numpy to C name for {name}")
 
     def __init__(self, dtype, const=False):
+        # TODO: type hints
         self.const = const
         if isinstance(dtype, Type):
-            self._dtype = dtype.numpy_dtype
+            self._dtype = dtype.numpy_dtype  # TODO: wtf?
         else:
             self._dtype = np.dtype(dtype)
         assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
@@ -660,7 +699,7 @@ class BasicType(Type):
         return (self.numpy_dtype, self.const), {}
 
     @property
-    def base_type(self):
+    def base_type(self): # TODO: what is base_type?
         return None
 
     @property
@@ -672,7 +711,7 @@ class BasicType(Type):
         return getattr(sympy.codegen.ast, str(self.numpy_dtype))
 
     @property
-    def item_size(self):
+    def item_size(self):  # TODO: what is this?
         return 1
 
     def is_int(self):
@@ -691,7 +730,7 @@ class BasicType(Type):
         return self.numpy_dtype in np.sctypes['others']
 
     @property
-    def base_name(self):
+    def base_name(self):  # TODO: name of the function is highly confusing
         return BasicType.numpy_name_to_c(str(self._dtype))
 
     def __str__(self):
@@ -714,6 +753,7 @@ class BasicType(Type):
 
 
 class VectorType(Type):
+    # TODO: check with rest
     instruction_set = None
 
     def __init__(self, base_type, width=4):
@@ -760,6 +800,7 @@ class VectorType(Type):
 
 
 class PointerType(Type):
+    # TODO: rename to FieldType
     def __init__(self, base_type, const=False, restrict=True):
         self._base_type = base_type
         self.const = const
@@ -805,6 +846,7 @@ class PointerType(Type):
 
 
 class StructType:
+    # TODO: Docs. This is a struct. A list of types (with C offsets)
     def __init__(self, numpy_type, const=False):
         self.const = const
         self._dtype = np.dtype(numpy_type)
@@ -858,6 +900,8 @@ class StructType:
 
 
 class TypedImaginaryUnit(TypedSymbol):
+    # TODO: why is this an extra class???
+    # TODO: remove?
     def __new__(cls, *args, **kwds):
         obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
         return obj
diff --git a/pystencils/gpucuda/periodicity.py b/pystencils/gpucuda/periodicity.py
index da62af0f6..e5083af4a 100644
--- a/pystencils/gpucuda/periodicity.py
+++ b/pystencils/gpucuda/periodicity.py
@@ -31,6 +31,7 @@ def create_copy_kernel(domain_size, from_slice, to_slice, index_dimensions=0, in
     return ast
 
 
+# TODO: type float is dangerous here
 def get_periodic_boundary_functor(stencil, domain_size, index_dimensions=0, index_dim_shape=1, ghost_layers=1,
                                   thickness=None, dtype=float, target=Target.GPU, opencl_queue=None, opencl_ctx=None):
     assert target in {Target.GPU, Target.OPENCL}
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index ac4412256..f106f7d2a 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -38,6 +38,8 @@ class CreateKernelConfig:
     """
     Name of the generated function - only important if generated code is written out
     """
+    # TODO: config should check that the datatype is a Numpy type
+    # TODO: check for the python types and issue warnings
     data_type: Union[str, dict] = 'double'
     """
     Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type
@@ -125,6 +127,7 @@ class CreateKernelConfig:
 
     def __post_init__(self):
         # ----  Legacy parameters
+        # TODO adapt here the types
         if isinstance(self.target, str):
             new_target = Target[self.target.upper()]
             warnings.warn(f'Target "{self.target}" as str is deprecated. Use {new_target} instead',
@@ -249,6 +252,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
     if config.target == Target.CPU:
         if config.backend == Backend.C:
             from pystencils.cpu import add_openmp, create_kernel
+            # TODO: data type keyword should be unified to data_type
             ast = create_kernel(assignments, function_name=config.function_name, type_info=config.data_type,
                                 split_groups=split_groups,
                                 iteration_slice=config.iteration_slice, ghost_layers=config.ghost_layers,
diff --git a/pystencils/kernelparameters.py b/pystencils/kernelparameters.py
index 934c305cc..8bd4341be 100644
--- a/pystencils/kernelparameters.py
+++ b/pystencils/kernelparameters.py
@@ -18,6 +18,11 @@ from sympy.core.cache import cacheit
 from pystencils.data_types import (
     PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
 
+
+# TODO: Why do we need extra classes? Why isn't TypedSymbol enough?
+# TODO: Replace with a factory function
+
+
 SHAPE_DTYPE = create_composite_type_from_string("const int64")
 STRIDE_DTYPE = create_composite_type_from_string("const int64")
 
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index c2b6cf54b..1175580f7 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -1,6 +1,7 @@
 import hashlib
 import pickle
 import warnings
+from typing import List, Dict
 from collections import OrderedDict, defaultdict, namedtuple
 from copy import deepcopy
 from types import MappingProxyType
@@ -424,7 +425,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
     return visit_node(ast_node)
 
 
-def resolve_field_accesses(ast_node, read_only_field_names=set(),
+def resolve_field_accesses(ast_node, read_only_field_names=None,
                            field_to_base_pointer_info=MappingProxyType({}),
                            field_to_fixed_coordinates=MappingProxyType({})):
     """
@@ -441,6 +442,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
     Returns
         transformed AST
     """
+    if read_only_field_names is None:
+        read_only_field_names = set()
     field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
     field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
 
@@ -936,7 +939,8 @@ class KernelConstraintsCheck:
             self.scopes.access_symbol(rhs)
 
 
-def add_types(eqs, type_for_symbol, check_independence_condition, check_double_write_condition=True):
+def add_types(eqs: List[Assignment], type_for_symbol: Dict[sp.Symbol, np.dtype], check_independence_condition: bool,
+              check_double_write_condition: bool=True):
     """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
 
     Additionally returns sets of all fields which are read/written
@@ -956,9 +960,12 @@ def add_types(eqs, type_for_symbol, check_independence_condition, check_double_w
 
     type_for_symbol = adjust_c_single_precision_type(type_for_symbol)
 
+    # TODO what does this do????
+    # TODO: ask Martin
     check = KernelConstraintsCheck(type_for_symbol, check_independence_condition,
                                    check_double_write_condition=check_double_write_condition)
 
+    # TODO: check if this adds only types to leaf nodes of AST, get type info
     def visit(obj):
         if isinstance(obj, (list, tuple)):
             return [visit(e) for e in obj]
diff --git a/pystencils_tests/test_subexpression_insertion.py b/pystencils_tests/test_subexpression_insertion.py
index 9ae64d9fe..790d97d76 100644
--- a/pystencils_tests/test_subexpression_insertion.py
+++ b/pystencils_tests/test_subexpression_insertion.py
@@ -1,4 +1,3 @@
-import sympy as sp
 from pystencils import fields, Assignment, AssignmentCollection
 from pystencils.simp.subexpression_insertion import *
 
diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py
index c63ab6923..87124ec9e 100644
--- a/pystencils_tests/test_types.py
+++ b/pystencils_tests/test_types.py
@@ -180,3 +180,29 @@ def test_ctypes_from_llvm():
 
     assert ctypes_from_llvm(ir.FloatType()) == ctypes.c_float
     assert ctypes_from_llvm(ir.DoubleType()) == ctypes.c_double
+
+
+def test_division():
+    f = ps.fields('f(10): float32[2D]')
+    m, tau = sp.symbols("m, tau")
+
+    up = [ps.Assignment(tau, 1.0 / (0.5 + (3.0 * m))),
+          ps.Assignment(f.center, tau)]
+
+    ast = ps.create_kernel(up, config=ps.CreateKernelConfig(data_type="float32"))
+    code = ps.get_code_str(ast)
+
+    assert "1.0f" in code
+
+
+def test_pow():
+    f = ps.fields('f(10): float32[2D]')
+    m, tau = sp.symbols("m, tau")
+
+    up = [ps.Assignment(tau, m ** 1.5),
+          ps.Assignment(f.center, tau)]
+
+    ast = ps.create_kernel(up, config=ps.CreateKernelConfig(data_type="float32"))
+    code = ps.get_code_str(ast)
+
+    assert "1.5f" in code
-- 
GitLab