Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonas Plewinski
pystencils
Commits
43f6d5de
Commit
43f6d5de
authored
Aug 28, 2019
by
Stephan Seitz
Committed by
Stephan Seitz
Sep 23, 2019
Browse files
Use get_type_of_expression in typing_form_sympy_inspection to infer types
parent
d6301eea
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
pystencils/cache.py
View file @
43f6d5de
import
os
from
collections
import
Hashable
from
functools
import
partial
from
itertools
import
chain
try
:
from
functools
import
lru_cache
as
memorycache
except
ImportError
:
from
backports.functools_lru_cache
import
lru_cache
as
memorycache
try
:
from
joblib
import
Memory
from
appdirs
import
user_cache_dir
...
...
@@ -22,6 +26,20 @@ except ImportError:
return
o
def
_wrapper
(
wrapped_func
,
cached_func
,
*
args
,
**
kwargs
):
if
all
(
isinstance
(
a
,
Hashable
)
for
a
in
chain
(
args
,
kwargs
.
values
())):
return
cached_func
(
*
args
,
**
kwargs
)
else
:
return
wrapped_func
(
*
args
,
**
kwargs
)
def
memorycache_if_hashable
(
maxsize
=
128
,
typed
=
False
):
def
wrapper
(
func
):
return
partial
(
_wrapper
,
func
,
memorycache
(
maxsize
,
typed
)(
func
))
return
wrapper
# Disable memory cache:
# disk_cache = lambda o: o
# disk_cache_no_fallback = lambda o: o
pystencils/data_types.py
View file @
43f6d5de
import
ctypes
from
collections
import
defaultdict
from
functools
import
partial
import
numpy
as
np
import
sympy
as
sp
from
sympy.core.cache
import
cacheit
from
sympy.logic.boolalg
import
Boolean
from
pystencils.cache
import
memorycache
from
pystencils.cache
import
memorycache
,
memorycache_if_hashable
from
pystencils.utils
import
all_equal
try
:
...
...
@@ -408,11 +410,22 @@ def collate_types(types):
return
result
@
memorycache
(
maxsize
=
2048
)
def
get_type_of_expression
(
expr
,
default_float_type
=
'double'
,
default_int_type
=
'int'
):
@
memorycache_if_hashable
(
maxsize
=
2048
)
def
get_type_of_expression
(
expr
,
default_float_type
=
'double'
,
default_int_type
=
'int'
,
symbol_type_dict
=
None
):
from
pystencils.astnodes
import
ResolvedFieldAccess
from
pystencils.cpu.vectorization
import
vec_all
,
vec_any
if
not
symbol_type_dict
:
symbol_type_dict
=
defaultdict
(
lambda
:
create_type
(
'double'
))
get_type
=
partial
(
get_type_of_expression
,
default_float_type
=
default_float_type
,
default_int_type
=
default_int_type
,
symbol_type_dict
=
symbol_type_dict
)
expr
=
sp
.
sympify
(
expr
)
if
isinstance
(
expr
,
sp
.
Integer
):
return
create_type
(
default_int_type
)
...
...
@@ -423,14 +436,15 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
elif
isinstance
(
expr
,
TypedSymbol
):
return
expr
.
dtype
elif
isinstance
(
expr
,
sp
.
Symbol
):
raise
ValueError
(
"All symbols inside this expression have to be typed! "
,
str
(
expr
))
return
symbol_type_dict
[
expr
.
name
]
# raise ValueError("All symbols iside this expression have to be typed! ", str(expr))
elif
isinstance
(
expr
,
cast_func
):
return
expr
.
args
[
1
]
elif
isinstance
(
expr
,
vec_any
)
or
isinstance
(
expr
,
vec_all
):
elif
isinstance
(
expr
,
(
vec_any
,
vec_all
)
)
:
return
create_type
(
"bool"
)
elif
hasattr
(
expr
,
'func'
)
and
expr
.
func
==
sp
.
Piecewise
:
collated_result_type
=
collate_types
(
tuple
(
get_type
_of_expression
(
a
[
0
])
for
a
in
expr
.
args
))
collated_condition_type
=
collate_types
(
tuple
(
get_type
_of_expression
(
a
[
1
])
for
a
in
expr
.
args
))
collated_result_type
=
collate_types
(
tuple
(
get_type
(
a
[
0
])
for
a
in
expr
.
args
))
collated_condition_type
=
collate_types
(
tuple
(
get_type
(
a
[
1
])
for
a
in
expr
.
args
))
if
type
(
collated_condition_type
)
is
VectorType
and
type
(
collated_result_type
)
is
not
VectorType
:
collated_result_type
=
VectorType
(
collated_result_type
,
width
=
collated_condition_type
.
width
)
return
collated_result_type
...
...
@@ -440,16 +454,16 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
elif
isinstance
(
expr
,
sp
.
boolalg
.
Boolean
)
or
isinstance
(
expr
,
sp
.
boolalg
.
BooleanFunction
):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result
=
create_type
(
"bool"
)
vec_args
=
[
get_type
_of_expression
(
a
)
for
a
in
expr
.
args
if
isinstance
(
get_type
_of_expression
(
a
),
VectorType
)]
vec_args
=
[
get_type
(
a
)
for
a
in
expr
.
args
if
isinstance
(
get_type
(
a
),
VectorType
)]
if
vec_args
:
result
=
VectorType
(
result
,
width
=
vec_args
[
0
].
width
)
return
result
elif
isinstance
(
expr
,
sp
.
Pow
):
return
get_type
_of_expression
(
expr
.
args
[
0
])
elif
isinstance
(
expr
,
(
sp
.
Pow
,
sp
.
Sum
,
sp
.
Product
)
):
return
get_type
(
expr
.
args
[
0
])
elif
isinstance
(
expr
,
sp
.
Expr
):
expr
:
sp
.
Expr
if
expr
.
args
:
types
=
tuple
(
get_type
_of_expression
(
a
)
for
a
in
expr
.
args
)
types
=
tuple
(
get_type
(
a
)
for
a
in
expr
.
args
)
return
collate_types
(
types
)
else
:
if
expr
.
is_integer
:
...
...
pystencils/test_type_interference.py
0 → 100644
View file @
43f6d5de
from
sympy.abc
import
a
,
b
,
c
,
d
,
e
,
f
import
pystencils
from
pystencils.data_types
import
cast_func
,
create_type
def
test_type_interference
():
x
=
pystencils
.
fields
(
'x: float32[3d]'
)
assignments
=
pystencils
.
AssignmentCollection
({
a
:
cast_func
(
10
,
create_type
(
'float64'
)),
b
:
cast_func
(
10
,
create_type
(
'uint16'
)),
e
:
11
,
c
:
b
,
f
:
c
+
b
,
d
:
c
+
b
+
x
.
center
+
e
,
x
.
center
:
c
+
b
+
x
.
center
})
ast
=
pystencils
.
create_kernel
(
assignments
)
code
=
str
(
pystencils
.
show_code
(
ast
))
print
(
code
)
assert
'double a'
in
code
assert
'uint16_t b'
in
code
assert
'uint16_t f'
in
code
assert
'int64_t e'
in
code
pystencils/transformations.py
View file @
43f6d5de
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment