Commit d8f49def authored by Martin Bauer's avatar Martin Bauer
Browse files

First try to get a better const treatment

parent 443527ae
Pipeline #18689 failed with stage
in 3 minutes and 7 seconds
......@@ -506,7 +506,6 @@ class SympyAssignment(Node):
super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol
self.rhs = sp.sympify(rhs_expr)
self._is_const = is_const
self._is_declaration = self.__is_declaration()
def __is_declaration(self):
......@@ -563,10 +562,6 @@ class SympyAssignment(Node):
def is_declaration(self):
return self._is_declaration
@property
def is_const(self):
return self._is_const
def replace(self, child, replacement):
if child == self.lhs:
replacement.parent = self
......
......@@ -225,13 +225,9 @@ class CBackend:
def _print_SympyAssignment(self, node):
if node.is_declaration:
if node.is_const:
prefix = 'const '
else:
prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
data_type = self._print(node.lhs.dtype)
return "%s %s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else:
lhs_type = get_type_of_expression(node.lhs)
if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
......
......@@ -63,11 +63,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
if assume_inner_stride_one:
replace_inner_stride_with_one(kernel_ast)
field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float())
field_float_dtypes = set(f.dtype.numpy_dtype for f in all_fields if f.dtype.is_float())
if len(field_float_dtypes) != 1:
raise NotImplementedError("Cannot vectorize kernels that contain accesses "
"to differently typed floating point fields")
float_size = field_float_dtypes.pop().numpy_dtype.itemsize
float_size = field_float_dtypes.pop().itemsize
assert float_size in (8, 4)
vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
instruction_set=instruction_set)
......@@ -148,7 +148,7 @@ def insert_vector_casts(ast_node):
return expr
else:
target_type = collate_types(arg_types)
casted_args = [cast_func(a, target_type) if t != target_type else a
casted_args = [cast_func(a, target_type) if not t.equal_ignoring_const(target_type) else a
for a, t in zip(new_args, arg_types)]
return expr.func(*casted_args)
elif expr.func is sp.Pow:
......@@ -167,11 +167,11 @@ def insert_vector_casts(ast_node):
if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType:
condition_target_type = VectorType(condition_target_type, width=result_target_type.width)
casted_results = [cast_func(a, result_target_type) if t != result_target_type else a
casted_results = [cast_func(a, result_target_type) if not t.equal_ignoring_const(result_target_type) else a
for a, t in zip(new_results, types_of_results)]
casted_conditions = [cast_func(a, condition_target_type)
if t != condition_target_type and a is not True else a
if not t.equal_ignoring_const(condition_target_type) and a is not True else a
for a, t in zip(new_conditions, types_of_conditions)]
return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
......
......@@ -453,7 +453,7 @@ def collate_types(types, forbid_collation_to_float=False):
types = tuple(t for t in types if t.is_float())
# use numpy collation -> create type from numpy type -> and, put vector type around if necessary
result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
result = BasicType(result_numpy_type)
result = BasicType(result_numpy_type, const=any(t.const for t in types))
if vector_type:
result = VectorType(result, vector_type[0].width)
return result
......@@ -618,6 +618,12 @@ class BasicType(Type):
else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def equal_ignoring_const(self, other):
if not isinstance(other, BasicType):
return False
else:
return self.numpy_dtype == other.numpy_dtype
def __hash__(self):
return hash(str(self))
......@@ -643,17 +649,23 @@ class VectorType(Type):
else:
return (self.base_type, self.width) == (other.base_type, other.width)
def equal_ignoring_const(self, other):
if not isinstance(other, VectorType):
return False
else:
return self.base_type.equal_ignoring_const(other.base_type)
def __str__(self):
if self.instruction_set is None:
return "%s[%d]" % (self.base_type, self.width)
else:
if self.base_type == create_type("int64"):
if self.base_type.numpy_dtype == np.int64:
return self.instruction_set['int']
elif self.base_type == create_type("float64"):
elif self.base_type.numpy_dtype == np.float64:
return self.instruction_set['double']
elif self.base_type == create_type("float32"):
elif self.base_type.numpy_dtype == np.float32:
return self.instruction_set['float']
elif self.base_type == create_type("bool"):
elif self.base_type.numpy_dtype == np.bool:
return self.instruction_set['bool']
else:
raise NotImplementedError()
......@@ -692,6 +704,12 @@ class PointerType(Type):
else:
return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
def equal_ignoring_const(self, other):
if not isinstance(other, PointerType):
return False
else:
return self.base_type.equal_ignoring_const(other.base_type)
def __str__(self):
components = [str(self.base_type), '*']
if self.restrict:
......@@ -743,6 +761,12 @@ class StructType:
else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def equal_ignoring_const(self, other):
if not isinstance(other, StructType):
return False
else:
return self.numpy_dtype == other.numpy_dtype
def __str__(self):
# structs are handled byte-wise
result = "uint8_t"
......
......@@ -16,7 +16,7 @@ would reference back to the field.
from sympy.core.cache import cacheit
from pystencils.data_types import (
PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
BasicType, PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
SHAPE_DTYPE = create_composite_type_from_string("const int64")
STRIDE_DTYPE = create_composite_type_from_string("const int64")
......@@ -78,7 +78,8 @@ class FieldPointerSymbol(TypedSymbol):
def __new_stage2__(cls, field_name, field_dtype, const):
name = "_data_{name}".format(name=field_name)
dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True)
base_type = BasicType(get_base_type(field_dtype), const=const)
dtype = PointerType(base_type, const=True, restrict=True)
obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
obj.field_name = field_name
return obj
......
......@@ -878,7 +878,9 @@ class KernelConstraintsCheck:
assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs)
if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
dtype = create_type(self._type_for_symbol[lhs.name])
dtype.const = True
return TypedSymbol(lhs.name, dtype)
else:
return lhs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment