Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Frederik Hennig
pystencils
Commits
89fe68a6
Commit
89fe68a6
authored
Feb 16, 2022
by
Jan Hönig
Browse files
Removed inster_casts
parent
32c4acb5
Changes
2
Hide whitespace changes
Inline
Side-by-side
pystencils/typing/__init__.py
View file @
89fe68a6
...
...
@@ -5,7 +5,7 @@ from pystencils.typing.types import (is_supported_type, numpy_name_to_c, Abstrac
from
pystencils.typing.typed_sympy
import
(
assumptions_from_dtype
,
TypedSymbol
,
FieldStrideSymbol
,
FieldShapeSymbol
,
FieldPointerSymbol
)
from
pystencils.typing.utilities
import
(
typed_symbols
,
get_base_type
,
result_type
,
collate_types
,
get_type_of_expression
,
insert_casts
,
get_next_parent_of_type
,
parents_of_type
)
get_type_of_expression
,
get_next_parent_of_type
,
parents_of_type
)
__all__
=
[
'CastFunc'
,
'BooleanCastFunc'
,
'VectorMemoryAccess'
,
'ReinterpretCastFunc'
,
'PointerArithmeticFunc'
,
...
...
@@ -13,4 +13,4 @@ __all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCast
'VectorType'
,
'PointerType'
,
'StructType'
,
'create_type'
,
'assumptions_from_dtype'
,
'TypedSymbol'
,
'FieldStrideSymbol'
,
'FieldShapeSymbol'
,
'FieldPointerSymbol'
,
'typed_symbols'
,
'get_base_type'
,
'result_type'
,
'collate_types'
,
'get_type_of_expression'
,
'insert_casts'
,
'get_next_parent_of_type'
,
'parents_of_type'
]
'get_type_of_expression'
,
'get_next_parent_of_type'
,
'parents_of_type'
]
pystencils/typing/utilities.py
View file @
89fe68a6
...
...
@@ -211,101 +211,6 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
sp
.
Basic
.
__reduce_ex__
=
basic_reduce_ex
def
insert_casts
(
node
):
"""Checks the types and inserts casts and pointer arithmetic where necessary.
Args:
node: the head node of the ast
Returns:
modified AST
"""
from
pystencils.astnodes
import
SympyAssignment
,
ResolvedFieldAccess
,
LoopOverCoordinate
,
Block
def
cast
(
zipped_args_types
,
target_dtype
):
"""
Adds casts to the arguments if their type differs from the target type
:param zipped_args_types: a zipped list of args and types
:param target_dtype: The target data type
:return: args with possible casts
"""
casted_args
=
[]
for
argument
,
data_type
in
zipped_args_types
:
if
data_type
.
numpy_dtype
!=
target_dtype
.
numpy_dtype
:
# ignoring const
casted_args
.
append
(
CastFunc
(
argument
,
target_dtype
))
else
:
casted_args
.
append
(
argument
)
return
casted_args
def
pointer_arithmetic
(
expr_args
):
"""
Creates a valid pointer arithmetic function
:param expr_args: Arguments of the add expression
:return: pointer_arithmetic_func
"""
pointer
=
None
new_args
=
[]
for
arg
,
data_type
in
expr_args
:
if
data_type
.
func
is
PointerType
:
assert
pointer
is
None
pointer
=
arg
for
arg
,
data_type
in
expr_args
:
if
arg
!=
pointer
:
assert
data_type
.
is_int
()
or
data_type
.
is_uint
()
new_args
.
append
(
arg
)
new_args
=
sp
.
Add
(
*
new_args
)
if
len
(
new_args
)
>
0
else
new_args
return
PointerArithmeticFunc
(
pointer
,
new_args
)
if
isinstance
(
node
,
sp
.
AtomicExpr
)
or
isinstance
(
node
,
CastFunc
):
return
node
args
=
[]
for
arg
in
node
.
args
:
args
.
append
(
insert_casts
(
arg
))
# TODO indexed, LoopOverCoordinate
if
node
.
func
in
(
sp
.
Add
,
sp
.
Mul
,
sp
.
Or
,
sp
.
And
,
sp
.
Pow
,
sp
.
Eq
,
sp
.
Ne
,
sp
.
Lt
,
sp
.
Le
,
sp
.
Gt
,
sp
.
Ge
):
# TODO optimize pow, don't cast integer on double
types
=
[
get_type_of_expression
(
arg
)
for
arg
in
args
]
assert
len
(
types
)
>
0
# Never ever, ever collate to float type for boolean functions!
target
=
collate_types
(
types
,
forbid_collation_to_float
=
isinstance
(
node
.
func
,
BooleanFunction
))
zipped
=
list
(
zip
(
args
,
types
))
if
target
.
func
is
PointerType
:
assert
node
.
func
is
sp
.
Add
return
pointer_arithmetic
(
zipped
)
else
:
return
node
.
func
(
*
cast
(
zipped
,
target
))
elif
node
.
func
is
SympyAssignment
:
lhs
=
args
[
0
]
rhs
=
args
[
1
]
target
=
get_type_of_expression
(
lhs
)
if
target
.
func
is
PointerType
:
return
node
.
func
(
*
args
)
# TODO fix, not complete
else
:
return
node
.
func
(
lhs
,
*
cast
([(
rhs
,
get_type_of_expression
(
rhs
))],
target
))
elif
node
.
func
is
ResolvedFieldAccess
:
return
node
elif
node
.
func
is
Block
:
for
old_arg
,
new_arg
in
zip
(
node
.
args
,
args
):
node
.
replace
(
old_arg
,
new_arg
)
return
node
elif
node
.
func
is
LoopOverCoordinate
:
for
old_arg
,
new_arg
in
zip
(
node
.
args
,
args
):
node
.
replace
(
old_arg
,
new_arg
)
return
node
elif
node
.
func
is
sp
.
Piecewise
:
expressions
=
[
expr
for
(
expr
,
_
)
in
args
]
types
=
[
get_type_of_expression
(
expr
)
for
expr
in
expressions
]
target
=
collate_types
(
types
)
zipped
=
list
(
zip
(
expressions
,
types
))
casted_expressions
=
cast
(
zipped
,
target
)
args
=
[
arg
.
func
(
*
[
expr
,
arg
.
cond
])
for
(
arg
,
expr
)
in
zip
(
args
,
casted_expressions
)
]
return
node
.
func
(
*
args
)
def
get_next_parent_of_type
(
node
,
parent_type
):
"""Returns the next parent node of given type or None, if root is reached.
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment