Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jan Hönig
pystencils
Commits
9c93da4a
Commit
9c93da4a
authored
Sep 14, 2021
by
Christoph Alt
Committed by
Jan Hönig
Sep 14, 2021
Browse files
fixed create_kernel parameter data_type="float" to procucde single precision
parent
52775e94
Changes
2
Hide whitespace changes
Inline
Side-by-side
pystencils/transformations.py
View file @
9c93da4a
...
...
@@ -960,6 +960,8 @@ def add_types(eqs, type_for_symbol, check_independence_condition, check_double_w
if
isinstance
(
type_for_symbol
,
(
str
,
type
))
or
not
hasattr
(
type_for_symbol
,
'__getitem__'
):
type_for_symbol
=
typing_from_sympy_inspection
(
eqs
,
type_for_symbol
)
type_for_symbol
=
adjust_c_single_precision_type
(
type_for_symbol
)
check
=
KernelConstraintsCheck
(
type_for_symbol
,
check_independence_condition
,
check_double_write_condition
=
check_double_write_condition
)
...
...
@@ -1397,3 +1399,16 @@ def implement_interpolations(ast_node: ast.Node,
ast_node
.
subs
(
substitutions
)
return
ast_node
def
adjust_c_single_precision_type
(
type_for_symbol
):
"""Replaces every occurrence of 'float' with 'single' to enforce the numpy single precision type."""
def
single_factory
():
return
"single"
for
symbol
in
type_for_symbol
:
if
type_for_symbol
[
symbol
]
==
"float"
:
type_for_symbol
[
symbol
]
=
single_factory
()
if
hasattr
(
type_for_symbol
,
"default_factory"
)
and
type_for_symbol
.
default_factory
()
==
"float"
:
type_for_symbol
.
default_factory
=
single_factory
return
type_for_symbol
pystencils_tests/test_kernel_data_type.py
0 → 100644
View file @
9c93da4a
from
collections
import
defaultdict
import
numpy
as
np
import
pytest
from
sympy.abc
import
x
,
y
from
pystencils
import
Assignment
,
create_kernel
,
fields
,
CreateKernelConfig
from
pystencils.transformations
import
adjust_c_single_precision_type
@
pytest
.
mark
.
parametrize
(
"data_type"
,
(
"float"
,
"double"
))
def
test_single_precision
(
data_type
):
dtype
=
f
"float
{
64
if
data_type
==
'double'
else
32
}
"
s
=
fields
(
f
"s:
{
dtype
}
[1D]"
)
assignments
=
[
Assignment
(
x
,
y
),
Assignment
(
s
[
0
],
x
)]
ast
=
create_kernel
(
assignments
,
config
=
CreateKernelConfig
(
data_type
=
data_type
))
assert
ast
.
body
.
args
[
0
].
lhs
.
dtype
.
numpy_dtype
==
np
.
dtype
(
dtype
)
assert
ast
.
body
.
args
[
0
].
rhs
.
dtype
.
numpy_dtype
==
np
.
dtype
(
dtype
)
assert
ast
.
body
.
args
[
1
].
body
.
args
[
0
].
rhs
.
dtype
.
numpy_dtype
==
np
.
dtype
(
dtype
)
def
test_adjustment_dict
():
d
=
dict
({
"x"
:
"float"
,
"y"
:
"double"
})
adjust_c_single_precision_type
(
d
)
assert
np
.
dtype
(
d
[
"x"
])
==
np
.
dtype
(
"float32"
)
assert
np
.
dtype
(
d
[
"y"
])
==
np
.
dtype
(
"float64"
)
def
test_adjustement_default_dict
():
dd
=
defaultdict
(
lambda
:
"float"
)
dd
[
"x"
]
adjust_c_single_precision_type
(
dd
)
dd
[
"y"
]
assert
np
.
dtype
(
dd
[
"x"
])
==
np
.
dtype
(
"float32"
)
assert
np
.
dtype
(
dd
[
"y"
])
==
np
.
dtype
(
"float32"
)
assert
np
.
dtype
(
dd
[
"z"
])
==
np
.
dtype
(
"float32"
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment