Commit e31f1062 authored by Martin Bauer's avatar Martin Bauer
Browse files

flake8 linter

- removed warnings
- added flake8 as CI target
parent afc933d9
"""Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions""" """Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions"""
from . import sympy_gmpy_bug_workaround from . import sympy_gmpy_bug_workaround # NOQA
from .field import Field, FieldType from .field import Field, FieldType
from .data_types import TypedSymbol from .data_types import TypedSymbol
from .slicing import make_slice from .slicing import make_slice
......
...@@ -98,4 +98,4 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Call ...@@ -98,4 +98,4 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Call
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
return ac.copy(ac.main_assignments, result) return ac.copy(ac.main_assignments, result)
f.__name__ = operation.__name__ f.__name__ = operation.__name__
return f return f
\ No newline at end of file
...@@ -84,7 +84,7 @@ class SimplificationStrategy(object): ...@@ -84,7 +84,7 @@ class SimplificationStrategy(object):
report = Report() report = Report()
op = assignment_collection.operation_count op = assignment_collection.operation_count
total = op['adds'] + op['muls'] + op['divs'] total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total)) report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total))
for t in self._rules: for t in self._rules:
start_time = timeit.default_timer() start_time = timeit.default_timer()
assignment_collection = t(assignment_collection) assignment_collection = t(assignment_collection)
......
...@@ -60,7 +60,8 @@ class Conditional(Node): ...@@ -60,7 +60,8 @@ class Conditional(Node):
false_block: optional block which is run if conditional is false false_block: optional block which is run if conditional is false
""" """
def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'], false_block: Optional['Block'] = None) -> None: def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
false_block: Optional['Block'] = None) -> None:
super(Conditional, self).__init__(parent=None) super(Conditional, self).__init__(parent=None)
assert condition_expr.is_Boolean or condition_expr.is_Relational assert condition_expr.is_Boolean or condition_expr.is_Relational
...@@ -379,7 +380,7 @@ class LoopOverCoordinate(Node): ...@@ -379,7 +380,7 @@ class LoopOverCoordinate(Node):
return None return None
if symbol.dtype != create_type('int'): if symbol.dtype != create_type('int'):
return None return None
coordinate = int(symbol.name[len(prefix)+1:]) coordinate = int(symbol.name[len(prefix) + 1:])
return coordinate return coordinate
@staticmethod @staticmethod
......
from .cbackend import generate_c from .cbackend import generate_c
__all__ = ['generate_c']
try: try:
from .dot import print_dot from .dot import print_dot # NOQA
from .llvm import generate_llvm __all__.append('print_dot')
except ImportError:
pass
try:
from .llvm import generate_llvm # NOQA
__all__.append('generate_llvm')
except ImportError: except ImportError:
pass pass
...@@ -13,7 +13,7 @@ from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment ...@@ -13,7 +13,7 @@ from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func
from pystencils.backends.simd_instruction_sets import selected_instruction_set from pystencils.backends.simd_instruction_sets import selected_instruction_set
__all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers'] __all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants: Optional[bool] = None) -> str: def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants: Optional[bool] = None) -> str:
...@@ -161,7 +161,8 @@ class CBackend: ...@@ -161,7 +161,8 @@ class CBackend:
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
if node.is_declaration: if node.is_declaration:
data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " " data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
return "%s %s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) return "%s %s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else: else:
lhs_type = get_type_of_expression(node.lhs) lhs_type = get_type_of_expression(node.lhs)
if type(lhs_type) is VectorType and node.lhs.func == cast_func: if type(lhs_type) is VectorType and node.lhs.func == cast_func:
......
...@@ -104,4 +104,3 @@ def print_dot(node, view=False, short=False, full=False, **kwargs): ...@@ -104,4 +104,3 @@ def print_dot(node, view=False, short=False, full=False, **kwargs):
if view: if view:
return graphviz.Source(dot) return graphviz.Source(dot)
return dot return dot
...@@ -20,7 +20,7 @@ def x86_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -20,7 +20,7 @@ def x86_vector_instruction_set(data_type='double', instruction_set='avx'):
'sqrt': 'sqrt[0]', 'sqrt': 'sqrt[0]',
'makeVec': 'set[0,0,0,0]', 'makeVec': 'set[0,0,0,0]',
'makeZero': 'setzero[]', 'makeZero': 'setzero[]',
'loadU': 'loadu[0]', 'loadU': 'loadu[0]',
......
from pystencils.boundaries.boundaryhandling import BoundaryHandling from pystencils.boundaries.boundaryhandling import BoundaryHandling
from pystencils.boundaries.boundaryconditions import Neumann from pystencils.boundaries.boundaryconditions import Neumann
from pystencils.boundaries.inkernel import add_neumann_boundary from pystencils.boundaries.inkernel import add_neumann_boundary
__all__ = ['BoundaryHandling', 'Neumann', 'add_neumann_boundary']
...@@ -20,7 +20,7 @@ class FlagInterface: ...@@ -20,7 +20,7 @@ class FlagInterface:
# Add flag field to data handling if it does not yet exist # Add flag field to data handling if it does not yet exist
if data_handling.has_data(self.flag_field_name): if data_handling.has_data(self.flag_field_name):
raise ValueError("There is already a boundary handling registered at the data handling." raise ValueError("There is already a boundary handling registered at the data handling."
"If you want to add multiple handlings, choose a different name.") "If you want to add multiple handling objects, choose a different name.")
data_handling.add_array(self.flag_field_name, dtype=self.FLAG_DTYPE, cpu=True, gpu=False) data_handling.add_array(self.flag_field_name, dtype=self.FLAG_DTYPE, cpu=True, gpu=False)
ff_ghost_layers = data_handling.ghost_layers_of_field(self.flag_field_name) ff_ghost_layers = data_handling.ghost_layers_of_field(self.flag_field_name)
...@@ -47,7 +47,8 @@ class BoundaryHandling: ...@@ -47,7 +47,8 @@ class BoundaryHandling:
self._boundary_object_to_boundary_info = {} self._boundary_object_to_boundary_info = {}
self.stencil = stencil self.stencil = stencil
self._dirty = True self._dirty = True
self.flag_interface = flag_interface if flag_interface is not None else FlagInterface(data_handling, name + "Flags") fi = flag_interface
self.flag_interface = fi if fi is not None else FlagInterface(data_handling, name + "Flags")
gpu = self._target == 'gpu' gpu = self._target == 'gpu'
data_handling.add_custom_class(self._index_array_name, self.IndexFieldBlockData, cpu=True, gpu=gpu) data_handling.add_custom_class(self._index_array_name, self.IndexFieldBlockData, cpu=True, gpu=gpu)
...@@ -121,7 +122,8 @@ class BoundaryHandling: ...@@ -121,7 +122,8 @@ class BoundaryHandling:
else: else:
flag = self._add_boundary(boundary_obj) flag = self._add_boundary(boundary_obj)
for b in self._data_handling.iterate(slice_obj, ghost_layers=ghost_layers, inner_ghost_layers=inner_ghost_layers): for b in self._data_handling.iterate(slice_obj, ghost_layers=ghost_layers,
inner_ghost_layers=inner_ghost_layers):
flag_arr = b[self.flag_interface.flag_field_name] flag_arr = b[self.flag_interface.flag_field_name]
if mask_callback is not None: if mask_callback is not None:
mask = mask_callback(*b.midpoint_arrays) mask = mask_callback(*b.midpoint_arrays)
...@@ -206,10 +208,10 @@ class BoundaryHandling: ...@@ -206,10 +208,10 @@ class BoundaryHandling:
def _add_boundary(self, boundary_obj, flag=None): def _add_boundary(self, boundary_obj, flag=None):
if boundary_obj not in self._boundary_object_to_boundary_info: if boundary_obj not in self._boundary_object_to_boundary_info:
symbolic_index_field = Field.create_generic('indexField', spatial_dimensions=1, sym_index_field = Field.create_generic('indexField', spatial_dimensions=1,
dtype=numpy_data_type_for_boundary_object(boundary_obj, self.dim)) dtype=numpy_data_type_for_boundary_object(boundary_obj, self.dim))
ast = self._create_boundary_kernel(self._data_handling.fields[self._field_name], ast = self._create_boundary_kernel(self._data_handling.fields[self._field_name],
symbolic_index_field, boundary_obj) sym_index_field, boundary_obj)
if flag is None: if flag is None:
flag = self.flag_interface.allocate_next_flag() flag = self.flag_interface.allocate_next_flag()
boundary_info = self.BoundaryInfo(boundary_obj, flag=flag, kernel=ast.compile()) boundary_info = self.BoundaryInfo(boundary_obj, flag=flag, kernel=ast.compile())
...@@ -253,7 +255,7 @@ class BoundaryHandling: ...@@ -253,7 +255,7 @@ class BoundaryHandling:
self.kernel = kernel self.kernel = kernel
class IndexFieldBlockData: class IndexFieldBlockData:
def __init__(self, *args, **kwargs): def __init__(self, *_1, **_2):
self.boundary_object_to_index_list = {} self.boundary_object_to_index_list = {}
self.boundary_objectToDataSetter = {} self.boundary_objectToDataSetter = {}
......
...@@ -3,7 +3,7 @@ import itertools ...@@ -3,7 +3,7 @@ import itertools
import warnings import warnings
try: try:
import pyximport; import pyximport
pyximport.install() pyximport.install()
from pystencils.boundaries.createindexlistcython import create_boundary_index_list_2d, create_boundary_index_list_3d from pystencils.boundaries.createindexlistcython import create_boundary_index_list_2d, create_boundary_index_list_3d
...@@ -31,7 +31,7 @@ def _create_boundary_index_list_python(flag_field_arr, nr_of_ghost_layers, bound ...@@ -31,7 +31,7 @@ def _create_boundary_index_list_python(flag_field_arr, nr_of_ghost_layers, bound
result = [] result = []
gl = nr_of_ghost_layers gl = nr_of_ghost_layers
for cell in itertools.product(*reversed([range(gl, i-gl) for i in flag_field_arr.shape])): for cell in itertools.product(*reversed([range(gl, i - gl) for i in flag_field_arr.shape])):
cell = cell[::-1] cell = cell[::-1]
if not flag_field_arr[cell] & fluid_mask: if not flag_field_arr[cell] & fluid_mask:
continue continue
......
from pystencils.cpu.kernelcreation import create_kernel, create_indexed_kernel, add_openmp from pystencils.cpu.kernelcreation import create_kernel, create_indexed_kernel, add_openmp
from pystencils.cpu.cpujit import make_python_function from pystencils.cpu.cpujit import make_python_function
from pystencils.backends.cbackend import generate_c
__all__ = ['create_kernel', 'create_indexed_kernel', 'add_openmp', 'make_python_function']
...@@ -247,7 +247,7 @@ def compile_object_cache_to_shared_library(): ...@@ -247,7 +247,7 @@ def compile_object_cache_to_shared_library():
try: try:
if compiler_config['os'] == 'windows': if compiler_config['os'] == 'windows':
all_object_files = glob.glob(os.path.join(cache_config['object_cache'], '*.obj')) all_object_files = glob.glob(os.path.join(cache_config['object_cache'], '*.obj'))
link_cmd = ['link.exe', '/DLL', '/out:' + shared_library] link_cmd = ['link.exe', '/DLL', '/out:' + shared_library]
else: else:
all_object_files = glob.glob(os.path.join(cache_config['object_cache'], '*.o')) all_object_files = glob.glob(os.path.join(cache_config['object_cache'], '*.o'))
link_cmd = [compiler_config['command'], '-shared', '-o', shared_library] link_cmd = [compiler_config['command'], '-shared', '-o', shared_library]
...@@ -318,7 +318,7 @@ def compile_windows(ast, code_hash_str, src_file, lib_file): ...@@ -318,7 +318,7 @@ def compile_windows(ast, code_hash_str, src_file, lib_file):
# Compilation # Compilation
if not os.path.exists(object_file): if not os.path.exists(object_file):
generate_code(ast, compiler_config['restrict_qualifier'], generate_code(ast, compiler_config['restrict_qualifier'],
'__declspec(dllexport)', src_file) '__declspec(dllexport)', src_file)
# /c compiles only, /EHsc turns of exception handling in c code # /c compiles only, /EHsc turns of exception handling in c code
compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split() compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
......
...@@ -2,8 +2,8 @@ import sympy as sp ...@@ -2,8 +2,8 @@ import sympy as sp
from functools import partial from functools import partial
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \ from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \
type_all_equations, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, split_inner_loop, \ type_all_equations, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \
substitute_array_accesses_with_constants split_inner_loop, substitute_array_accesses_with_constants
from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type
from pystencils.field import Field, FieldType from pystencils.field import Field, FieldType
import pystencils.astnodes as ast import pystencils.astnodes as ast
...@@ -175,7 +175,7 @@ def add_openmp(ast_node, schedule="static", num_threads=True): ...@@ -175,7 +175,7 @@ def add_openmp(ast_node, schedule="static", num_threads=True):
outer_loops = [l for l in body.atoms(ast.LoopOverCoordinate) if l.is_outermost_loop] outer_loops = [l for l in body.atoms(ast.LoopOverCoordinate) if l.is_outermost_loop]
assert outer_loops, "No outer loop found" assert outer_loops, "No outer loop found"
assert len(outer_loops) <= 1, "More than one outer loop found. Which one should be parallelized?" assert len(outer_loops) <= 1, "More than one outer loop found. Not clear where to put OpenMP pragma."
loop_to_parallelize = outer_loops[0] loop_to_parallelize = outer_loops[0]
try: try:
loop_range = int(loop_to_parallelize.stop - loop_to_parallelize.start) loop_range = int(loop_to_parallelize.stop - loop_to_parallelize.start)
......
...@@ -352,7 +352,7 @@ class Block: ...@@ -352,7 +352,7 @@ class Block:
@property @property
def global_slice(self): def global_slice(self):
"""Slice in global coordinates.""" """Slice in global coordinates."""
return tuple(slice(off, off+size) for off, size in zip(self._offset, self.shape)) return tuple(slice(off, off + size) for off, size in zip(self._offset, self.shape))
def __getitem__(self, data_name: str) -> np.ndarray: def __getitem__(self, data_name: str) -> np.ndarray:
raise NotImplementedError() raise NotImplementedError()
...@@ -10,7 +10,7 @@ from pystencils.utils import DotDict ...@@ -10,7 +10,7 @@ from pystencils.utils import DotDict
try: try:
import pycuda.gpuarray as gpuarray import pycuda.gpuarray as gpuarray
import pycuda.autoinit import pycuda.autoinit # NOQA
except ImportError: except ImportError:
gpuarray = None gpuarray = None
...@@ -276,13 +276,12 @@ class SerialDataHandling(DataHandling): ...@@ -276,13 +276,12 @@ class SerialDataHandling(DataHandling):
from pystencils.slicing import get_periodic_boundary_functor from pystencils.slicing import get_periodic_boundary_functor
result.append(get_periodic_boundary_functor(filtered_stencil, ghost_layers=gls)) result.append(get_periodic_boundary_functor(filtered_stencil, ghost_layers=gls))
else: else:
from pystencils.gpucuda.periodicity import get_periodic_boundary_functor from pystencils.gpucuda.periodicity import get_periodic_boundary_functor as boundary_func
result.append(get_periodic_boundary_functor(filtered_stencil, self._domainSize, result.append(boundary_func(filtered_stencil, self._domainSize,
index_dimensions=self.fields[name].index_dimensions, index_dimensions=self.fields[name].index_dimensions,
index_dim_shape=self._field_information[name][ index_dim_shape=self._field_information[name]['values_per_cell'],
'values_per_cell'], dtype=self.fields[name].dtype.numpy_dtype,
dtype=self.fields[name].dtype.numpy_dtype, ghost_layers=gls))
ghost_layers=gls))
if target == 'cpu': if target == 'cpu':
def result_functor(): def result_functor():
......
...@@ -149,6 +149,7 @@ class DiffOperator(sp.Expr): ...@@ -149,6 +149,7 @@ class DiffOperator(sp.Expr):
Multiplications of 'DiffOperator's are interpreted as nested application of differentiation: Multiplications of 'DiffOperator's are interpreted as nested application of differentiation:
i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t) i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t)
""" """
def handle_mul(mul): def handle_mul(mul):
args = normalize_product(mul) args = normalize_product(mul)
diffs = [a for a in args if isinstance(a, DiffOperator)] diffs = [a for a in args if isinstance(a, DiffOperator)]
...@@ -169,6 +170,7 @@ class DiffOperator(sp.Expr): ...@@ -169,6 +170,7 @@ class DiffOperator(sp.Expr):
else: else:
return expr * argument if apply_to_constants else expr return expr * argument if apply_to_constants else expr
# ---------------------------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------------------------
...@@ -186,6 +188,7 @@ def derivative_terms(expr): ...@@ -186,6 +188,7 @@ def derivative_terms(expr):
else: else:
for a in e.args: for a in e.args:
visit(a) visit(a)
visit(expr) visit(expr)
return result return result
...@@ -261,7 +264,7 @@ def full_diff_expand(expr, functions=None, constants=None): ...@@ -261,7 +264,7 @@ def full_diff_expand(expr, functions=None, constants=None):
independent_terms *= factor independent_terms *= factor
for i in range(len(dependent_terms)): for i in range(len(dependent_terms)):
dependent_term = dependent_terms[i] dependent_term = dependent_terms[i]
other_dependent_terms = dependent_terms[:i] + dependent_terms[i+1:] other_dependent_terms = dependent_terms[:i] + dependent_terms[i + 1:]
processed_diff = normalize_diff_order(Diff(dependent_term, **diff_args)) processed_diff = normalize_diff_order(Diff(dependent_term, **diff_args))
result += independent_terms * prod(other_dependent_terms) * processed_diff result += independent_terms * prod(other_dependent_terms) * processed_diff
return result return result
...@@ -278,6 +281,7 @@ def full_diff_expand(expr, functions=None, constants=None): ...@@ -278,6 +281,7 @@ def full_diff_expand(expr, functions=None, constants=None):
def normalize_diff_order(expression, functions=None, constants=None, sort_key=default_diff_sort_key): def normalize_diff_order(expression, functions=None, constants=None, sort_key=default_diff_sort_key):
"""Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined """Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined
by the sorting key 'sort_key' such that the derivative terms can be further simplified """ by the sorting key 'sort_key' such that the derivative terms can be further simplified """
def visit(expr): def visit(expr):
if isinstance(expr, Diff): if isinstance(expr, Diff):
nodes = [expr] nodes = [expr]
...@@ -425,12 +429,14 @@ def replace_diff(expr, replacement_dict): ...@@ -425,12 +429,14 @@ def replace_diff(expr, replacement_dict):
def zero_diffs(expr, label): def zero_diffs(expr, label):
"""Replaces all differentials with the given target by 0""" """Replaces all differentials with the given target by 0"""
def visit(e): def visit(e):
if isinstance(e, Diff): if isinstance(e, Diff):
if e.target == label: if e.target == label:
return 0 return 0
new_args = [visit(arg) for arg in e.args] new_args = [visit(arg) for arg in e.args]
return e.func(*new_args) if new_args else e return e.func(*new_args) if new_args else e
return visit(expr) return visit(expr)
......
...@@ -37,7 +37,7 @@ def show_code(ast: KernelFunction): ...@@ -37,7 +37,7 @@ def show_code(ast: KernelFunction):
Can either be displayed as HTML in Jupyter notebooks or printed as normal string. Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
""" """
from pystencils.cpu import generate_c from pystencils.backends.cbackend import generate_c
class CodeDisplay: class CodeDisplay:
def __init__(self, ast_input): def __init__(self, ast_input):
......
...@@ -5,8 +5,6 @@ import numpy as np ...@@ -5,8 +5,6 @@ import numpy as np
import sympy as sp import sympy as sp
from sympy.core.cache import cacheit from sympy.core.cache import cacheit
from sympy.tensor import IndexedBase from sympy.tensor import IndexedBase
from pystencils.assignment import Assignment
from pystencils.alignedarray import aligned_empty from pystencils.alignedarray import aligned_empty
from pystencils.data_types import TypedSymbol, create_type, create_composite_type_from_string, StructType from pystencils.data_types import TypedSymbol, create_type, create_composite_type_from_string, StructType
from pystencils.sympyextensions import is_integer_sequence from pystencils.sympyextensions import is_integer_sequence
...@@ -69,6 +67,7 @@ class Field(object): ...@@ -69,6 +67,7 @@ class Field(object):
>>> jacobi = ( f[-1,0] + f[1,0] + f[0,-1] + f[0,1] ) / 4 >>> jacobi = ( f[-1,0] + f[1,0] + f[0,-1] + f[0,1] ) / 4
Example with index dimensions: LBM D2Q9 stream pull Example with index dimensions: LBM D2Q9 stream pull
>>> from pystencils import Assignment
>>> stencil = np.array([[0,0], [0,1], [0,-1]]) >>> stencil = np.array([[0,0], [0,1], [0,-1]])
>>> src = Field.create_generic("src", spatial_dimensions=2, index_dimensions=1) >>> src = Field.create_generic("src", spatial_dimensions=2, index_dimensions=1)
>>> dst = Field.create_generic("dst", spatial_dimensions=2, index_dimensions=1) >>> dst = Field.create_generic("dst", spatial_dimensions=2, index_dimensions=1)
...@@ -366,7 +365,7 @@ class Field(object): ...@@ -366,7 +365,7 @@ class Field(object):
__xnew_cached_ = staticmethod(cacheit(__new_stage2__)) __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __call__(self, *idx): def __call__(self, *idx):
if self._index != tuple([0]*self.field.index_dimensions): if self._index != tuple([0] * self.field.index_dimensions):
raise ValueError("Indexing an already indexed Field.Access") raise ValueError("Indexing an already indexed Field.Access")
idx = tuple(idx) idx = tuple(idx)
...@@ -520,7 +519,7 @@ def layout_string_to_tuple(layout_str, dim): ...@@ -520,7 +519,7 @@ def layout_string_to_tuple(layout_str, dim):
return tuple(reversed(range(dim))) return tuple(reversed(range(dim)))
elif layout_str == 'zyxf' or layout_str == 'aos': elif layout_str == 'zyxf' or layout_str == 'aos':
assert dim <= 4 assert dim <= 4
return tuple(reversed(range(dim - 1))) + (dim-1,) return tuple(reversed(range(dim - 1))) + (dim - 1,)
elif layout_str == 'f' or layout_str == 'reverse_numpy': elif layout_str == 'f' or layout_str == 'reverse_numpy':
return tuple(reversed(range(dim))) return tuple(reversed(range(dim)))
elif layout_str == 'c' or layout_str == 'numpy': elif layout_str == 'c' or layout_str == 'numpy':
......
...@@ -103,7 +103,7 @@ def discretize_staggered(term, symbols_to_field_dict, coordinate, coordinate_off ...@@ -103,7 +103,7 @@ def discretize_staggered(term, symbols_to_field_dict, coordinate, coordinate_off
up, down = __up_down_offsets(d, dim) up, down = __up_down_offsets(d, dim)
for i, s in enumerate(symbols): for i, s in enumerate(symbols):
center_grad = (field[up](i) - field[down](i)) / (2 * dx) center_grad = (field[up](i) - field[down](i)) / (2 * dx)
neighbor_grad = (field[up+offset](i) - field[down+offset](i)) / (2 * dx) neighbor_grad = (field[up + offset](i) - field[down + offset](i)) / (2 * dx)
substitutions[grad(s)[d]] = (center_grad + neighbor_grad) / 2 substitutions[grad(s)[d]] = (center_grad + neighbor_grad) / 2
return fast_subs(term, substitutions) return fast_subs(term, substitutions)
...@@ -170,9 +170,9 @@ class Advection(sp.Function): ...@@ -170,9 +170,9 @@ class Advection(sp.Function):
name_suffix = "_%s" % self.scalar_index if self.scalar_index is not None else "" name_suffix = "_%s" % self.scalar_index if self.scalar_index is not None else ""
if isinstance(self.vector, Field): if isinstance(self.vector, Field):
return r"\nabla \cdot(%s %s)" % (printer.doprint(sp.Symbol(self.vector.name)), return r"\nabla \cdot(%s %s)" % (printer.doprint(sp.Symbol(self.vector.name)),
printer.doprint(sp.Symbol(self.scalar.name+name_suffix))) printer.doprint(sp.Symbol(self.scalar.name + name_suffix)))
else: else:
args = [r"\partial_%d(%s %s)" % (i, printer.doprint(sp.Symbol(self.scalar.name+name_suffix)), args = [r"\partial_%d(%s %s)" % (i, printer.doprint(sp.Symbol(self.scalar.name + name_suffix)),
printer.doprint(self.vector[i])) printer.doprint(self.vector[i]))
for i in range(self.dim)] for i in range(self.dim)]
return " + ".join(args) return " + ".join(args)
...@@ -233,7 +233,7 @@ class Diffusion(sp.Function): ...@@ -233,7 +233,7 @@ class Diffusion(sp.Function):
coeff = self.diffusion_coeff coeff = self.diffusion_coeff
diff_coeff = sp.Symbol(coeff.name) if isinstance(coeff, Field) else coeff diff_coeff = sp.Symbol(coeff.name) if isinstance(coeff, Field) else coeff
return r"div(%s \nabla %s)" % (printer.doprint(diff_coeff), return r"div(%s \nabla %s)" % (printer.doprint(diff_coeff),
printer.doprint(sp.Symbol(self.scalar.name+name_suffix))) printer.doprint(sp.Symbol(self.scalar.name + name_suffix)))
# --- Interface for discretization strategy # --- Interface for discretization strategy
...@@ -277,7 +277,7 @@ class Transient(sp.Function): ...@@ -277,7 +277,7 @@ class Transient(sp.Function):
def _latex(self, printer): def _latex(self, printer):
name_suffix = "_%s" % self.scalar_index if self.scalar_index is not None else "" name_suffix = "_%s" % self.scalar_index if self.scalar_index is not None else ""
return r"\partial_t %s" % (printer.doprint(sp.Symbol(self.scalar.name+name_suffix)),) return r"\partial_t %s" % (printer.doprint(sp.Symbol(self.scalar.name + name_suffix)),)
def transient(scalar, idx=None): def transient(scalar, idx=None):
...@@ -312,7 +312,7 @@ class Discretization2ndOrder: ...@@ -312,7 +312,7 @@ class Discretization2ndOrder:
- expr.diffusion_scalar_at_offset(0, 0) * expr.diffusion_coefficient_at_offset(0, 0)) - expr.diffusion_scalar_at_offset(0, 0) * expr.diffusion_coefficient_at_offset(0, 0))
for offset in [-1, 1]] for offset in [-1, 1]]
result += first_diffs[1] - first_diffs[0] result += first_diffs[1] - first_diffs[0]
return result / (self.dx**2) return result / (self.dx ** 2)