diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 47f1fd7d1d1715bf85e326e485aad8231dadcdfe..b4db7f274b5bfbb1d99767a3f35ac2a32bd2cac1 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -506,7 +506,6 @@ class SympyAssignment(Node): super(SympyAssignment, self).__init__(parent=None) self._lhs_symbol = lhs_symbol self.rhs = sp.sympify(rhs_expr) - self._is_const = is_const self._is_declaration = self.__is_declaration() def __is_declaration(self): @@ -563,10 +562,6 @@ class SympyAssignment(Node): def is_declaration(self): return self._is_declaration - @property - def is_const(self): - return self._is_const - def replace(self, child, replacement): if child == self.lhs: replacement.parent = self diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 7248846b312c920768e1f5b68af65aa21cfb46b8..ee6811f93f6b602602a1be1b8705b967142bd95c 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -225,13 +225,9 @@ class CBackend: def _print_SympyAssignment(self, node): if node.is_declaration: - if node.is_const: - prefix = 'const ' - else: - prefix = '' - data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " " - return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), - self.sympy_printer.doprint(node.rhs)) + data_type = self._print(node.lhs.dtype) + return "%s %s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), + self.sympy_printer.doprint(node.rhs)) else: lhs_type = get_type_of_expression(node.lhs) if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 6bf3a26def2dba238eeb4c64e325a41d58fd2afa..ac0f149475e1fa0ba2d7110bbbbab8d8bb62e790 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -63,11 +63,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', if assume_inner_stride_one: replace_inner_stride_with_one(kernel_ast) - field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float()) + field_float_dtypes = set(f.dtype.numpy_dtype for f in all_fields if f.dtype.is_float()) if len(field_float_dtypes) != 1: raise NotImplementedError("Cannot vectorize kernels that contain accesses " "to differently typed floating point fields") - float_size = field_float_dtypes.pop().numpy_dtype.itemsize + float_size = field_float_dtypes.pop().itemsize assert float_size in (8, 4) vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float', instruction_set=instruction_set) @@ -148,7 +148,7 @@ def insert_vector_casts(ast_node): return expr else: target_type = collate_types(arg_types) - casted_args = [cast_func(a, target_type) if t != target_type else a + casted_args = [cast_func(a, target_type) if not t.equal_ignoring_const(target_type) else a for a, t in zip(new_args, arg_types)] return expr.func(*casted_args) elif expr.func is sp.Pow: @@ -167,11 +167,11 @@ def insert_vector_casts(ast_node): if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType: condition_target_type = VectorType(condition_target_type, width=result_target_type.width) - casted_results = [cast_func(a, result_target_type) if t != result_target_type else a + casted_results = [cast_func(a, result_target_type) if not t.equal_ignoring_const(result_target_type) else a for a, t in zip(new_results, types_of_results)] casted_conditions = [cast_func(a, condition_target_type) - if t != condition_target_type and a is not True else a + if not t.equal_ignoring_const(condition_target_type) and a is not True else a for a, t in zip(new_conditions, types_of_conditions)] return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)]) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 20eb94d6b83474bfab46dc491a05b3b1ed2191d9..2c07f02b3b90f0077cc75660d59e51ccffab6666 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -453,7 +453,7 @@ def collate_types(types, forbid_collation_to_float=False): types = tuple(t for t in types if t.is_float()) # use numpy collation -> create type from numpy type -> and, put vector type around if necessary result_numpy_type = np.result_type(*(t.numpy_dtype for t in types)) - result = BasicType(result_numpy_type) + result = BasicType(result_numpy_type, const=any(t.const for t in types)) if vector_type: result = VectorType(result, vector_type[0].width) return result @@ -618,6 +618,12 @@ class BasicType(Type): else: return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) + def equal_ignoring_const(self, other): + if not isinstance(other, BasicType): + return False + else: + return self.numpy_dtype == other.numpy_dtype + def __hash__(self): return hash(str(self)) @@ -643,17 +649,23 @@ class VectorType(Type): else: return (self.base_type, self.width) == (other.base_type, other.width) + def equal_ignoring_const(self, other): + if not isinstance(other, VectorType): + return False + else: + return self.base_type.equal_ignoring_const(other.base_type) + def __str__(self): if self.instruction_set is None: return "%s[%d]" % (self.base_type, self.width) else: - if self.base_type == create_type("int64"): + if self.base_type.numpy_dtype == np.int64: return self.instruction_set['int'] - elif self.base_type == create_type("float64"): + elif self.base_type.numpy_dtype == np.float64: return self.instruction_set['double'] - elif self.base_type == create_type("float32"): + elif self.base_type.numpy_dtype == np.float32: return self.instruction_set['float'] - elif self.base_type == create_type("bool"): + elif self.base_type.numpy_dtype == np.bool: return self.instruction_set['bool'] else: raise NotImplementedError() @@ -692,6 +704,12 @@ class PointerType(Type): else: return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict) + def equal_ignoring_const(self, other): + if not isinstance(other, PointerType): + return False + else: + return self.base_type.equal_ignoring_const(other.base_type) + def __str__(self): components = [str(self.base_type), '*'] if self.restrict: @@ -743,6 +761,12 @@ class StructType: else: return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) + def equal_ignoring_const(self, other): + if not isinstance(other, StructType): + return False + else: + return self.numpy_dtype == other.numpy_dtype + def __str__(self): # structs are handled byte-wise result = "uint8_t" diff --git a/pystencils/kernelparameters.py b/pystencils/kernelparameters.py index 9284a1ee5d14dd97c63f44574e020f1bff24d4e5..adfd58d7fa23b259fa05ab6d169fba0999e030de 100644 --- a/pystencils/kernelparameters.py +++ b/pystencils/kernelparameters.py @@ -16,7 +16,7 @@ would reference back to the field. from sympy.core.cache import cacheit from pystencils.data_types import ( - PointerType, TypedSymbol, create_composite_type_from_string, get_base_type) + BasicType, PointerType, TypedSymbol, create_composite_type_from_string, get_base_type) SHAPE_DTYPE = create_composite_type_from_string("const int64") STRIDE_DTYPE = create_composite_type_from_string("const int64") @@ -78,7 +78,8 @@ class FieldPointerSymbol(TypedSymbol): def __new_stage2__(cls, field_name, field_dtype, const): name = "_data_{name}".format(name=field_name) - dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True) + base_type = BasicType(get_base_type(field_dtype), const=const) + dtype = PointerType(base_type, const=True, restrict=True) obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype) obj.field_name = field_name return obj diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 1bfb0511ac9a6a8afe11b8275e4ec8d8d75cb9f6..133f48f45968f786115db29756157fd784494bda 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -878,7 +878,9 @@ class KernelConstraintsCheck: assert isinstance(lhs, sp.Symbol) self._update_accesses_lhs(lhs) if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)): - return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) + dtype = create_type(self._type_for_symbol[lhs.name]) + dtype.const = True + return TypedSymbol(lhs.name, dtype) else: return lhs