Commit d8f49def authored by Martin Bauer
Browse files

First try to get a better const treatment

parent 443527ae
...@@ -506,7 +506,6 @@ class SympyAssignment(Node): ...@@ -506,7 +506,6 @@ class SympyAssignment(Node):
super(SympyAssignment, self).__init__(parent=None) super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol self._lhs_symbol = lhs_symbol
self.rhs = sp.sympify(rhs_expr) self.rhs = sp.sympify(rhs_expr)
self._is_const = is_const
self._is_declaration = self.__is_declaration() self._is_declaration = self.__is_declaration()
def __is_declaration(self): def __is_declaration(self):
...@@ -563,10 +562,6 @@ class SympyAssignment(Node): ...@@ -563,10 +562,6 @@ class SympyAssignment(Node):
def is_declaration(self): def is_declaration(self):
return self._is_declaration return self._is_declaration
@property
def is_const(self):
return self._is_const
def replace(self, child, replacement): def replace(self, child, replacement):
if child == self.lhs: if child == self.lhs:
replacement.parent = self replacement.parent = self
......
...@@ -225,13 +225,9 @@ class CBackend: ...@@ -225,13 +225,9 @@ class CBackend:
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
if node.is_declaration: if node.is_declaration:
if node.is_const: data_type = self._print(node.lhs.dtype)
prefix = 'const ' return "%s %s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
else: self.sympy_printer.doprint(node.rhs))
prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else: else:
lhs_type = get_type_of_expression(node.lhs) lhs_type = get_type_of_expression(node.lhs)
if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
......
...@@ -63,11 +63,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', ...@@ -63,11 +63,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
if assume_inner_stride_one: if assume_inner_stride_one:
replace_inner_stride_with_one(kernel_ast) replace_inner_stride_with_one(kernel_ast)
field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float()) field_float_dtypes = set(f.dtype.numpy_dtype for f in all_fields if f.dtype.is_float())
if len(field_float_dtypes) != 1: if len(field_float_dtypes) != 1:
raise NotImplementedError("Cannot vectorize kernels that contain accesses " raise NotImplementedError("Cannot vectorize kernels that contain accesses "
"to differently typed floating point fields") "to differently typed floating point fields")
float_size = field_float_dtypes.pop().numpy_dtype.itemsize float_size = field_float_dtypes.pop().itemsize
assert float_size in (8, 4) assert float_size in (8, 4)
vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float', vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
instruction_set=instruction_set) instruction_set=instruction_set)
...@@ -148,7 +148,7 @@ def insert_vector_casts(ast_node): ...@@ -148,7 +148,7 @@ def insert_vector_casts(ast_node):
return expr return expr
else: else:
target_type = collate_types(arg_types) target_type = collate_types(arg_types)
casted_args = [cast_func(a, target_type) if t != target_type else a casted_args = [cast_func(a, target_type) if not t.equal_ignoring_const(target_type) else a
for a, t in zip(new_args, arg_types)] for a, t in zip(new_args, arg_types)]
return expr.func(*casted_args) return expr.func(*casted_args)
elif expr.func is sp.Pow: elif expr.func is sp.Pow:
...@@ -167,11 +167,11 @@ def insert_vector_casts(ast_node): ...@@ -167,11 +167,11 @@ def insert_vector_casts(ast_node):
if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType: if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType:
condition_target_type = VectorType(condition_target_type, width=result_target_type.width) condition_target_type = VectorType(condition_target_type, width=result_target_type.width)
casted_results = [cast_func(a, result_target_type) if t != result_target_type else a casted_results = [cast_func(a, result_target_type) if not t.equal_ignoring_const(result_target_type) else a
for a, t in zip(new_results, types_of_results)] for a, t in zip(new_results, types_of_results)]
casted_conditions = [cast_func(a, condition_target_type) casted_conditions = [cast_func(a, condition_target_type)
if t != condition_target_type and a is not True else a if not t.equal_ignoring_const(condition_target_type) and a is not True else a
for a, t in zip(new_conditions, types_of_conditions)] for a, t in zip(new_conditions, types_of_conditions)]
return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)]) return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
......
...@@ -453,7 +453,7 @@ def collate_types(types, forbid_collation_to_float=False): ...@@ -453,7 +453,7 @@ def collate_types(types, forbid_collation_to_float=False):
types = tuple(t for t in types if t.is_float()) types = tuple(t for t in types if t.is_float())
# use numpy collation -> create type from numpy type -> and, put vector type around if necessary # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
result_numpy_type = np.result_type(*(t.numpy_dtype for t in types)) result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
result = BasicType(result_numpy_type) result = BasicType(result_numpy_type, const=any(t.const for t in types))
if vector_type: if vector_type:
result = VectorType(result, vector_type[0].width) result = VectorType(result, vector_type[0].width)
return result return result
...@@ -618,6 +618,12 @@ class BasicType(Type): ...@@ -618,6 +618,12 @@ class BasicType(Type):
else: else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def equal_ignoring_const(self, other):
    """Compare with ``other`` while disregarding the ``const`` qualifier.

    Returns False when ``other`` is not a ``BasicType``; otherwise the
    result of comparing the two ``numpy_dtype`` attributes.
    """
    return isinstance(other, BasicType) and self.numpy_dtype == other.numpy_dtype
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(str(self))
...@@ -643,17 +649,23 @@ class VectorType(Type): ...@@ -643,17 +649,23 @@ class VectorType(Type):
else: else:
return (self.base_type, self.width) == (other.base_type, other.width) return (self.base_type, self.width) == (other.base_type, other.width)
def equal_ignoring_const(self, other):
    """Compare with ``other`` while disregarding ``const`` on the base type.

    Returns False when ``other`` is not a ``VectorType`` or when the two
    vector widths differ; otherwise delegates to the base types'
    ``equal_ignoring_const``.
    """
    if not isinstance(other, VectorType):
        return False
    # Only constness is meant to be ignored: the width still has to match,
    # mirroring the (base_type, width) comparison done by __eq__.
    if self.width != other.width:
        return False
    return self.base_type.equal_ignoring_const(other.base_type)
def __str__(self): def __str__(self):
if self.instruction_set is None: if self.instruction_set is None:
return "%s[%d]" % (self.base_type, self.width) return "%s[%d]" % (self.base_type, self.width)
else: else:
if self.base_type == create_type("int64"): if self.base_type.numpy_dtype == np.int64:
return self.instruction_set['int'] return self.instruction_set['int']
elif self.base_type == create_type("float64"): elif self.base_type.numpy_dtype == np.float64:
return self.instruction_set['double'] return self.instruction_set['double']
elif self.base_type == create_type("float32"): elif self.base_type.numpy_dtype == np.float32:
return self.instruction_set['float'] return self.instruction_set['float']
elif self.base_type == create_type("bool"): elif self.base_type.numpy_dtype == np.bool:
return self.instruction_set['bool'] return self.instruction_set['bool']
else: else:
raise NotImplementedError() raise NotImplementedError()
...@@ -692,6 +704,12 @@ class PointerType(Type): ...@@ -692,6 +704,12 @@ class PointerType(Type):
else: else:
return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict) return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
def equal_ignoring_const(self, other):
    """Compare with ``other`` while disregarding the ``const`` qualifier.

    Returns False when ``other`` is not a ``PointerType``; otherwise the
    result of the base types' ``equal_ignoring_const``.
    """
    # NOTE(review): ``restrict`` is not compared here either, unlike in
    # __eq__ -- confirm that this is intended.
    return isinstance(other, PointerType) and self.base_type.equal_ignoring_const(other.base_type)
def __str__(self): def __str__(self):
components = [str(self.base_type), '*'] components = [str(self.base_type), '*']
if self.restrict: if self.restrict:
...@@ -743,6 +761,12 @@ class StructType: ...@@ -743,6 +761,12 @@ class StructType:
else: else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def equal_ignoring_const(self, other):
    """Compare with ``other`` while disregarding the ``const`` qualifier.

    Returns False when ``other`` is not a ``StructType``; otherwise the
    result of comparing the two ``numpy_dtype`` attributes.
    """
    return isinstance(other, StructType) and self.numpy_dtype == other.numpy_dtype
def __str__(self): def __str__(self):
# structs are handled byte-wise # structs are handled byte-wise
result = "uint8_t" result = "uint8_t"
......
...@@ -16,7 +16,7 @@ would reference back to the field. ...@@ -16,7 +16,7 @@ would reference back to the field.
from sympy.core.cache import cacheit from sympy.core.cache import cacheit
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, TypedSymbol, create_composite_type_from_string, get_base_type) BasicType, PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
SHAPE_DTYPE = create_composite_type_from_string("const int64") SHAPE_DTYPE = create_composite_type_from_string("const int64")
STRIDE_DTYPE = create_composite_type_from_string("const int64") STRIDE_DTYPE = create_composite_type_from_string("const int64")
...@@ -78,7 +78,8 @@ class FieldPointerSymbol(TypedSymbol): ...@@ -78,7 +78,8 @@ class FieldPointerSymbol(TypedSymbol):
def __new_stage2__(cls, field_name, field_dtype, const): def __new_stage2__(cls, field_name, field_dtype, const):
name = "_data_{name}".format(name=field_name) name = "_data_{name}".format(name=field_name)
dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True) base_type = BasicType(get_base_type(field_dtype), const=const)
dtype = PointerType(base_type, const=True, restrict=True)
obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype) obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
obj.field_name = field_name obj.field_name = field_name
return obj return obj
......
...@@ -878,7 +878,9 @@ class KernelConstraintsCheck: ...@@ -878,7 +878,9 @@ class KernelConstraintsCheck:
assert isinstance(lhs, sp.Symbol) assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs) self._update_accesses_lhs(lhs)
if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)): if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) dtype = create_type(self._type_for_symbol[lhs.name])
dtype.const = True
return TypedSymbol(lhs.name, dtype)
else: else:
return lhs return lhs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment