diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 4d9a77e479ca1602d67157c56f7cad197334281e..7b819db3cb3b3097a32f22657a85b02ae6b981e5 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -6,10 +6,10 @@ from typing import Any, List, Optional, Sequence, Set, Union import sympy as sp import pystencils -from pystencils.typing import TypedSymbol, CastFunc, create_type, get_next_parent_of_type +from pystencils.typing import create_type, get_next_parent_of_type, CastFunc from pystencils.enums import Target, Backend from pystencils.field import Field -from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol +from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol from pystencils.sympyextensions import fast_subs NodeOrExpr = Union['Node', sp.Expr] diff --git a/pystencils/boundaries/boundaryhandling.py b/pystencils/boundaries/boundaryhandling.py index 4ad3ab3ffba2d4d8e774121ab3123a4683d9ccf9..52a314d75e5783085d73772439d91d0c44fb4d02 100644 --- a/pystencils/boundaries/boundaryhandling.py +++ b/pystencils/boundaries/boundaryhandling.py @@ -10,7 +10,7 @@ from pystencils.cache import memorycache from pystencils.typing import TypedSymbol, create_type from pystencils.datahandling.pycuda import PyCudaArrayHandler from pystencils.field import Field -from pystencils.kernelparameters import FieldPointerSymbol +from pystencils.typing.typed_sympy import FieldPointerSymbol try: # noinspection PyPep8Naming diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py index f2dc0ff933a1e51d6bfc67cc0b90963b478f010d..77a4a7d7920b2a64fb5129d86e9e342c9bd63906 100644 --- a/pystencils/cpu/kernelcreation.py +++ b/pystencils/cpu/kernelcreation.py @@ -59,7 +59,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke else: raise ValueError("Term has to be field access or symbol") - # TODO 1) check kernel + # TODO 1) check kernel -> do general checks elsewhere # TODO 2) add leaf types fields_read, fields_written, assignments = add_types( assignments, type_info, not skip_independence_check, check_double_write_condition=not allow_double_writes) diff --git a/pystencils/field.py b/pystencils/field.py index 91b33eed39c45998256b9e58e1705404e1f7d437..4a29a1be201cf34cc396e5da355dc77fdba23446 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -319,7 +319,7 @@ class Field: assert isinstance(field_type, FieldType) assert len(shape) == len(strides) self.field_type = field_type - self._dtype = create_type(dtype) + self._dtype = create_type(dtype) # TODO do we have AoS??? self._layout = normalize_layout(layout) self.shape = shape self.strides = strides @@ -619,7 +619,7 @@ class Field: self.coordinate_origin = -sp.Matrix([i / 2 for i in self.spatial_shape]) # noinspection PyAttributeOutsideInit,PyUnresolvedReferences - class Access(TypedSymbol, Field.Access): + class Access(TypedSymbol): """Class representing a relative access into a `Field`. This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py index a13297e0d7a222f40af25ccefb2623304a9f2f62..b6fb901750895b341d44fde26040ff3b91d0e9e9 100644 --- a/pystencils/gpucuda/cudajit.py +++ b/pystencils/gpucuda/cudajit.py @@ -6,7 +6,7 @@ from pystencils.typing import StructType from pystencils.field import FieldType from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils.kernel_wrapper import KernelWrapper -from pystencils.kernelparameters import FieldPointerSymbol +from pystencils.typing.typed_sympy import FieldPointerSymbol USE_FAST_MATH = True diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py index 842e70ad93cbdc3cd16710c212ecfd51b71b4456..a2b0740c987098a74984789f4c76d7bf35445f83 100644 --- a/pystencils/kernel_contrains_check.py +++ b/pystencils/kernel_contrains_check.py @@ -9,8 +9,15 @@ from pystencils.field import Field from pystencils.transformations import NestedScopes +accepted_functions = [ + sp.Pow, + sp.sqrt, # TODO why not a class?? + # TODO trigonometric functions +] + + class KernelConstraintsCheck: - # TODO: specification + # TODO: proper specification # TODO: More checks :) """Checks if the input to create_kernel is valid. @@ -26,28 +33,52 @@ class KernelConstraintsCheck: """ FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index']) - def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True): - self._type_for_symbol = type_for_symbol - + def __init__(self, check_independence_condition, check_double_write_condition=True): self.scopes = NestedScopes() self.field_writes = defaultdict(set) self.fields_read = set() self.check_independence_condition = check_independence_condition self.check_double_write_condition = check_double_write_condition + def visit(self, obj): + if isinstance(obj, (list, tuple)): + [self.visit(e) for e in obj] + if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): + self.process_assignment(obj) + elif isinstance(obj, ast.Conditional): + self.scopes.push() + # Disable double write check inside conditionals + # would be triggered by e.g. in-kernel boundaries + old_double_write = self.check_double_write_condition + self.check_double_write_condition = False + if obj.false_block: + self.visit(obj.false_block) + self.process_expression(obj.condition_expr) + self.check_double_write_condition = old_double_write + self.scopes.pop() + elif isinstance(obj, ast.Block): + self.scopes.push() + [self.visit(e) for e in obj.args] + self.scopes.pop() + elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate): + pass + else: + raise ValueError(f'Invalid object in kernel {type(obj)}') + def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]): # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 self.process_expression(assignment.rhs) self.process_lhs(assignment.lhs) - def process_expression(self, rhs, type_constants=True): + def process_expression(self, rhs): + # TODO constraint for accepted functions self.update_accesses_rhs(rhs) if isinstance(rhs, Field.Access): self.fields_read.add(rhs.field) self.fields_read.update(rhs.indirect_addressing_fields) else: for arg in rhs.args: - self.process_expression(arg, type_constants) + self.process_expression(arg) @property def fields_written(self): diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 1746a8b9994292bee2b74aeaa4aacefb4931f5f7..0a9aea653fc716e2ce2f5c129cfb62f30e7c44f9 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -11,7 +11,7 @@ from sympy.core.numbers import Zero from pystencils.assignment import Assignment from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType -from pystencils.kernelparameters import FieldPointerSymbol +from pystencils.typing.typed_sympy import FieldPointerSymbol T = TypeVar('T') diff --git a/pystencils/typing/__init__.py b/pystencils/typing/__init__.py index 55fb731c0bd45c006d38c72cd20558fdf2dd6d17..2221b812b82e08976e2b4bdc73a1181605a0fcad 100644 --- a/pystencils/typing/__init__.py +++ b/pystencils/typing/__init__.py @@ -1,4 +1,6 @@ -from pystencils.typing.utilities import * + + from pystencils.typing.types import * from pystencils.typing.typed_sympy import * from pystencils.typing.cast_functions import * +from pystencils.typing.utilities import * diff --git a/pystencils/leaf_typing.py b/pystencils/typing/leaf_typing.py similarity index 51% rename from pystencils/leaf_typing.py rename to pystencils/typing/leaf_typing.py index 789bb4a8d8601e6a8cbabb5c87277c9e3ddc15c9..b6ef0362f17b8da213ea2157b11ebb19ead784c5 100644 --- a/pystencils/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -1,5 +1,5 @@ -from collections import namedtuple, defaultdict -from typing import List, Union +from collections import namedtuple +from typing import Union, Dict, Tuple, Any import numpy as np @@ -9,13 +9,13 @@ import sympy as sp from pystencils import astnodes as ast, TypedSymbol from pystencils.bit_masks import flag_cond from pystencils.field import Field -from pystencils.transformations import NestedScopes -from pystencils.typing import CastFunc, create_type, get_type_of_expression, collate_types +from pystencils.typing import AbstractType, BasicType, CastFunc, create_type, get_type_of_expression, collate_types +from pystencils.utils import ContextVar from sympy.codegen import Assignment from sympy.logic.boolalg import BooleanFunction -class KernelConstraintsCheck: # TODO rename +class TypeAdder: # TODO: Logs # TODO: specification # TODO: split this into checker and leaf typing @@ -33,33 +33,95 @@ class KernelConstraintsCheck: # TODO rename """ FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index']) - def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True): - self._type_for_symbol = type_for_symbol - - self.scopes = NestedScopes() - self.field_writes = defaultdict(set) - self.fields_read = set() - self.check_independence_condition = check_independence_condition - self.check_double_write_condition = check_double_write_condition + def __init__(self, default_symbol_type: BasicType, type_for_symbol: Dict[str, BasicType], + default_number_float: BasicType, default_number_int: BasicType): + self.type_for_symbol = type_for_symbol + self.default_symbol_type = ContextVar(default_symbol_type) + self.default_number_float = ContextVar(default_number_float) + self.default_number_int = ContextVar(default_number_int) + + def get_symbol_type(self, symbol: str) -> BasicType: + return self.type_for_symbol.get(symbol, self.default_symbol_type.get()) + + # TODO: check if this adds only types to leave nodes of AST, get type info + def visit(self, obj): + if isinstance(obj, (list, tuple)): + return [self.visit(e) for e in obj] + if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): + return self.process_assignment(obj) + elif isinstance(obj, ast.Conditional): + false_block = None if obj.false_block is None else self.visit( + obj.false_block) + result = ast.Conditional(self.process_expression( + obj.condition_expr, type_constants=False), + true_block=self.visit(obj.true_block), + false_block=false_block) + return result + elif isinstance(obj, ast.Block): + result = ast.Block([self.visit(e) for e in obj.args]) + return result + elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate): + return obj + else: + raise ValueError("Invalid object in kernel " + str(type(obj))) def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]) -> ast.SympyAssignment: # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 new_rhs = self.process_expression(assignment.rhs) + # TODO check type rhs lhs new_lhs = self.process_lhs(assignment.lhs) return ast.SympyAssignment(new_lhs, new_rhs) + # Type System Specification + # - Defined Types: TypedSymbol, Field, Field.Access, ...? + # - Indexed: always unsigned_integer64 + # - Undefined Types: Symbol - Is specified in Config in the dict or as 'default_type' + # - Constants/Numbers: Are either integer or floating. The precision and sign is specified via config + # - Example: 1.4 config:float32 -> float32 + # - Expressions deduce types from arguments + # - Functions deduce types from arguments + # - default_type and default_float and default_int can be given for a list of assignment, or + # individually as a list for assignment + + # Possible Problems - Do we need to support this? + # - Mixture in expression with int and float + # - Mixture in expression with uint64 and sint64 + + def figure_out_type(self, expr) -> Tuple[Any, BasicType]: #TODO or abstract type? + # Trivial cases + if isinstance(expr, Field.Access): + return expr, expr.dtype + elif isinstance(expr, TypedSymbol): + return expr, expr.dtype + elif isinstance(expr, sp.Symbol): + t = TypedSymbol(expr.name, self.get_symbol_type(expr.name)) # TODO with or without name + return t, t.dtype + elif isinstance(expr, np.generic): + assert False, f'Why do we have a np.generic in rhs???? {expr}' + elif isinstance(expr, sp.Number): + if expr.is_Float: + data_type = self.default_number_float.get() + elif expr.is_Integer: + data_type = self.default_number_int.get() + return expr, data_type + # TODO add everything in between + elif isinstance(expr, sp.Mul): + # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? + types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)] + return None # TODO collate types + elif isinstance(expr, sp.Indexed): + self.apply_type(expr, BasicType('uintp')) # TODO double check + return None + elif isinstance(expr, sp.Pow): + # TODO sp.Pow should know a type + return None # TODO + else: + types = [self.figure_out_type(arg) for arg in expr.args] + # TODO collate + return None # TODO - # Expression - # 1) ask children if they are cocksure about a type - # 1b) Postpone clueless children (see 5) - # cocksure: Children have somewhere type from Field.Access, TypedSymbol, CastFunction or Function^TM - # clueless: Children without Field.Access,... - # 1c) none child is cocksure -> do nothing a return None, wait for recall from parent - # 2) collate_type of children - # 3) apply collated type on children - # 4) issue warnings of casts on cocksure children - # 5a) resume on clueless children with the collated type as default datatype, issue warning - # 5b) or apply special circumstances + def apply_type(self, expr, data_type: AbstractType): + pass def process_expression(self, rhs, type_constants=True): # TODO default_type as parameter if isinstance(rhs, Field.Access): @@ -115,13 +177,6 @@ class KernelConstraintsCheck: # TODO rename new_args = [self.process_expression(arg, type_constants) for arg in rhs.args] return rhs.func(*new_args) if new_args else rhs - @property - def fields_written(self): - """ - Return all rhs fields - """ - return set(k.field for k, v in self.field_writes.items() if len(v)) - def process_lhs(self, lhs: Union[Field.Access, TypedSymbol, sp.Symbol]): if not isinstance(lhs, (Field.Access, TypedSymbol)): return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) diff --git a/pystencils/typing/typed_sympy.py b/pystencils/typing/typed_sympy.py index 0a253f748082aa7baf3359a716dcd0a873cb02fb..dffffe9e26763e0474a3d5ec3a5d59c28c3a1270 100644 --- a/pystencils/typing/typed_sympy.py +++ b/pystencils/typing/typed_sympy.py @@ -5,7 +5,6 @@ import sympy as sp from sympy.core.cache import cacheit from pystencils.typing.types import BasicType, create_type, PointerType -from pystencils.typing.utilities import get_base_type def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]): @@ -44,6 +43,7 @@ class TypedSymbol(sp.Symbol): return obj def __new_stage2__(cls, name, dtype, **kwargs): # TODO does not match signature of sp.Symbol??? + # TODO: also Symbol should be allowed ---> see sympy Variable assumptions = assumptions_from_dtype(dtype) # TODO should by dtype a np.dtype or our Type??? assumptions.update(kwargs) obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions) @@ -59,10 +59,10 @@ class TypedSymbol(sp.Symbol): @property def dtype(self): - return self._dtype + return self.numpy_dtype def _hashable_content(self): - return super()._hashable_content(), hash(self._dtype) + return super()._hashable_content(), hash(self.numpy_dtype) def __getnewargs__(self): return self.name, self.dtype @@ -160,6 +160,8 @@ class FieldPointerSymbol(TypedSymbol): return obj def __new_stage2__(cls, field_name, field_dtype, const): + from pystencils.typing.utilities import get_base_type + name = f"_data_{field_name}" dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True) obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype) diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py index eabe87dbd837940f0ec48edb33962befb482d354..318b0932ac1733c20f474721c26a6246b78d874a 100644 --- a/pystencils/typing/types.py +++ b/pystencils/typing/types.py @@ -38,7 +38,7 @@ def numpy_name_to_c(name: str) -> str: raise NotImplementedError(f"Can't map numpy to C name for {name}") -class AbstractType(sp.Atom, ABC): +class AbstractType(sp.Atom): # TODO: inherits from sp.Atom because of cast function (and maybe others) # TODO: is this necessary? def __new__(cls, *args, **kwargs): diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py index a7a506f0cbfcd724d0d9f6b0f6e900568f94dda1..2f3d175daa3aa80d8138977d69a8a9c06a6543ef 100644 --- a/pystencils/typing/utilities.py +++ b/pystencils/typing/utilities.py @@ -1,18 +1,16 @@ from collections import defaultdict from functools import partial -from typing import Tuple, Union, List, Dict +from typing import Tuple, List, Dict import numpy as np import sympy as sp -from pystencils import astnodes as ast -from pystencils.kernel_contrains_check import KernelConstraintsCheck +# from pystencils.typing.leaf_typing import TypeAdder # TODO this should be leaf_typing from sympy.codegen import Assignment from sympy.logic.boolalg import Boolean, BooleanFunction import pystencils -from pystencils.cache import memorycache, memorycache_if_hashable -from pystencils.utils import all_equal -from pystencils.typing.types import AbstractType, BasicType, VectorType, PointerType, StructType, create_type +from pystencils.cache import memorycache_if_hashable +from pystencils.typing.types import BasicType, VectorType, PointerType, create_type from pystencils.typing.cast_functions import CastFunc, PointerArithmeticFunc from pystencils.typing.typed_sympy import TypedSymbol @@ -74,49 +72,53 @@ def peel_off_type(dtype, type_to_peel_off): return dtype - ############################# This is basically our type system ######################################################## -def collate_types(types, - forbid_collation_to_complex=False, # TODO: type system shouldn't need this!!! - forbid_collation_to_float=False, # TODO: type system shouldn't need this!!! - default_float_type='float64', - # TODO: AST leaves should be typed. Expressions should be able to find out correct type - default_int_type='int64'): # TODO: AST leaves should be typed. Expressions should be able to find out correct type + +def result_type(*args: np.dtype): + s = sorted(args, key=lambda x: x.itemsize) + + def kind_to_value(kind: str) -> int: + if kind == 'f': + return 3 + elif kind == 'i': + return 2 + elif kind == 'u': + return 1 + elif kind == 'b': + return 0 + else: + raise NotImplementedError(f'{kind=} is not a supported kind of a type. See "numpy.dtype.kind" for options') + s = sorted(s, key=lambda x: kind_to_value(x.kind)) + return s[-1] + + +def collate_types(types): """ Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Uses the collation rules from numpy. """ # TODO: use np.can_cast and np.promote_types and np.result_type and np.find_common_type - if forbid_collation_to_complex: - types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)] - if not types: - return create_type(default_float_type) - - if forbid_collation_to_float: - types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)] - if not types: - return create_type(default_int_type) - - # Pointer arithmetic case i.e. pointer + integer is allowed - if any(type(t) is PointerType for t in types): - pointer_type = None - for t in types: - if type(t) is PointerType: - if pointer_type is not None: - raise ValueError("Cannot collate the combination of two pointer types") - pointer_type = t - elif type(t) is BasicType: - if not (t.is_int() or t.is_uint()): - raise ValueError("Invalid pointer arithmetic") - else: - raise ValueError("Invalid pointer arithmetic") - return pointer_type - # peel of vector types, if at least one vector type occurred the result will also be the vector type - vector_type = [t for t in types if type(t) is VectorType] - if not all_equal(t.width for t in vector_type): - raise ValueError("Collation failed because of vector types with different width") - types = [peel_off_type(t, VectorType) for t in types] + # # Pointer arithmetic case i.e. pointer + integer is allowed + # if any(type(t) is PointerType for t in types): + # pointer_type = None + # for t in types: + # if type(t) is PointerType: + # if pointer_type is not None: + # raise ValueError("Cannot collate the combination of two pointer types") + # pointer_type = t + # elif type(t) is BasicType: + # if not (t.is_int() or t.is_uint()): + # raise ValueError("Invalid pointer arithmetic") + # else: + # raise ValueError("Invalid pointer arithmetic") + # return pointer_type + # + # # peel of vector types, if at least one vector type occurred the result will also be the vector type + # vector_type = [t for t in types if type(t) is VectorType] + # if not all_equal(t.width for t in vector_type): + # raise ValueError("Collation failed because of vector types with different width") + # types = [peel_off_type(t, VectorType) for t in types] # now we should have a list of basic types - struct types are not yet supported assert all(type(t) is BasicType for t in types) @@ -126,8 +128,8 @@ def collate_types(types, # use numpy collation -> create type from numpy type -> and, put vector type around if necessary result_numpy_type = np.result_type(*(t.numpy_dtype for t in types)) result = BasicType(result_numpy_type) - if vector_type: - result = VectorType(result, vector_type[0].width) + # if vector_type: + # result = VectorType(result, vector_type[0].width) return result @@ -166,6 +168,7 @@ def get_type_of_expression(expr, elif isinstance(expr, TypedSymbol): return expr.dtype elif isinstance(expr, sp.Symbol): + # TODO delete if case if symbol_type_dict: return symbol_type_dict[expr.name] else: @@ -288,36 +291,7 @@ def add_types(eqs: List[Assignment], type_for_symbol: Dict[sp.Symbol, np.dtype], check = KernelConstraintsCheck(type_for_symbol, check_independence_condition, check_double_write_condition=check_double_write_condition) - # TODO: check if this adds only types to leave nodes of AST, get type info - def visit(obj): - if isinstance(obj, (list, tuple)): - return [visit(e) for e in obj] - if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): - return check.process_assignment(obj) - elif isinstance(obj, ast.Conditional): - check.scopes.push() - # Disable double write check inside conditionals - # would be triggered by e.g. in-kernel boundaries - old_double_write = check.check_double_write_condition - check.check_double_write_condition = False - false_block = None if obj.false_block is None else visit( - obj.false_block) - result = ast.Conditional(check.process_expression( - obj.condition_expr, type_constants=False), - true_block=visit(obj.true_block), - false_block=false_block) - check.check_double_write_condition = old_double_write - check.scopes.pop() - return result - elif isinstance(obj, ast.Block): - check.scopes.push() - result = ast.Block([visit(e) for e in obj.args]) - check.scopes.pop() - return result - elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate): - return obj - else: - raise ValueError("Invalid object in kernel " + str(type(obj))) + typed_equations = visit(eqs) @@ -333,6 +307,8 @@ def insert_casts(node): Returns: modified AST """ + from pystencils.astnodes import SympyAssignment, ResolvedFieldAccess, LoopOverCoordinate, Block + def cast(zipped_args_types, target_dtype): """ Adds casts to the arguments if their type differs from the target type @@ -385,7 +361,7 @@ def insert_casts(node): return pointer_arithmetic(zipped) else: return node.func(*cast(zipped, target)) - elif node.func is ast.SympyAssignment: + elif node.func is SympyAssignment: lhs = args[0] rhs = args[1] target = get_type_of_expression(lhs) @@ -393,13 +369,13 @@ def insert_casts(node): return node.func(*args) # TODO fix, not complete else: return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target)) - elif node.func is ast.ResolvedFieldAccess: + elif node.func is ResolvedFieldAccess: return node - elif node.func is ast.Block: + elif node.func is Block: for old_arg, new_arg in zip(node.args, args): node.replace(old_arg, new_arg) return node - elif node.func is ast.LoopOverCoordinate: + elif node.func is LoopOverCoordinate: for old_arg, new_arg in zip(node.args, args): node.replace(old_arg, new_arg) return node @@ -464,18 +440,19 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i Returns: dictionary, mapping symbol name to type """ + from pystencils.astnodes import SympyAssignment, Conditional, Node result = defaultdict(lambda: default_type) if hasattr(default_type, 'numpy_dtype'): result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype else: result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype for eq in eqs: - if isinstance(eq, ast.Conditional): + if isinstance(eq, Conditional): result.update(typing_from_sympy_inspection(eq.true_block.args)) if eq.false_block: result.update(typing_from_sympy_inspection( eq.false_block.args)) - elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment): + elif isinstance(eq, Node) and not isinstance(eq, SympyAssignment): continue else: from pystencils.cpu.vectorization import vec_all, vec_any diff --git a/pystencils/utils.py b/pystencils/utils.py index 3afdbc582ef7dece1933dbaf5b00be149f9cbd30..dc8d35ee64dcfdb0ef6f9f687526fe3379ce8fbd 100644 --- a/pystencils/utils.py +++ b/pystencils/utils.py @@ -220,3 +220,17 @@ class LinearEquationSystem: break result -= 1 self.next_zero_row = result + + +class ContextVar: + def __init__(self, value): + self.stack = [value] + + @contextmanager + def __call__(self, new_value): + self.stack.append(new_value) + yield self + self.stack.pop() + + def get(self): + return self.stack[-1] diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index 5c2b008e4ba4b0bd1fbf28e96fbef8affeef0e4c..774306d8d7c36ab2ff4a026e89f4ed4e78785c4e 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -3,7 +3,40 @@ import numpy as np import pystencils as ps from pystencils.typing import TypedSymbol, get_type_of_expression, VectorType, collate_types, create_type, \ - typed_symbols, type_all_numbers, matrix_symbols, CastFunc, PointerArithmeticFunc, PointerType + typed_symbols, CastFunc, PointerArithmeticFunc, PointerType, result_type + + +def test_result_type(): + i = np.dtype('int32') + l = np.dtype('int64') + ui = np.dtype('uint32') + ul = np.dtype('uint64') + f = np.dtype('float32') + d = np.dtype('float64') + b = np.dtype('bool') + + assert result_type(i, l) == l + assert result_type(l, i) == l + assert result_type(ui, i) == i + assert result_type(ui, l) == l + assert result_type(ul, i) == i + assert result_type(ul, l) == l + assert result_type(d, f) == d + assert result_type(f, d) == d + assert result_type(i, f) == f + assert result_type(l, f) == f + assert result_type(ui, f) == f + assert result_type(ul, f) == f + assert result_type(i, d) == d + assert result_type(l, d) == d + assert result_type(ui, d) == d + assert result_type(ul, d) == d + assert result_type(b, i) == i + assert result_type(b, l) == l + assert result_type(b, ui) == ui + assert result_type(b, ul) == ul + assert result_type(b, f) == f + assert result_type(b, d) == d def test_collation():