Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencils
Commits
edb6fdcd
Commit
edb6fdcd
authored
Dec 15, 2016
by
Jan Hoenig
Browse files
More fixes on DataType transition
Move my llvm demo notebook in the correct folder
parent
815edd12
Changes
8
Hide whitespace changes
Inline
Side-by-side
ast.py
View file @
edb6fdcd
...
...
@@ -2,7 +2,7 @@ import sympy as sp
import
textwrap
as
textwrap
from
sympy.tensor
import
IndexedBase
,
Indexed
from
pystencils.field
import
Field
from
pystencils.types
import
TypedSymbol
from
pystencils.types
import
TypedSymbol
,
DataType
class
Node
(
object
):
...
...
@@ -266,7 +266,7 @@ class LoopOverCoordinate(Node):
@
staticmethod
def
getLoopCounterSymbol
(
coordinateToLoopOver
):
return
TypedSymbol
(
LoopOverCoordinate
.
getLoopCounterName
(
coordinateToLoopOver
),
"
int
"
)
return
TypedSymbol
(
LoopOverCoordinate
.
getLoopCounterName
(
coordinateToLoopOver
),
DataType
(
'
int
'
)
)
@
property
def
loopCounterSymbol
(
self
):
...
...
backends/__init__.py
View file @
edb6fdcd
from
.llvm
import
generateLLVM
from
.cbackend
import
generateC
,
generateCUDA
backends/llvm.py
View file @
edb6fdcd
...
...
@@ -78,6 +78,7 @@ class LLVMPrinter(Printer):
def
_print_Mul
(
self
,
expr
):
nodes
=
[
self
.
_print
(
a
)
for
a
in
expr
.
args
]
e
=
nodes
[
0
]
print
(
nodes
)
for
node
in
nodes
[
1
:]:
e
=
self
.
builder
.
fmul
(
e
,
node
)
return
e
...
...
@@ -120,10 +121,12 @@ class LLVMPrinter(Printer):
def
_print_LoopOverCoordinate
(
self
,
loop
):
with
Loop
(
self
.
builder
,
self
.
_print
(
loop
.
start
),
self
.
_print
(
loop
.
stop
),
self
.
_print
(
loop
.
step
),
loop
.
loopCounterName
,
loop
.
loopCounterSymbol
.
name
)
as
i
:
self
.
_add_tmp_var
(
loop
.
loopCounterSymbol
,
i
)
self
.
_print
(
loop
.
body
)
def
_print_SympyAssignment
(
self
,
loop
):
pass
def
_print_SympyAssignment
(
self
,
assignment
):
expr
=
self
.
_print
(
assignment
.
rhs
)
# Should have a list of math library functions to validate this.
...
...
cpu/kernelcreation.py
View file @
edb6fdcd
import
sympy
as
sp
from
pystencils.transformations
import
resolveFieldAccesses
,
makeLoopOverDomain
,
typingFromSympyInspection
,
\
typeAllEquations
,
getOptimalLoopOrdering
,
parseBasePointerInfo
,
moveConstantsBeforeLoop
,
splitInnerLoop
from
pystencils.types
import
TypedSymbol
from
pystencils.types
import
TypedSymbol
,
DataType
from
pystencils.field
import
Field
import
pystencils.ast
as
ast
...
...
@@ -37,7 +37,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
if
isinstance
(
term
,
Field
.
Access
)
or
isinstance
(
term
,
TypedSymbol
):
return
term
elif
isinstance
(
term
,
sp
.
Symbol
):
return
TypedSymbol
(
term
.
name
,
typeForSymbol
[
term
.
name
])
return
TypedSymbol
(
term
.
name
,
DataType
(
typeForSymbol
[
term
.
name
])
)
else
:
raise
ValueError
(
"Term has to be field access or symbol"
)
...
...
llvm/__init__.py
View file @
edb6fdcd
from
pystencils.cpu
.kernelcreation
import
createKernel
from
.kernelcreation
import
createKernel
llvm/kernelcreation.py
View file @
edb6fdcd
import
sympy
as
sp
from
pystencils.transformations
import
resolveFieldAccesses
,
makeLoopOverDomain
,
typingFromSympyInspection
,
\
typeAllEquations
,
getOptimalLoopOrdering
,
parseBasePointerInfo
,
moveConstantsBeforeLoop
,
splitInnerLoop
from
pystencils.types
import
TypedSymbol
from
pystencils.types
import
TypedSymbol
,
DataType
from
pystencils.field
import
Field
import
pystencils.ast
as
ast
...
...
@@ -35,7 +35,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
if
isinstance
(
term
,
Field
.
Access
)
or
isinstance
(
term
,
TypedSymbol
):
return
term
elif
isinstance
(
term
,
sp
.
Symbol
):
return
TypedSymbol
(
term
.
name
,
typeForSymbol
[
term
.
name
])
return
TypedSymbol
(
term
.
name
,
DataType
(
typeForSymbol
[
term
.
name
])
)
else
:
raise
ValueError
(
"Term has to be field access or symbol"
)
...
...
transformations.py
View file @
edb6fdcd
...
...
@@ -98,7 +98,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
Example:
>>> field = Field.createGeneric('myfield', spatialDimensions=2, indexDimensions=1)
>>> x, y = sp.symbols("x y")
>>> prevPointer = TypedSymbol("ptr", "double")
>>> prevPointer = TypedSymbol("ptr",
DataType(
"double")
)
>>> createIntermediateBasePointer(field[1,-2](5), {0: x}, prevPointer)
(ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
>>> createIntermediateBasePointer(field[1,-2](5), {0: x, 1 : y }, prevPointer)
...
...
@@ -129,7 +129,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
if
len
(
listToHash
)
>
0
:
name
+=
"%0.6X"
%
(
abs
(
hash
(
tuple
(
listToHash
))))
newPtr
=
TypedSymbol
(
previousPtr
.
name
+
name
,
previousPtr
.
dtype
)
newPtr
=
TypedSymbol
(
previousPtr
.
name
+
name
,
DataType
(
previousPtr
.
dtype
)
)
return
newPtr
,
offset
...
...
@@ -238,7 +238,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
coordDict
[
e
]
=
fieldToFixedCoordinates
[
field
.
name
][
e
]
else
:
ctrName
=
ast
.
LoopOverCoordinate
.
LOOP_COUNTER_NAME_PREFIX
coordDict
[
e
]
=
TypedSymbol
(
"%s_%d"
%
(
ctrName
,
e
),
"
int
"
)
coordDict
[
e
]
=
TypedSymbol
(
"%s_%d"
%
(
ctrName
,
e
),
DataType
(
'
int
'
)
)
else
:
coordDict
[
e
]
=
fieldAccess
.
index
[
e
-
field
.
spatialDimensions
]
return
coordDict
...
...
@@ -420,7 +420,7 @@ def typeAllEquations(eqs, typeForSymbol):
elif
isinstance
(
term
,
TypedSymbol
):
return
term
elif
isinstance
(
term
,
sp
.
Symbol
):
return
TypedSymbol
(
symbolNameToVariableName
(
term
.
name
),
typeForSymbol
[
term
.
name
])
return
TypedSymbol
(
symbolNameToVariableName
(
term
.
name
),
DataType
(
typeForSymbol
[
term
.
name
])
)
else
:
newArgs
=
[
processRhs
(
arg
)
for
arg
in
term
.
args
]
return
term
.
func
(
*
newArgs
)
if
newArgs
else
term
...
...
@@ -433,7 +433,7 @@ def typeAllEquations(eqs, typeForSymbol):
elif
isinstance
(
term
,
TypedSymbol
):
return
term
elif
isinstance
(
term
,
sp
.
Symbol
):
return
TypedSymbol
(
term
.
name
,
typeForSymbol
[
term
.
name
])
return
TypedSymbol
(
term
.
name
,
DataType
(
typeForSymbol
[
term
.
name
])
)
else
:
assert
False
,
"Expected a symbol as left-hand-side"
...
...
types.py
View file @
edb6fdcd
...
...
@@ -10,7 +10,7 @@ class TypedSymbol(sp.Symbol):
def
__new_stage2__
(
cls
,
name
,
dtype
):
obj
=
super
(
TypedSymbol
,
cls
).
__xnew__
(
cls
,
name
)
obj
.
_dtype
=
dtype
obj
.
_dtype
=
DataType
(
dtype
)
if
isinstance
(
dtype
,
str
)
else
dtype
return
obj
__xnew__
=
staticmethod
(
__new_stage2__
)
...
...
@@ -29,8 +29,8 @@ class TypedSymbol(sp.Symbol):
return
self
.
name
,
self
.
dtype
_c_dtype_dict
=
{
0
:
'int'
,
1
:
'double'
,
2
:
'float'
}
_dtype_dict
=
{
'int'
:
0
,
'double'
:
1
,
'float'
:
2
}
_c_dtype_dict
=
{
0
:
'int'
,
1
:
'double'
,
2
:
'float'
,
3
:
'bool'
}
_dtype_dict
=
{
'int'
:
0
,
'double'
:
1
,
'float'
:
2
,
'bool'
:
3
}
class
DataType
(
object
):
...
...
@@ -38,11 +38,28 @@ class DataType(object):
self
.
alias
=
True
self
.
const
=
False
self
.
ptr
=
False
self
.
dtype
=
0
if
isinstance
(
dtype
,
str
):
self
.
dtype
=
_dtype_dict
[
dtype
]
for
s
in
dtype
.
split
():
if
s
==
'const'
:
self
.
const
=
True
elif
s
==
'*'
:
self
.
ptr
=
True
elif
s
==
'__restrict__'
:
self
.
alias
=
False
else
:
self
.
dtype
=
_dtype_dict
[
s
]
elif
isinstance
(
dtype
,
DataType
):
self
.
__dict__
.
update
(
dtype
.
__dict__
)
else
:
self
.
dtype
=
dtype
def
__repr__
(
self
):
return
"{!s} {!s}{!s} {!s}"
.
format
(
"const"
if
self
.
const
else
""
,
_c_dtype_dict
[
self
.
dtype
],
"*"
if
self
.
ptr
else
""
,
"__restrict__"
if
not
self
.
alias
else
""
)
def
__eq__
(
self
,
other
):
if
self
.
alias
==
other
.
alias
and
self
.
const
==
other
.
const
and
self
.
ptr
==
other
.
ptr
and
self
.
dtype
==
other
.
dtype
:
return
True
else
:
return
False
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment