From 96c04e6b453901188e0137af62a1b7c7f7527946 Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Fri, 19 Nov 2021 17:27:55 +0100 Subject: [PATCH] Start with refactoring of the type system --- pystencils/data_types.py | 68 +++++++++++++++---- pystencils/gpucuda/periodicity.py | 1 + pystencils/kernelcreation.py | 4 ++ pystencils/kernelparameters.py | 5 ++ pystencils/transformations.py | 11 ++- .../test_subexpression_insertion.py | 1 - pystencils_tests/test_types.py | 26 +++++++ 7 files changed, 101 insertions(+), 15 deletions(-) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 331de62bd..84af2f7f5 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -29,11 +29,13 @@ def typed_symbols(names, dtype, *args): def type_all_numbers(expr, dtype): + # TODO: move to pystnecils_walberla substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)} return expr.subs(substitutions) def matrix_symbols(names, dtype, rows, cols): + # TODO: check if needed. (lbmpy, walberla) if isinstance(names, str): names = names.replace(' ', '').split(',') @@ -46,6 +48,7 @@ def matrix_symbols(names, dtype, rows, cols): def assumptions_from_dtype(dtype): + # TODO: type hints and if dtype is correct type form Numpy """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype Args: @@ -76,6 +79,9 @@ def assumptions_from_dtype(dtype): # noinspection PyPep8Naming class address_of(sp.Function): + # TODO: ask Martin + # TODO: documentation + # TODO: move function to `functions.py` is_Atom = True def __new__(cls, arg): @@ -103,6 +109,8 @@ class address_of(sp.Function): # noinspection PyPep8Naming class cast_func(sp.Function): + # TODO: documentation + # TODO: move function to `functions.py` is_Atom = True def __new__(cls, *args, **kwargs): @@ -190,22 +198,30 @@ class cast_func(sp.Function): # noinspection PyPep8Naming class boolean_cast_func(cast_func, Boolean): + # TODO: documentation + # TODO: move function to `functions.py` pass # noinspection PyPep8Naming class vector_memory_access(cast_func): + # TODO: documentation + # TODO: move function to `functions.py` # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride nargs = (6,) # noinspection PyPep8Naming class reinterpret_cast_func(cast_func): + # TODO: documentation + # TODO: move function to `functions.py` pass # noinspection PyPep8Naming class pointer_arithmetic_func(sp.Function, Boolean): + # TODO: documentation + # TODO: move function to `functions.py` @property def canonical(self): if hasattr(self.args[0], 'canonical'): @@ -272,6 +288,8 @@ class TypedSymbol(sp.Symbol): def create_type(specification): + # TODO: HERE + # TODO: type hint -> np.type """Creates a subclass of Type according to a string or an object of subclass Type. Args: @@ -292,6 +310,7 @@ def create_type(specification): @memorycache(maxsize=64) def create_composite_type_from_string(specification): + # TODO: can be removed after llvm removla and fix of kernelparameters """Creates a new Type object from a c-like string specification. Args: @@ -338,12 +357,15 @@ def create_composite_type_from_string(specification): def get_base_type(data_type): + # TODO: WTF is this?? DOCS!!! + # TODO: Can be removed after removal of kerncraft and fix in FieldPointer Symbol while data_type.base_type is not None: data_type = data_type.base_type return data_type def to_ctypes(data_type): + # TODO: can be removed with llvm """ Transforms a given Type into ctypes :param data_type: Subclass of Type @@ -356,7 +378,7 @@ def to_ctypes(data_type): else: return to_ctypes.map[data_type.numpy_dtype] - +# TODO: can be removed with llvm to_ctypes.map = { np.dtype(np.int8): ctypes.c_int8, np.dtype(np.int16): ctypes.c_int16, @@ -374,6 +396,7 @@ to_ctypes.map = { def ctypes_from_llvm(data_type): + # TODO can be removed with LLVM if not ir: raise _ir_importerror if isinstance(data_type, ir.PointerType): @@ -404,6 +427,7 @@ def ctypes_from_llvm(data_type): def to_llvm_type(data_type, nvvm_target=False): + # TODO: can be removed with LLVM """ Transforms a given type into ctypes :param data_type: Subclass of Type @@ -417,6 +441,7 @@ def to_llvm_type(data_type, nvvm_target=False): return to_llvm_type.map[data_type.numpy_dtype] +# TODO: can be removed with LLVM if ir: to_llvm_type.map = { np.dtype(np.int8): ir.IntType(8), @@ -435,16 +460,19 @@ if ir: def peel_off_type(dtype, type_to_peel_off): + # TODO: WTF is this??? DOCS!!! + # TODO: used only once.... can be a lambda there while type(dtype) is type_to_peel_off: dtype = dtype.base_type return dtype +############################# This is basically our type system ######################################################## def collate_types(types, - forbid_collation_to_complex=False, - forbid_collation_to_float=False, - default_float_type='float64', - default_int_type='int64'): + forbid_collation_to_complex=False, # TODO: type system shouldn't need this!!! + forbid_collation_to_float=False, # TODO: type system shouldn't need this!!! + default_float_type='float64', # TODO: AST leaves should be typed. Expressions should be able to find out correct type + default_int_type='int64'): # TODO: AST leaves should be typed. Expressions should be able to find out correct type """ Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Uses the collation rules from numpy. @@ -495,9 +523,9 @@ def collate_types(types, @memorycache_if_hashable(maxsize=2048) def get_type_of_expression(expr, - default_float_type='double', - default_int_type='int', - symbol_type_dict=None): + default_float_type='double', # TODO: we shouldn't need to have default. AST leaves should have a type + default_int_type='int', # TODO: we shouldn't need to have default. AST leaves should have a type + symbol_type_dict=None): # TODO: we shouldn't need to have default. AST leaves should have a type from pystencils.astnodes import ResolvedFieldAccess from pystencils.cpu.vectorization import vec_all, vec_any @@ -582,6 +610,7 @@ def get_type_of_expression(expr, return create_type(default_float_type) raise NotImplementedError("Could not determine type for", expr, type(expr)) +############################# End This is basically our type system ################################################## sympy_version = sp.__version__.split('.') @@ -614,6 +643,8 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: class Type(sp.Atom): + # TODO: why is our type system dependent on sympy??? + # TODO: ask Martin def __new__(cls, *args, **kwargs): return sp.Basic.__new__(cls) @@ -622,8 +653,15 @@ class Type(sp.Atom): class BasicType(Type): + # TODO: check if Type inheritance is needed + # TODO: should be a sensible interface to np.dtype + # TODO: read numpy docs (Jan) @staticmethod def numpy_name_to_c(name): + # TODO: this should be a free function + # TODO: also check if numpy has this functionality + # TODO: docs!!! + # TODO: is this C? if name == 'float64': return 'double' elif name == 'float32': @@ -644,9 +682,10 @@ class BasicType(Type): raise NotImplementedError(f"Can map numpy to C name for {name}") def __init__(self, dtype, const=False): + # TODO: type hints self.const = const if isinstance(dtype, Type): - self._dtype = dtype.numpy_dtype + self._dtype = dtype.numpy_dtype # TODO: wtf? else: self._dtype = np.dtype(dtype) assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type" @@ -660,7 +699,7 @@ class BasicType(Type): return (self.numpy_dtype, self.const), {} @property - def base_type(self): + def base_type(self): # TODO: what is base_type? return None @property @@ -672,7 +711,7 @@ class BasicType(Type): return getattr(sympy.codegen.ast, str(self.numpy_dtype)) @property - def item_size(self): + def item_size(self): # TODO: what is this? return 1 def is_int(self): @@ -691,7 +730,7 @@ class BasicType(Type): return self.numpy_dtype in np.sctypes['others'] @property - def base_name(self): + def base_name(self): # TODO: name of the function is highly confusing return BasicType.numpy_name_to_c(str(self._dtype)) def __str__(self): @@ -714,6 +753,7 @@ class BasicType(Type): class VectorType(Type): + # TODO: check with rest instruction_set = None def __init__(self, base_type, width=4): @@ -760,6 +800,7 @@ class VectorType(Type): class PointerType(Type): + # TODO: rename to FieldType def __init__(self, base_type, const=False, restrict=True): self._base_type = base_type self.const = const @@ -805,6 +846,7 @@ class PointerType(Type): class StructType: + # TODO: Docs. This is a struct. A list of types (with C offsets) def __init__(self, numpy_type, const=False): self.const = const self._dtype = np.dtype(numpy_type) @@ -858,6 +900,8 @@ class StructType: class TypedImaginaryUnit(TypedSymbol): + # TODO: why is this an extra class??? + # TODO: remove? def __new__(cls, *args, **kwds): obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds) return obj diff --git a/pystencils/gpucuda/periodicity.py b/pystencils/gpucuda/periodicity.py index da62af0f6..e5083af4a 100644 --- a/pystencils/gpucuda/periodicity.py +++ b/pystencils/gpucuda/periodicity.py @@ -31,6 +31,7 @@ def create_copy_kernel(domain_size, from_slice, to_slice, index_dimensions=0, in return ast +# TODO: type flot is dangerous here def get_periodic_boundary_functor(stencil, domain_size, index_dimensions=0, index_dim_shape=1, ghost_layers=1, thickness=None, dtype=float, target=Target.GPU, opencl_queue=None, opencl_ctx=None): assert target in {Target.GPU, Target.OPENCL} diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index ac4412256..f106f7d2a 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -38,6 +38,8 @@ class CreateKernelConfig: """ Name of the generated function - only important if generated code is written out """ + # TODO: config should check that the datatype is a Numpy type + # TODO: check for the python types and issue warnings data_type: Union[str, dict] = 'double' """ Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type @@ -125,6 +127,7 @@ class CreateKernelConfig: def __post_init__(self): # ---- Legacy parameters + # TODO adapt here the types if isinstance(self.target, str): new_target = Target[self.target.upper()] warnings.warn(f'Target "{self.target}" as str is deprecated. Use {new_target} instead', @@ -249,6 +252,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC if config.target == Target.CPU: if config.backend == Backend.C: from pystencils.cpu import add_openmp, create_kernel + # TODO: data type keyword should be unified to data_type ast = create_kernel(assignments, function_name=config.function_name, type_info=config.data_type, split_groups=split_groups, iteration_slice=config.iteration_slice, ghost_layers=config.ghost_layers, diff --git a/pystencils/kernelparameters.py b/pystencils/kernelparameters.py index 934c305cc..8bd4341be 100644 --- a/pystencils/kernelparameters.py +++ b/pystencils/kernelparameters.py @@ -18,6 +18,11 @@ from sympy.core.cache import cacheit from pystencils.data_types import ( PointerType, TypedSymbol, create_composite_type_from_string, get_base_type) + +# TODO: Why do we need extra classes? Why isn't TypedSymbol enough? +# TODO: Replace with a factory function + + SHAPE_DTYPE = create_composite_type_from_string("const int64") STRIDE_DTYPE = create_composite_type_from_string("const int64") diff --git a/pystencils/transformations.py b/pystencils/transformations.py index c2b6cf54b..1175580f7 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -1,6 +1,7 @@ import hashlib import pickle import warnings +from typing import List, Dict from collections import OrderedDict, defaultdict, namedtuple from copy import deepcopy from types import MappingProxyType @@ -424,7 +425,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s return visit_node(ast_node) -def resolve_field_accesses(ast_node, read_only_field_names=set(), +def resolve_field_accesses(ast_node, read_only_field_names=None, field_to_base_pointer_info=MappingProxyType({}), field_to_fixed_coordinates=MappingProxyType({})): """ @@ -441,6 +442,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), Returns transformed AST """ + if read_only_field_names is None: + read_only_field_names = set() field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0])) field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) @@ -936,7 +939,8 @@ class KernelConstraintsCheck: self.scopes.access_symbol(rhs) -def add_types(eqs, type_for_symbol, check_independence_condition, check_double_write_condition=True): +def add_types(eqs: List[Assignment], type_for_symbol: Dict[sp.Symbol, np.dtype], check_independence_condition: bool, + check_double_write_condition: bool=True): """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`. Additionally returns sets of all fields which are read/written @@ -956,9 +960,12 @@ def add_types(eqs, type_for_symbol, check_independence_condition, check_double_w type_for_symbol = adjust_c_single_precision_type(type_for_symbol) + # TODO what does this do???? + # TODO: ask Martin check = KernelConstraintsCheck(type_for_symbol, check_independence_condition, check_double_write_condition=check_double_write_condition) + # TODO: check if this adds only types to leave nodes of AST, get type info def visit(obj): if isinstance(obj, (list, tuple)): return [visit(e) for e in obj] diff --git a/pystencils_tests/test_subexpression_insertion.py b/pystencils_tests/test_subexpression_insertion.py index 9ae64d9fe..790d97d76 100644 --- a/pystencils_tests/test_subexpression_insertion.py +++ b/pystencils_tests/test_subexpression_insertion.py @@ -1,4 +1,3 @@ -import sympy as sp from pystencils import fields, Assignment, AssignmentCollection from pystencils.simp.subexpression_insertion import * diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index c63ab6923..87124ec9e 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -180,3 +180,29 @@ def test_ctypes_from_llvm(): assert ctypes_from_llvm(ir.FloatType()) == ctypes.c_float assert ctypes_from_llvm(ir.DoubleType()) == ctypes.c_double + + +def test_division(): + f = ps.fields('f(10): float32[2D]') + m, tau = sp.symbols("m, tau") + + up = [ps.Assignment(tau, 1.0 / (0.5 + (3.0 * m))), + ps.Assignment(f.center, tau)] + + ast = ps.create_kernel(up, config=ps.CreateKernelConfig(data_type="float32")) + code = ps.get_code_str(ast) + + assert "1.0f" in code + + +def test_pow(): + f = ps.fields('f(10): float32[2D]') + m, tau = sp.symbols("m, tau") + + up = [ps.Assignment(tau, m ** 1.5), + ps.Assignment(f.center, tau)] + + ast = ps.create_kernel(up, config=ps.CreateKernelConfig(data_type="float32")) + code = ps.get_code_str(ast) + + assert "1.5f" in code -- GitLab