Skip to content
Snippets Groups Projects
Commit 43f6d5de authored by Stephan Seitz's avatar Stephan Seitz Committed by Stephan Seitz
Browse files

Use get_type_of_expression in typing_form_sympy_inspection to infer types

parent d6301eea
1 merge request!43Use get_type_of_expression in typing_form_sympy_inspection to infer types
import os
from collections import Hashable
from functools import partial
from itertools import chain
try:
from functools import lru_cache as memorycache
except ImportError:
from backports.functools_lru_cache import lru_cache as memorycache
try:
from joblib import Memory
from appdirs import user_cache_dir
......@@ -22,6 +26,20 @@ except ImportError:
return o
def _wrapper(wrapped_func, cached_func, *args, **kwargs):
if all(isinstance(a, Hashable) for a in chain(args, kwargs.values())):
return cached_func(*args, **kwargs)
else:
return wrapped_func(*args, **kwargs)
def memorycache_if_hashable(maxsize=128, typed=False):
def wrapper(func):
return partial(_wrapper, func, memorycache(maxsize, typed)(func))
return wrapper
# Disable memory cache:
# disk_cache = lambda o: o
# disk_cache_no_fallback = lambda o: o
import ctypes
from collections import defaultdict
from functools import partial
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
from pystencils.cache import memorycache
from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal
try:
......@@ -408,11 +410,22 @@ def collate_types(types):
return result
@memorycache(maxsize=2048)
def get_type_of_expression(expr, default_float_type='double', default_int_type='int'):
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
default_float_type='double',
default_int_type='int',
symbol_type_dict=None):
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
if not symbol_type_dict:
symbol_type_dict = defaultdict(lambda: create_type('double'))
get_type = partial(get_type_of_expression,
default_float_type=default_float_type,
default_int_type=default_int_type,
symbol_type_dict=symbol_type_dict)
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type(default_int_type)
......@@ -423,14 +436,15 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
elif isinstance(expr, TypedSymbol):
return expr.dtype
elif isinstance(expr, sp.Symbol):
raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
return symbol_type_dict[expr.name]
# raise ValueError("All symbols iside this expression have to be typed! ", str(expr))
elif isinstance(expr, cast_func):
return expr.args[1]
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
elif isinstance(expr, (vec_any, vec_all)):
return create_type("bool")
elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
return collated_result_type
......@@ -440,16 +454,16 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool")
vec_args = [get_type_of_expression(a) for a in expr.args if isinstance(get_type_of_expression(a), VectorType)]
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
if vec_args:
result = VectorType(result, width=vec_args[0].width)
return result
elif isinstance(expr, sp.Pow):
return get_type_of_expression(expr.args[0])
elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
return get_type(expr.args[0])
elif isinstance(expr, sp.Expr):
expr: sp.Expr
if expr.args:
types = tuple(get_type_of_expression(a) for a in expr.args)
types = tuple(get_type(a) for a in expr.args)
return collate_types(types)
else:
if expr.is_integer:
......
from sympy.abc import a, b, c, d, e, f
import pystencils
from pystencils.data_types import cast_func, create_type
def test_type_interference():
x = pystencils.fields('x: float32[3d]')
assignments = pystencils.AssignmentCollection({
a: cast_func(10, create_type('float64')),
b: cast_func(10, create_type('uint16')),
e: 11,
c: b,
f: c + b,
d: c + b + x.center + e,
x.center: c + b + x.center
})
ast = pystencils.create_kernel(assignments)
code = str(pystencils.show_code(ast))
print(code)
assert 'double a' in code
assert 'uint16_t b' in code
assert 'uint16_t f' in code
assert 'int64_t e' in code
This diff is collapsed.
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment