Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonas Plewinski
pystencils
Commits
fc198a17
Commit
fc198a17
authored
Oct 11, 2017
by
Martin Bauer
Browse files
Further vectorization tests & bugfixes
- phasefield phi sweep vectorizes successfully
parent
9d1e022d
Changes
3
Hide whitespace changes
Inline
Side-by-side
backends/cbackend.py
View file @
fc198a17
...
...
@@ -259,6 +259,24 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
processed
=
func
.
format
(
processed
,
summand
.
term
)
return
processed
def
_print_Pow
(
self
,
expr
):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
exprType
=
getTypeOfExpression
(
expr
)
if
type
(
exprType
)
is
not
VectorType
:
return
super
(
VectorizedCustomSympyPrinter
,
self
).
_print_Pow
(
expr
)
assert
self
.
instructionSet
[
'width'
]
==
exprType
.
width
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
return
"("
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
")"
else
:
if
expr
.
exp
==
-
1
:
one
=
self
.
instructionSet
[
'makeVec'
].
format
(
1.0
)
return
self
.
instructionSet
[
'/'
].
format
(
one
,
self
.
_print
(
expr
.
base
))
elif
expr
.
exp
==
0.5
:
return
self
.
instructionSet
[
'sqrt'
].
format
(
self
.
_print
(
expr
.
base
))
else
:
raise
ValueError
(
"Generic exponential not supported"
)
def
_print_Mul
(
self
,
expr
,
insideAdd
=
False
):
exprType
=
getTypeOfExpression
(
expr
)
if
type
(
exprType
)
is
not
VectorType
:
...
...
@@ -286,6 +304,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
a
.
append
(
item
)
a
=
a
or
[
S
.
One
]
# a = a or [castFunc(S.One, VectorType(createTypeFromString("double"), exprType.width))]
a_str
=
[
self
.
_print
(
x
)
for
x
in
a
]
b_str
=
[
self
.
_print
(
x
)
for
x
in
b
]
...
...
data_types.py
View file @
fc198a17
...
...
@@ -10,7 +10,6 @@ from pystencils.utils import allEqual
# to work in conditions of sp.Piecewise castFunc has to be of type Relational as well
class
castFunc
(
sp
.
Function
,
sp
.
Rel
):
@
property
def
canonical
(
self
):
if
hasattr
(
self
.
args
[
0
],
'canonical'
):
...
...
@@ -18,6 +17,10 @@ class castFunc(sp.Function, sp.Rel):
else
:
raise
NotImplementedError
()
@
property
def
is_commutative
(
self
):
return
self
.
args
[
0
].
is_commutative
class
pointerArithmeticFunc
(
sp
.
Function
,
sp
.
Rel
):
...
...
@@ -281,8 +284,11 @@ def getTypeOfExpression(expr):
elif
hasattr
(
expr
,
'func'
)
and
expr
.
func
==
castFunc
:
return
expr
.
args
[
1
]
elif
hasattr
(
expr
,
'func'
)
and
expr
.
func
==
sp
.
Piecewise
:
branchResults
=
[
a
[
0
]
for
a
in
expr
.
args
]
return
collateTypes
(
tuple
(
getTypeOfExpression
(
a
)
for
a
in
branchResults
))
collatedResultType
=
collateTypes
(
tuple
(
getTypeOfExpression
(
a
[
0
])
for
a
in
expr
.
args
))
collatedConditionType
=
collateTypes
(
tuple
(
getTypeOfExpression
(
a
[
1
])
for
a
in
expr
.
args
))
if
type
(
collatedConditionType
)
is
VectorType
and
type
(
collatedResultType
)
is
not
VectorType
:
collatedResultType
=
VectorType
(
collatedResultType
,
width
=
collatedConditionType
.
width
)
return
collatedResultType
elif
isinstance
(
expr
,
sp
.
Indexed
):
typedSymbol
=
expr
.
base
.
label
return
typedSymbol
.
dtype
.
baseType
...
...
@@ -328,6 +334,9 @@ class Type(sp.Basic):
def
_sympystr
(
self
,
*
args
,
**
kwargs
):
return
str
(
self
)
def
_sympystr
(
self
,
*
args
,
**
kwargs
):
return
str
(
self
)
class
BasicType
(
Type
):
@
staticmethod
...
...
vectorization.py
View file @
fc198a17
...
...
@@ -70,6 +70,9 @@ def insertVectorCasts(astNode):
castedArgs
=
[
castFunc
(
a
,
targetType
)
if
t
!=
targetType
else
a
for
a
,
t
in
zip
(
newArgs
,
argTypes
)]
return
expr
.
func
(
*
castedArgs
)
elif
expr
.
func
is
sp
.
Pow
:
newArg
=
visitExpr
(
expr
.
args
[
0
])
return
sp
.
Pow
(
newArg
,
expr
.
args
[
1
])
elif
expr
.
func
==
sp
.
Piecewise
:
newResults
=
[
visitExpr
(
a
[
0
])
for
a
in
expr
.
args
]
newConditions
=
[
visitExpr
(
a
[
1
])
for
a
in
expr
.
args
]
...
...
@@ -77,10 +80,13 @@ def insertVectorCasts(astNode):
typesOfConditions
=
[
getTypeOfExpression
(
a
)
for
a
in
newConditions
]
resultTargetType
=
getTypeOfExpression
(
expr
)
conditionTargetType
=
collateTypes
(
typesOfConditions
)
if
type
(
conditionTargetType
)
is
VectorType
and
type
(
resultTargetType
)
is
not
VectorType
:
resultTargetType
=
VectorType
(
resultTargetType
,
width
=
conditionTargetType
.
width
)
castedResults
=
[
castFunc
(
a
,
resultTargetType
)
if
t
!=
resultTargetType
else
a
for
a
,
t
in
zip
(
newResults
,
typesOfResults
)]
conditionTargetType
=
collateTypes
(
typesOfConditions
)
castedConditions
=
[
castFunc
(
a
,
conditionTargetType
)
if
t
!=
conditionTargetType
and
a
!=
True
else
a
for
a
,
t
in
zip
(
newConditions
,
typesOfConditions
)]
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment