Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Markus Holzer
pystencils
Commits
c5f6c366
Commit
c5f6c366
authored
Dec 03, 2021
by
Markus Holzer
Browse files
Implemented piecewise
parent
11a81449
Pipeline
#36117
failed with stages
in 2 minutes and 19 seconds
Changes
8
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
pystencils/backends/cbackend.py
View file @
c5f6c366
...
...
@@ -491,10 +491,7 @@ class CustomSympyPrinter(CCodePrinter):
return
f
"&(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
elif
isinstance
(
expr
,
CastFunc
):
arg
,
data_type
=
expr
.
args
if
isinstance
(
arg
,
sp
.
Number
)
and
arg
.
is_finite
:
return
self
.
_typed_number
(
arg
,
data_type
)
else
:
return
f
"((
{
data_type
}
)(
{
self
.
_print
(
arg
)
}
))"
return
f
"((
{
data_type
}
)(
{
self
.
_print
(
arg
)
}
))"
elif
isinstance
(
expr
,
fast_division
):
return
f
"(
{
self
.
_print
(
expr
.
args
[
0
]
/
expr
.
args
[
1
])
}
)"
elif
isinstance
(
expr
,
fast_sqrt
):
...
...
pystencils/cpu/kernelcreation.py
View file @
c5f6c366
...
...
@@ -4,7 +4,7 @@ import sympy as sp
import
numpy
as
np
import
pystencils.astnodes
as
ast
from
pystencils.assignment
import
Assignment
from
pystencils.
simp.
assignment
_collection
import
Assignment
Collection
from
pystencils.config
import
CreateKernelConfig
from
pystencils.enums
import
Target
,
Backend
from
pystencils.astnodes
import
Block
,
KernelFunction
,
LoopOverCoordinate
,
SympyAssignment
...
...
@@ -17,12 +17,8 @@ from pystencils.transformations import (
move_constants_before_loop
,
parse_base_pointer_info
,
resolve_buffer_accesses
,
resolve_field_accesses
,
split_inner_loop
)
from
pystencils.kernel_contrains_check
import
KernelConstraintsCheck
AssignmentOrAstNodeList
=
List
[
Union
[
Assignment
,
ast
.
Node
]]
def
create_kernel
(
assignments
:
AssignmentOrAstNodeList
,
config
:
CreateKernelConfig
,
split_groups
)
->
KernelFunction
:
def
create_kernel
(
assignments
:
AssignmentCollection
,
config
:
CreateKernelConfig
)
->
KernelFunction
:
"""Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
Loops are created according to the field accesses in the equations.
...
...
@@ -31,8 +27,6 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf
assignments: list of sympy equations, containing accesses to :class:`pystencils.field.Field`.
Defining the update rules of the kernel
config: create kernel config
split_groups: Specification on how to split up inner loop into multiple loops. For details see
transformation :func:`pystencils.transformation.split_inner_loop`
Returns:
AST node representing a function, that can be printed as C or CUDA code
...
...
@@ -41,8 +35,13 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf
type_info
=
config
.
data_type
iteration_slice
=
config
.
iteration_slice
ghost_layers
=
config
.
ghost_layers
skip_independence_check
=
config
.
skip_independence_check
allow_double_writes
=
config
.
allow_double_writes
fields_written
=
assignments
.
bound_fields
fields_read
=
assignments
.
free_fields
split_groups
=
()
if
'split_groups'
in
assignments
.
simplification_hints
:
split_groups
=
assignments
.
simplification_hints
[
'split_groups'
]
assignments
=
assignments
.
all_assignments
# TODO: try to delete
def
type_symbol
(
term
):
...
...
@@ -56,12 +55,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf
else
:
raise
ValueError
(
"Term has to be field access or symbol"
)
check
=
KernelConstraintsCheck
(
check_independence_condition
=
skip_independence_check
,
check_double_write_condition
=
allow_double_writes
)
check
.
visit
(
assignments
)
fields_read
=
check
.
fields_read
fields_written
=
check
.
fields_written
# TODO move add_types to create_domain_kernel or create_kernel
assignments
=
add_types
(
assignments
,
config
)
...
...
@@ -78,7 +72,6 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf
ast_node
=
KernelFunction
(
loop_node
,
Target
.
CPU
,
Backend
.
C
,
compile_function
=
make_python_function
,
ghost_layers
=
ghost_layer_info
,
function_name
=
function_name
,
assignments
=
assignments
)
# TODO move split groups here
if
split_groups
:
typed_split_groups
=
[[
type_symbol
(
s
)
for
s
in
split_group
]
for
split_group
in
split_groups
]
split_inner_loop
(
ast_node
,
typed_split_groups
)
...
...
@@ -100,7 +93,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf
return
ast_node
def
create_indexed_kernel
(
assignments
:
Assignment
OrAstNodeList
,
index_fields
,
function_name
=
"kernel"
,
def
create_indexed_kernel
(
assignments
:
Assignment
Collection
,
index_fields
,
function_name
=
"kernel"
,
type_info
=
None
,
coordinate_names
=
(
'x'
,
'y'
,
'z'
))
->
KernelFunction
:
"""
Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
...
...
pystencils/kernel_contrains_check.py
View file @
c5f6c366
...
...
@@ -4,6 +4,7 @@ from typing import Union
import
sympy
as
sp
from
sympy.codegen
import
Assignment
from
pystencils.simp
import
AssignmentCollection
from
pystencils
import
astnodes
as
ast
,
TypedSymbol
from
pystencils.field
import
Field
from
pystencils.transformations
import
NestedScopes
...
...
@@ -41,7 +42,9 @@ class KernelConstraintsCheck:
self
.
check_double_write_condition
=
check_double_write_condition
def
visit
(
self
,
obj
):
if
isinstance
(
obj
,
list
)
or
isinstance
(
obj
,
tuple
):
if
isinstance
(
obj
,
AssignmentCollection
):
[
self
.
visit
(
e
)
for
e
in
obj
.
all_assignments
]
elif
isinstance
(
obj
,
list
)
or
isinstance
(
obj
,
tuple
):
[
self
.
visit
(
e
)
for
e
in
obj
]
elif
isinstance
(
obj
,
(
sp
.
Eq
,
ast
.
SympyAssignment
,
Assignment
)):
self
.
process_assignment
(
obj
)
...
...
pystencils/kernelcreation.py
View file @
c5f6c366
...
...
@@ -12,7 +12,7 @@ from pystencils.enums import Target, Backend
from
pystencils.field
import
Field
,
FieldType
from
pystencils.gpucuda.indexing
import
indexing_creator_from_params
from
pystencils.simp.assignment_collection
import
AssignmentCollection
from
pystencils.
simp.simplifications
import
apply_sympy_optimisations
from
pystencils.
kernel_contrains_check
import
KernelConstraintsCheck
from
pystencils.simplificationfactory
import
create_simplification_strategy
from
pystencils.stencil
import
direction_string_to_offset
,
inverse_direction_string
from
pystencils.transformations
import
(
...
...
@@ -62,6 +62,8 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
if
isinstance
(
assignments
,
Assignment
):
assignments
=
[
assignments
]
assert
assignments
,
"Assignments must not be empty!"
if
isinstance
(
assignments
,
list
):
assignments
=
AssignmentCollection
(
assignments
)
if
config
.
index_fields
:
return
create_indexed_kernel
(
assignments
,
config
=
config
)
...
...
@@ -69,7 +71,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
return
create_domain_kernel
(
assignments
,
config
=
config
)
def
create_domain_kernel
(
assignments
:
List
[
Assignment
]
,
*
,
config
:
CreateKernelConfig
):
def
create_domain_kernel
(
assignments
:
Assignment
Collection
,
*
,
config
:
CreateKernelConfig
):
"""
Creates abstract syntax tree (AST) of kernel, using a list of update equations.
...
...
@@ -82,6 +84,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
can be compiled with through its 'compile()' member
Example:
# TODO change to assignment collection
>>> import pystencils as ps
>>> import numpy as np
>>> s, d = ps.fields('s, d: [2D]')
...
...
@@ -98,6 +101,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
[0., 4., 4., 4., 0.],
[0., 0., 0., 0., 0.]])
"""
# --- applying first default simplifications
try
:
if
config
.
default_assignment_simplifications
and
isinstance
(
assignments
,
AssignmentCollection
):
...
...
@@ -107,20 +111,18 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
warnings
.
warn
(
f
"It was not possible to apply the default pystencils optimisations to the "
f
"AssignmentCollection due to the following problem :
{
e
}
"
)
# TODO: shift to CPU
# ---- Normalizing parameters
split_groups
=
()
if
isinstance
(
assignments
,
AssignmentCollection
):
if
'split_groups'
in
assignments
.
simplification_hints
:
split_groups
=
assignments
.
simplification_hints
[
'split_groups'
]
assignments
=
assignments
.
all_assignments
assignments
.
evaluate_terms
()
try
:
if
config
.
default_assignment_simplifications
:
assignments
=
apply_sympy_optimisations
(
assignments
)
except
Exception
as
e
:
warnings
.
warn
(
f
"It was not possible to apply the default SymPy optimisations to the "
f
"Assignments due to the following problem :
{
e
}
"
)
# --- eval
# TODO split apply_sympy_optimisations and do the eval here
# FUTURE WORK from here we shouldn't NEED sympy
# --- check constrains
check
=
KernelConstraintsCheck
(
check_independence_condition
=
config
.
skip_independence_check
,
check_double_write_condition
=
config
.
allow_double_writes
)
check
.
visit
(
assignments
)
assert
assignments
.
bound_fields
==
check
.
fields_written
,
f
'WTF'
assert
assignments
.
rhs_fields
==
check
.
fields_read
,
f
'WTF'
# ---- Creating ast
ast
=
None
...
...
@@ -128,7 +130,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
if
config
.
backend
==
Backend
.
C
:
from
pystencils.cpu
import
add_openmp
,
create_kernel
# TODO: data type keyword should be unified to data_type
ast
=
create_kernel
(
assignments
,
config
=
config
,
split_groups
=
split_groups
)
ast
=
create_kernel
(
assignments
,
config
=
config
)
for
optimization
in
config
.
cpu_prepend_optimizations
:
optimization
(
ast
)
omp_collapse
=
None
...
...
@@ -170,7 +172,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
return
ast
def
create_indexed_kernel
(
assignments
:
List
[
Assignment
]
,
*
,
config
:
CreateKernelConfig
):
def
create_indexed_kernel
(
assignments
:
Assignment
Collection
,
*
,
config
:
CreateKernelConfig
):
"""
Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.
...
...
@@ -212,6 +214,8 @@ import pystencils.kernel_creation_config >>> import pystencils as ps
[0. , 0. , 0. , 4.3, 0. ],
[0. , 0. , 0. , 0. , 0. ]])
"""
# TODO do this in backends
assignments
=
assignments
.
all_assignments
ast
=
None
if
config
.
target
==
Target
.
CPU
and
config
.
backend
==
Backend
.
C
:
from
pystencils.cpu
import
add_openmp
,
create_indexed_kernel
...
...
pystencils/simp/assignment_collection.py
View file @
c5f6c366
...
...
@@ -3,6 +3,7 @@ from copy import copy
from
typing
import
Any
,
Dict
,
Iterable
,
Iterator
,
List
,
Optional
,
Sequence
,
Set
,
Union
import
sympy
as
sp
from
sympy.codegen.rewriting
import
ReplaceOptim
,
optimize
import
pystencils
from
pystencils.assignment
import
Assignment
...
...
@@ -107,16 +108,21 @@ class AssignmentCollection:
return
self
.
subexpressions
+
self
.
main_assignments
@
property
def
free
_symbols
(
self
)
->
Set
[
sp
.
Symbol
]:
"""All symbols used in the assignment collection, which
do not occur as left hand sides in
any assignment."""
free
_symbols
=
set
()
def
rhs
_symbols
(
self
)
->
Set
[
sp
.
Symbol
]:
"""All symbols used in the assignment collection, which
occur on the rhs of
any assignment."""
rhs
_symbols
=
set
()
for
eq
in
self
.
all_assignments
:
if
isinstance
(
eq
,
Assignment
):
free
_symbols
.
update
(
eq
.
rhs
.
atoms
(
sp
.
Symbol
))
rhs
_symbols
.
update
(
eq
.
rhs
.
atoms
(
sp
.
Symbol
))
elif
isinstance
(
eq
,
pystencils
.
astnodes
.
Node
):
free
_symbols
.
update
(
eq
.
undefined_symbols
)
rhs
_symbols
.
update
(
eq
.
undefined_symbols
)
return
free_symbols
-
self
.
bound_symbols
return
rhs_symbols
@
property
def
free_symbols
(
self
)
->
Set
[
sp
.
Symbol
]:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
return
self
.
rhs_symbols
-
self
.
bound_symbols
@
property
def
bound_symbols
(
self
)
->
Set
[
sp
.
Symbol
]:
...
...
@@ -132,10 +138,15 @@ class AssignmentCollection:
assignment
.
symbols_defined
for
assignment
in
self
.
all_assignments
if
isinstance
(
assignment
,
pystencils
.
astnodes
.
Node
)
]
)
)
return
bound_symbols_set
@
property
def
rhs_fields
(
self
):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return
{
s
.
field
for
s
in
self
.
rhs_symbols
if
hasattr
(
s
,
'field'
)}
@
property
def
free_fields
(
self
):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
...
...
@@ -152,7 +163,7 @@ class AssignmentCollection:
return
(
set
(
[
assignment
.
lhs
for
assignment
in
self
.
main_assignments
if
isinstance
(
assignment
,
Assignment
)]
).
union
(
*
[
assignment
.
symbols_defined
for
assignment
in
self
.
main_assignments
if
isinstance
(
assignment
,
pystencils
.
astnodes
.
Node
)]
assignment
,
pystencils
.
astnodes
.
Node
)]
))
@
property
...
...
@@ -214,6 +225,7 @@ class AssignmentCollection:
return
{
s
:
func
(
*
args
,
**
kwargs
)
for
s
,
func
in
lambdas
.
items
()}
return
f
# ---------------------------- Creating new modified collections ---------------------------------------------------
def
copy
(
self
,
...
...
@@ -353,10 +365,26 @@ class AssignmentCollection:
new_assignment
=
[
fast_subs
(
eq
,
substitution_dict
)
for
eq
in
self
.
main_assignments
]
return
self
.
copy
(
new_assignment
,
kept_subexpressions
)
def
evaluate_terms
(
self
):
evaluate_constant_terms
=
ReplaceOptim
(
lambda
e
:
hasattr
(
e
,
'is_constant'
)
and
e
.
is_constant
and
not
e
.
is_integer
,
lambda
p
:
p
.
evalf
())
sympy_optimisations
=
[
evaluate_constant_terms
]
self
.
subexpressions
=
[
Assignment
(
a
.
lhs
,
optimize
(
a
.
rhs
,
sympy_optimisations
))
if
hasattr
(
a
,
'lhs'
)
else
a
for
a
in
self
.
subexpressions
]
self
.
main_assignments
=
[
Assignment
(
a
.
lhs
,
optimize
(
a
.
rhs
,
sympy_optimisations
))
if
hasattr
(
a
,
'lhs'
)
else
a
for
a
in
self
.
main_assignments
]
# ----------------------------------------- Display and Printing -------------------------------------------------
def
_repr_html_
(
self
):
"""Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
def
make_html_equation_table
(
equations
):
no_border
=
'style="border:none"'
html_table
=
'<table style="border:none; width: 100%; ">'
...
...
pystencils/simp/simplifications.py
View file @
c5f6c366
...
...
@@ -3,12 +3,10 @@ from typing import Callable, List, Sequence, Union
from
collections
import
defaultdict
import
sympy
as
sp
from
sympy.codegen.rewriting
import
optims_c99
,
optimize
from
sympy.codegen.rewriting
import
ReplaceOptim
from
pystencils.assignment
import
Assignment
from
pystencils.astnodes
import
Node
,
SympyAssignment
from
pystencils.field
import
Field
,
Field
from
pystencils.astnodes
import
Node
from
pystencils.field
import
Field
from
pystencils.sympyextensions
import
subs_additive
,
is_constant
,
recursive_collect
...
...
@@ -227,22 +225,29 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
return
f
def
apply_sympy_optimisations
(
assignments
):
""" Evaluates constant expressions (e.g. :math:`
\\
sqrt{3}` will be replaced by its floating point representation)
and applies the default sympy optimisations. See sympy.codegen.rewriting
"""
# Evaluates all constant terms
evaluate_constant_terms
=
ReplaceOptim
(
lambda
e
:
hasattr
(
e
,
'is_constant'
)
and
e
.
is_constant
and
not
e
.
is_integer
,
lambda
p
:
p
.
evalf
())
sympy_optimisations
=
[
evaluate_constant_terms
]
+
list
(
optims_c99
)
assignments
=
[
Assignment
(
a
.
lhs
,
optimize
(
a
.
rhs
,
sympy_optimisations
))
if
hasattr
(
a
,
'lhs'
)
else
a
for
a
in
assignments
]
assignments_nodes
=
[
a
.
atoms
(
SympyAssignment
)
for
a
in
assignments
]
for
a
in
chain
.
from_iterable
(
assignments_nodes
):
a
.
optimize
(
sympy_optimisations
)
return
assignments
# TODO Markus
# TODO: make this really work for Assignmentcollections
# TODO: this function should ONLY evaluate
# TODO: do the optims_c99 elsewhere optionally
# def apply_sympy_optimisations(ac: AssignmentCollection):
# """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
# and applies the default sympy optimisations. See sympy.codegen.rewriting
# """
#
# # Evaluates all constant terms
#
# assignments = ac.all_assignments
#
# evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
# lambda p: p.evalf())
#
# sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
#
# assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
# if hasattr(a, 'lhs')
# else a for a in assignments]
# assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
# for a in chain.from_iterable(assignments_nodes):
# a.optimize(sympy_optimisations)
#
# return AssignmentCollection(assignments)
pystencils/typing/leaf_typing.py
View file @
c5f6c366
...
...
@@ -5,6 +5,11 @@ import logging
import
numpy
as
np
import
sympy
as
sp
from
sympy
import
Piecewise
from
sympy.functions.elementary.piecewise
import
ExprCondPair
from
sympy.codegen
import
Assignment
from
sympy.logic.boolalg
import
BooleanFunction
from
sympy.logic.boolalg
import
BooleanAtom
from
pystencils
import
astnodes
as
ast
from
pystencils.field
import
Field
...
...
@@ -13,8 +18,6 @@ from pystencils.typing.utilities import get_type_of_expression, collate_types
from
pystencils.typing.cast_functions
import
CastFunc
from
pystencils.typing.typed_sympy
import
TypedSymbol
from
pystencils.utils
import
ContextVar
from
sympy.codegen
import
Assignment
from
sympy.logic.boolalg
import
BooleanFunction
class
TypeAdder
:
...
...
@@ -93,6 +96,8 @@ class TypeAdder:
def
figure_out_type
(
self
,
expr
)
->
Tuple
[
Any
,
BasicType
]:
# TODO or abstract type? vector type?
# Trivial cases
from
pystencils.field
import
Field
import
pystencils.integer_functions
from
pystencils.bit_masks
import
flag_cond
if
isinstance
(
expr
,
Field
.
Access
):
return
expr
,
expr
.
dtype
...
...
@@ -104,24 +109,56 @@ class TypeAdder:
elif
isinstance
(
expr
,
np
.
generic
):
assert
False
,
f
'Why do we have a np.generic in rhs????
{
expr
}
'
elif
isinstance
(
expr
,
sp
.
Number
):
if
expr
.
is_Float
:
data_type
=
self
.
default_number_float
.
get
()
elif
expr
.
is_Integer
:
if
expr
.
is_Integer
:
data_type
=
self
.
default_number_int
.
get
()
elif
expr
.
is_Float
or
expr
.
is_Rational
:
data_type
=
self
.
default_number_float
.
get
()
else
:
assert
False
,
f
'
{
sp
.
Number
}
is neither Float nor Integer'
return
expr
,
data_type
# TODO add everything in between
elif
isinstance
(
expr
,
BooleanAtom
):
return
expr
,
BasicType
(
'bool'
)
elif
isinstance
(
expr
,
sp
.
Equality
):
new_args
=
[
self
.
figure_out_type
(
arg
)[
0
]
for
arg
in
expr
.
args
]
new_eq
=
sp
.
Equality
(
*
new_args
)
return
new_eq
,
BasicType
(
'bool'
)
elif
isinstance
(
expr
,
CastFunc
):
raise
NotImplementedError
(
'CastFunc'
)
elif
isinstance
(
expr
,
BooleanFunction
)
or
\
type
(
expr
,
)
in
pystencils
.
integer_functions
.
__dict__
.
values
():
raise
NotImplementedError
(
'BooleanFunction'
)
elif
isinstance
(
expr
,
flag_cond
):
# do not process the arguments to the bit shift - they must remain integers
raise
NotImplementedError
(
'flag_cond'
)
elif
isinstance
(
expr
,
sp
.
Mul
):
raise
NotImplementedError
(
'sp.Mul'
)
# TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
args_types
=
[
self
.
figure_out_type
(
arg
)
for
arg
in
expr
.
args
if
arg
not
in
(
-
1
,
1
)]
return
None
# TODO collate types
# args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
elif
isinstance
(
expr
,
sp
.
Indexed
):
self
.
apply_type
(
expr
,
BasicType
(
'uintp'
))
# TODO double check
return
None
raise
NotImplementedError
(
'sp.Indexed'
)
elif
isinstance
(
expr
,
sp
.
Pow
):
# TODO sp.Pow should know a type
return
None
# TODO
args_types
=
[
self
.
figure_out_type
(
arg
)
for
arg
in
expr
.
args
]
collated_type
=
collate_types
([
t
for
_
,
t
in
args_types
])
return
expr
,
collated_type
elif
isinstance
(
expr
,
ExprCondPair
):
expr_expr
,
expr_type
=
self
.
figure_out_type
(
expr
.
expr
)
condition
,
condition_type
=
self
.
figure_out_type
(
expr
.
cond
)
if
condition_type
!=
BasicType
(
'bool'
):
logging
.
warning
(
f
'Condition "
{
condition
}
" is of type "
{
condition_type
}
" and not "bool"'
)
return
expr
.
func
(
expr_expr
,
condition
),
expr_type
elif
isinstance
(
expr
,
Piecewise
):
args_types
=
[
self
.
figure_out_type
(
arg
)
for
arg
in
expr
.
args
]
collated_type
=
collate_types
([
t
for
_
,
t
in
args_types
])
new_args
=
[]
for
a
,
t
in
args_types
:
if
t
!=
collated_type
:
if
isinstance
(
a
,
ExprCondPair
):
new_args
.
append
(
a
.
func
(
CastFunc
(
a
.
expr
,
collated_type
),
a
.
cond
))
else
:
new_args
.
append
(
CastFunc
(
a
,
collated_type
))
else
:
new_args
.
append
(
a
)
return
expr
.
func
(
*
new_args
)
if
new_args
else
expr
,
collated_type
else
:
args_types
=
[
self
.
figure_out_type
(
arg
)
for
arg
in
expr
.
args
]
collated_type
=
collate_types
([
t
for
_
,
t
in
args_types
])
...
...
@@ -190,6 +227,6 @@ class TypeAdder:
def
process_lhs
(
self
,
lhs
:
Union
[
Field
.
Access
,
TypedSymbol
,
sp
.
Symbol
]):
if
not
isinstance
(
lhs
,
(
Field
.
Access
,
TypedSymbol
)):
return
TypedSymbol
(
lhs
.
name
,
self
.
_
type_for_symbol
[
lhs
.
name
])
return
TypedSymbol
(
lhs
.
name
,
self
.
type_for_symbol
[
lhs
.
name
])
else
:
return
lhs
pystencils_tests/test_types.py
View file @
c5f6c366
...
...
@@ -84,7 +84,7 @@ def test_mixed_add(dtype1, dtype2):
assert
test_f
[
0
]
==
constant
+
constant
# TODO
redo following tests
# TODO
vector
def
test_collation
():
double_type
=
BasicType
(
'float64'
)
float_type
=
BasicType
(
'float32'
)
...
...
@@ -95,8 +95,9 @@ def test_collation():
assert
collate_types
([
double4_type
,
float4_type
])
==
double4_type
# TODO this
def
test_vector_type
():
double_type
=
BasicType
(
"double"
)
double_type
=
BasicType
(
'float64'
)
float_type
=
BasicType
(
'float32'
)
double4_type
=
VectorType
(
double_type
,
4
)
float4_type
=
VectorType
(
float_type
,
4
)
...
...
@@ -147,36 +148,33 @@ def test_assumptions():
assert
(
x
.
shape
[
0
]
+
1
).
is_real
def
test_sqrt_of_integer
():
@
pytest
.
mark
.
parametrize
(
'dtype'
,
(
'float64'
,
'float32'
))
def
test_sqrt_of_integer
(
dtype
):
"""Regression test for bug where sqrt(3) was classified as integer"""
f
=
ps
.
fields
(
"f: [1D]"
)
tmp
=
sp
.
symbols
(
"tmp"
)
assignments
=
[
ps
.
Assignment
(
tmp
,
sp
.
sqrt
(
3
)),
ps
.
Assignment
(
f
[
0
],
tmp
)]
arr_double
=
np
.
array
([
1
],
dtype
=
np
.
float64
)
kernel
=
ps
.
create_kernel
(
assignments
).
compile
()
kernel
(
f
=
arr_double
)
assert
1.7
<
arr_double
[
0
]
<
1.8
f
=
ps
.
fields
(
"f: float32[1D]"
)
tmp
=
sp
.
symbols
(
"tmp"
)
f
=
ps
.
fields
(
f
'f:
{
dtype
}
[1D]'
)
tmp
=
sp
.
symbols
(
'tmp'
)
assignments
=
[
ps
.
Assignment
(
tmp
,
sp
.
sqrt
(
3
)),
ps
.
Assignment
(
f
[
0
],
tmp
)]
arr_single
=
np
.
array
([
1
],
dtype
=
np
.
float32
)
config
=
pystencils
.
config
.
CreateKernelConfig
(
data_type
=
"float32"
)
kernel
=
ps
.
create_kernel
(
assignments
,
config
=
config
).
compile
()
kernel
(
f
=
arr_single
)
arr
=
np
.
array
([
1
],
dtype
=
dtype
)
config
=
pystencils
.
config
.
CreateKernelConfig
(
data_type
=
dtype
)
code
=
ps
.
get_code_str
(
kernel
.
ast
)
ast
=
ps
.
create_kernel
(
assignments
,
config
=
config
)
kernel
=
ast
.
compile
()
kernel
(
f
=
arr
)
assert
1.7
<
arr
[
0
]
<
1.8
assert
"1.7320508075688772f"
in
code
assert
1.7
<
arr_single
[
0
]
<
1.8
code
=
ps
.
get_code_str
(
ast
)
constant
=
'1.7320508075688772f'
if
dtype
==
'float32'
:
assert
constant
in
code
else
:
assert
constant
not
in
code
def
test_integer_comparision
():
f
=
ps
.
fields
(
"f [2D]"
)
@
pytest
.
mark
.
parametrize
(
'dtype'
,
(
'float64'
,
'float32'
))
def
test_integer_comparision
(
dtype
):
f
=
ps
.
fields
(
f
"f:
{
dtype
}
[2D]"
)
d
=
sp
.
Symbol
(
"dir"
)
ur
=
ps
.
Assignment
(
f
[
0
,
0
],
sp
.
Piecewise
((
0
,
sp
.
Equality
(
d
,
1
)),
(
f
[
0
,
0
],
True
)))
...
...
@@ -184,9 +182,17 @@ def test_integer_comparision():
ast
=
ps
.
create_kernel
(
ur
)
code
=
ps
.
get_code_str
(
ast
)
assert
"_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0): (_data_f_00[_stride_f_1*ctr_1]));"
in
code
print
(
code
)
# There should be an explicit cast for the integer zero to the type of the field on the rhs
if
dtype
==
'float64'
:
t
=
"_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (((double)(0))): (_data_f_00[_stride_f_1*ctr_1]));"
else
:
t
=
"_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (((float)(0))): (_data_f_00[_stride_f_1*ctr_1]));"
assert
t
in
code
# TODO this
def
test_Basic_data_type
():
assert
typed_symbols
((
"s"
,
"f"
),
np
.
uint
)
==
typed_symbols
(
"s, f"
,
np
.
uint
)
t_symbols
=
typed_symbols
((
"s"
,
"f"
),
np
.
uint
)
...
...
@@ -223,6 +229,7 @@ def test_Basic_data_type():
assert
TypedSymbol
(
"s"
,
np
.
uint
).
reversed
==
TypedSymbol
(
"s"
,
np
.
uint
)
# TODO this
def
test_cast_func
():
assert
CastFunc
(
TypedSymbol
(
"s"
,
np
.
uint
),
np
.
int64
).
canonical
==
TypedSymbol
(
"s"
,
np
.
uint
).
canonical
...
...
@@ -235,6 +242,7 @@ def test_pointer_arithmetic_func():
assert
PointerArithmeticFunc
(
TypedSymbol
(
"s"
,
np
.
uint
),
1
).
canonical
==
TypedSymbol
(
"s"
,
np
.
uint
).
canonical
# TODO this
def
test_division
():
f
=
ps
.
fields
(
'f(10): float32[2D]'
)
m
,
tau
=
sp
.
symbols
(
"m, tau"
)
...
...
@@ -248,6 +256,7 @@ def test_division():
assert
"1.0f"
in
code
# TODO this
def
test_pow
():
f
=
ps
.
fields
(
'f(10): float32[2D]'
)
m
,
tau
=
sp
.
symbols
(
"m, tau"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment