Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonas Plewinski
pystencils
Commits
0d12a2ac
Commit
0d12a2ac
authored
Mar 07, 2019
by
Martin Bauer
Browse files
pystencils: support for approximate divisions and sqrt's (CUDA)
parent
61800b73
Changes
3
Hide whitespace changes
Inline
Side-by-side
backends/cbackend.py
View file @
0d12a2ac
...
...
@@ -3,6 +3,9 @@ from collections import namedtuple
from
sympy.core
import
S
from
typing
import
Set
from
sympy.printing.ccode
import
C89CodePrinter
from
pystencils.fast_approximation
import
fast_division
,
fast_sqrt
,
fast_inv_sqrt
try
:
from
sympy.printing.ccode
import
C99CodePrinter
as
CCodePrinter
except
ImportError
:
...
...
@@ -98,9 +101,9 @@ class CBackend:
signature_only
=
False
,
vector_instruction_set
=
None
,
dialect
=
'c'
):
if
sympy_printer
is
None
:
if
vector_instruction_set
is
not
None
:
self
.
sympy_printer
=
VectorizedCustomSympyPrinter
(
vector_instruction_set
)
self
.
sympy_printer
=
VectorizedCustomSympyPrinter
(
vector_instruction_set
,
dialect
)
else
:
self
.
sympy_printer
=
CustomSympyPrinter
()
self
.
sympy_printer
=
CustomSympyPrinter
(
dialect
)
else
:
self
.
sympy_printer
=
sympy_printer
...
...
@@ -210,9 +213,10 @@ class CBackend:
# noinspection PyPep8Naming
class
CustomSympyPrinter
(
CCodePrinter
):
def
__init__
(
self
):
def
__init__
(
self
,
dialect
):
super
(
CustomSympyPrinter
,
self
).
__init__
()
self
.
_float_type
=
create_type
(
"float32"
)
self
.
_dialect
=
dialect
if
'Min'
in
self
.
known_functions
:
del
self
.
known_functions
[
'Min'
]
if
'Max'
in
self
.
known_functions
:
...
...
@@ -259,7 +263,22 @@ class CustomSympyPrinter(CCodePrinter):
if
isinstance
(
arg
,
sp
.
Number
):
return
self
.
_typed_number
(
arg
,
data_type
)
else
:
return
"*((%s)(& %s))"
%
(
PointerType
(
data_type
,
restrict
=
False
),
self
.
_print
(
arg
))
return
"((%s)(%s))"
%
(
data_type
,
self
.
_print
(
arg
))
elif
isinstance
(
expr
,
fast_division
):
if
self
.
_dialect
==
"cuda"
:
return
"__fdividef(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
else
:
return
"({})"
.
format
(
self
.
_print
(
expr
.
args
[
0
]
/
expr
.
args
[
1
]))
elif
isinstance
(
expr
,
fast_sqrt
):
if
self
.
_dialect
==
"cuda"
:
return
"__fsqrt_rn(%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
else
:
return
"({})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
])))
elif
isinstance
(
expr
,
fast_inv_sqrt
):
if
self
.
_dialect
==
"cuda"
:
return
"__frsqrt_rn(%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
else
:
return
"({})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
])))
elif
expr
.
func
in
infix_functions
:
return
"(%s %s %s)"
%
(
self
.
_print
(
expr
.
args
[
0
]),
infix_functions
[
expr
.
func
],
self
.
_print
(
expr
.
args
[
1
]))
else
:
...
...
@@ -285,8 +304,8 @@ class CustomSympyPrinter(CCodePrinter):
class
VectorizedCustomSympyPrinter
(
CustomSympyPrinter
):
SummandInfo
=
namedtuple
(
"SummandInfo"
,
[
'sign'
,
'term'
])
def
__init__
(
self
,
instruction_set
):
super
(
VectorizedCustomSympyPrinter
,
self
).
__init__
()
def
__init__
(
self
,
instruction_set
,
dialect
):
super
(
VectorizedCustomSympyPrinter
,
self
).
__init__
(
dialect
=
dialect
)
self
.
instruction_set
=
instruction_set
def
_scalarFallback
(
self
,
func_name
,
expr
,
*
args
,
**
kwargs
):
...
...
@@ -306,7 +325,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
arg
,
data_type
=
expr
.
args
if
type
(
data_type
)
is
VectorType
:
return
self
.
instruction_set
[
'makeVec'
].
format
(
self
.
_print
(
arg
))
elif
expr
.
func
==
fast_division
:
return
self
.
instruction_set
[
'/'
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
]))
elif
expr
.
func
==
fast_sqrt
:
return
"({})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
])))
elif
expr
.
func
==
fast_inv_sqrt
:
return
"({})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
])))
return
super
(
VectorizedCustomSympyPrinter
,
self
).
_print_Function
(
expr
)
def
_print_And
(
self
,
expr
):
...
...
display_utils.py
View file @
0d12a2ac
...
...
@@ -38,17 +38,18 @@ def show_code(ast: KernelFunction):
Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
"""
from
pystencils.backends.cbackend
import
generate_c
dialect
=
'cuda'
if
ast
.
backend
==
'gpucuda'
else
'c'
class
CodeDisplay
:
def
__init__
(
self
,
ast_input
):
self
.
ast
=
ast_input
def
_repr_html_
(
self
):
return
highlight_cpp
(
generate_c
(
self
.
ast
)).
__html__
()
return
highlight_cpp
(
generate_c
(
self
.
ast
,
dialect
=
dialect
)).
__html__
()
def
__str__
(
self
):
return
generate_c
(
self
.
ast
)
return
generate_c
(
self
.
ast
,
dialect
=
dialect
)
def
__repr__
(
self
):
return
generate_c
(
self
.
ast
)
return
generate_c
(
self
.
ast
,
dialect
=
dialect
)
return
CodeDisplay
(
ast
)
fast_approximation.py
0 → 100644
View file @
0d12a2ac
import
sympy
as
sp
from
typing
import
List
,
Union
from
pystencils.astnodes
import
Node
from
pystencils.simp
import
AssignmentCollection
# noinspection PyPep8Naming
class fast_division(sp.Function):
    """Symbolic marker for an approximate division ``args[0] / args[1]``.

    Takes exactly two arguments (numerator, denominator). The C backend's
    sympy printer emits ``__fdividef(a, b)`` for the ``"cuda"`` dialect and an
    ordinary division for every other dialect.
    """
    nargs = (2,)
# noinspection PyPep8Naming
class
fast_sqrt
(
sp
.
Function
):
nargs
=
(
1
,
)
# noinspection PyPep8Naming
class
fast_inv_sqrt
(
sp
.
Function
):
nargs
=
(
1
,
)
def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
    """Replace half-integer powers by ``fast_sqrt`` / ``fast_inv_sqrt`` calls.

    Every ``sp.Pow`` whose exponent is a rational ``p/2`` is rewritten:

    * ``p > 0``:  ``base ** (p/2)``  ->  ``fast_sqrt(base) ** p``
    * ``p < 0``:  ``base ** (p/2)``  ->  ``fast_inv_sqrt(base) ** (-p)``

    Args:
        term: a single sympy expression, a list of expressions, or an
            ``AssignmentCollection``.

    Returns:
        An object of the same kind as ``term`` with the replacements applied.
        For an ``AssignmentCollection`` a copy is returned whose main
        assignments and subexpressions have both been transformed.
    """
    def visit(expr):
        # pystencils AST nodes are passed through untouched.
        if isinstance(expr, Node):
            return expr

        if expr.func == sp.Pow and isinstance(expr.exp, sp.Rational) and expr.exp.q == 2:
            power = expr.exp.p
            # Recurse into the base as well, so that nested roots such as
            # sqrt(1 + sqrt(x)) are replaced on every level.
            base = visit(expr.args[0])
            if power < 0:
                return fast_inv_sqrt(base) ** (-power)
            return fast_sqrt(base) ** power

        new_args = [visit(a) for a in expr.args]
        # Atoms have no args and cannot be rebuilt via func(*args).
        return expr.func(*new_args) if new_args else expr

    if isinstance(term, AssignmentCollection):
        new_main_assignments = insert_fast_sqrts(term.main_assignments)
        new_subexpressions = insert_fast_sqrts(term.subexpressions)
        return term.copy(new_main_assignments, new_subexpressions)
    elif isinstance(term, list):
        return [insert_fast_sqrts(e) for e in term]
    else:
        return visit(term)
def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
    """Replace divisions by ``fast_division`` calls.

    SymPy represents ``a / b`` as ``Mul(a, Pow(b, -1))``. This function
    collects all factors of a product that are powers with a negative integer
    exponent into a denominator and emits
    ``fast_division(numerator, denominator)``. A stand-alone power with a
    negative integer exponent becomes ``fast_division(1, base ** (-exp))``.

    Args:
        term: a single sympy expression, a list of expressions, or an
            ``AssignmentCollection``.

    Returns:
        An object of the same kind as ``term`` with the replacements applied.
        For an ``AssignmentCollection`` a copy is returned whose main
        assignments and subexpressions have both been transformed.
    """
    def is_negative_integer_power(expr):
        # Use the three-valued assumption `is_negative` instead of `exp < 0`:
        # for an integer exponent of unknown sign the comparison yields an
        # unevaluated relational whose truth value raises TypeError. An unknown
        # sign (None) is treated as "not a division".
        return bool(expr.func == sp.Pow and expr.exp.is_integer and expr.exp.is_negative)

    def visit(expr):
        # pystencils AST nodes are passed through untouched.
        if isinstance(expr, Node):
            return expr

        if expr.func == sp.Mul:
            div_args = []
            other_args = []
            for a in expr.args:
                if is_negative_integer_power(a):
                    div_args.append(visit(a.base) ** (-a.exp))
                else:
                    other_args.append(visit(a))
            if div_args:
                return fast_division(sp.Mul(*other_args), sp.Mul(*div_args))
            return sp.Mul(*other_args)
        elif is_negative_integer_power(expr):
            return fast_division(1, visit(expr.base) ** (-expr.exp))
        else:
            new_args = [visit(a) for a in expr.args]
            # Atoms have no args and cannot be rebuilt via func(*args).
            return expr.func(*new_args) if new_args else expr

    if isinstance(term, AssignmentCollection):
        new_main_assignments = insert_fast_divisions(term.main_assignments)
        new_subexpressions = insert_fast_divisions(term.subexpressions)
        return term.copy(new_main_assignments, new_subexpressions)
    elif isinstance(term, list):
        return [insert_fast_divisions(e) for e in term]
    else:
        return visit(term)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment