diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py index 76686c21110e21a8454c81369764135b70182db9..1b83d223cbff2ce08c1fc0516d2ce53dc2ec350a 100644 --- a/pystencils/typing/cast_functions.py +++ b/pystencils/typing/cast_functions.py @@ -7,7 +7,10 @@ from pystencils.typing.typed_sympy import TypedSymbol class CastFunc(sp.Function): - # TODO: documentation + """ + CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type + a certain node should have, if it is impossible to add a type to a node, e.g. a sp.Number. + """ is_Atom = True def __new__(cls, *args, **kwargs): @@ -29,7 +32,6 @@ class CastFunc(sp.Function): # rhs = cast_func(0, 'int') # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans # -> thus a separate class boolean_cast_func is introduced - # TODO check this if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType('bool')): cls = BooleanCastFunc @@ -105,19 +107,22 @@ class BooleanCastFunc(CastFunc, Boolean): class VectorMemoryAccess(CastFunc): - # TODO: documentation - # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride + """ + Special memory access for vectorized kernel. + Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride + """ nargs = (6,) class ReinterpretCastFunc(CastFunc): - # TODO: documentation + """ + Reinterpret cast is necessary for the StructType + """ pass class PointerArithmeticFunc(sp.Function, Boolean): - # TODO: documentation - # TODO wtf is this???? + # TODO: documentation, or deprecate! @property def canonical(self): if hasattr(self.args[0], 'canonical'): diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index 0d133038688f722d7d428c01442db2d8fb2458a9..560b94143f8b35182d2be4d02033a8ade652839b 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -27,9 +27,7 @@ from pystencils.utils import ContextVar class TypeAdder: - # TODO: Logs - # TODO: specification - # TODO: split this into checker and leaf typing + # TODO: specification -> jupyter notebook """Checks if the input to create_kernel is valid. Test the following conditions: @@ -50,7 +48,6 @@ class TypeAdder: self.default_number_float = ContextVar(default_number_float) self.default_number_int = ContextVar(default_number_int) - # TODO: check if this adds only types to leave nodes of AST, get type info def visit(self, obj): if isinstance(obj, (list, tuple)): return [self.visit(e) for e in obj] @@ -244,60 +241,3 @@ class TypeAdder: return expr.func(*new_args) if new_args else expr, collated_type else: raise NotImplementedError(f'expr {type(expr)}: {expr} unknown to typing') - - def process_expression(self, rhs, type_constants=True): # TODO DELETE - import pystencils.integer_functions - from pystencils.bit_masks import flag_cond - - if isinstance(rhs, Field.Access): - return rhs - elif isinstance(rhs, TypedSymbol): - return rhs - elif isinstance(rhs, sp.Symbol): - return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name]) - elif type_constants and isinstance(rhs, np.generic): - assert False, f'Why do we have a np.generic in rhs???? {rhs}' - # return CastFunc(rhs, create_type(rhs.dtype)) - elif type_constants and isinstance(rhs, sp.Number): - return CastFunc(rhs, create_type(self._type_for_symbol['_constant'])) - # Very important that this clause comes before BooleanFunction - elif isinstance(rhs, sp.Equality): - if isinstance(rhs.args[1], sp.Number): - return sp.Equality( - self.process_expression(rhs.args[0], type_constants), - rhs.args[1]) # TODO: process args[1] as number with a good type - else: - return sp.Equality( - self.process_expression(rhs.args[0], type_constants), - self.process_expression(rhs.args[1], type_constants)) - elif isinstance(rhs, CastFunc): - return CastFunc( - self.process_expression(rhs.args[0], type_constants=False), # TODO: recommend type - rhs.dtype) - elif isinstance(rhs, BooleanFunction) or \ - type(rhs) in pystencils.integer_functions.__dict__.values(): - new_args = [self.process_expression(a, type_constants) for a in rhs.args] - types_of_expressions = [get_type_of_expression(a) for a in new_args] - arg_type = collate_types(types_of_expressions) - new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type - else CastFunc(a, arg_type) - for a in new_args] - return rhs.func(*new_args) - elif isinstance(rhs, flag_cond): # TODO - # do not process the arguments to the bit shift - they must remain integers - processed_args = (self.process_expression(a) for a in rhs.args[2:]) - return flag_cond(rhs.args[0], rhs.args[1], *processed_args) - elif isinstance(rhs, sp.Mul): - new_args = [ - self.process_expression(arg, type_constants) - if arg not in (-1, 1) else arg for arg in rhs.args - ] - return rhs.func(*new_args) if new_args else rhs - elif isinstance(rhs, sp.Indexed): - return rhs - elif isinstance(rhs, sp.Pow): - # don't process exponents -> they should remain integers # TODO - return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1]) - else: - new_args = [self.process_expression(arg, type_constants) for arg in rhs.args] - return rhs.func(*new_args) if new_args else rhs diff --git a/pystencils/typing/typed_sympy.py b/pystencils/typing/typed_sympy.py index e99227c520c798faece3929f5c5dc1f2143db8ea..302c2f9987b2db1a907710678ddbb7234668cfc6 100644 --- a/pystencils/typing/typed_sympy.py +++ b/pystencils/typing/typed_sympy.py @@ -8,7 +8,6 @@ from pystencils.typing.types import BasicType, create_type, PointerType def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]): - # TODO: type hints and if dtype is correct type form Numpy """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype Args: @@ -44,7 +43,7 @@ class TypedSymbol(sp.Symbol): def __new_stage2__(cls, name, dtype, **kwargs): # TODO does not match signature of sp.Symbol??? # TODO: also Symbol should be allowed ---> see sympy Variable - assumptions = assumptions_from_dtype(dtype) # TODO should by dtype a np.dtype or our Type??? + assumptions = assumptions_from_dtype(dtype) assumptions.update(kwargs) obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions) try: @@ -99,7 +98,6 @@ SHAPE_DTYPE = BasicType('int64', const=True) STRIDE_DTYPE = BasicType('int64', const=True) -# TODO: is it really necessary to have special symbols for that???? class FieldStrideSymbol(TypedSymbol): """Sympy symbol representing the stride value of a field in a specific coordinate.""" def __new__(cls, *args, **kwds):