Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonas Plewinski
pystencils
Commits
e31f1062
Commit
e31f1062
authored
Apr 13, 2018
by
Martin Bauer
Browse files
flake8 linter
- removed warnings - added flake8 as CI target
parent
afc933d9
Changes
41
Hide whitespace changes
Inline
Side-by-side
__init__.py
View file @
e31f1062
"""Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions"""
from
.
import
sympy_gmpy_bug_workaround
from
.
import
sympy_gmpy_bug_workaround
# NOQA
from
.field
import
Field
,
FieldType
from
.data_types
import
TypedSymbol
from
.slicing
import
make_slice
...
...
assignment_collection/simplifications.py
View file @
e31f1062
...
...
@@ -98,4 +98,4 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Call
result
=
[
Assignment
(
eq
.
lhs
,
operation
(
eq
.
rhs
))
for
eq
in
ac
.
subexpressions
]
return
ac
.
copy
(
ac
.
main_assignments
,
result
)
f
.
__name__
=
operation
.
__name__
return
f
\ No newline at end of file
return
f
assignment_collection/simplificationstrategy.py
View file @
e31f1062
...
...
@@ -84,7 +84,7 @@ class SimplificationStrategy(object):
report
=
Report
()
op
=
assignment_collection
.
operation_count
total
=
op
[
'adds'
]
+
op
[
'muls'
]
+
op
[
'divs'
]
report
.
add
(
ReportElement
(
"OriginalTerm"
,
'-'
,
op
[
'adds'
],
op
[
'muls'
],
op
[
'divs'
],
total
))
report
.
add
(
ReportElement
(
"OriginalTerm"
,
'-'
,
op
[
'adds'
],
op
[
'muls'
],
op
[
'divs'
],
total
))
for
t
in
self
.
_rules
:
start_time
=
timeit
.
default_timer
()
assignment_collection
=
t
(
assignment_collection
)
...
...
astnodes.py
View file @
e31f1062
...
...
@@ -60,7 +60,8 @@ class Conditional(Node):
false_block: optional block which is run if conditional is false
"""
def
__init__
(
self
,
condition_expr
:
sp
.
Basic
,
true_block
:
Union
[
'Block'
,
'SympyAssignment'
],
false_block
:
Optional
[
'Block'
]
=
None
)
->
None
:
def
__init__
(
self
,
condition_expr
:
sp
.
Basic
,
true_block
:
Union
[
'Block'
,
'SympyAssignment'
],
false_block
:
Optional
[
'Block'
]
=
None
)
->
None
:
super
(
Conditional
,
self
).
__init__
(
parent
=
None
)
assert
condition_expr
.
is_Boolean
or
condition_expr
.
is_Relational
...
...
@@ -379,7 +380,7 @@ class LoopOverCoordinate(Node):
return
None
if
symbol
.
dtype
!=
create_type
(
'int'
):
return
None
coordinate
=
int
(
symbol
.
name
[
len
(
prefix
)
+
1
:])
coordinate
=
int
(
symbol
.
name
[
len
(
prefix
)
+
1
:])
return
coordinate
@
staticmethod
...
...
backends/__init__.py
View file @
e31f1062
from
.cbackend
import
generate_c
__all__
=
[
'generate_c'
]
try
:
from
.dot
import
print_dot
from
.llvm
import
generate_llvm
from
.dot
import
print_dot
# NOQA
__all__
.
append
(
'print_dot'
)
except
ImportError
:
pass
try
:
from
.llvm
import
generate_llvm
# NOQA
__all__
.
append
(
'generate_llvm'
)
except
ImportError
:
pass
backends/cbackend.py
View file @
e31f1062
...
...
@@ -13,7 +13,7 @@ from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
from
pystencils.data_types
import
create_type
,
PointerType
,
get_type_of_expression
,
VectorType
,
cast_func
from
pystencils.backends.simd_instruction_sets
import
selected_instruction_set
__all__
=
[
'generate_c'
,
'CustomCppCode'
,
'PrintNode'
,
'get_headers'
]
__all__
=
[
'generate_c'
,
'CustomCppCode'
,
'PrintNode'
,
'get_headers'
,
'CustomSympyPrinter'
]
def
generate_c
(
ast_node
:
Node
,
signature_only
:
bool
=
False
,
use_float_constants
:
Optional
[
bool
]
=
None
)
->
str
:
...
...
@@ -161,7 +161,8 @@ class CBackend:
def
_print_SympyAssignment
(
self
,
node
):
if
node
.
is_declaration
:
data_type
=
"const "
+
str
(
node
.
lhs
.
dtype
)
+
" "
if
node
.
is_const
else
str
(
node
.
lhs
.
dtype
)
+
" "
return
"%s %s = %s;"
%
(
data_type
,
self
.
sympy_printer
.
doprint
(
node
.
lhs
),
self
.
sympy_printer
.
doprint
(
node
.
rhs
))
return
"%s %s = %s;"
%
(
data_type
,
self
.
sympy_printer
.
doprint
(
node
.
lhs
),
self
.
sympy_printer
.
doprint
(
node
.
rhs
))
else
:
lhs_type
=
get_type_of_expression
(
node
.
lhs
)
if
type
(
lhs_type
)
is
VectorType
and
node
.
lhs
.
func
==
cast_func
:
...
...
backends/dot.py
View file @
e31f1062
...
...
@@ -104,4 +104,3 @@ def print_dot(node, view=False, short=False, full=False, **kwargs):
if
view
:
return
graphviz
.
Source
(
dot
)
return
dot
backends/simd_instruction_sets.py
View file @
e31f1062
...
...
@@ -20,7 +20,7 @@ def x86_vector_instruction_set(data_type='double', instruction_set='avx'):
'sqrt'
:
'sqrt[0]'
,
'makeVec'
:
'set[0,0,0,0]'
,
'makeVec'
:
'set[0,0,0,0]'
,
'makeZero'
:
'setzero[]'
,
'loadU'
:
'loadu[0]'
,
...
...
boundaries/__init__.py
View file @
e31f1062
from
pystencils.boundaries.boundaryhandling
import
BoundaryHandling
from
pystencils.boundaries.boundaryconditions
import
Neumann
from
pystencils.boundaries.inkernel
import
add_neumann_boundary
__all__
=
[
'BoundaryHandling'
,
'Neumann'
,
'add_neumann_boundary'
]
boundaries/boundaryhandling.py
View file @
e31f1062
...
...
@@ -20,7 +20,7 @@ class FlagInterface:
# Add flag field to data handling if it does not yet exist
if
data_handling
.
has_data
(
self
.
flag_field_name
):
raise
ValueError
(
"There is already a boundary handling registered at the data handling."
"If you want to add multiple handlings, choose a different name."
)
"If you want to add multiple handling
object
s, choose a different name."
)
data_handling
.
add_array
(
self
.
flag_field_name
,
dtype
=
self
.
FLAG_DTYPE
,
cpu
=
True
,
gpu
=
False
)
ff_ghost_layers
=
data_handling
.
ghost_layers_of_field
(
self
.
flag_field_name
)
...
...
@@ -47,7 +47,8 @@ class BoundaryHandling:
self
.
_boundary_object_to_boundary_info
=
{}
self
.
stencil
=
stencil
self
.
_dirty
=
True
self
.
flag_interface
=
flag_interface
if
flag_interface
is
not
None
else
FlagInterface
(
data_handling
,
name
+
"Flags"
)
fi
=
flag_interface
self
.
flag_interface
=
fi
if
fi
is
not
None
else
FlagInterface
(
data_handling
,
name
+
"Flags"
)
gpu
=
self
.
_target
==
'gpu'
data_handling
.
add_custom_class
(
self
.
_index_array_name
,
self
.
IndexFieldBlockData
,
cpu
=
True
,
gpu
=
gpu
)
...
...
@@ -121,7 +122,8 @@ class BoundaryHandling:
else
:
flag
=
self
.
_add_boundary
(
boundary_obj
)
for
b
in
self
.
_data_handling
.
iterate
(
slice_obj
,
ghost_layers
=
ghost_layers
,
inner_ghost_layers
=
inner_ghost_layers
):
for
b
in
self
.
_data_handling
.
iterate
(
slice_obj
,
ghost_layers
=
ghost_layers
,
inner_ghost_layers
=
inner_ghost_layers
):
flag_arr
=
b
[
self
.
flag_interface
.
flag_field_name
]
if
mask_callback
is
not
None
:
mask
=
mask_callback
(
*
b
.
midpoint_arrays
)
...
...
@@ -206,10 +208,10 @@ class BoundaryHandling:
def
_add_boundary
(
self
,
boundary_obj
,
flag
=
None
):
if
boundary_obj
not
in
self
.
_boundary_object_to_boundary_info
:
sym
bolic
_index_field
=
Field
.
create_generic
(
'indexField'
,
spatial_dimensions
=
1
,
dtype
=
numpy_data_type_for_boundary_object
(
boundary_obj
,
self
.
dim
))
sym_index_field
=
Field
.
create_generic
(
'indexField'
,
spatial_dimensions
=
1
,
dtype
=
numpy_data_type_for_boundary_object
(
boundary_obj
,
self
.
dim
))
ast
=
self
.
_create_boundary_kernel
(
self
.
_data_handling
.
fields
[
self
.
_field_name
],
sym
bolic
_index_field
,
boundary_obj
)
sym_index_field
,
boundary_obj
)
if
flag
is
None
:
flag
=
self
.
flag_interface
.
allocate_next_flag
()
boundary_info
=
self
.
BoundaryInfo
(
boundary_obj
,
flag
=
flag
,
kernel
=
ast
.
compile
())
...
...
@@ -253,7 +255,7 @@ class BoundaryHandling:
self
.
kernel
=
kernel
class
IndexFieldBlockData
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
_1
,
**
_2
):
self
.
boundary_object_to_index_list
=
{}
self
.
boundary_objectToDataSetter
=
{}
...
...
boundaries/createindexlist.py
View file @
e31f1062
...
...
@@ -3,7 +3,7 @@ import itertools
import
warnings
try
:
import
pyximport
;
import
pyximport
pyximport
.
install
()
from
pystencils.boundaries.createindexlistcython
import
create_boundary_index_list_2d
,
create_boundary_index_list_3d
...
...
@@ -31,7 +31,7 @@ def _create_boundary_index_list_python(flag_field_arr, nr_of_ghost_layers, bound
result
=
[]
gl
=
nr_of_ghost_layers
for
cell
in
itertools
.
product
(
*
reversed
([
range
(
gl
,
i
-
gl
)
for
i
in
flag_field_arr
.
shape
])):
for
cell
in
itertools
.
product
(
*
reversed
([
range
(
gl
,
i
-
gl
)
for
i
in
flag_field_arr
.
shape
])):
cell
=
cell
[::
-
1
]
if
not
flag_field_arr
[
cell
]
&
fluid_mask
:
continue
...
...
cpu/__init__.py
View file @
e31f1062
from
pystencils.cpu.kernelcreation
import
create_kernel
,
create_indexed_kernel
,
add_openmp
from
pystencils.cpu.cpujit
import
make_python_function
from
pystencils.backends.cbackend
import
generate_c
__all__
=
[
'create_kernel'
,
'create_indexed_kernel'
,
'add_openmp'
,
'make_python_function'
]
cpu/cpujit.py
View file @
e31f1062
...
...
@@ -247,7 +247,7 @@ def compile_object_cache_to_shared_library():
try
:
if
compiler_config
[
'os'
]
==
'windows'
:
all_object_files
=
glob
.
glob
(
os
.
path
.
join
(
cache_config
[
'object_cache'
],
'*.obj'
))
link_cmd
=
[
'link.exe'
,
'/DLL'
,
'/out:'
+
shared_library
]
link_cmd
=
[
'link.exe'
,
'/DLL'
,
'/out:'
+
shared_library
]
else
:
all_object_files
=
glob
.
glob
(
os
.
path
.
join
(
cache_config
[
'object_cache'
],
'*.o'
))
link_cmd
=
[
compiler_config
[
'command'
],
'-shared'
,
'-o'
,
shared_library
]
...
...
@@ -318,7 +318,7 @@ def compile_windows(ast, code_hash_str, src_file, lib_file):
# Compilation
if
not
os
.
path
.
exists
(
object_file
):
generate_code
(
ast
,
compiler_config
[
'restrict_qualifier'
],
'__declspec(dllexport)'
,
src_file
)
'__declspec(dllexport)'
,
src_file
)
# /c compiles only, /EHsc turns of exception handling in c code
compile_cmd
=
[
'cl.exe'
,
'/c'
,
'/EHsc'
]
+
compiler_config
[
'flags'
].
split
()
...
...
cpu/kernelcreation.py
View file @
e31f1062
...
...
@@ -2,8 +2,8 @@ import sympy as sp
from
functools
import
partial
from
pystencils.astnodes
import
SympyAssignment
,
Block
,
LoopOverCoordinate
,
KernelFunction
from
pystencils.transformations
import
resolve_buffer_accesses
,
resolve_field_accesses
,
make_loop_over_domain
,
\
type_all_equations
,
get_optimal_loop_ordering
,
parse_base_pointer_info
,
move_constants_before_loop
,
split_inner_loop
,
\
substitute_array_accesses_with_constants
type_all_equations
,
get_optimal_loop_ordering
,
parse_base_pointer_info
,
move_constants_before_loop
,
\
split_inner_loop
,
substitute_array_accesses_with_constants
from
pystencils.data_types
import
TypedSymbol
,
BasicType
,
StructType
,
create_type
from
pystencils.field
import
Field
,
FieldType
import
pystencils.astnodes
as
ast
...
...
@@ -175,7 +175,7 @@ def add_openmp(ast_node, schedule="static", num_threads=True):
outer_loops
=
[
l
for
l
in
body
.
atoms
(
ast
.
LoopOverCoordinate
)
if
l
.
is_outermost_loop
]
assert
outer_loops
,
"No outer loop found"
assert
len
(
outer_loops
)
<=
1
,
"More than one outer loop found.
Which one should be parallelized?
"
assert
len
(
outer_loops
)
<=
1
,
"More than one outer loop found.
Not clear where to put OpenMP pragma.
"
loop_to_parallelize
=
outer_loops
[
0
]
try
:
loop_range
=
int
(
loop_to_parallelize
.
stop
-
loop_to_parallelize
.
start
)
...
...
datahandling/datahandling_interface.py
View file @
e31f1062
...
...
@@ -352,7 +352,7 @@ class Block:
@
property
def
global_slice
(
self
):
"""Slice in global coordinates."""
return
tuple
(
slice
(
off
,
off
+
size
)
for
off
,
size
in
zip
(
self
.
_offset
,
self
.
shape
))
return
tuple
(
slice
(
off
,
off
+
size
)
for
off
,
size
in
zip
(
self
.
_offset
,
self
.
shape
))
def
__getitem__
(
self
,
data_name
:
str
)
->
np
.
ndarray
:
raise
NotImplementedError
()
datahandling/serial_datahandling.py
View file @
e31f1062
...
...
@@ -10,7 +10,7 @@ from pystencils.utils import DotDict
try
:
import
pycuda.gpuarray
as
gpuarray
import
pycuda.autoinit
import
pycuda.autoinit
# NOQA
except
ImportError
:
gpuarray
=
None
...
...
@@ -276,13 +276,12 @@ class SerialDataHandling(DataHandling):
from
pystencils.slicing
import
get_periodic_boundary_functor
result
.
append
(
get_periodic_boundary_functor
(
filtered_stencil
,
ghost_layers
=
gls
))
else
:
from
pystencils.gpucuda.periodicity
import
get_periodic_boundary_functor
result
.
append
(
get_periodic_boundary_functor
(
filtered_stencil
,
self
.
_domainSize
,
index_dimensions
=
self
.
fields
[
name
].
index_dimensions
,
index_dim_shape
=
self
.
_field_information
[
name
][
'values_per_cell'
],
dtype
=
self
.
fields
[
name
].
dtype
.
numpy_dtype
,
ghost_layers
=
gls
))
from
pystencils.gpucuda.periodicity
import
get_periodic_boundary_functor
as
boundary_func
result
.
append
(
boundary_func
(
filtered_stencil
,
self
.
_domainSize
,
index_dimensions
=
self
.
fields
[
name
].
index_dimensions
,
index_dim_shape
=
self
.
_field_information
[
name
][
'values_per_cell'
],
dtype
=
self
.
fields
[
name
].
dtype
.
numpy_dtype
,
ghost_layers
=
gls
))
if
target
==
'cpu'
:
def
result_functor
():
...
...
derivative.py
View file @
e31f1062
...
...
@@ -149,6 +149,7 @@ class DiffOperator(sp.Expr):
Multiplications of 'DiffOperator's are interpreted as nested application of differentiation:
i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t)
"""
def
handle_mul
(
mul
):
args
=
normalize_product
(
mul
)
diffs
=
[
a
for
a
in
args
if
isinstance
(
a
,
DiffOperator
)]
...
...
@@ -169,6 +170,7 @@ class DiffOperator(sp.Expr):
else
:
return
expr
*
argument
if
apply_to_constants
else
expr
# ----------------------------------------------------------------------------------------------------------------------
...
...
@@ -186,6 +188,7 @@ def derivative_terms(expr):
else
:
for
a
in
e
.
args
:
visit
(
a
)
visit
(
expr
)
return
result
...
...
@@ -261,7 +264,7 @@ def full_diff_expand(expr, functions=None, constants=None):
independent_terms
*=
factor
for
i
in
range
(
len
(
dependent_terms
)):
dependent_term
=
dependent_terms
[
i
]
other_dependent_terms
=
dependent_terms
[:
i
]
+
dependent_terms
[
i
+
1
:]
other_dependent_terms
=
dependent_terms
[:
i
]
+
dependent_terms
[
i
+
1
:]
processed_diff
=
normalize_diff_order
(
Diff
(
dependent_term
,
**
diff_args
))
result
+=
independent_terms
*
prod
(
other_dependent_terms
)
*
processed_diff
return
result
...
...
@@ -278,6 +281,7 @@ def full_diff_expand(expr, functions=None, constants=None):
def
normalize_diff_order
(
expression
,
functions
=
None
,
constants
=
None
,
sort_key
=
default_diff_sort_key
):
"""Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined
by the sorting key 'sort_key' such that the derivative terms can be further simplified """
def
visit
(
expr
):
if
isinstance
(
expr
,
Diff
):
nodes
=
[
expr
]
...
...
@@ -425,12 +429,14 @@ def replace_diff(expr, replacement_dict):
def
zero_diffs
(
expr
,
label
):
"""Replaces all differentials with the given target by 0"""
def
visit
(
e
):
if
isinstance
(
e
,
Diff
):
if
e
.
target
==
label
:
return
0
new_args
=
[
visit
(
arg
)
for
arg
in
e
.
args
]
return
e
.
func
(
*
new_args
)
if
new_args
else
e
return
visit
(
expr
)
...
...
display_utils.py
View file @
e31f1062
...
...
@@ -37,7 +37,7 @@ def show_code(ast: KernelFunction):
Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
"""
from
pystencils.
cpu
import
generate_c
from
pystencils.
backends.cbackend
import
generate_c
class
CodeDisplay
:
def
__init__
(
self
,
ast_input
):
...
...
field.py
View file @
e31f1062
...
...
@@ -5,8 +5,6 @@ import numpy as np
import
sympy
as
sp
from
sympy.core.cache
import
cacheit
from
sympy.tensor
import
IndexedBase
from
pystencils.assignment
import
Assignment
from
pystencils.alignedarray
import
aligned_empty
from
pystencils.data_types
import
TypedSymbol
,
create_type
,
create_composite_type_from_string
,
StructType
from
pystencils.sympyextensions
import
is_integer_sequence
...
...
@@ -69,6 +67,7 @@ class Field(object):
>>> jacobi = ( f[-1,0] + f[1,0] + f[0,-1] + f[0,1] ) / 4
Example with index dimensions: LBM D2Q9 stream pull
>>> from pystencils import Assignment
>>> stencil = np.array([[0,0], [0,1], [0,-1]])
>>> src = Field.create_generic("src", spatial_dimensions=2, index_dimensions=1)
>>> dst = Field.create_generic("dst", spatial_dimensions=2, index_dimensions=1)
...
...
@@ -366,7 +365,7 @@ class Field(object):
__xnew_cached_
=
staticmethod
(
cacheit
(
__new_stage2__
))
def
__call__
(
self
,
*
idx
):
if
self
.
_index
!=
tuple
([
0
]
*
self
.
field
.
index_dimensions
):
if
self
.
_index
!=
tuple
([
0
]
*
self
.
field
.
index_dimensions
):
raise
ValueError
(
"Indexing an already indexed Field.Access"
)
idx
=
tuple
(
idx
)
...
...
@@ -520,7 +519,7 @@ def layout_string_to_tuple(layout_str, dim):
return
tuple
(
reversed
(
range
(
dim
)))
elif
layout_str
==
'zyxf'
or
layout_str
==
'aos'
:
assert
dim
<=
4
return
tuple
(
reversed
(
range
(
dim
-
1
)))
+
(
dim
-
1
,)
return
tuple
(
reversed
(
range
(
dim
-
1
)))
+
(
dim
-
1
,)
elif
layout_str
==
'f'
or
layout_str
==
'reverse_numpy'
:
return
tuple
(
reversed
(
range
(
dim
)))
elif
layout_str
==
'c'
or
layout_str
==
'numpy'
:
...
...
finitedifferences.py
View file @
e31f1062
...
...
@@ -103,7 +103,7 @@ def discretize_staggered(term, symbols_to_field_dict, coordinate, coordinate_off
up
,
down
=
__up_down_offsets
(
d
,
dim
)
for
i
,
s
in
enumerate
(
symbols
):
center_grad
=
(
field
[
up
](
i
)
-
field
[
down
](
i
))
/
(
2
*
dx
)
neighbor_grad
=
(
field
[
up
+
offset
](
i
)
-
field
[
down
+
offset
](
i
))
/
(
2
*
dx
)
neighbor_grad
=
(
field
[
up
+
offset
](
i
)
-
field
[
down
+
offset
](
i
))
/
(
2
*
dx
)
substitutions
[
grad
(
s
)[
d
]]
=
(
center_grad
+
neighbor_grad
)
/
2
return
fast_subs
(
term
,
substitutions
)
...
...
@@ -170,9 +170,9 @@ class Advection(sp.Function):
name_suffix
=
"_%s"
%
self
.
scalar_index
if
self
.
scalar_index
is
not
None
else
""
if
isinstance
(
self
.
vector
,
Field
):
return
r
"\nabla \cdot(%s %s)"
%
(
printer
.
doprint
(
sp
.
Symbol
(
self
.
vector
.
name
)),
printer
.
doprint
(
sp
.
Symbol
(
self
.
scalar
.
name
+
name_suffix
)))
printer
.
doprint
(
sp
.
Symbol
(
self
.
scalar
.
name
+
name_suffix
)))
else
:
args
=
[
r
"\partial_%d(%s %s)"
%
(
i
,
printer
.
doprint
(
sp
.
Symbol
(
self
.
scalar
.
name
+
name_suffix
)),
args
=
[
r
"\partial_%d(%s %s)"
%
(
i
,
printer
.
doprint
(
sp
.
Symbol
(
self
.
scalar
.
name
+
name_suffix
)),
printer
.
doprint
(
self
.
vector
[
i
]))
for
i
in
range
(
self
.
dim
)]
return
" + "
.
join
(
args
)
...
...
@@ -233,7 +233,7 @@ class Diffusion(sp.Function):
coeff
=
self
.
diffusion_coeff
diff_coeff
=
sp
.
Symbol
(
coeff
.
name
)
if
isinstance
(
coeff
,
Field
)
else
coeff
return
r
"div(%s \nabla %s)"
%
(
printer
.
doprint
(
diff_coeff
),
printer
.
doprint
(
sp
.
Symbol
(
self
.
scalar
.
name
+
name_suffix
)))
printer
.
doprint
(
sp
.
Symbol
(
self
.
scalar
.
name
+
name_suffix
)))
# --- Interface for discretization strategy
...
...
@@ -277,7 +277,7 @@ class Transient(sp.Function):
def
_latex
(
self
,
printer
):
name_suffix
=
"_%s"
%
self
.
scalar_index
if
self
.
scalar_index
is
not
None
else
""
return
r
"\partial_t %s"
%
(
printer
.
doprint
(
sp
.
Symbol
(
self
.
scalar
.
name
+
name_suffix
)),)
return
r
"\partial_t %s"
%
(
printer
.
doprint
(
sp
.
Symbol
(
self
.
scalar
.
name
+
name_suffix
)),)
def
transient
(
scalar
,
idx
=
None
):
...
...
@@ -312,7 +312,7 @@ class Discretization2ndOrder:
-
expr
.
diffusion_scalar_at_offset
(
0
,
0
)
*
expr
.
diffusion_coefficient_at_offset
(
0
,
0
))
for
offset
in
[
-
1
,
1
]]
result
+=
first_diffs
[
1
]
-
first_diffs
[
0
]
return
result
/
(
self
.
dx
**
2
)
return
result
/
(
self
.
dx
**
2
)
def
_discretize_advection
(
self
,
expr
):
result
=
0
...
...
@@ -352,8 +352,8 @@ class Discretization2ndOrder:
else
:
assert
all
(
i
>=
0
for
i
in
indices
)
offsets
=
[(
1
,
1
),
[
-
1
,
1
],
[
1
,
-
1
],
[
-
1
,
-
1
]]
result
=
sum
(
o1
*
o2
*
fa
.
neighbor
(
indices
[
0
],
o1
).
neighbor
(
indices
[
1
],
o2
)
for
o1
,
o2
in
offsets
)
/
4
return
result
/
(
self
.
dx
**
2
)
result
=
sum
(
o1
*
o2
*
fa
.
neighbor
(
indices
[
0
],
o1
).
neighbor
(
indices
[
1
],
o2
)
for
o1
,
o2
in
offsets
)
/
4
return
result
/
(
self
.
dx
**
2
)
else
:
raise
NotImplementedError
(
"Term contains derivatives of order > 2"
)
...
...
@@ -380,4 +380,3 @@ class Discretization2ndOrder:
else
:
print
(
transient_terms
)
raise
NotImplementedError
(
"Cannot discretize expression with more than one transient term"
)
Prev
1
2
3
Next
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment