Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencils
Commits
684ef359
Commit
684ef359
authored
Sep 02, 2019
by
Stephan Seitz
Browse files
Correctly determine complex dtype of symbols and imaginary unit
parent
b9423f80
Pipeline
#17710
failed with stage
in 5 minutes and 1 second
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
pystencils/data_types.py
View file @
684ef359
...
...
@@ -431,11 +431,15 @@ def collate_types(types,
def
get_type_of_expression
(
expr
,
default_float_type
=
'double'
,
default_int_type
=
'int'
,
default_complex_type
=
'complex128'
,
symbol_type_dict
=
None
):
from
pystencils.astnodes
import
ResolvedFieldAccess
from
pystencils.cpu.vectorization
import
vec_all
,
vec_any
# TODO: determine more general
if
default_float_type
==
'double'
or
default_float_type
==
'float64'
:
default_complex_type
=
'complex128'
else
:
default_complex_type
=
'complex64'
if
not
symbol_type_dict
:
symbol_type_dict
=
defaultdict
(
lambda
:
create_type
(
'double'
))
...
...
@@ -443,7 +447,6 @@ def get_type_of_expression(expr,
get_type
=
partial
(
get_type_of_expression
,
default_float_type
=
default_float_type
,
default_int_type
=
default_int_type
,
default_complex_type
=
default_complex_type
,
symbol_type_dict
=
symbol_type_dict
)
expr
=
sp
.
sympify
(
expr
)
...
...
pystencils/transformations.py
View file @
684ef359
...
...
@@ -12,8 +12,8 @@ from sympy.logic.boolalg import Boolean
import
pystencils.astnodes
as
ast
from
pystencils.assignment
import
Assignment
from
pystencils.data_types
import
(
PointerType
,
StructType
,
TypedSymbol
,
cast_func
,
collate_types
,
create_type
,
get_base_type
,
get_type_of_expression
,
pointer_arithmetic_func
,
reinterpret_cast_func
)
PointerType
,
StructType
,
TypedImaginaryUnit
,
TypedSymbol
,
cast_func
,
collate_types
,
create_type
,
get_base_type
,
get_type_of_expression
,
pointer_arithmetic_func
,
reinterpret_cast_func
)
from
pystencils.field
import
AbstractField
,
Field
,
FieldType
from
pystencils.kernelparameters
import
FieldPointerSymbol
from
pystencils.simp.assignment_collection
import
AssignmentCollection
...
...
@@ -898,6 +898,11 @@ class KernelConstraintsCheck:
return
rhs
elif
isinstance
(
rhs
,
TypedSymbol
):
return
rhs
elif
isinstance
(
rhs
,
sp
.
numbers
.
ImaginaryUnit
):
return
TypedImaginaryUnit
(
self
.
_type_for_symbol
[
'_ImaginaryUnit'
])
elif
isinstance
(
rhs
,
sp
.
Symbol
):
return
TypedSymbol
(
rhs
.
name
,
self
.
_type_for_symbol
[
rhs
.
name
])
return
TypedSymbol
(
rhs
.
name
,
self
.
_type_for_symbol
[
rhs
.
name
])
elif
isinstance
(
rhs
,
sp
.
Symbol
):
return
TypedSymbol
(
rhs
.
name
,
self
.
_type_for_symbol
[
rhs
.
name
])
elif
type_constants
and
isinstance
(
rhs
,
np
.
generic
):
...
...
@@ -1167,6 +1172,11 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i
dictionary, mapping symbol name to type
"""
result
=
defaultdict
(
lambda
:
default_type
)
if
default_type
==
'double'
or
default_type
==
'float64'
:
# todo: fix
result
[
'_ImaginaryUnit'
]
=
create_type
(
'complex128'
)
else
:
result
[
'_ImaginaryUnit'
]
=
create_type
(
'complex64'
)
for
eq
in
eqs
:
if
isinstance
(
eq
,
ast
.
Conditional
):
result
.
update
(
typing_from_sympy_inspection
(
eq
.
true_block
.
args
))
...
...
pystencils_tests/test_complex_numbers.py
View file @
684ef359
...
...
@@ -20,7 +20,8 @@ from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, create_type
X
,
Y
=
pystencils
.
fields
(
'x, y: complex64[2d]'
)
A
,
B
=
pystencils
.
fields
(
'a, b: float32[2d]'
)
S1
,
S2
=
sympy
.
symbols
(
'S1, S2'
)
T64
=
TypedSymbol
(
't'
,
create_type
(
'complex64'
))
# T64 = TypedSymbol('t', create_type('complex64'))
T64
=
sympy
.
Symbol
(
't'
)
TEST_ASSIGNMENTS
=
[
AssignmentCollection
({
X
[
0
,
0
]:
1j
}),
...
...
@@ -48,11 +49,9 @@ SCALAR_DTYPES = ['float32', 'float64']
@
pytest
.
mark
.
parametrize
(
"assignment, scalar_dtypes"
,
itertools
.
product
(
TEST_ASSIGNMENTS
,
SCALAR_DTYPES
))
def
test_complex_numbers
(
assignment
,
scalar_dtypes
):
ast
=
pystencils
.
create_kernel
(
assignment
.
subs
(
{
sympy
.
sympify
(
1j
).
args
[
1
]:
TypedImaginaryUnit
(
create_type
(
'complex64'
))}),
target
=
'cpu'
,
data_type
=
scalar_dtypes
)
ast
=
pystencils
.
create_kernel
(
assignment
,
target
=
'cpu'
,
data_type
=
'float32'
)
code
=
str
(
pystencils
.
show_code
(
ast
))
print
(
code
)
...
...
@@ -94,11 +93,9 @@ SCALAR_DTYPES = ['float32', 'float64']
@
pytest
.
mark
.
parametrize
(
"assignment, scalar_dtypes"
,
itertools
.
product
(
TEST_ASSIGNMENTS
,
SCALAR_DTYPES
))
def
test_complex_numbers_64
(
assignment
,
scalar_dtypes
):
ast
=
pystencils
.
create_kernel
(
assignment
.
subs
(
{
sympy
.
sympify
(
1j
).
args
[
1
]:
TypedImaginaryUnit
(
create_type
(
'complex128'
))}),
target
=
'cpu'
,
data_type
=
scalar_dtypes
)
ast
=
pystencils
.
create_kernel
(
assignment
,
target
=
'cpu'
,
data_type
=
'double'
)
code
=
str
(
pystencils
.
show_code
(
ast
))
print
(
code
)
...
...
@@ -113,5 +110,8 @@ def test_get_data_type():
from
pystencils.data_types
import
get_type_of_expression
i
=
TypedImaginaryUnit
(
create_type
(
'complex128'
))
#
assert get_type_of_expression(i+3).numpy_dtype == np.complex128
assert
get_type_of_expression
(
i
+
3
).
numpy_dtype
==
np
.
complex128
assert
get_type_of_expression
(
i
+
3.
).
numpy_dtype
==
np
.
complex128
i
=
TypedImaginaryUnit
(
create_type
(
'complex64'
))
assert
get_type_of_expression
(
i
+
3
,
default_float_type
=
'float32'
).
numpy_dtype
==
np
.
complex64
assert
get_type_of_expression
(
i
+
3.
,
default_float_type
=
'float32'
).
numpy_dtype
==
np
.
complex64
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment