Skip to content
Snippets Groups Projects
Commit 0a8c16b0 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'infer-symbol-types-from-definition' into 'master'

Use get_type_of_expression in typing_form_sympy_inspection to infer types

See merge request pycodegen/pystencils!43
parents 6d875789 8770acd7
No related merge requests found
......@@ -101,6 +101,41 @@ minimal-sympy-master:
tags:
- docker
pycodegen-integration:
image: i10git.cs.fau.de:5005/pycodegen/pycodegen/full
stage: test
when: manual
script:
- git clone https://gitlab-ci-token:${CI_JOB_TOKEN}@i10git.cs.fau.de/pycodegen/pycodegen.git
- cd pycodegen
- git submodule sync --recursive
- git submodule update --init --recursive
- git submodule foreach git fetch origin # compare the latest master version!
- git submodule foreach git reset --hard origin/master
- cd pystencils
- git remote add test $CI_REPOSITORY_URL
- git fetch test
- git reset --hard $CI_COMMIT_SHA
- cd ..
- export PYTHONPATH=`pwd`/pystencils:`pwd`/lbmpy:`pwd`/pygrandchem:`pwd`/pystencils_walberla:`pwd`/lbmpy_walberla
- ./install_walberla.sh
- export NUM_CORES=$(nproc --all)
- mkdir -p ~/.config/matplotlib
- echo "backend:template" > ~/.config/matplotlib/matplotlibrc
- cd pystencils
- py.test -v -n $NUM_CORES .
- cd ../lbmpy
- py.test -v -n $NUM_CORES .
- cd ../pygrandchem
- py.test -v -n $NUM_CORES .
- cd ../walberla/build/
- make CodegenJacobiCPU CodegenJacobiGPU MicroBenchmarkGpuLbm LbCodeGenerationExample
tags:
- docker
- cuda
- AVX
# -------------------- Linter & Documentation --------------------------------------------------------------------------
......
import os
from collections import Hashable
from functools import partial
from itertools import chain
try:
from functools import lru_cache as memorycache
except ImportError:
from backports.functools_lru_cache import lru_cache as memorycache
try:
from joblib import Memory
from appdirs import user_cache_dir
......@@ -22,6 +26,20 @@ except ImportError:
return o
def _wrapper(wrapped_func, cached_func, *args, **kwargs):
if all(isinstance(a, Hashable) for a in chain(args, kwargs.values())):
return cached_func(*args, **kwargs)
else:
return wrapped_func(*args, **kwargs)
def memorycache_if_hashable(maxsize=128, typed=False):
def wrapper(func):
return partial(_wrapper, func, memorycache(maxsize, typed)(func))
return wrapper
# Disable memory cache:
# disk_cache = lambda o: o
# disk_cache_no_fallback = lambda o: o
import ctypes
from collections import defaultdict
from functools import partial
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
from pystencils.cache import memorycache
from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal
try:
......@@ -408,11 +410,22 @@ def collate_types(types):
return result
@memorycache(maxsize=2048)
def get_type_of_expression(expr, default_float_type='double', default_int_type='int'):
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
default_float_type='double',
default_int_type='int',
symbol_type_dict=None):
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
if not symbol_type_dict:
symbol_type_dict = defaultdict(lambda: create_type('double'))
get_type = partial(get_type_of_expression,
default_float_type=default_float_type,
default_int_type=default_int_type,
symbol_type_dict=symbol_type_dict)
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type(default_int_type)
......@@ -423,14 +436,17 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
elif isinstance(expr, TypedSymbol):
return expr.dtype
elif isinstance(expr, sp.Symbol):
raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
if symbol_type_dict:
return symbol_type_dict[expr.name]
else:
raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
elif isinstance(expr, cast_func):
return expr.args[1]
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
elif isinstance(expr, (vec_any, vec_all)):
return create_type("bool")
elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
return collated_result_type
......@@ -440,16 +456,16 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool")
vec_args = [get_type_of_expression(a) for a in expr.args if isinstance(get_type_of_expression(a), VectorType)]
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
if vec_args:
result = VectorType(result, width=vec_args[0].width)
return result
elif isinstance(expr, sp.Pow):
return get_type_of_expression(expr.args[0])
elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
return get_type(expr.args[0])
elif isinstance(expr, sp.Expr):
expr: sp.Expr
if expr.args:
types = tuple(get_type_of_expression(a) for a in expr.args)
types = tuple(get_type(a) for a in expr.args)
return collate_types(types)
else:
if expr.is_integer:
......
from sympy.abc import a, b, c, d, e, f
import pystencils
from pystencils.data_types import cast_func, create_type
def test_type_interference():
x = pystencils.fields('x: float32[3d]')
assignments = pystencils.AssignmentCollection({
a: cast_func(10, create_type('float64')),
b: cast_func(10, create_type('uint16')),
e: 11,
c: b,
f: c + b,
d: c + b + x.center + e,
x.center: c + b + x.center
})
ast = pystencils.create_kernel(assignments)
code = str(pystencils.show_code(ast))
print(code)
assert 'double a' in code
assert 'uint16_t b' in code
assert 'uint16_t f' in code
assert 'int64_t e' in code
This diff is collapsed.
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment