From 942c7d965b05c83b8f16fd699253c694b5089007 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de> Date: Wed, 24 Nov 2021 19:38:22 +0100 Subject: [PATCH] Created typing as own module --- pystencils/__init__.py | 2 +- pystencils/alignedarray.py | 4 +- pystencils/astnodes.py | 6 +- pystencils/backends/cbackend.py | 36 +- pystencils/bit_masks.py | 2 +- pystencils/boundaries/boundaryconditions.py | 2 +- pystencils/boundaries/boundaryhandling.py | 2 +- pystencils/boundaries/inkernel.py | 2 +- pystencils/cache.py | 2 +- pystencils/cpu/cpujit.py | 6 +- pystencils/cpu/kernelcreation.py | 4 +- pystencils/cpu/vectorization.py | 38 +- pystencils/data_types.py | 927 ------------------ pystencils/field.py | 26 +- pystencils/gpucuda/cudajit.py | 2 +- pystencils/gpucuda/indexing.py | 2 +- pystencils/gpucuda/kernelcreation.py | 4 +- pystencils/integer_functions.py | 7 +- pystencils/integer_set_analysis.py | 2 +- .../kerncraft_coupling/generate_benchmark.py | 2 +- .../kerncraft_coupling/kerncraft_interface.py | 2 +- pystencils/kernel_contrains_check.py | 150 +++ pystencils/rng.py | 6 +- pystencils/sympyextensions.py | 4 +- pystencils/transformations.py | 394 +------- pystencils/typing/__init__.py | 4 + pystencils/typing/cast_functions.py | 120 +++ .../typed_sympy.py} | 114 ++- pystencils/typing/types.py | 297 ++++++ pystencils/typing/utilities.py | 494 ++++++++++ pystencils_tests/test_abs.py | 4 +- pystencils_tests/test_address_of.py | 10 +- pystencils_tests/test_complex_numbers.py | 2 +- pystencils_tests/test_cuda_known_functions.py | 2 +- pystencils_tests/test_field.py | 2 +- .../test_floor_ceil_int_optimization.py | 2 +- pystencils_tests/test_global_definitions.py | 2 +- pystencils_tests/test_kernel_data_type.py | 2 +- ...st_match_subs_for_assignment_collection.py | 4 +- pystencils_tests/test_pickle_support.py | 2 +- pystencils_tests/test_random.py | 2 +- pystencils_tests/test_sum_prod.py | 2 +- pystencils_tests/test_transformations.py | 2 +- pystencils_tests/test_type_interference.py | 6 +- pystencils_tests/test_types.py | 24 +- 45 files changed, 1266 insertions(+), 1465 deletions(-) delete mode 100644 pystencils/data_types.py create mode 100644 pystencils/kernel_contrains_check.py create mode 100644 pystencils/typing/__init__.py create mode 100644 pystencils/typing/cast_functions.py rename pystencils/{kernelparameters.py => typing/typed_sympy.py} (52%) create mode 100644 pystencils/typing/types.py create mode 100644 pystencils/typing/utilities.py diff --git a/pystencils/__init__.py b/pystencils/__init__.py index a10acb8f6..4d97202bd 100644 --- a/pystencils/__init__.py +++ b/pystencils/__init__.py @@ -3,7 +3,7 @@ from .enums import Backend, Target from . import fd from . import stencil as stencil from .assignment import Assignment, assignment_from_stencil -from .data_types import TypedSymbol +from pystencils.typing.typed_sympy import TypedSymbol from .datahandling import create_data_handling from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .field import Field, FieldType, fields diff --git a/pystencils/alignedarray.py b/pystencils/alignedarray.py index da20a778e..26c3aa5ba 100644 --- a/pystencils/alignedarray.py +++ b/pystencils/alignedarray.py @@ -1,5 +1,5 @@ import numpy as np -from pystencils.data_types import BasicType +from pystencils.typing import numpy_name_to_c def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, order='C', align_inner_coordinate=True): @@ -21,7 +21,7 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o from pystencils.backends.simd_instruction_sets import (get_supported_instruction_sets, get_cacheline_size, get_vector_instruction_set) - type_name = BasicType.numpy_name_to_c(np.dtype(dtype).name) + type_name = numpy_name_to_c(np.dtype(dtype).name) instruction_sets = get_supported_instruction_sets() if instruction_sets is None: byte_alignment = 64 diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 689a18b02..b9f13ae64 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -6,7 +6,7 @@ from typing import Any, List, Optional, Sequence, Set, Union import sympy as sp import pystencils -from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type +from pystencils.typing import TypedSymbol, CastFunc, create_type, get_next_parent_of_type from pystencils.enums import Target, Backend from pystencils.field import Field from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol @@ -542,7 +542,6 @@ class LoopOverCoordinate(Node): @property def is_outermost_loop(self): - from pystencils.transformations import get_next_parent_of_type return get_next_parent_of_type(self, LoopOverCoordinate) is None @property @@ -571,7 +570,7 @@ class SympyAssignment(Node): self.use_auto = use_auto def __is_declaration(self): - if isinstance(self._lhs_symbol, cast_func): + if isinstance(self._lhs_symbol, CastFunc): return False if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)): return False @@ -616,7 +615,6 @@ class SympyAssignment(Node): if isinstance(symbol, Field.Access): for i in range(len(symbol.offsets)): loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i)) - result = {r for r in result if not isinstance(r, TypedImaginaryUnit)} result.update(loop_counters) result.update(self._lhs_symbol.atoms(sp.Symbol)) return result diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index a17ee7269..fa3079e32 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -11,9 +11,9 @@ from sympy.logic.boolalg import BooleanFalse, BooleanTrue from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize -from pystencils.data_types import ( - PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, - reinterpret_cast_func, vector_memory_access, BasicType, TypedSymbol) +from pystencils.typing import ( + PointerType, VectorType, address_of, CastFunc, create_type, get_type_of_expression, + ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol) from pystencils.enums import Backend from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.integer_functions import ( @@ -276,7 +276,7 @@ class CBackend: else: lhs_type = get_type_of_expression(node.lhs) printed_mask = "" - if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): + if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc): arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args instr = 'storeU' if aligned: @@ -289,12 +289,12 @@ class CBackend: self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs), '{1}', '{2}', **self._kwargs), **self._kwargs) printed_mask = self.sympy_printer.doprint(mask) - if data_type.base_type.base_name == 'double': + if data_type.base_type.c_name == 'double': if self._vector_instruction_set['double'] == '__m256d': printed_mask = f"_mm256_castpd_si256({printed_mask})" elif self._vector_instruction_set['double'] == '__m128d': printed_mask = f"_mm_castpd_si128({printed_mask})" - elif data_type.base_type.base_name == 'float': + elif data_type.base_type.c_name == 'float': if self._vector_instruction_set['float'] == '__m256': printed_mask = f"_mm256_castps_si256({printed_mask})" elif self._vector_instruction_set['float'] == '__m128': @@ -302,7 +302,7 @@ class CBackend: rhs_type = get_type_of_expression(node.rhs) if type(rhs_type) is not VectorType: - rhs = cast_func(node.rhs, VectorType(rhs_type)) + rhs = CastFunc(node.rhs, VectorType(rhs_type)) else: rhs = node.rhs @@ -322,7 +322,7 @@ class CBackend: if stride == 1: offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1}) size = sp.Mul(*node.lhs.args[0].field.spatial_shape) - element_size = 8 if data_type.base_type.base_name == 'double' else 4 + element_size = 8 if data_type.base_type.c_name == 'double' else 4 size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}" pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \ self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n' @@ -483,13 +483,13 @@ class CustomSympyPrinter(CCodePrinter): } if hasattr(expr, 'to_c'): return expr.to_c(self._print) - if isinstance(expr, reinterpret_cast_func): + if isinstance(expr, ReinterpretCastFunc): arg, data_type = expr.args return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))" elif isinstance(expr, address_of): assert len(expr.args) == 1, "address_of must only have one argument" return f"&({self._print(expr.args[0])})" - elif isinstance(expr, cast_func): + elif isinstance(expr, CastFunc): arg, data_type = expr.args if isinstance(arg, sp.Number) and arg.is_finite: return self._typed_number(arg, data_type) @@ -648,22 +648,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return None def _print_Abs(self, expr): - if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access): + if 'abs' in self.instruction_set and isinstance(expr.args[0], VectorMemoryAccess): return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs) return super()._print_Abs(expr) def _print_Function(self, expr): - if isinstance(expr, vector_memory_access): + if isinstance(expr, VectorMemoryAccess): arg, data_type, aligned, _, mask, stride = expr.args if stride != 1: return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs) instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] return instruction.format(f"& {self._print(arg)}", **self._kwargs) - elif isinstance(expr, cast_func): + elif isinstance(expr, CastFunc): arg, data_type = expr.args if type(data_type) is VectorType: # vector_memory_access is a cast_func itself so it should't be directly inside a cast_func - assert not isinstance(arg, vector_memory_access) + assert not isinstance(arg, VectorMemoryAccess) if isinstance(arg, sp.Tuple): is_boolean = get_type_of_expression(arg[0]) == create_type("bool") is_integer = get_type_of_expression(arg[0]) == create_type("int") @@ -747,12 +747,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): # special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization suffix = "" - if all([(type(e) is cast_func and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer) + if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer) or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]): - dtype = set([e.dtype for e in args if type(e) is cast_func]) + dtype = set([e.dtype for e in args if type(e) is CastFunc]) assert len(dtype) == 1 dtype = dtype.pop() - args = [cast_func(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e + args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e for e in args] suffix = "int" @@ -880,7 +880,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._print(expr.args[-1][0]) for true_expr, condition in reversed(expr.args[:-1]): - if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"): + if isinstance(condition, CastFunc) and get_type_of_expression(condition.args[0]) == create_type("bool"): if not KERNCRAFT_NO_TERNARY_MODE: result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), result, **self._kwargs) diff --git a/pystencils/bit_masks.py b/pystencils/bit_masks.py index 0fab63b25..73c18688c 100644 --- a/pystencils/bit_masks.py +++ b/pystencils/bit_masks.py @@ -1,5 +1,5 @@ import sympy as sp -from pystencils.data_types import get_type_of_expression +from pystencils.typing import get_type_of_expression # noinspection PyPep8Naming diff --git a/pystencils/boundaries/boundaryconditions.py b/pystencils/boundaries/boundaryconditions.py index dc01224d0..c53d248ac 100644 --- a/pystencils/boundaries/boundaryconditions.py +++ b/pystencils/boundaries/boundaryconditions.py @@ -2,7 +2,7 @@ from typing import Any, List, Tuple from pystencils import Assignment from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo -from pystencils.data_types import create_type +from pystencils.typing import create_type class Boundary: diff --git a/pystencils/boundaries/boundaryhandling.py b/pystencils/boundaries/boundaryhandling.py index 5705d3d53..4ad3ab3ff 100644 --- a/pystencils/boundaries/boundaryhandling.py +++ b/pystencils/boundaries/boundaryhandling.py @@ -7,7 +7,7 @@ from pystencils.backends.cbackend import CustomCodeNode from pystencils.boundaries.createindexlist import ( create_boundary_index_array, numpy_data_type_for_boundary_object) from pystencils.cache import memorycache -from pystencils.data_types import TypedSymbol, create_type +from pystencils.typing import TypedSymbol, create_type from pystencils.datahandling.pycuda import PyCudaArrayHandler from pystencils.field import Field from pystencils.kernelparameters import FieldPointerSymbol diff --git a/pystencils/boundaries/inkernel.py b/pystencils/boundaries/inkernel.py index 1d78814db..479f30d22 100644 --- a/pystencils/boundaries/inkernel.py +++ b/pystencils/boundaries/inkernel.py @@ -1,7 +1,7 @@ import sympy as sp from pystencils.boundaries.boundaryhandling import DEFAULT_FLAG_TYPE -from pystencils.data_types import TypedSymbol, create_type +from pystencils.typing import TypedSymbol, create_type from pystencils.field import Field from pystencils.integer_functions import bitwise_and diff --git a/pystencils/cache.py b/pystencils/cache.py index f29678920..15274ccb8 100644 --- a/pystencils/cache.py +++ b/pystencils/cache.py @@ -5,7 +5,7 @@ from itertools import chain try: from functools import lru_cache as memorycache -except ImportError: +except ImportError: # TODO what python version is this??? from backports.functools_lru_cache import lru_cache as memorycache from joblib import Memory diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index 240cddd49..2861d671f 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -60,7 +60,7 @@ from appdirs import user_cache_dir, user_config_dir from pystencils import FieldType from pystencils.astnodes import LoopOverCoordinate from pystencils.backends.cbackend import generate_c, get_headers, CFunction -from pystencils.data_types import cast_func, VectorType, vector_memory_access +from pystencils.typing import CastFunc, VectorType, VectorMemoryAccess from pystencils.include import get_pystencils_include_path from pystencils.kernel_wrapper import KernelWrapper from pystencils.utils import atomic_file_write, recursive_dict_update @@ -388,7 +388,7 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec aligned = False if ast_node.assignments: aligned = any([a.lhs.args[2] for a in ast_node.assignments - if hasattr(a, 'lhs') and isinstance(a.lhs, cast_func) + if hasattr(a, 'lhs') and isinstance(a.lhs, CastFunc) and hasattr(a.lhs, 'dtype') and isinstance(a.lhs.dtype, VectorType)]) if ast_node.instruction_set and aligned: @@ -398,7 +398,7 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec for loop in ast_node.atoms(LoopOverCoordinate): has_openmp = has_openmp or any(['#pragma omp' in p for p in loop.prefix_lines]) has_nontemporal = has_nontemporal or any([a.args[0].field == field and a.args[3] for a in - loop.atoms(vector_memory_access)]) + loop.atoms(VectorMemoryAccess)]) if has_openmp and has_nontemporal: byte_width = ast_node.instruction_set['cachelineSize'] offset = max(max(ast_node.ghost_layers)) * item_size diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py index 865beefa9..7b2719fd7 100644 --- a/pystencils/cpu/kernelcreation.py +++ b/pystencils/cpu/kernelcreation.py @@ -8,10 +8,10 @@ from pystencils.assignment import Assignment from pystencils.enums import Target, Backend from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment from pystencils.cpu.cpujit import make_python_function -from pystencils.data_types import StructType, TypedSymbol, create_type +from pystencils.typing import StructType, TypedSymbol, create_type, add_types from pystencils.field import Field, FieldType from pystencils.transformations import ( - add_types, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, make_loop_over_domain, + filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, make_loop_over_domain, move_constants_before_loop, parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, split_inner_loop) diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index c0511aa16..a161d5879 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -7,8 +7,8 @@ from sympy.logic.boolalg import BooleanFunction import pystencils.astnodes as ast from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set -from pystencils.data_types import ( - PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, vector_memory_access) +from pystencils.typing import ( + PointerType, TypedSymbol, VectorType, CastFunc, collate_types, get_type_of_expression, VectorMemoryAccess) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.field import Field from pystencils.integer_functions import modulo_ceil, modulo_floor @@ -180,8 +180,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a nontemporal = False if hasattr(indexed, 'field'): nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields) - substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True, - stride if strided else 1) + substitutions[indexed] = VectorMemoryAccess(indexed, vec_type, use_aligned_access, nontemporal, True, + stride if strided else 1) if nontemporal: # insert NontemporalFence after the outermost loop parent = loop_node.parent @@ -197,12 +197,12 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a loop_node.step = vector_width loop_node.subs(substitutions) vector_int_width = ast_node.instruction_set['intwidth'] - vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \ - + cast_func(tuple(range(vector_int_width if type(vector_int_width) is int else 2)), - VectorType(loop_counter_symbol.dtype, vector_int_width)) + vector_loop_counter = CastFunc(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \ + + CastFunc(tuple(range(vector_int_width if type(vector_int_width) is int else 2)), + VectorType(loop_counter_symbol.dtype, vector_int_width)) fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter}, - skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access)) + skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, VectorMemoryAccess)) mask_conditionals(loop_node) @@ -232,8 +232,8 @@ def mask_conditionals(loop_body): node.condition_expr = vec_any(node.condition_expr) elif isinstance(node, ast.SympyAssignment): if mask is not True: - s = {ma: vector_memory_access(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:]) - for ma in node.atoms(vector_memory_access)} + s = {ma: VectorMemoryAccess(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:]) + for ma in node.atoms(VectorMemoryAccess)} node.subs(s) else: for arg in node.args: @@ -248,13 +248,13 @@ def insert_vector_casts(ast_node, default_float_type='double'): handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all) def visit_expr(expr, default_type='double'): - if isinstance(expr, vector_memory_access): - return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:]) - elif isinstance(expr, cast_func): + if isinstance(expr, VectorMemoryAccess): + return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:]) + elif isinstance(expr, CastFunc): return expr elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set: new_arg = visit_expr(expr.args[0], default_type) - base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is vector_memory_access \ + base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \ else get_type_of_expression(expr.args[0]) pw = sp.Piecewise((-new_arg, new_arg < base_type.numpy_dtype.type(0)), (new_arg, True)) @@ -263,7 +263,7 @@ def insert_vector_casts(ast_node, default_float_type='double'): if expr.func is sp.Mul and expr.args[0] == -1: # special treatment for the unary minus: make sure that the -1 has the same type as the argument dtype = int - for arg in expr.atoms(vector_memory_access): + for arg in expr.atoms(VectorMemoryAccess): if arg.dtype.base_type.is_float(): dtype = arg.dtype.base_type.numpy_dtype.type for arg in expr.atoms(TypedSymbol): @@ -280,7 +280,7 @@ def insert_vector_casts(ast_node, default_float_type='double'): else: target_type = collate_types(arg_types) casted_args = [ - cast_func(a, target_type) if t != target_type and not isinstance(a, vector_memory_access) else a + CastFunc(a, target_type) if t != target_type and not isinstance(a, VectorMemoryAccess) else a for a, t in zip(new_args, arg_types)] return expr.func(*casted_args) elif expr.func is sp.Pow: @@ -299,10 +299,10 @@ def insert_vector_casts(ast_node, default_float_type='double'): if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType: condition_target_type = VectorType(condition_target_type, width=result_target_type.width) - casted_results = [cast_func(a, result_target_type) if t != result_target_type else a + casted_results = [CastFunc(a, result_target_type) if t != result_target_type else a for a, t in zip(new_results, types_of_results)] - casted_conditions = [cast_func(a, condition_target_type) + casted_conditions = [CastFunc(a, condition_target_type) if t != condition_target_type and a is not True else a for a, t in zip(new_conditions, types_of_conditions)] @@ -326,7 +326,7 @@ def insert_vector_casts(ast_node, default_float_type='double'): new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type) substitution_dict[assignment.lhs] = new_lhs assignment.lhs = new_lhs - elif isinstance(assignment.lhs, vector_memory_access): + elif isinstance(assignment.lhs, VectorMemoryAccess): assignment.lhs = visit_expr(assignment.lhs, default_type) elif isinstance(arg, ast.Conditional): arg.condition_expr = fast_subs(arg.condition_expr, substitution_dict, diff --git a/pystencils/data_types.py b/pystencils/data_types.py deleted file mode 100644 index 9bf8375bf..000000000 --- a/pystencils/data_types.py +++ /dev/null @@ -1,927 +0,0 @@ -import ctypes -from collections import defaultdict -from functools import partial -from typing import Tuple - -import numpy as np -import sympy as sp -import sympy.codegen.ast -from sympy.core.cache import cacheit -from sympy.logic.boolalg import Boolean, BooleanFunction - -import pystencils -from pystencils.cache import memorycache, memorycache_if_hashable -from pystencils.utils import all_equal - -try: - import llvmlite.ir as ir -except ImportError as e: - ir = None - _ir_importerror = e - - -def typed_symbols(names, dtype, *args): - symbols = sp.symbols(names, *args) - if isinstance(symbols, Tuple): - return tuple(TypedSymbol(str(s), dtype) for s in symbols) - else: - return TypedSymbol(str(symbols), dtype) - - -def type_all_numbers(expr, dtype): - # TODO: move to pystnecils_walberla - substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)} - return expr.subs(substitutions) - - -def matrix_symbols(names, dtype, rows, cols): - # TODO: check if needed. (lbmpy, walberla) - if isinstance(names, str): - names = names.replace(' ', '').split(',') - - matrices = [] - for n in names: - symbols = typed_symbols(f"{n}:{rows * cols}", dtype) - matrices.append(sp.Matrix(rows, cols, lambda i, j: symbols[i * cols + j])) - - return tuple(matrices) - - -def assumptions_from_dtype(dtype): - # TODO: type hints and if dtype is correct type form Numpy - """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype - - Args: - dtype (BasicType, np.dtype): a Numpy data type - Returns: - A dict of SymPy assumptions - """ - if hasattr(dtype, 'numpy_dtype'): - dtype = dtype.numpy_dtype - - assumptions = dict() - - try: - if np.issubdtype(dtype, np.integer): - assumptions.update({'integer': True}) - - if np.issubdtype(dtype, np.unsignedinteger): - assumptions.update({'negative': False}) - - if np.issubdtype(dtype, np.integer) or \ - np.issubdtype(dtype, np.floating): - assumptions.update({'real': True}) - except Exception: - pass - - return assumptions - - -# noinspection PyPep8Naming -class address_of(sp.Function): - # DONE: ask Martin - # TODO: documentation - # TODO: move function to `functions.py` - # this is '&' in C - is_Atom = True - - def __new__(cls, arg): - obj = sp.Function.__new__(cls, arg) - return obj - - @property - def canonical(self): - if hasattr(self.args[0], 'canonical'): - return self.args[0].canonical - else: - raise NotImplementedError() - - @property - def is_commutative(self): - return self.args[0].is_commutative - - @property - def dtype(self): - if hasattr(self.args[0], 'dtype'): - return PointerType(self.args[0].dtype, restrict=True) - else: - return PointerType('void', restrict=True) - - -# noinspection PyPep8Naming -class cast_func(sp.Function): - # TODO: documentation - # TODO: move function to `functions.py` - is_Atom = True - - def __new__(cls, *args, **kwargs): - if len(args) != 2: - pass - expr, dtype, *other_args = args - if not isinstance(dtype, Type): - dtype = create_type(dtype) - # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well - # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads - # to problems when for example comparing cast_func's for equality - # - # lhs = bitwise_and(a, cast_func(1, 'int')) - # rhs = cast_func(0, 'int') - # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans - # -> thus a separate class boolean_cast_func is introduced - if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType(bool)): - cls = boolean_cast_func - - return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs) - - @property - def canonical(self): - if hasattr(self.args[0], 'canonical'): - return self.args[0].canonical - else: - raise NotImplementedError() - - @property - def is_commutative(self): - return self.args[0].is_commutative - - def _eval_evalf(self, *args, **kwargs): - return self.args[0].evalf() - - @property - def dtype(self): - return self.args[1] - - @property - def is_integer(self): - """ - Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate - - For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html - """ - if hasattr(self.dtype, 'numpy_dtype'): - return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer - else: - return super().is_integer - - @property - def is_negative(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if hasattr(self.dtype, 'numpy_dtype'): - if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger): - return False - - return super().is_negative - - @property - def is_nonnegative(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if self.is_negative is False: - return True - else: - return super().is_nonnegative - - @property - def is_real(self): - """ - See :func:`.TypedSymbol.is_integer` - """ - if hasattr(self.dtype, 'numpy_dtype'): - return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \ - np.issubdtype(self.dtype.numpy_dtype, np.floating) or \ - super().is_real - else: - return super().is_real - - -# noinspection PyPep8Naming -class boolean_cast_func(cast_func, Boolean): - # TODO: documentation - # TODO: move function to `functions.py` - pass - - -# noinspection PyPep8Naming -class vector_memory_access(cast_func): - # TODO: documentation - # TODO: move function to `functions.py` - # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride - nargs = (6,) - - -# noinspection PyPep8Naming -class reinterpret_cast_func(cast_func): - # TODO: documentation - # TODO: move function to `functions.py` - pass - - -# noinspection PyPep8Naming -class pointer_arithmetic_func(sp.Function, Boolean): - # TODO: documentation - # TODO: move function to `functions.py` - @property - def canonical(self): - if hasattr(self.args[0], 'canonical'): - return self.args[0].canonical - else: - raise NotImplementedError() - - -class TypedSymbol(sp.Symbol): - def __new__(cls, *args, **kwds): - obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds) - return obj - - def __new_stage2__(cls, name, dtype, **kwargs): - assumptions = assumptions_from_dtype(dtype) - assumptions.update(kwargs) - obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions) - try: - obj._dtype = create_type(dtype) - except (TypeError, ValueError): - # on error keep the string - obj._dtype = dtype - return obj - - __xnew__ = staticmethod(__new_stage2__) - __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) - - @property - def dtype(self): - return self._dtype - - def _hashable_content(self): - return super()._hashable_content(), hash(self._dtype) - - def __getnewargs__(self): - return self.name, self.dtype - - def __getnewargs_ex__(self): - return (self.name, self.dtype), self.assumptions0 - - @property - def canonical(self): - return self - - @property - def reversed(self): - return self - - @property - def headers(self): - headers = [] - try: - if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating): - headers.append('"cuda_complex.hpp"') - except Exception: - pass - try: - if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating): - headers.append('"cuda_complex.hpp"') - except Exception: - pass - - return headers - - -def create_type(specification): - # TODO: HERE - # TODO: type hint -> np.type - """Creates a subclass of Type according to a string or an object of subclass Type. - - Args: - specification: Type object, or a string - - Returns: - Type object, or a new Type object parsed from the string - """ - if isinstance(specification, Type): - return specification - else: - numpy_dtype = np.dtype(specification) - if numpy_dtype.fields is None: - return BasicType(numpy_dtype, const=False) - else: - return StructType(numpy_dtype, const=False) - - -@memorycache(maxsize=64) -def create_composite_type_from_string(specification): - # TODO: can be removed after llvm removla and fix of kernelparameters - """Creates a new Type object from a c-like string specification. - - Args: - specification: Specification string - - Returns: - Type object - """ - specification = specification.lower().split() - parts = [] - current = [] - for s in specification: - if s == '*': - parts.append(current) - current = [s] - else: - current.append(s) - if len(current) > 0: - parts.append(current) - # Parse native part - base_part = parts.pop(0) - const = False - if 'const' in base_part: - const = True - base_part.remove('const') - assert len(base_part) == 1 - if base_part[0][-1] == "*": - base_part[0] = base_part[0][:-1] - parts.append('*') - current_type = BasicType(np.dtype(base_part[0]), const) - # Parse pointer parts - for part in parts: - restrict = False - const = False - if 'restrict' in part: - restrict = True - part.remove('restrict') - if 'const' in part: - const = True - part.remove("const") - assert len(part) == 1 and part[0] == '*' - current_type = PointerType(current_type, const, restrict) - return current_type - - -def get_base_type(data_type): - # TODO: WTF is this?? DOCS!!! - # TODO: Can be removed after removal of kerncraft and fix in FieldPointer Symbol - while data_type.base_type is not None: - data_type = data_type.base_type - return data_type - - -def to_ctypes(data_type): - # TODO: can be removed with llvm - """ - Transforms a given Type into ctypes - :param data_type: Subclass of Type - :return: ctypes type object - """ - if isinstance(data_type, PointerType): - return ctypes.POINTER(to_ctypes(data_type.base_type)) - elif isinstance(data_type, StructType): - return ctypes.POINTER(ctypes.c_uint8) - else: - return to_ctypes.map[data_type.numpy_dtype] - -# TODO: can be removed with llvm -to_ctypes.map = { - np.dtype(np.int8): ctypes.c_int8, - np.dtype(np.int16): ctypes.c_int16, - np.dtype(np.int32): ctypes.c_int32, - np.dtype(np.int64): ctypes.c_int64, - - np.dtype(np.uint8): ctypes.c_uint8, - np.dtype(np.uint16): ctypes.c_uint16, - np.dtype(np.uint32): ctypes.c_uint32, - np.dtype(np.uint64): ctypes.c_uint64, - - np.dtype(np.float32): ctypes.c_float, - np.dtype(np.float64): ctypes.c_double, -} - - -def ctypes_from_llvm(data_type): - # TODO can be removed with LLVM - if not ir: - raise _ir_importerror - if isinstance(data_type, ir.PointerType): - ctype = ctypes_from_llvm(data_type.pointee) - if ctype is None: - return ctypes.c_void_p - else: - return ctypes.POINTER(ctype) - elif isinstance(data_type, ir.IntType): - if data_type.width == 8: - return ctypes.c_int8 - elif data_type.width == 16: - return ctypes.c_int16 - elif data_type.width == 32: - return ctypes.c_int32 - elif data_type.width == 64: - return ctypes.c_int64 - else: - raise ValueError("Int width %d is not supported" % data_type.width) - elif isinstance(data_type, ir.FloatType): - return ctypes.c_float - elif isinstance(data_type, ir.DoubleType): - return ctypes.c_double - elif isinstance(data_type, ir.VoidType): - return None # Void type is not supported by ctypes - else: - raise NotImplementedError(f'Data type {type(data_type)} of {data_type} is not supported yet') - - -def to_llvm_type(data_type, nvvm_target=False): - # TODO: can be removed with LLVM - """ - Transforms a given type into ctypes - :param data_type: Subclass of Type - :return: llvmlite type object - """ - if not ir: - raise _ir_importerror - if isinstance(data_type, PointerType): - return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0) - else: - return to_llvm_type.map[data_type.numpy_dtype] - - -# TODO: can be removed with LLVM -if ir: - to_llvm_type.map = { - np.dtype(np.int8): ir.IntType(8), - np.dtype(np.int16): ir.IntType(16), - np.dtype(np.int32): ir.IntType(32), - np.dtype(np.int64): ir.IntType(64), - - np.dtype(np.uint8): ir.IntType(8), - np.dtype(np.uint16): ir.IntType(16), - np.dtype(np.uint32): ir.IntType(32), - np.dtype(np.uint64): ir.IntType(64), - - np.dtype(np.float32): ir.FloatType(), - np.dtype(np.float64): ir.DoubleType(), - } - - -def peel_off_type(dtype, type_to_peel_off): - # TODO: WTF is this??? DOCS!!! - # TODO: used only once.... can be a lambda there - while type(dtype) is type_to_peel_off: - dtype = dtype.base_type - return dtype - - -############################# This is basically our type system ######################################################## -def collate_types(types, - forbid_collation_to_complex=False, # TODO: type system shouldn't need this!!! - forbid_collation_to_float=False, # TODO: type system shouldn't need this!!! - default_float_type='float64', # TODO: AST leaves should be typed. Expressions should be able to find out correct type - default_int_type='int64'): # TODO: AST leaves should be typed. Expressions should be able to find out correct type - """ - Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double - Uses the collation rules from numpy. - """ - if forbid_collation_to_complex: - types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)] - if not types: - return create_type(default_float_type) - - if forbid_collation_to_float: - types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)] - if not types: - return create_type(default_int_type) - - # Pointer arithmetic case i.e. pointer + integer is allowed - if any(type(t) is PointerType for t in types): - pointer_type = None - for t in types: - if type(t) is PointerType: - if pointer_type is not None: - raise ValueError("Cannot collate the combination of two pointer types") - pointer_type = t - elif type(t) is BasicType: - if not (t.is_int() or t.is_uint()): - raise ValueError("Invalid pointer arithmetic") - else: - raise ValueError("Invalid pointer arithmetic") - return pointer_type - - # peel of vector types, if at least one vector type occurred the result will also be the vector type - vector_type = [t for t in types if type(t) is VectorType] - if not all_equal(t.width for t in vector_type): - raise ValueError("Collation failed because of vector types with different width") - types = [peel_off_type(t, VectorType) for t in types] - - # now we should have a list of basic types - struct types are not yet supported - assert all(type(t) is BasicType for t in types) - - if any(t.is_float() for t in types): - types = tuple(t for t in types if t.is_float()) - # use numpy collation -> create type from numpy type -> and, put vector type around if necessary - result_numpy_type = np.result_type(*(t.numpy_dtype for t in types)) - result = BasicType(result_numpy_type) - if vector_type: - result = VectorType(result, vector_type[0].width) - return result - - -@memorycache_if_hashable(maxsize=2048) -def get_type_of_expression(expr, - default_float_type='double', # TODO: we shouldn't need to have default. AST leaves should have a type - default_int_type='int', # TODO: we shouldn't need to have default. AST leaves should have a type - symbol_type_dict=None): # TODO: we shouldn't need to have default. AST leaves should have a type - from pystencils.astnodes import ResolvedFieldAccess - from pystencils.cpu.vectorization import vec_all, vec_any - - if default_float_type == 'float': - default_float_type = 'float32' - - if not symbol_type_dict: - symbol_type_dict = defaultdict(lambda: create_type('double')) - - get_type = partial(get_type_of_expression, - default_float_type=default_float_type, - default_int_type=default_int_type, - symbol_type_dict=symbol_type_dict) - - expr = sp.sympify(expr) - if isinstance(expr, sp.Integer): - return create_type(default_int_type) - elif expr.is_real is False: - return create_type((np.zeros((1,), default_float_type) * 1j).dtype) - elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): - return create_type(default_float_type) - elif isinstance(expr, ResolvedFieldAccess): - return expr.field.dtype - elif isinstance(expr, pystencils.field.Field.AbstractAccess): - return expr.field.dtype - elif isinstance(expr, TypedSymbol): - return expr.dtype - elif isinstance(expr, sp.Symbol): - if symbol_type_dict: - return symbol_type_dict[expr.name] - else: - raise ValueError("All symbols inside this expression have to be typed! ", str(expr)) - elif isinstance(expr, cast_func): - return expr.args[1] - elif isinstance(expr, (vec_any, vec_all)): - return create_type("bool") - elif hasattr(expr, 'func') and expr.func == sp.Piecewise: - collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args)) - collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args)) - if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType: - collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width) - return collated_result_type - elif isinstance(expr, sp.Indexed): - typed_symbol = expr.base.label - return typed_symbol.dtype.base_type - elif isinstance(expr, (Boolean, BooleanFunction)): - # if any arg is of vector type return a vector boolean, else return a normal scalar boolean - result = create_type("bool") - vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)] - if vec_args: - result = VectorType(result, width=vec_args[0].width) - return result - elif isinstance(expr, sp.Pow): - base_type = get_type(expr.args[0]) - if expr.exp.is_integer: - return base_type - else: - return collate_types([create_type(default_float_type), base_type]) - elif isinstance(expr, (sp.Sum, sp.Product)): - return get_type(expr.args[0]) - elif isinstance(expr, sp.Expr): - expr: sp.Expr - if expr.args: - types = tuple(get_type(a) for a in expr.args) - # collate_types checks numpy_dtype in the special cases - if any(not hasattr(t, 'numpy_dtype') for t in types): - forbid_collation_to_complex = False - forbid_collation_to_float = False - else: - forbid_collation_to_complex = expr.is_real is True - forbid_collation_to_float = expr.is_integer is True - return collate_types( - types, - forbid_collation_to_complex=forbid_collation_to_complex, - forbid_collation_to_float=forbid_collation_to_float, - default_float_type=default_float_type, - default_int_type=default_int_type) - else: - if expr.is_integer: - return create_type(default_int_type) - else: - return create_type(default_float_type) - - raise NotImplementedError("Could not determine type for", expr, type(expr)) -############################# End This is basically our type system ################################################## - - -sympy_version = sp.__version__.split('.') -if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: - # __setstate__ would bypass the contructor, so we remove it - sp.Number.__getstate__ = sp.Basic.__getstate__ - del sp.Basic.__getstate__ - - class FunctorWithStoredKwargs: - def __init__(self, func, **kwargs): - self.func = func - self.kwargs = kwargs - - def __call__(self, *args): - return self.func(*args, **self.kwargs) - - # __reduce_ex__ would strip kwargs, so we override it - def basic_reduce_ex(self, protocol): - if hasattr(self, '__getnewargs_ex__'): - args, kwargs = self.__getnewargs_ex__() - else: - args, kwargs = self.__getnewargs__(), {} - if hasattr(self, '__getstate__'): - state = self.__getstate__() - else: - state = None - return FunctorWithStoredKwargs(type(self), **kwargs), args, state - sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__ - sp.Basic.__reduce_ex__ = basic_reduce_ex - - -class Type(sp.Atom): - # TODO: why is our type system dependent on sympy??? - # DONE: ask Martin - # TODO: inherits from sp.Atom because of cast function (and maybe others) - def __new__(cls, *args, **kwargs): - return sp.Basic.__new__(cls) - - def _sympystr(self, *args, **kwargs): - return str(self) - - -class BasicType(Type): - # TODO: check if Type inheritance is needed - # TODO: should be a sensible interface to np.dtype - # TODO: read numpy docs (Jan) - @staticmethod - def numpy_name_to_c(name): - # TODO: this should be a free function - # TODO: also check if numpy has this functionality - # TODO: docs!!! - # TODO: is this C? - if name == 'float64': - return 'double' - elif name == 'float32': - return 'float' - elif name == 'complex64': - return 'ComplexFloat' - elif name == 'complex128': - return 'ComplexDouble' - elif name.startswith('int'): - width = int(name[len("int"):]) - return f"int{width}_t" - elif name.startswith('uint'): - width = int(name[len("uint"):]) - return f"uint{width}_t" - elif name == 'bool': - return 'bool' - else: - raise NotImplementedError(f"Can map numpy to C name for {name}") - - def __init__(self, dtype, const=False): - # TODO: type hints - self.const = const - if isinstance(dtype, Type): - self._dtype = dtype.numpy_dtype # TODO: wtf? - else: - self._dtype = np.dtype(dtype) - assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type" - assert self._dtype.hasobject is False - assert self._dtype.subdtype is None - - def __getnewargs__(self): - return self.numpy_dtype, self.const - - def __getnewargs_ex__(self): - return (self.numpy_dtype, self.const), {} - - @property - def base_type(self): # TODO: what is base_type? - return None - - @property - def numpy_dtype(self): - return self._dtype - - @property - def sympy_dtype(self): - return getattr(sympy.codegen.ast, str(self.numpy_dtype)) - - @property - def item_size(self): # TODO: what is this? - return 1 - - def is_int(self): - return self.numpy_dtype in np.sctypes['int'] or self.numpy_dtype in np.sctypes['uint'] - - def is_float(self): - return self.numpy_dtype in np.sctypes['float'] - - def is_uint(self): - return self.numpy_dtype in np.sctypes['uint'] - - def is_complex(self): - return self.numpy_dtype in np.sctypes['complex'] - - def is_other(self): - return self.numpy_dtype in np.sctypes['others'] - - @property - def base_name(self): # TODO: name of the function is highly confusing - return BasicType.numpy_name_to_c(str(self._dtype)) - - def __str__(self): - result = BasicType.numpy_name_to_c(str(self._dtype)) - if self.const: - result += " const" - return result - - def __repr__(self): - return str(self) - - def __eq__(self, other): - if not isinstance(other, BasicType): - return False - else: - return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) - - def __hash__(self): - return hash(str(self)) - - -class VectorType(Type): - # TODO: check with rest - instruction_set = None - - def __init__(self, base_type, width=4): - self._base_type = base_type - self.width = width - - @property - def base_type(self): - return self._base_type - - @property - def item_size(self): - return self.width * self.base_type.item_size - - def __eq__(self, other): - if not isinstance(other, VectorType): - return False - else: - return (self.base_type, self.width) == (other.base_type, other.width) - - def __str__(self): - if self.instruction_set is None: - return f"{self.base_type}[{self.width}]" - else: - if self.base_type == create_type("int64") or self.base_type == create_type("int32"): - return self.instruction_set['int'] - elif self.base_type == create_type("float64"): - return self.instruction_set['double'] - elif self.base_type == create_type("float32"): - return self.instruction_set['float'] - elif self.base_type == create_type("bool"): - return self.instruction_set['bool'] - else: - raise NotImplementedError() - - def __hash__(self): - return hash((self.base_type, self.width)) - - def __getnewargs__(self): - return self._base_type, self.width - - def __getnewargs_ex__(self): - return (self._base_type, self.width), {} - - -class PointerType(Type): - # TODO: rename to FieldType - def __init__(self, base_type, const=False, restrict=True): - self._base_type = base_type - self.const = const - self.restrict = restrict - - def __getnewargs__(self): - return self.base_type, self.const, self.restrict - - def __getnewargs_ex__(self): - return (self.base_type, self.const, self.restrict), {} - - @property - def alias(self): - return not self.restrict - - @property - def base_type(self): - return self._base_type - - @property - def item_size(self): - return self.base_type.item_size - - def __eq__(self, other): - if not isinstance(other, PointerType): - return False - else: - return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict) - - def __str__(self): - components = [str(self.base_type), '*'] - if self.restrict: - components.append('RESTRICT') - if self.const: - components.append("const") - return " ".join(components) - - def __repr__(self): - return str(self) - - def __hash__(self): - return hash((self._base_type, self.const, self.restrict)) - - -class StructType: - # TODO: Docs. This is a struct. A list of types (with C offsets) - def __init__(self, numpy_type, const=False): - self.const = const - self._dtype = np.dtype(numpy_type) - - def __getnewargs__(self): - return self.numpy_dtype, self.const - - def __getnewargs_ex__(self): - return (self.numpy_dtype, self.const), {} - - @property - def base_type(self): - return None - - @property - def numpy_dtype(self): - return self._dtype - - @property - def item_size(self): - return self.numpy_dtype.itemsize - - def get_element_offset(self, element_name): - return self.numpy_dtype.fields[element_name][1] - - def get_element_type(self, element_name): - np_element_type = self.numpy_dtype.fields[element_name][0] - return BasicType(np_element_type, self.const) - - def has_element(self, element_name): - return element_name in self.numpy_dtype.fields - - def __eq__(self, other): - if not isinstance(other, StructType): - return False - else: - return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) - - def __str__(self): - # structs are handled byte-wise - result = "uint8_t" - if self.const: - result += " const" - return result - - def __repr__(self): - return str(self) - - def __hash__(self): - return hash((self.numpy_dtype, self.const)) - - -class TypedImaginaryUnit(TypedSymbol): - # TODO: why is this an extra class??? - # TODO: remove? - def __new__(cls, *args, **kwds): - obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds) - return obj - - def __new_stage2__(cls, dtype): - obj = super(TypedImaginaryUnit, cls).__xnew__(cls, - "_i", - dtype, - imaginary=True) - return obj - - headers = ['"cuda_complex.hpp"'] - - __xnew__ = staticmethod(__new_stage2__) - __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) - - def __getnewargs__(self): - return (self.dtype,) - - def __getnewargs_ex__(self): - return (self.dtype,), {} diff --git a/pystencils/field.py b/pystencils/field.py index dcb33ca99..146a1cacb 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -13,8 +13,8 @@ from sympy.core.cache import cacheit import pystencils from pystencils.alignedarray import aligned_empty -from pystencils.data_types import StructType, TypedSymbol, create_type -from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol +from pystencils.typing import StructType, TypedSymbol, create_type +from pystencils.typing.typed_sympy import FieldShapeSymbol, FieldStrideSymbol from pystencils.stencil import ( direction_string_to_offset, inverse_direction, offset_to_direction_string) from pystencils.sympyextensions import is_integer_sequence @@ -137,6 +137,7 @@ def fields(description=None, index_dimensions=0, layout=None, field_type=FieldTy return result +# TODO why this??? Why abstarct? class AbstractField: class AbstractAccess: pass @@ -472,27 +473,6 @@ class Field(AbstractField): assert FieldType.is_custom(self) return Field.Access(self, offset, index, is_absolute_access=True) - def interpolated_access(self, - offset: Tuple, - interpolation_mode='linear', - address_mode='BORDER', - allow_textures=True): - """Provides access to field values at non-integer positions - - ``interpolated_access`` is similar to :func:`Field.absolute_access` except that - it allows non-integer offsets and automatic handling of out-of-bound accesses. - - :param offset: Tuple of spatial coordinates (can be floats) - :param interpolation_mode: One of :class:`pystencils.interpolation_astnodes.InterpolationMode` - :param address_mode: How boundaries are handled can be 'border', 'wrap', 'mirror', 'clamp' - :param allow_textures: Allow implementation by texture accesses on GPUs - """ - from pystencils.interpolation_astnodes import Interpolator - return Interpolator(self, - interpolation_mode, - address_mode, - allow_textures=allow_textures).at(offset) - def staggered_access(self, offset, index=None): """If this field is a staggered field, it can be accessed using half-integer offsets. For example, an offset of ``(0, sp.Rational(1,2))`` or ``"E"`` corresponds to the staggered point to the east diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py index 67adac657..a13297e0d 100644 --- a/pystencils/gpucuda/cudajit.py +++ b/pystencils/gpucuda/cudajit.py @@ -2,7 +2,7 @@ import numpy as np from pystencils.backends.cbackend import get_headers from pystencils.backends.cuda_backend import generate_cuda -from pystencils.data_types import StructType +from pystencils.typing import StructType from pystencils.field import FieldType from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils.kernel_wrapper import KernelWrapper diff --git a/pystencils/gpucuda/indexing.py b/pystencils/gpucuda/indexing.py index ae5db1b98..6f30b0a1c 100644 --- a/pystencils/gpucuda/indexing.py +++ b/pystencils/gpucuda/indexing.py @@ -5,7 +5,7 @@ import sympy as sp from sympy.core.cache import cacheit from pystencils.astnodes import Block, Conditional -from pystencils.data_types import TypedSymbol, create_type +from pystencils.typing import TypedSymbol, create_type from pystencils.integer_functions import div_ceil, div_floor from pystencils.slicing import normalize_slice from pystencils.sympyextensions import is_integer_sequence, prod diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py index 39808eab0..96399ae1c 100644 --- a/pystencils/gpucuda/kernelcreation.py +++ b/pystencils/gpucuda/kernelcreation.py @@ -1,13 +1,13 @@ import numpy as np from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment -from pystencils.data_types import StructType, TypedSymbol +from pystencils.typing import StructType, TypedSymbol, add_types from pystencils.field import Field, FieldType from pystencils.enums import Target, Backend from pystencils.gpucuda.cudajit import make_python_function from pystencils.gpucuda.indexing import BlockIndexing from pystencils.transformations import ( - add_types, get_base_buffer_index, get_common_shape, parse_base_pointer_info, + get_base_buffer_index, get_common_shape, parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols) diff --git a/pystencils/integer_functions.py b/pystencils/integer_functions.py index efdaaaecf..1975a877e 100644 --- a/pystencils/integer_functions.py +++ b/pystencils/integer_functions.py @@ -1,7 +1,8 @@ +# TODO move to a module functions import numpy as np import sympy as sp -from pystencils.data_types import cast_func, collate_types, create_type, get_type_of_expression +from pystencils.typing import CastFunc, collate_types, create_type, get_type_of_expression from pystencils.sympyextensions import is_integer_sequence @@ -12,9 +13,9 @@ class IntegerFunctionTwoArgsMixIn(sp.Function): args = [] for a in (arg1, arg2): if isinstance(a, sp.Number) or isinstance(a, int): - args.append(cast_func(a, create_type("int"))) + args.append(CastFunc(a, create_type("int"))) elif isinstance(a, np.generic): - args.append(cast_func(a, a.dtype)) + args.append(CastFunc(a, a.dtype)) else: args.append(a) diff --git a/pystencils/integer_set_analysis.py b/pystencils/integer_set_analysis.py index 82af791ca..2e37c643f 100644 --- a/pystencils/integer_set_analysis.py +++ b/pystencils/integer_set_analysis.py @@ -4,7 +4,7 @@ import islpy as isl import sympy as sp import pystencils.astnodes as ast -from pystencils.transformations import parents_of_type +from pystencils.typing import parents_of_type def remove_brackets(s): diff --git a/pystencils/kerncraft_coupling/generate_benchmark.py b/pystencils/kerncraft_coupling/generate_benchmark.py index 955098d2c..8d8d7d1da 100644 --- a/pystencils/kerncraft_coupling/generate_benchmark.py +++ b/pystencils/kerncraft_coupling/generate_benchmark.py @@ -8,7 +8,7 @@ from jinja2 import Environment, PackageLoader, StrictUndefined from pystencils.astnodes import PragmaBlock from pystencils.backends.cbackend import generate_c, get_headers from pystencils.cpu.cpujit import get_compiler_config, run_compile_step -from pystencils.data_types import get_base_type +from pystencils.typing import get_base_type from pystencils.enums import Backend from pystencils.include import get_pystencils_include_path from pystencils.integer_functions import modulo_ceil diff --git a/pystencils/kerncraft_coupling/kerncraft_interface.py b/pystencils/kerncraft_coupling/kerncraft_interface.py index 61867e518..bfb5a2d6a 100644 --- a/pystencils/kerncraft_coupling/kerncraft_interface.py +++ b/pystencils/kerncraft_coupling/kerncraft_interface.py @@ -21,7 +21,7 @@ from pystencils.sympyextensions import count_operations_in_ast from pystencils.transformations import filtered_tree_iteration from pystencils.utils import DotDict from pystencils.cpu.kernelcreation import add_openmp -from pystencils.data_types import get_base_type +from pystencils.typing.utilities import get_base_type from pystencils.sympyextensions import prod diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py new file mode 100644 index 000000000..55f141201 --- /dev/null +++ b/pystencils/kernel_contrains_check.py @@ -0,0 +1,150 @@ +from collections import namedtuple, defaultdict + +import numpy as np + +import pystencils.integer_functions +import sympy as sp +from pystencils import astnodes as ast, TypedSymbol +from pystencils.bit_masks import flag_cond +from pystencils.field import AbstractField +from pystencils.transformations import NestedScopes +from pystencils.typing import CastFunc, create_type, get_type_of_expression, collate_types +from sympy.logic.boolalg import BooleanFunction + + +class KernelConstraintsCheck: + # TODO: Logs + # TODO: specification + """Checks if the input to create_kernel is valid. + + Test the following conditions: + + - SSA Form for pure symbols: + - Every pure symbol may occur only once as left-hand-side of an assignment + - Every pure symbol that is read, may not be written to later + - Independence / Parallelization condition: + - a field that is written may only be read at exact the same spatial position + + (Pure symbols are symbols that are not Field.Accesses) + """ + FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index']) + + def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True): + self._type_for_symbol = type_for_symbol + + self.scopes = NestedScopes() + self._field_writes = defaultdict(set) + self.fields_read = set() + self.check_independence_condition = check_independence_condition + self.check_double_write_condition = check_double_write_condition + + def process_assignment(self, assignment): + # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 + new_rhs = self.process_expression(assignment.rhs) + new_lhs = self._process_lhs(assignment.lhs) + return ast.SympyAssignment(new_lhs, new_rhs) + + def process_expression(self, rhs, type_constants=True): + + self._update_accesses_rhs(rhs) + if isinstance(rhs, AbstractField.AbstractAccess): + self.fields_read.add(rhs.field) + self.fields_read.update(rhs.indirect_addressing_fields) + return rhs + # TODO remove this + #elif isinstance(rhs, ImaginaryUnit): + # return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type'])) + elif isinstance(rhs, TypedSymbol): + return rhs + elif isinstance(rhs, sp.Symbol): + return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name]) + elif type_constants and isinstance(rhs, np.generic): + return CastFunc(rhs, create_type(rhs.dtype)) + elif type_constants and isinstance(rhs, sp.Number): + return CastFunc(rhs, create_type(self._type_for_symbol['_constant'])) + # Very important that this clause comes before BooleanFunction + elif isinstance(rhs, sp.Equality): + if isinstance(rhs.args[1], sp.Number): + return sp.Equality( + self.process_expression(rhs.args[0], type_constants), + rhs.args[1]) + else: + return sp.Equality( + self.process_expression(rhs.args[0], type_constants), + self.process_expression(rhs.args[1], type_constants)) + elif isinstance(rhs, CastFunc): + return CastFunc( + self.process_expression(rhs.args[0], type_constants=False), + rhs.dtype) + elif isinstance(rhs, BooleanFunction) or \ + type(rhs) in pystencils.integer_functions.__dict__.values(): + new_args = [self.process_expression(a, type_constants) for a in rhs.args] + types_of_expressions = [get_type_of_expression(a) for a in new_args] + arg_type = collate_types(types_of_expressions, forbid_collation_to_float=True) + new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type + else CastFunc(a, arg_type) + for a in new_args] + return rhs.func(*new_args) + elif isinstance(rhs, flag_cond): + # do not process the arguments to the bit shift - they must remain integers + processed_args = (self.process_expression(a) for a in rhs.args[2:]) + return flag_cond(rhs.args[0], rhs.args[1], *processed_args) + elif isinstance(rhs, sp.Mul): + new_args = [ + self.process_expression(arg, type_constants) + if arg not in (-1, 1) else arg for arg in rhs.args + ] + return rhs.func(*new_args) if new_args else rhs + elif isinstance(rhs, sp.Indexed): + return rhs + else: + if isinstance(rhs, sp.Pow): + # don't process exponents -> they should remain integers + return sp.Pow( + self.process_expression(rhs.args[0], type_constants), + rhs.args[1]) + else: + new_args = [ + self.process_expression(arg, type_constants) + for arg in rhs.args + ] + return rhs.func(*new_args) if new_args else rhs + + @property + def fields_written(self): + return set(k.field for k, v in self._field_writes.items() if len(v)) + + def _process_lhs(self, lhs): + assert isinstance(lhs, sp.Symbol) + self._update_accesses_lhs(lhs) + if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)): + return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) + else: + return lhs + + def _update_accesses_lhs(self, lhs): + if isinstance(lhs, AbstractField.AbstractAccess): + fai = self.FieldAndIndex(lhs.field, lhs.index) + self._field_writes[fai].add(lhs.offsets) + if self.check_double_write_condition and len(self._field_writes[fai]) > 1: + raise ValueError( + f"Field {lhs.field.name} is written at two different locations") + elif isinstance(lhs, sp.Symbol): + if self.scopes.is_defined_locally(lhs): + raise ValueError(f"Assignments not in SSA form, multiple assignments to {lhs.name}") + if lhs in self.scopes.free_parameters: + raise ValueError(f"Symbol {lhs.name} is written, after it has been read") + self.scopes.define_symbol(lhs) + + def _update_accesses_rhs(self, rhs): + if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition: + writes = self._field_writes[self.FieldAndIndex( + rhs.field, rhs.index)] + for write_offset in writes: + assert len(writes) == 1 + if write_offset != rhs.offsets: + raise ValueError("Violation of loop independence condition. Field " + "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset)) + self.fields_read.add(rhs.field) + elif isinstance(rhs, sp.Symbol): + self.scopes.access_symbol(rhs) \ No newline at end of file diff --git a/pystencils/rng.py b/pystencils/rng.py index 7c4f894f9..c75c3f972 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -2,7 +2,7 @@ import copy import numpy as np import sympy as sp -from pystencils.data_types import TypedSymbol, cast_func +from pystencils.typing import TypedSymbol, CastFunc from pystencils.astnodes import LoopOverCoordinate from pystencils.backends.cbackend import CustomCodeNode from pystencils.sympyextensions import fast_subs @@ -47,11 +47,11 @@ class RNGBase(CustomCodeNode): def get_code(self, dialect, vector_instruction_set, print_arg): code = "\n" for r in self.result_symbols: - if vector_instruction_set and not self.args[1].atoms(cast_func): + if vector_instruction_set and not self.args[1].atoms(CastFunc): # this vector RNG has become scalar through substitution code += f"{r.dtype} {r.name};\n" else: - code += f"{vector_instruction_set[r.dtype.base_name] if vector_instruction_set else r.dtype} " + \ + code += f"{vector_instruction_set[r.dtype.c_name] if vector_instruction_set else r.dtype} " + \ f"{r.name};\n" args = [print_arg(a) for a in self.args] + ['' + r.name for r in self.result_symbols] code += (self._name + "(" + ", ".join(args) + ");\n") diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index f63328d81..1746a8b99 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -10,7 +10,7 @@ from sympy.functions import Abs from sympy.core.numbers import Zero from pystencils.assignment import Assignment -from pystencils.data_types import cast_func, get_type_of_expression, PointerType, VectorType +from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType from pystencils.kernelparameters import FieldPointerSymbol T = TypeVar('T') @@ -519,7 +519,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], visit_children = False elif t.is_integer: pass - elif isinstance(t, cast_func): + elif isinstance(t, CastFunc): visit_children = False visit(t.args[0]) elif t.func is fast_sqrt: diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 0c6d00658..beb5d287e 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -1,28 +1,21 @@ import hashlib import pickle import warnings -from typing import List, Dict -from collections import OrderedDict, defaultdict, namedtuple +from collections import OrderedDict from copy import deepcopy from types import MappingProxyType -import numpy as np import sympy as sp -from sympy.core.numbers import ImaginaryUnit -from sympy.logic.boolalg import Boolean, BooleanFunction import pystencils.astnodes as ast -import pystencils.integer_functions from pystencils.assignment import Assignment -from pystencils.data_types import ( - PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, - get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) +from pystencils.typing import ( + PointerType, StructType, TypedSymbol, get_base_type, ReinterpretCastFunc, get_next_parent_of_type, parents_of_type) from pystencils.field import AbstractField, Field, FieldType -from pystencils.kernelparameters import FieldPointerSymbol +from pystencils.typing import FieldPointerSymbol from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.slicing import normalize_slice from pystencils.integer_functions import int_div -from pystencils.bit_masks import flag_cond class NestedScopes: @@ -379,7 +372,10 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): return base_buffer_index * buffer_index_size -def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()): +def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=None): + + if read_only_field_names is None: + read_only_field_names = set() def visit_sympy_expr(expr, enclosing_block, sympy_assignment): if isinstance(expr, AbstractField.AbstractAccess): @@ -522,7 +518,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=None, if isinstance(accessed_field_name, sp.Symbol): accessed_field_name = accessed_field_name.name new_type = field_access.field.dtype.get_element_type(accessed_field_name) - result = reinterpret_cast_func(result, new_type) + result = ReinterpretCastFunc(result, new_type) return visit_sympy_expr(result, enclosing_block, sympy_assignment) else: @@ -804,298 +800,6 @@ def cleanup_blocks(node: ast.Node) -> None: cleanup_blocks(a) -class KernelConstraintsCheck: - # TODO: Logs - # TODO: specification - """Checks if the input to create_kernel is valid. - - Test the following conditions: - - - SSA Form for pure symbols: - - Every pure symbol may occur only once as left-hand-side of an assignment - - Every pure symbol that is read, may not be written to later - - Independence / Parallelization condition: - - a field that is written may only be read at exact the same spatial position - - (Pure symbols are symbols that are not Field.Accesses) - """ - FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index']) - - def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True): - self._type_for_symbol = type_for_symbol - - self.scopes = NestedScopes() - self._field_writes = defaultdict(set) - self.fields_read = set() - self.check_independence_condition = check_independence_condition - self.check_double_write_condition = check_double_write_condition - - def process_assignment(self, assignment): - # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 - new_rhs = self.process_expression(assignment.rhs) - new_lhs = self._process_lhs(assignment.lhs) - return ast.SympyAssignment(new_lhs, new_rhs) - - def process_expression(self, rhs, type_constants=True): - - self._update_accesses_rhs(rhs) - if isinstance(rhs, AbstractField.AbstractAccess): - self.fields_read.add(rhs.field) - self.fields_read.update(rhs.indirect_addressing_fields) - return rhs - elif isinstance(rhs, ImaginaryUnit): - return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type'])) - elif isinstance(rhs, TypedSymbol): - return rhs - elif isinstance(rhs, sp.Symbol): - return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name]) - elif type_constants and isinstance(rhs, np.generic): - return cast_func(rhs, create_type(rhs.dtype)) - elif type_constants and isinstance(rhs, sp.Number): - return cast_func(rhs, create_type(self._type_for_symbol['_constant'])) - # Very important that this clause comes before BooleanFunction - elif isinstance(rhs, sp.Equality): - if isinstance(rhs.args[1], sp.Number): - return sp.Equality( - self.process_expression(rhs.args[0], type_constants), - rhs.args[1]) - else: - return sp.Equality( - self.process_expression(rhs.args[0], type_constants), - self.process_expression(rhs.args[1], type_constants)) - elif isinstance(rhs, cast_func): - return cast_func( - self.process_expression(rhs.args[0], type_constants=False), - rhs.dtype) - elif isinstance(rhs, BooleanFunction) or \ - type(rhs) in pystencils.integer_functions.__dict__.values(): - new_args = [self.process_expression(a, type_constants) for a in rhs.args] - types_of_expressions = [get_type_of_expression(a) for a in new_args] - arg_type = collate_types(types_of_expressions, forbid_collation_to_float=True) - new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type - else cast_func(a, arg_type) - for a in new_args] - return rhs.func(*new_args) - elif isinstance(rhs, flag_cond): - # do not process the arguments to the bit shift - they must remain integers - processed_args = (self.process_expression(a) for a in rhs.args[2:]) - return flag_cond(rhs.args[0], rhs.args[1], *processed_args) - elif isinstance(rhs, sp.Mul): - new_args = [ - self.process_expression(arg, type_constants) - if arg not in (-1, 1) else arg for arg in rhs.args - ] - return rhs.func(*new_args) if new_args else rhs - elif isinstance(rhs, sp.Indexed): - return rhs - else: - if isinstance(rhs, sp.Pow): - # don't process exponents -> they should remain integers - return sp.Pow( - self.process_expression(rhs.args[0], type_constants), - rhs.args[1]) - else: - new_args = [ - self.process_expression(arg, type_constants) - for arg in rhs.args - ] - return rhs.func(*new_args) if new_args else rhs - - @property - def fields_written(self): - return set(k.field for k, v in self._field_writes.items() if len(v)) - - def _process_lhs(self, lhs): - assert isinstance(lhs, sp.Symbol) - self._update_accesses_lhs(lhs) - if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)): - return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) - else: - return lhs - - def _update_accesses_lhs(self, lhs): - if isinstance(lhs, AbstractField.AbstractAccess): - fai = self.FieldAndIndex(lhs.field, lhs.index) - self._field_writes[fai].add(lhs.offsets) - if self.check_double_write_condition and len(self._field_writes[fai]) > 1: - raise ValueError( - f"Field {lhs.field.name} is written at two different locations") - elif isinstance(lhs, sp.Symbol): - if self.scopes.is_defined_locally(lhs): - raise ValueError(f"Assignments not in SSA form, multiple assignments to {lhs.name}") - if lhs in self.scopes.free_parameters: - raise ValueError(f"Symbol {lhs.name} is written, after it has been read") - self.scopes.define_symbol(lhs) - - def _update_accesses_rhs(self, rhs): - if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition: - writes = self._field_writes[self.FieldAndIndex( - rhs.field, rhs.index)] - for write_offset in writes: - assert len(writes) == 1 - if write_offset != rhs.offsets: - raise ValueError("Violation of loop independence condition. Field " - "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset)) - self.fields_read.add(rhs.field) - elif isinstance(rhs, sp.Symbol): - self.scopes.access_symbol(rhs) - - -def add_types(eqs: List[Assignment], type_for_symbol: Dict[sp.Symbol, np.dtype], check_independence_condition: bool, - check_double_write_condition: bool=True): - """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`. - - Additionally returns sets of all fields which are read/written - - Args: - eqs: list of equations - type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double' - check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed - kernels - - Returns: - ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields, - list of equations where symbols have been replaced by typed symbols - """ - if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'): - type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol) - - type_for_symbol = adjust_c_single_precision_type(type_for_symbol) - - # TODO what does this do???? - # TODO: ask Martin - check = KernelConstraintsCheck(type_for_symbol, check_independence_condition, - check_double_write_condition=check_double_write_condition) - - # TODO: check if this adds only types to leave nodes of AST, get type info - def visit(obj): - if isinstance(obj, (list, tuple)): - return [visit(e) for e in obj] - if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): - return check.process_assignment(obj) - elif isinstance(obj, ast.Conditional): - check.scopes.push() - # Disable double write check inside conditionals - # would be triggered by e.g. in-kernel boundaries - old_double_write = check.check_double_write_condition - check.check_double_write_condition = False - false_block = None if obj.false_block is None else visit( - obj.false_block) - result = ast.Conditional(check.process_expression( - obj.condition_expr, type_constants=False), - true_block=visit(obj.true_block), - false_block=false_block) - check.check_double_write_condition = old_double_write - check.scopes.pop() - return result - elif isinstance(obj, ast.Block): - check.scopes.push() - result = ast.Block([visit(e) for e in obj.args]) - check.scopes.pop() - return result - elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate): - return obj - else: - raise ValueError("Invalid object in kernel " + str(type(obj))) - - typed_equations = visit(eqs) - - return check.fields_read, check.fields_written, typed_equations - - -def insert_casts(node): - """Checks the types and inserts casts and pointer arithmetic where necessary. - - Args: - node: the head node of the ast - - Returns: - modified AST - """ - def cast(zipped_args_types, target_dtype): - """ - Adds casts to the arguments if their type differs from the target type - :param zipped_args_types: a zipped list of args and types - :param target_dtype: The target data type - :return: args with possible casts - """ - casted_args = [] - for argument, data_type in zipped_args_types: - if data_type.numpy_dtype != target_dtype.numpy_dtype: # ignoring const - casted_args.append(cast_func(argument, target_dtype)) - else: - casted_args.append(argument) - return casted_args - - def pointer_arithmetic(expr_args): - """ - Creates a valid pointer arithmetic function - :param expr_args: Arguments of the add expression - :return: pointer_arithmetic_func - """ - pointer = None - new_args = [] - for arg, data_type in expr_args: - if data_type.func is PointerType: - assert pointer is None - pointer = arg - for arg, data_type in expr_args: - if arg != pointer: - assert data_type.is_int() or data_type.is_uint() - new_args.append(arg) - new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args - return pointer_arithmetic_func(pointer, new_args) - - if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func): - return node - args = [] - for arg in node.args: - args.append(insert_casts(arg)) - # TODO indexed, LoopOverCoordinate - if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge): - # TODO optimize pow, don't cast integer on double - types = [get_type_of_expression(arg) for arg in args] - assert len(types) > 0 - # Never ever, ever collate to float type for boolean functions! - target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction)) - zipped = list(zip(args, types)) - if target.func is PointerType: - assert node.func is sp.Add - return pointer_arithmetic(zipped) - else: - return node.func(*cast(zipped, target)) - elif node.func is ast.SympyAssignment: - lhs = args[0] - rhs = args[1] - target = get_type_of_expression(lhs) - if target.func is PointerType: - return node.func(*args) # TODO fix, not complete - else: - return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target)) - elif node.func is ast.ResolvedFieldAccess: - return node - elif node.func is ast.Block: - for old_arg, new_arg in zip(node.args, args): - node.replace(old_arg, new_arg) - return node - elif node.func is ast.LoopOverCoordinate: - for old_arg, new_arg in zip(node.args, args): - node.replace(old_arg, new_arg) - return node - elif node.func is sp.Piecewise: - expressions = [expr for (expr, _) in args] - types = [get_type_of_expression(expr) for expr in expressions] - target = collate_types(types) - zipped = list(zip(expressions, types)) - casted_expressions = cast(zipped, target) - args = [ - arg.func(*[expr, arg.cond]) - for (arg, expr) in zip(args, casted_expressions) - ] - - return node.func(*args) - - def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None: """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or first and last element""" @@ -1118,73 +822,6 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, i # --------------------------------------- Helper Functions ------------------------------------------------------------- - - -def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='int64'): - """ - Creates a default symbol name to type mapping. - If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double') - - Args: - eqs: list of equations - default_type: the type for non-boolean symbols - Returns: - dictionary, mapping symbol name to type - """ - result = defaultdict(lambda: default_type) - if hasattr(default_type, 'numpy_dtype'): - result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype - else: - result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype - for eq in eqs: - if isinstance(eq, ast.Conditional): - result.update(typing_from_sympy_inspection(eq.true_block.args)) - if eq.false_block: - result.update(typing_from_sympy_inspection( - eq.false_block.args)) - elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment): - continue - else: - from pystencils.cpu.vectorization import vec_all, vec_any - if isinstance(eq.rhs, (vec_all, vec_any)): - result[eq.lhs.name] = "bool" - # problematic case here is when rhs is a symbol: then it is impossible to decide here without - # further information what type the left hand side is - default fallback is the dict value then - if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol): - result[eq.lhs.name] = "bool" - try: - result[eq.lhs.name] = get_type_of_expression(eq.rhs, - default_float_type=default_type, - default_int_type=default_int_type, - symbol_type_dict=result) - except Exception: - pass # gracefully fail in case get_type_of_expression cannot determine type - return result - - -def get_next_parent_of_type(node, parent_type): - """Returns the next parent node of given type or None, if root is reached. - - Traverses the AST nodes parents until a parent of given type was found. - If no such parent is found, None is returned - """ - parent = node.parent - while parent is not None: - if isinstance(parent, parent_type): - return parent - parent = parent.parent - return None - - -def parents_of_type(node, parent_type, include_current=False): - """Generator for all parent nodes of given type""" - parent = node if include_current else node.parent - while parent is not None: - if isinstance(parent, parent_type): - yield parent - parent = parent.parent - - def get_optimal_loop_ordering(fields): """ Determines the optimal loop order for a given set of fields. @@ -1340,16 +977,3 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: inner_loop.start = block_ctr inner_loop.stop = stop return coordinates_taken_into_account - - -def adjust_c_single_precision_type(type_for_symbol): - """Replaces every occurrence of 'float' with 'single' to enforce the numpy single precision type.""" - def single_factory(): - return "single" - - for symbol in type_for_symbol: - if type_for_symbol[symbol] == "float": - type_for_symbol[symbol] = single_factory() - if hasattr(type_for_symbol, "default_factory") and type_for_symbol.default_factory() == "float": - type_for_symbol.default_factory = single_factory - return type_for_symbol diff --git a/pystencils/typing/__init__.py b/pystencils/typing/__init__.py new file mode 100644 index 000000000..55fb731c0 --- /dev/null +++ b/pystencils/typing/__init__.py @@ -0,0 +1,4 @@ +from pystencils.typing.utilities import * +from pystencils.typing.types import * +from pystencils.typing.typed_sympy import * +from pystencils.typing.cast_functions import * diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py new file mode 100644 index 000000000..0c2da8d20 --- /dev/null +++ b/pystencils/typing/cast_functions.py @@ -0,0 +1,120 @@ +import numpy as np +import sympy as sp +from sympy.logic.boolalg import Boolean + +from pystencils.typing.types import AbstractType, BasicType, create_type +from pystencils.typing.typed_sympy import TypedSymbol + + +class CastFunc(sp.Function): + # TODO: documentation + # TODO: move function to `functions.py` + is_Atom = True + + def __new__(cls, *args, **kwargs): + if len(args) != 2: + pass + expr, dtype, *other_args = args + if not isinstance(dtype, AbstractType): + dtype = create_type(dtype) + # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well + # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads + # to problems when for example comparing cast_func's for equality + # + # lhs = bitwise_and(a, cast_func(1, 'int')) + # rhs = cast_func(0, 'int') + # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans + # -> thus a separate class boolean_cast_func is introduced + if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType(bool)): + cls = BooleanCastFunc + + return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs) + + @property + def canonical(self): + if hasattr(self.args[0], 'canonical'): + return self.args[0].canonical + else: + raise NotImplementedError() + + @property + def is_commutative(self): + return self.args[0].is_commutative + + def _eval_evalf(self, *args, **kwargs): + return self.args[0].evalf() + + @property + def dtype(self): + return self.args[1] + + @property + def is_integer(self): + """ + Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate + + For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html + """ + if hasattr(self.dtype, 'numpy_dtype'): + return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer + else: + return super().is_integer + + @property + def is_negative(self): + """ + See :func:`.TypedSymbol.is_integer` + """ + if hasattr(self.dtype, 'numpy_dtype'): + if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger): + return False + + return super().is_negative + + @property + def is_nonnegative(self): + """ + See :func:`.TypedSymbol.is_integer` + """ + if self.is_negative is False: + return True + else: + return super().is_nonnegative + + @property + def is_real(self): + """ + See :func:`.TypedSymbol.is_integer` + """ + if hasattr(self.dtype, 'numpy_dtype'): + return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \ + np.issubdtype(self.dtype.numpy_dtype, np.floating) or \ + super().is_real + else: + return super().is_real + + +class BooleanCastFunc(CastFunc, Boolean): + # TODO: documentation + pass + + +class VectorMemoryAccess(CastFunc): + # TODO: documentation + # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride + nargs = (6,) + + +class ReinterpretCastFunc(CastFunc): + # TODO: documentation + pass + + +class PointerArithmeticFunc(sp.Function, Boolean): + # TODO: documentation + @property + def canonical(self): + if hasattr(self.args[0], 'canonical'): + return self.args[0].canonical + else: + raise NotImplementedError() diff --git a/pystencils/kernelparameters.py b/pystencils/typing/typed_sympy.py similarity index 52% rename from pystencils/kernelparameters.py rename to pystencils/typing/typed_sympy.py index 8bd4341be..0a253f748 100644 --- a/pystencils/kernelparameters.py +++ b/pystencils/typing/typed_sympy.py @@ -1,30 +1,102 @@ -"""Special symbols representing kernel parameters related to fields/arrays. - -A `KernelFunction` node determines parameters that have to be passed to the function by searching for all undefined -symbols. Some symbols are not directly defined by the user, but are related to the `Field`s used in the kernel: -For each field a `FieldPointerSymbol` needs to be passed in, which is the pointer to the memory region where -the field is stored. This pointer is represented by the `FieldPointerSymbol` class that additionally stores the -name of the corresponding field. For fields where the size is not known at compile time, additionally shape and stride -information has to be passed in at runtime. These values are represented by `FieldShapeSymbol` -and `FieldPointerSymbol`. - -The special symbols in this module store only the field name instead of a field reference. Storing a field reference -directly leads to problems with copying and pickling behaviour due to the circular dependency of `Field` and -e.g. `FieldShapeSymbol`, since a Field contains `FieldShapeSymbol`s in its shape, and a `FieldShapeSymbol` -would reference back to the field. -""" +from typing import Union + +import numpy as np +import sympy as sp from sympy.core.cache import cacheit -from pystencils.data_types import ( - PointerType, TypedSymbol, create_composite_type_from_string, get_base_type) +from pystencils.typing.types import BasicType, create_type, PointerType +from pystencils.typing.utilities import get_base_type + + +def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]): + # TODO: type hints and if dtype is correct type form Numpy + """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype + + Args: + dtype (BasicType, np.dtype): a Numpy data type + Returns: + A dict of SymPy assumptions + """ + if hasattr(dtype, 'numpy_dtype'): + dtype = dtype.numpy_dtype + + assumptions = dict() + + try: + if np.issubdtype(dtype, np.integer): + assumptions.update({'integer': True}) + if np.issubdtype(dtype, np.unsignedinteger): + assumptions.update({'negative': False}) -# TODO: Why do we need extra classes? Why isn't TypedSymbol enough? -# TODO: Replace with a factory function + if np.issubdtype(dtype, np.integer) or \ + np.issubdtype(dtype, np.floating): + assumptions.update({'real': True}) + except Exception: # TODO this is dirty + pass + return assumptions -SHAPE_DTYPE = create_composite_type_from_string("const int64") -STRIDE_DTYPE = create_composite_type_from_string("const int64") + +class TypedSymbol(sp.Symbol): + def __new__(cls, *args, **kwds): + obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds) + return obj + + def __new_stage2__(cls, name, dtype, **kwargs): # TODO does not match signature of sp.Symbol??? + assumptions = assumptions_from_dtype(dtype) # TODO should by dtype a np.dtype or our Type??? + assumptions.update(kwargs) + obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions) + try: + obj.numpy_dtype = create_type(dtype) + except (TypeError, ValueError): + # on error keep the string + obj.numpy_dtype = dtype + return obj + + __xnew__ = staticmethod(__new_stage2__) + __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) + + @property + def dtype(self): + return self._dtype + + def _hashable_content(self): + return super()._hashable_content(), hash(self._dtype) + + def __getnewargs__(self): + return self.name, self.dtype + + def __getnewargs_ex__(self): + return (self.name, self.dtype), self.assumptions0 + + @property + def canonical(self): + return self + + @property + def reversed(self): + return self + + @property + def headers(self): + headers = [] + try: + if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating): + headers.append('"cuda_complex.hpp"') + except Exception: + pass + try: + if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating): + headers.append('"cuda_complex.hpp"') + except Exception: + pass + + return headers + + +SHAPE_DTYPE = BasicType('int64', const=True) +STRIDE_DTYPE = BasicType('int64', const=True) class FieldStrideSymbol(TypedSymbol): diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py new file mode 100644 index 000000000..eabe87dbd --- /dev/null +++ b/pystencils/typing/types.py @@ -0,0 +1,297 @@ +from abc import ABC, abstractmethod +from typing import Union + +import numpy as np +import sympy as sp +import sympy.codegen.ast + + +def is_supported_type(dtype: np.dtype): + scalar = dtype.type + c = np.issctype(dtype) + subclass = issubclass(scalar, np.floating) or issubclass(scalar, np.integer) or issubclass(scalar, np.bool) + additional_checks = dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None + return c and subclass and additional_checks + + +def numpy_name_to_c(name: str) -> str: + """ + Converts a np.dtype.name into a C type + Args: + name: np.dtype.name string + Returns: + type as a C string + """ + if name == 'float64': + return 'double' + elif name == 'float32': + return 'float' + elif name.startswith('int'): + width = int(name[len("int"):]) + return f"int{width}_t" + elif name.startswith('uint'): + width = int(name[len("uint"):]) + return f"uint{width}_t" + elif name == 'bool': + return 'bool' + else: + raise NotImplementedError(f"Can't map numpy to C name for {name}") + + +class AbstractType(sp.Atom, ABC): + # TODO: inherits from sp.Atom because of cast function (and maybe others) + # TODO: is this necessary? + def __new__(cls, *args, **kwargs): + return sp.Basic.__new__(cls) + + def _sympystr(self, *args, **kwargs): + return str(self) + + @property + @abstractmethod + def base_type(self) -> Union[None, 'BasicType']: + """ + Returns: Returns BasicType of a Vector or Pointer type, None otherwise + """ + pass + + @property + @abstractmethod + def item_size(self) -> int: + """ + Returns: WHO THE FUCK KNOWS!??!!? + """ + pass + + +class BasicType(AbstractType): + # TODO: should be a sensible interface to np.dtype + + def __init__(self, dtype: Union[np.dtype, 'BasicType', str], const: bool = False): + self.const = const + if isinstance(dtype, BasicType): + self.numpy_dtype = dtype.numpy_dtype # TODO copy const as well?? + else: + self.numpy_dtype = np.dtype(dtype) + assert is_supported_type(self.numpy_dtype), f'Type {self.numpy_dtype} is currently not supported!' + + def __getnewargs__(self): + return self.numpy_dtype, self.const + + def __getnewargs_ex__(self): + return (self.numpy_dtype, self.const), {} + + @property + def base_type(self): + return None + + @property + def sympy_dtype(self): + return getattr(sympy.codegen.ast, str(self.numpy_dtype)) + + @property + def item_size(self): # TODO: what is this? Do we want self.numpy_type.itemsize???? + return 1 + + def is_float(self): + return issubclass(self.numpy_dtype.type, np.floating) + + def is_int(self): + return issubclass(self.numpy_dtype.type, np.integer) + + def is_uint(self): + return issubclass(self.numpy_dtype.type, np.unsignedinteger) + + def is_sint(self): + return issubclass(self.numpy_dtype.type, np.signedinteger) + + def is_bool(self): + return issubclass(self.numpy_dtype.type, np.bool) + + @property + def c_name(self) -> str: + return numpy_name_to_c(self.numpy_dtype.name) + + def __str__(self): + return f'{self.c_name}{" const" if self.const else ""}' + + def __repr__(self): + return str(self) + + def __eq__(self, other): + if not isinstance(other, BasicType): + return False + else: + return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) + + def __hash__(self): + return hash(str(self)) + + +class VectorType(AbstractType): + # TODO: check with rest + instruction_set = None + + def __init__(self, base_type: BasicType, width: int = 4): # TODO default vector length is dangerous + self._base_type = base_type + self.width = width + + @property + def base_type(self): + return self._base_type + + @property + def item_size(self): + return self.width * self.base_type.item_size + + def __eq__(self, other): + if not isinstance(other, VectorType): + return False + else: + return (self.base_type, self.width) == (other.base_type, other.width) + + def __str__(self): + if self.instruction_set is None: + return f"{self.base_type}[{self.width}]" + else: + # TODO this seems super weird. the instruction_set should know how to print a type out!!! + # TODO this is error prone. base_type could be cons=True. Use dtype instead + if self.base_type == create_type("int64") or self.base_type == create_type("int32"): + return self.instruction_set['int'] + elif self.base_type == create_type("float64"): + return self.instruction_set['double'] + elif self.base_type == create_type("float32"): + return self.instruction_set['float'] + elif self.base_type == create_type("bool"): + return self.instruction_set['bool'] + else: + raise NotImplementedError() + + def __hash__(self): + return hash((self.base_type, self.width)) + + def __getnewargs__(self): + return self._base_type, self.width + + def __getnewargs_ex__(self): + return (self._base_type, self.width), {} + + +class PointerType(AbstractType): + def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True): + self._base_type = base_type + self.const = const + self.restrict = restrict + + def __getnewargs__(self): + return self.base_type, self.const, self.restrict + + def __getnewargs_ex__(self): + return (self.base_type, self.const, self.restrict), {} + + @property + def alias(self): + return not self.restrict + + @property + def base_type(self): + return self._base_type + + @property + def item_size(self): + return self.base_type.item_size + + def __eq__(self, other): + if not isinstance(other, PointerType): + return False + else: + return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict) + + def __str__(self): + return f'{str(self.base_type)} * {"RESTRICT " if self.restrict else "" }{"const" if self.const else ""}' + + def __repr__(self): + return str(self) + + def __hash__(self): + return hash((self._base_type, self.const, self.restrict)) + + +class StructType(AbstractType): + # TODO: Docs. This is a struct. A list of types (with C offsets) + # TODO StructType didn't inherit from AbstractType..... + # TODO: This is basically like a BasicType... only as struct + def __init__(self, numpy_type, const=False): + self.const = const + self._dtype = np.dtype(numpy_type) + + def __getnewargs__(self): + return self.numpy_dtype, self.const + + def __getnewargs_ex__(self): + return (self.numpy_dtype, self.const), {} + + @property + def base_type(self): + return None + + @property + def numpy_dtype(self): + return self._dtype + + @property + def item_size(self): + return self.numpy_dtype.itemsize + + def get_element_offset(self, element_name): + return self.numpy_dtype.fields[element_name][1] + + def get_element_type(self, element_name): + np_element_type = self.numpy_dtype.fields[element_name][0] + return BasicType(np_element_type, self.const) + + def has_element(self, element_name): + return element_name in self.numpy_dtype.fields + + def __eq__(self, other): + if not isinstance(other, StructType): + return False + else: + return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) + + def __str__(self): + # structs are handled byte-wise + # TODO structs are weird + result = "uint8_t" + if self.const: + result += " const" + return result + + def __repr__(self): + return str(self) + + def __hash__(self): + return hash((self.numpy_dtype, self.const)) + + +def create_type(specification: Union[np.dtype, AbstractType, str]) -> AbstractType: + # TODO: Ok, this is basically useless. Except for it can differentiate between BasicType and StructType + # TODO: Everything else is already implemented inside BasicType + # TODO: Also why don't we support Vector and Pointer types??? + """Creates a subclass of Type according to a string or an object of subclass Type. + + Args: + specification: Type object, or a string + + Returns: + Type object, or a new Type object parsed from the string + """ + if isinstance(specification, AbstractType): + return specification + else: + numpy_dtype = np.dtype(specification) + if numpy_dtype.fields is None: + return BasicType(numpy_dtype, const=False) + else: + return StructType(numpy_dtype, const=False) + diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py new file mode 100644 index 000000000..8187d929e --- /dev/null +++ b/pystencils/typing/utilities.py @@ -0,0 +1,494 @@ +from collections import defaultdict +from functools import partial +from typing import Tuple, Union, List, Dict + +import numpy as np +import sympy as sp +from pystencils import astnodes as ast +from pystencils.kernel_contrains_check import KernelConstraintsCheck +from sympy.codegen import Assignment +from sympy.logic.boolalg import Boolean, BooleanFunction + +import pystencils +from pystencils.cache import memorycache, memorycache_if_hashable +from pystencils.utils import all_equal +from pystencils.typing.types import AbstractType, BasicType, VectorType, PointerType, StructType, create_type +from pystencils.typing.cast_functions import CastFunc, PointerArithmeticFunc +from pystencils.typing.typed_sympy import TypedSymbol + + +def typed_symbols(names, dtype, *args): + # TODO docs, type hints + symbols = sp.symbols(names, *args) + if isinstance(symbols, Tuple): + return tuple(TypedSymbol(str(s), dtype) for s in symbols) + else: + return TypedSymbol(str(symbols), dtype) + + +# noinspection PyPep8Naming +class address_of(sp.Function): + # DONE: ask Martin + # TODO: docstring + # this is '&' in C + is_Atom = True + + def __new__(cls, arg): + obj = sp.Function.__new__(cls, arg) + return obj + + @property + def canonical(self): + if hasattr(self.args[0], 'canonical'): + return self.args[0].canonical + else: + raise NotImplementedError() + + @property + def is_commutative(self): + return self.args[0].is_commutative + + @property + def dtype(self): + if hasattr(self.args[0], 'dtype'): + return PointerType(self.args[0].dtype, restrict=True) + else: + return PointerType('void', restrict=True) # TODO this shouldn't work??? FIX: Allow BasicType to be Void and use that. Or raise exception + + +def get_base_type(data_type): + # TODO: WTF is this?? DOCS!!! + # TODO: This is unsafe. + # TODO: remove + # Pointer(Pointer(int)) + while data_type.base_type is not None: + data_type = data_type.base_type + return data_type + + +def peel_off_type(dtype, type_to_peel_off): + # TODO: WTF is this??? DOCS!!! + # TODO: used only once.... can be a lambda there + while type(dtype) is type_to_peel_off: + dtype = dtype.base_type + return dtype + + + +############################# This is basically our type system ######################################################## +def collate_types(types, + forbid_collation_to_complex=False, # TODO: type system shouldn't need this!!! + forbid_collation_to_float=False, # TODO: type system shouldn't need this!!! + default_float_type='float64', + # TODO: AST leaves should be typed. Expressions should be able to find out correct type + default_int_type='int64'): # TODO: AST leaves should be typed. Expressions should be able to find out correct type + """ + Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double + Uses the collation rules from numpy. + """ + # TODO: use np.can_cast and np.promote_types and np.result_type and np.find_common_type + if forbid_collation_to_complex: + types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)] + if not types: + return create_type(default_float_type) + + if forbid_collation_to_float: + types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)] + if not types: + return create_type(default_int_type) + + # Pointer arithmetic case i.e. pointer + integer is allowed + if any(type(t) is PointerType for t in types): + pointer_type = None + for t in types: + if type(t) is PointerType: + if pointer_type is not None: + raise ValueError("Cannot collate the combination of two pointer types") + pointer_type = t + elif type(t) is BasicType: + if not (t.is_int() or t.is_uint()): + raise ValueError("Invalid pointer arithmetic") + else: + raise ValueError("Invalid pointer arithmetic") + return pointer_type + + # peel of vector types, if at least one vector type occurred the result will also be the vector type + vector_type = [t for t in types if type(t) is VectorType] + if not all_equal(t.width for t in vector_type): + raise ValueError("Collation failed because of vector types with different width") + types = [peel_off_type(t, VectorType) for t in types] + + # now we should have a list of basic types - struct types are not yet supported + assert all(type(t) is BasicType for t in types) + + if any(t.is_float() for t in types): + types = tuple(t for t in types if t.is_float()) + # use numpy collation -> create type from numpy type -> and, put vector type around if necessary + result_numpy_type = np.result_type(*(t.numpy_dtype for t in types)) + result = BasicType(result_numpy_type) + if vector_type: + result = VectorType(result, vector_type[0].width) + return result + + +@memorycache_if_hashable(maxsize=2048) +def get_type_of_expression(expr, + default_float_type='double', + # TODO: we shouldn't need to have default. AST leaves should have a type + default_int_type='int', + # TODO: we shouldn't need to have default. AST leaves should have a type + symbol_type_dict=None): # TODO: we shouldn't need to have default. AST leaves should have a type + from pystencils.astnodes import ResolvedFieldAccess + from pystencils.cpu.vectorization import vec_all, vec_any + + if default_float_type == 'float': + default_float_type = 'float32' + + if not symbol_type_dict: + symbol_type_dict = defaultdict(lambda: create_type('double')) + + get_type = partial(get_type_of_expression, + default_float_type=default_float_type, + default_int_type=default_int_type, + symbol_type_dict=symbol_type_dict) + + expr = sp.sympify(expr) + if isinstance(expr, sp.Integer): + return create_type(default_int_type) + elif expr.is_real is False: + return create_type((np.zeros((1,), default_float_type) * 1j).dtype) + elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): + return create_type(default_float_type) + elif isinstance(expr, ResolvedFieldAccess): + return expr.field.dtype + elif isinstance(expr, pystencils.field.Field.AbstractAccess): + return expr.field.dtype + elif isinstance(expr, TypedSymbol): + return expr.dtype + elif isinstance(expr, sp.Symbol): + if symbol_type_dict: + return symbol_type_dict[expr.name] + else: + raise ValueError("All symbols inside this expression have to be typed! ", str(expr)) + elif isinstance(expr, CastFunc): + return expr.args[1] + elif isinstance(expr, (vec_any, vec_all)): + return create_type("bool") + elif hasattr(expr, 'func') and expr.func == sp.Piecewise: + collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args)) + collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args)) + if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType: + collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width) + return collated_result_type + elif isinstance(expr, sp.Indexed): + typed_symbol = expr.base.label + return typed_symbol.dtype.base_type + elif isinstance(expr, (Boolean, BooleanFunction)): + # if any arg is of vector type return a vector boolean, else return a normal scalar boolean + result = create_type("bool") + vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)] + if vec_args: + result = VectorType(result, width=vec_args[0].width) + return result + elif isinstance(expr, sp.Pow): + base_type = get_type(expr.args[0]) + if expr.exp.is_integer: + return base_type + else: + return collate_types([create_type(default_float_type), base_type]) + elif isinstance(expr, (sp.Sum, sp.Product)): + return get_type(expr.args[0]) + elif isinstance(expr, sp.Expr): + expr: sp.Expr + if expr.args: + types = tuple(get_type(a) for a in expr.args) + # collate_types checks numpy_dtype in the special cases + if any(not hasattr(t, 'numpy_dtype') for t in types): + forbid_collation_to_complex = False + forbid_collation_to_float = False + else: + forbid_collation_to_complex = expr.is_real is True + forbid_collation_to_float = expr.is_integer is True + return collate_types( + types, + forbid_collation_to_complex=forbid_collation_to_complex, + forbid_collation_to_float=forbid_collation_to_float, + default_float_type=default_float_type, + default_int_type=default_int_type) + else: + if expr.is_integer: + return create_type(default_int_type) + else: + return create_type(default_float_type) + + raise NotImplementedError("Could not determine type for", expr, type(expr)) + + +############################# End This is basically our type system ################################################## + + +# TODO this seems quite wrong... +sympy_version = sp.__version__.split('.') +if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: + # __setstate__ would bypass the contructor, so we remove it + sp.Number.__getstate__ = sp.Basic.__getstate__ + del sp.Basic.__getstate__ + + + class FunctorWithStoredKwargs: + def __init__(self, func, **kwargs): + self.func = func + self.kwargs = kwargs + + def __call__(self, *args): + return self.func(*args, **self.kwargs) + + + # __reduce_ex__ would strip kwargs, so we override it + def basic_reduce_ex(self, protocol): + if hasattr(self, '__getnewargs_ex__'): + args, kwargs = self.__getnewargs_ex__() + else: + args, kwargs = self.__getnewargs__(), {} + if hasattr(self, '__getstate__'): + state = self.__getstate__() + else: + state = None + return FunctorWithStoredKwargs(type(self), **kwargs), args, state + + + sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__ + sp.Basic.__reduce_ex__ = basic_reduce_ex + + +def add_types(eqs: List[Assignment], type_for_symbol: Dict[sp.Symbol, np.dtype], check_independence_condition: bool, + check_double_write_condition: bool=True): + """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`. + + Additionally returns sets of all fields which are read/written + + Args: + eqs: list of equations + type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double' + check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed + kernels + + Returns: + ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields, + list of equations where symbols have been replaced by typed symbols + """ + if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'): + type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol) + + type_for_symbol = adjust_c_single_precision_type(type_for_symbol) + + # TODO what does this do???? + # TODO: ask Martin + check = KernelConstraintsCheck(type_for_symbol, check_independence_condition, + check_double_write_condition=check_double_write_condition) + + # TODO: check if this adds only types to leave nodes of AST, get type info + def visit(obj): + if isinstance(obj, (list, tuple)): + return [visit(e) for e in obj] + if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): + return check.process_assignment(obj) + elif isinstance(obj, ast.Conditional): + check.scopes.push() + # Disable double write check inside conditionals + # would be triggered by e.g. in-kernel boundaries + old_double_write = check.check_double_write_condition + check.check_double_write_condition = False + false_block = None if obj.false_block is None else visit( + obj.false_block) + result = ast.Conditional(check.process_expression( + obj.condition_expr, type_constants=False), + true_block=visit(obj.true_block), + false_block=false_block) + check.check_double_write_condition = old_double_write + check.scopes.pop() + return result + elif isinstance(obj, ast.Block): + check.scopes.push() + result = ast.Block([visit(e) for e in obj.args]) + check.scopes.pop() + return result + elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate): + return obj + else: + raise ValueError("Invalid object in kernel " + str(type(obj))) + + typed_equations = visit(eqs) + + return check.fields_read, check.fields_written, typed_equations + + +def insert_casts(node): + """Checks the types and inserts casts and pointer arithmetic where necessary. + + Args: + node: the head node of the ast + + Returns: + modified AST + """ + def cast(zipped_args_types, target_dtype): + """ + Adds casts to the arguments if their type differs from the target type + :param zipped_args_types: a zipped list of args and types + :param target_dtype: The target data type + :return: args with possible casts + """ + casted_args = [] + for argument, data_type in zipped_args_types: + if data_type.numpy_dtype != target_dtype.numpy_dtype: # ignoring const + casted_args.append(CastFunc(argument, target_dtype)) + else: + casted_args.append(argument) + return casted_args + + def pointer_arithmetic(expr_args): + """ + Creates a valid pointer arithmetic function + :param expr_args: Arguments of the add expression + :return: pointer_arithmetic_func + """ + pointer = None + new_args = [] + for arg, data_type in expr_args: + if data_type.func is PointerType: + assert pointer is None + pointer = arg + for arg, data_type in expr_args: + if arg != pointer: + assert data_type.is_int() or data_type.is_uint() + new_args.append(arg) + new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args + return PointerArithmeticFunc(pointer, new_args) + + if isinstance(node, sp.AtomicExpr) or isinstance(node, CastFunc): + return node + args = [] + for arg in node.args: + args.append(insert_casts(arg)) + # TODO indexed, LoopOverCoordinate + if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge): + # TODO optimize pow, don't cast integer on double + types = [get_type_of_expression(arg) for arg in args] + assert len(types) > 0 + # Never ever, ever collate to float type for boolean functions! + target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction)) + zipped = list(zip(args, types)) + if target.func is PointerType: + assert node.func is sp.Add + return pointer_arithmetic(zipped) + else: + return node.func(*cast(zipped, target)) + elif node.func is ast.SympyAssignment: + lhs = args[0] + rhs = args[1] + target = get_type_of_expression(lhs) + if target.func is PointerType: + return node.func(*args) # TODO fix, not complete + else: + return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target)) + elif node.func is ast.ResolvedFieldAccess: + return node + elif node.func is ast.Block: + for old_arg, new_arg in zip(node.args, args): + node.replace(old_arg, new_arg) + return node + elif node.func is ast.LoopOverCoordinate: + for old_arg, new_arg in zip(node.args, args): + node.replace(old_arg, new_arg) + return node + elif node.func is sp.Piecewise: + expressions = [expr for (expr, _) in args] + types = [get_type_of_expression(expr) for expr in expressions] + target = collate_types(types) + zipped = list(zip(expressions, types)) + casted_expressions = cast(zipped, target) + args = [ + arg.func(*[expr, arg.cond]) + for (arg, expr) in zip(args, casted_expressions) + ] + + return node.func(*args) + + +def adjust_c_single_precision_type(type_for_symbol): + """Replaces every occurrence of 'float' with 'single' to enforce the numpy single precision type.""" + def single_factory(): + return "single" + + for symbol in type_for_symbol: + if type_for_symbol[symbol] == "float": + type_for_symbol[symbol] = single_factory() + if hasattr(type_for_symbol, "default_factory") and type_for_symbol.default_factory() == "float": + type_for_symbol.default_factory = single_factory + return type_for_symbol + + +def get_next_parent_of_type(node, parent_type): + """Returns the next parent node of given type or None, if root is reached. + + Traverses the AST nodes parents until a parent of given type was found. + If no such parent is found, None is returned + """ + parent = node.parent + while parent is not None: + if isinstance(parent, parent_type): + return parent + parent = parent.parent + return None + + +def parents_of_type(node, parent_type, include_current=False): + """Generator for all parent nodes of given type""" + parent = node if include_current else node.parent + while parent is not None: + if isinstance(parent, parent_type): + yield parent + parent = parent.parent + + +def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='int64'): + """ + Creates a default symbol name to type mapping. + If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double') + + Args: + eqs: list of equations + default_type: the type for non-boolean symbols + Returns: + dictionary, mapping symbol name to type + """ + result = defaultdict(lambda: default_type) + if hasattr(default_type, 'numpy_dtype'): + result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype + else: + result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype + for eq in eqs: + if isinstance(eq, ast.Conditional): + result.update(typing_from_sympy_inspection(eq.true_block.args)) + if eq.false_block: + result.update(typing_from_sympy_inspection( + eq.false_block.args)) + elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment): + continue + else: + from pystencils.cpu.vectorization import vec_all, vec_any + if isinstance(eq.rhs, (vec_all, vec_any)): + result[eq.lhs.name] = "bool" + # problematic case here is when rhs is a symbol: then it is impossible to decide here without + # further information what type the left hand side is - default fallback is the dict value then + if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol): + result[eq.lhs.name] = "bool" + try: + result[eq.lhs.name] = get_type_of_expression(eq.rhs, + default_float_type=default_type, + default_int_type=default_int_type, + symbol_type_dict=result) + except Exception: + pass # gracefully fail in case get_type_of_expression cannot determine type + return result \ No newline at end of file diff --git a/pystencils_tests/test_abs.py b/pystencils_tests/test_abs.py index cf71bc04c..7bf7a1a45 100644 --- a/pystencils_tests/test_abs.py +++ b/pystencils_tests/test_abs.py @@ -1,7 +1,7 @@ import sympy import pystencils as ps -from pystencils.data_types import cast_func, create_type +from pystencils.typing import CastFunc, create_type def test_abs(): @@ -10,7 +10,7 @@ def test_abs(): default_int_type = create_type('int64') assignments = ps.AssignmentCollection({ - x[0, 0]: sympy.Abs(cast_func(y[0, 0], default_int_type)) + x[0, 0]: sympy.Abs(CastFunc(y[0, 0], default_int_type)) }) config = ps.CreateKernelConfig(target=ps.Target.GPU) diff --git a/pystencils_tests/test_address_of.py b/pystencils_tests/test_address_of.py index 659f5d92f..1cb9c8ed1 100644 --- a/pystencils_tests/test_address_of.py +++ b/pystencils_tests/test_address_of.py @@ -3,7 +3,7 @@ Test of pystencils.data_types.address_of """ import sympy as sp import pystencils -from pystencils.data_types import PointerType, address_of, cast_func, create_type +from pystencils.typing import PointerType, address_of, CastFunc, create_type from pystencils.simp.simplifications import sympy_cse @@ -17,14 +17,14 @@ def test_address_of(): assignments = pystencils.AssignmentCollection({ s: address_of(x[0, 0]), - y[0, 0]: cast_func(s, create_type('int64')) + y[0, 0]: CastFunc(s, create_type('int64')) }, {}) ast = pystencils.create_kernel(assignments) pystencils.show_code(ast) assignments = pystencils.AssignmentCollection({ - y[0, 0]: cast_func(address_of(x[0, 0]), create_type('int64')) + y[0, 0]: CastFunc(address_of(x[0, 0]), create_type('int64')) }, {}) ast = pystencils.create_kernel(assignments) @@ -36,8 +36,8 @@ def test_address_of_with_cse(): s = pystencils.TypedSymbol('s', PointerType(create_type('int64'))) assignments = pystencils.AssignmentCollection({ - y[0, 0]: cast_func(address_of(x[0, 0]), create_type('int64')) + s, - x[0, 0]: cast_func(address_of(x[0, 0]), create_type('int64')) + 1 + y[0, 0]: CastFunc(address_of(x[0, 0]), create_type('int64')) + s, + x[0, 0]: CastFunc(address_of(x[0, 0]), create_type('int64')) + 1 }, {}) ast = pystencils.create_kernel(assignments) diff --git a/pystencils_tests/test_complex_numbers.py b/pystencils_tests/test_complex_numbers.py index 9d9f71952..7f3894825 100644 --- a/pystencils_tests/test_complex_numbers.py +++ b/pystencils_tests/test_complex_numbers.py @@ -16,7 +16,7 @@ from sympy.functions import im, re import pystencils from pystencils import AssignmentCollection -from pystencils.data_types import TypedSymbol, create_type +from pystencils.typing import TypedSymbol, create_type X, Y = pystencils.fields('x, y: complex64[2d]') A, B = pystencils.fields('a, b: float32[2d]') diff --git a/pystencils_tests/test_cuda_known_functions.py b/pystencils_tests/test_cuda_known_functions.py index 32b7d9b76..7e465da9f 100644 --- a/pystencils_tests/test_cuda_known_functions.py +++ b/pystencils_tests/test_cuda_known_functions.py @@ -5,7 +5,7 @@ import pytest import pystencils from pystencils.astnodes import get_dummy_symbol from pystencils.backends.cuda_backend import CudaSympyPrinter -from pystencils.data_types import address_of +from pystencils.typing import address_of from pystencils.enums import Target diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py index 596f9f4da..14c751336 100644 --- a/pystencils_tests/test_field.py +++ b/pystencils_tests/test_field.py @@ -4,7 +4,7 @@ import sympy as sp import pystencils as ps from pystencils import TypedSymbol -from pystencils.data_types import create_type +from pystencils.typing import create_type from pystencils.field import Field, FieldType, layout_string_to_tuple diff --git a/pystencils_tests/test_floor_ceil_int_optimization.py b/pystencils_tests/test_floor_ceil_int_optimization.py index 7ec81b05b..ce06f0559 100644 --- a/pystencils_tests/test_floor_ceil_int_optimization.py +++ b/pystencils_tests/test_floor_ceil_int_optimization.py @@ -11,7 +11,7 @@ import sympy as sp import pystencils -from pystencils.data_types import create_type +from pystencils.typing import create_type def test_floor_ceil_int_optimization(): diff --git a/pystencils_tests/test_global_definitions.py b/pystencils_tests/test_global_definitions.py index c08557018..fa51ccc9d 100644 --- a/pystencils_tests/test_global_definitions.py +++ b/pystencils_tests/test_global_definitions.py @@ -2,7 +2,7 @@ import sympy import pystencils.astnodes from pystencils.backends.cbackend import CBackend -from pystencils.data_types import TypedSymbol +from pystencils.typing import TypedSymbol class BogusDeclaration(pystencils.astnodes.Node): diff --git a/pystencils_tests/test_kernel_data_type.py b/pystencils_tests/test_kernel_data_type.py index 2fbab3ff1..25ca56c2b 100644 --- a/pystencils_tests/test_kernel_data_type.py +++ b/pystencils_tests/test_kernel_data_type.py @@ -5,7 +5,7 @@ import pytest from sympy.abc import x, y from pystencils import Assignment, create_kernel, fields, CreateKernelConfig -from pystencils.transformations import adjust_c_single_precision_type +from pystencils.typing import adjust_c_single_precision_type @pytest.mark.parametrize("data_type", ("float", "double")) diff --git a/pystencils_tests/test_match_subs_for_assignment_collection.py b/pystencils_tests/test_match_subs_for_assignment_collection.py index 9bcc5ad6b..7bb0ec509 100644 --- a/pystencils_tests/test_match_subs_for_assignment_collection.py +++ b/pystencils_tests/test_match_subs_for_assignment_collection.py @@ -11,12 +11,12 @@ import sympy as sp import pystencils -from pystencils.data_types import create_type +from pystencils.typing import create_type def test_wild_typed_symbol(): x = pystencils.fields('x: float32[3d]') - typed_symbol = pystencils.data_types.TypedSymbol('a', create_type('float64')) + typed_symbol = pystencils.typing.data_types.TypedSymbol('a', create_type('float64')) assert x.center().match(sp.Wild('w1')) assert typed_symbol.match(sp.Wild('w1')) diff --git a/pystencils_tests/test_pickle_support.py b/pystencils_tests/test_pickle_support.py index 462645198..87268a777 100644 --- a/pystencils_tests/test_pickle_support.py +++ b/pystencils_tests/test_pickle_support.py @@ -1,7 +1,7 @@ from copy import copy, deepcopy from pystencils.field import Field -from pystencils.data_types import TypedSymbol +from pystencils.typing import TypedSymbol def test_field_access(): diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py index d1f509e65..b29f15eb7 100644 --- a/pystencils_tests/test_random.py +++ b/pystencils_tests/test_random.py @@ -6,7 +6,7 @@ import pystencils as ps from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets from pystencils.cpu.cpujit import get_compiler_config -from pystencils.data_types import TypedSymbol +from pystencils.typing import TypedSymbol from pystencils.enums import Target RNGs = {('philox', 'float'): PhiloxFourFloats, ('philox', 'double'): PhiloxTwoDoubles, diff --git a/pystencils_tests/test_sum_prod.py b/pystencils_tests/test_sum_prod.py index 2f6bf7359..235644db2 100644 --- a/pystencils_tests/test_sum_prod.py +++ b/pystencils_tests/test_sum_prod.py @@ -13,7 +13,7 @@ import sympy as sp import sympy.abc import pystencils as ps -from pystencils.data_types import create_type +from pystencils.typing import create_type @pytest.mark.parametrize('default_assignment_simplifications', [False, True]) diff --git a/pystencils_tests/test_transformations.py b/pystencils_tests/test_transformations.py index 9b0024980..3ede70a85 100644 --- a/pystencils_tests/test_transformations.py +++ b/pystencils_tests/test_transformations.py @@ -1,7 +1,7 @@ import pystencils as ps from pystencils import TypedSymbol from pystencils.astnodes import LoopOverCoordinate, SympyAssignment -from pystencils.data_types import create_type +from pystencils.typing import create_type from pystencils.transformations import filtered_tree_iteration, get_loop_hierarchy, get_loop_counter_symbol_hierarchy diff --git a/pystencils_tests/test_type_interference.py b/pystencils_tests/test_type_interference.py index 953b87742..179fa2836 100644 --- a/pystencils_tests/test_type_interference.py +++ b/pystencils_tests/test_type_interference.py @@ -1,14 +1,14 @@ from sympy.abc import a, b, c, d, e, f import pystencils -from pystencils.data_types import cast_func, create_type +from pystencils.typing import CastFunc, create_type def test_type_interference(): x = pystencils.fields('x: float32[3d]') assignments = pystencils.AssignmentCollection({ - a: cast_func(10, create_type('float64')), - b: cast_func(10, create_type('uint16')), + a: CastFunc(10, create_type('float64')), + b: CastFunc(10, create_type('uint16')), e: 11, c: b, f: c + b, diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index 75ba2c5e3..5c2b008e4 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -1,21 +1,9 @@ import sympy as sp import numpy as np -import pytest -import ctypes import pystencils as ps -from pystencils import data_types -from pystencils.data_types import TypedSymbol, get_type_of_expression, VectorType, collate_types, create_type, \ - typed_symbols, type_all_numbers, matrix_symbols, cast_func, pointer_arithmetic_func, ctypes_from_llvm, PointerType - - -def test_parsing(): - assert str(data_types.create_composite_type_from_string("const double *")) == "double const *" - assert str(data_types.create_composite_type_from_string("double const *")) == "double const *" - - t1 = data_types.create_composite_type_from_string("const double * const * const restrict") - t2 = data_types.create_composite_type_from_string(str(t1)) - assert t1 == t2 +from pystencils.typing import TypedSymbol, get_type_of_expression, VectorType, collate_types, create_type, \ + typed_symbols, type_all_numbers, matrix_symbols, CastFunc, PointerArithmeticFunc, PointerType def test_collation(): @@ -133,7 +121,7 @@ def test_Basic_data_type(): assert typed_symbols("s", bool).dtype.is_other() assert typed_symbols("s", np.void).dtype.is_other() - assert typed_symbols("s", np.float64).dtype.base_name == 'double' + assert typed_symbols("s", np.float64).dtype.c_name == 'double' # removed for old sympy version # assert typed_symbols(("s"), np.float64).dtype.sympy_dtype == typed_symbols(("s"), float).dtype.sympy_dtype @@ -157,15 +145,15 @@ def test_Basic_data_type(): def test_cast_func(): - assert cast_func(TypedSymbol("s", np.uint), np.int64).canonical == TypedSymbol("s", np.uint).canonical + assert CastFunc(TypedSymbol("s", np.uint), np.int64).canonical == TypedSymbol("s", np.uint).canonical - a = cast_func(5, np.uint) + a = CastFunc(5, np.uint) assert a.is_negative is False assert a.is_nonnegative def test_pointer_arithmetic_func(): - assert pointer_arithmetic_func(TypedSymbol("s", np.uint), 1).canonical == TypedSymbol("s", np.uint).canonical + assert PointerArithmeticFunc(TypedSymbol("s", np.uint), 1).canonical == TypedSymbol("s", np.uint).canonical def test_division(): -- GitLab