Commit 820ad582 authored by Jan Hönig's avatar Jan Hönig
Browse files

Documentation. Todo cleanup. Removed unnecessary code

parent 3302c69d
......@@ -7,7 +7,10 @@ from pystencils.typing.typed_sympy import TypedSymbol
class CastFunc(sp.Function):
# TODO: documentation
CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type
a certain node should have, if it is impossible to add a type to a node, e.g. a sp.Number.
is_Atom = True
def __new__(cls, *args, **kwargs):
......@@ -29,7 +32,6 @@ class CastFunc(sp.Function):
# rhs = cast_func(0, 'int')
# print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
# -> thus a separate class boolean_cast_func is introduced
# TODO check this
if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType('bool')):
cls = BooleanCastFunc
......@@ -105,19 +107,22 @@ class BooleanCastFunc(CastFunc, Boolean):
class VectorMemoryAccess(CastFunc):
# TODO: documentation
# Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
Special memory access for vectorized kernel.
Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride
nargs = (6,)
class ReinterpretCastFunc(CastFunc):
# TODO: documentation
Reinterpret cast is necessary for the StructType
class PointerArithmeticFunc(sp.Function, Boolean):
# TODO: documentation
# TODO wtf is this????
# TODO: documentation, or deprecate!
def canonical(self):
if hasattr(self.args[0], 'canonical'):
......@@ -27,9 +27,7 @@ from pystencils.utils import ContextVar
class TypeAdder:
# TODO: Logs
# TODO: specification
# TODO: split this into checker and leaf typing
# TODO: specification -> jupyter notebook
"""Checks if the input to create_kernel is valid.
Test the following conditions:
......@@ -50,7 +48,6 @@ class TypeAdder:
self.default_number_float = ContextVar(default_number_float)
self.default_number_int = ContextVar(default_number_int)
# TODO: check if this adds only types to leave nodes of AST, get type info
def visit(self, obj):
if isinstance(obj, (list, tuple)):
return [self.visit(e) for e in obj]
......@@ -244,60 +241,3 @@ class TypeAdder:
return expr.func(*new_args) if new_args else expr, collated_type
raise NotImplementedError(f'expr {type(expr)}: {expr} unknown to typing')
def process_expression(self, rhs, type_constants=True): # TODO DELETE
import pystencils.integer_functions
from pystencils.bit_masks import flag_cond
if isinstance(rhs, Field.Access):
return rhs
elif isinstance(rhs, TypedSymbol):
return rhs
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(, self._type_for_symbol[])
elif type_constants and isinstance(rhs, np.generic):
assert False, f'Why do we have a np.generic in rhs???? {rhs}'
# return CastFunc(rhs, create_type(rhs.dtype))
elif type_constants and isinstance(rhs, sp.Number):
return CastFunc(rhs, create_type(self._type_for_symbol['_constant']))
# Very important that this clause comes before BooleanFunction
elif isinstance(rhs, sp.Equality):
if isinstance(rhs.args[1], sp.Number):
return sp.Equality(
self.process_expression(rhs.args[0], type_constants),
rhs.args[1]) # TODO: process args[1] as number with a good type
return sp.Equality(
self.process_expression(rhs.args[0], type_constants),
self.process_expression(rhs.args[1], type_constants))
elif isinstance(rhs, CastFunc):
return CastFunc(
self.process_expression(rhs.args[0], type_constants=False), # TODO: recommend type
elif isinstance(rhs, BooleanFunction) or \
type(rhs) in pystencils.integer_functions.__dict__.values():
new_args = [self.process_expression(a, type_constants) for a in rhs.args]
types_of_expressions = [get_type_of_expression(a) for a in new_args]
arg_type = collate_types(types_of_expressions)
new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type
else CastFunc(a, arg_type)
for a in new_args]
return rhs.func(*new_args)
elif isinstance(rhs, flag_cond): # TODO
# do not process the arguments to the bit shift - they must remain integers
processed_args = (self.process_expression(a) for a in rhs.args[2:])
return flag_cond(rhs.args[0], rhs.args[1], *processed_args)
elif isinstance(rhs, sp.Mul):
new_args = [
self.process_expression(arg, type_constants)
if arg not in (-1, 1) else arg for arg in rhs.args
return rhs.func(*new_args) if new_args else rhs
elif isinstance(rhs, sp.Indexed):
return rhs
elif isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers # TODO
return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])
new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
......@@ -8,7 +8,6 @@ from pystencils.typing.types import BasicType, create_type, PointerType
def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]):
# TODO: type hints and if dtype is correct type form Numpy
"""Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
......@@ -44,7 +43,7 @@ class TypedSymbol(sp.Symbol):
def __new_stage2__(cls, name, dtype, **kwargs): # TODO does not match signature of sp.Symbol???
# TODO: also Symbol should be allowed ---> see sympy Variable
assumptions = assumptions_from_dtype(dtype) # TODO should by dtype a np.dtype or our Type???
assumptions = assumptions_from_dtype(dtype)
obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
......@@ -99,7 +98,6 @@ SHAPE_DTYPE = BasicType('int64', const=True)
STRIDE_DTYPE = BasicType('int64', const=True)
# TODO: is it really necessary to have special symbols for that????
class FieldStrideSymbol(TypedSymbol):
"""Sympy symbol representing the stride value of a field in a specific coordinate."""
def __new__(cls, *args, **kwds):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment