From bad8ff83ca3af3824f82b6e525271cfa7a5fd47c Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Mon, 19 Feb 2024 10:27:28 +0100 Subject: [PATCH] More refactoring --- conftest.py | 3 +- src/pystencils/__init__.py | 7 +- .../backend/kernelcreation/__init__.py | 2 - .../backend/kernelcreation/analysis.py | 3 +- .../backend/kernelcreation/context.py | 4 +- .../backend/kernelcreation/defaults.py | 2 +- .../backend/kernelcreation/freeze.py | 5 +- .../backend/kernelcreation/iteration_space.py | 2 +- src/pystencils/config.py | 3 +- .../datahandling/serial_datahandling.py | 3 +- src/pystencils/display_utils.py | 5 +- src/pystencils/fd/derivation.py | 2 +- src/pystencils/fd/finitedifferences.py | 4 +- src/pystencils/fd/spatial.py | 6 +- src/pystencils/kernel_decorator.py | 4 +- src/pystencils/kernel_wrapper.py | 2 +- src/pystencils/kernelcreation.py | 2 +- src/pystencils/spatial_coordinates.py | 13 +- src/pystencils/sympyextensions/__init__.py | 5 +- src/pystencils/sympyextensions/assignment.py | 104 -- .../sympyextensions/assignment_collection.py | 474 -------- src/pystencils/sympyextensions/astnodes.py | 1056 ++++++++--------- src/pystencils/sympyextensions/bit_masks.py | 1 - src/pystencils/sympyextensions/math.py | 2 +- .../sympyextensions/simplifications.py | 15 +- .../sympyextensions/simplificationstrategy.py | 2 +- src/pystencils/sympyextensions/typed_sympy.py | 91 +- 27 files changed, 619 insertions(+), 1203 deletions(-) delete mode 100644 src/pystencils/sympyextensions/assignment.py delete mode 100644 src/pystencils/sympyextensions/assignment_collection.py diff --git a/conftest.py b/conftest.py index 2fd7f8e81..ca7c153b4 100644 --- a/conftest.py +++ b/conftest.py @@ -10,7 +10,8 @@ from nbconvert import PythonExporter # Trigger config file reading / creation once - to avoid race conditions when multiple instances are creating it # at the same time -from pystencils.cpu import cpujit +# TODO: replace with new backend +# from pystencils.cpu import cpujit # trigger cython imports - there seems to be a problem when multiple processes try to compile the same cython file # at the same time diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 161530f88..51556f799 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -2,16 +2,17 @@ from .enums import Backend, Target from . import fd from . import stencil as stencil -from pystencils.sympyextensions.assignmentcollection.assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .field import Field, FieldType, fields from .config import CreateKernelConfig from .cache import clear_cache from .kernel_decorator import kernel, kernel_config from .kernelcreation import create_kernel -from pystencils.sympyextensions.assignmentcollection import AssignmentCollection from .slicing import make_slice from .spatial_coordinates import x_, x_staggered, x_staggered_vector, x_vector, y_, y_staggered, z_, z_staggered +from .sympyextensions import Assignment, AssignmentCollection, AddAugmentedAssignment +from .sympyextensions.astnodes import assignment_from_stencil +from .sympyextensions.typed_sympy import TypedSymbol from .sympyextensions.math import SymbolCreator from .datahandling import create_data_handling @@ -19,7 +20,7 @@ __all__ = ['Field', 'FieldType', 'fields', 'TypedSymbol', 'make_slice', 'CreateKernelConfig', - 'create_kernel', 'create_staggered_kernel', + 'create_kernel', 'Target', 'Backend', 'show_code', 'to_dot', 'get_code_obj', 'get_code_str', 'AssignmentCollection', diff --git a/src/pystencils/backend/kernelcreation/__init__.py b/src/pystencils/backend/kernelcreation/__init__.py index 749493059..6ac30e9ee 100644 --- a/src/pystencils/backend/kernelcreation/__init__.py +++ b/src/pystencils/backend/kernelcreation/__init__.py @@ -99,7 +99,6 @@ It is furthermore annotated with constraints collected during the translation, a """ from .config import CreateKernelConfig -from .kernelcreation import create_kernel from .context import KernelCreationContext from .analysis import KernelAnalysis @@ -115,7 +114,6 @@ from .iteration_space import ( __all__ = [ "CreateKernelConfig", - "create_kernel", "KernelCreationContext", "KernelAnalysis", "FreezeExpressions", diff --git a/src/pystencils/backend/kernelcreation/analysis.py b/src/pystencils/backend/kernelcreation/analysis.py index 1b498427d..b0267a3d6 100644 --- a/src/pystencils/backend/kernelcreation/analysis.py +++ b/src/pystencils/backend/kernelcreation/analysis.py @@ -9,8 +9,7 @@ import sympy as sp from .context import KernelCreationContext from ...field import Field -from pystencils.sympyextensions.assignmentcollection.assignment import Assignment -from ...simp import AssignmentCollection +from ...sympyextensions import Assignment, AssignmentCollection from ..exceptions import PsInternalCompilerError, KernelConstraintsError diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index d2cddb142..2080efed3 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -1,14 +1,14 @@ from __future__ import annotations from ...field import Field, FieldType -from ...typing import TypedSymbol, BasicType, StructType - +from ...sympyextensions.typed_sympy import TypedSymbol, BasicType, StructType from ..arrays import PsLinearizedArray from ..types import PsIntegerType from ..types.quick import make_type from ..constraints import PsKernelConstraint from ..exceptions import PsInternalCompilerError, KernelConstraintsError + from .config import CreateKernelConfig from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace diff --git a/src/pystencils/backend/kernelcreation/defaults.py b/src/pystencils/backend/kernelcreation/defaults.py index fc0a602a1..b1822dc77 100644 --- a/src/pystencils/backend/kernelcreation/defaults.py +++ b/src/pystencils/backend/kernelcreation/defaults.py @@ -20,7 +20,7 @@ from typing import TypeVar, Generic, Callable from ..types import PsAbstractType, PsSignedIntegerType, PsStructType from ..typed_expressions import PsTypedVariable -from ...typing import TypedSymbol +from pystencils.sympyextensions.typed_sympy import TypedSymbol SymbolT = TypeVar("SymbolT") diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 06d6f1adc..94b660fd9 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -4,10 +4,9 @@ import sympy as sp import pymbolic.primitives as pb from pymbolic.interop.sympy import SympyToPymbolicMapper -from pystencils.sympyextensions.assignmentcollection.assignment import Assignment -from ...simp import AssignmentCollection +from ...sympyextensions import Assignment, AssignmentCollection +from ...sympyextensions.typed_sympy import BasicType from ...field import Field, FieldType -from ...typing import BasicType from .context import KernelCreationContext diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index a739893c0..0e5828ef1 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import reduce from operator import mul -from ...simp import AssignmentCollection +from ...sympyextensions import AssignmentCollection from ...field import Field, FieldType from ..typed_expressions import ( diff --git a/src/pystencils/config.py b/src/pystencils/config.py index 5fff3b799..fdaed6665 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -5,8 +5,7 @@ from types import MappingProxyType from typing import Union, Tuple, List, Dict, Callable, Any, DefaultDict, Iterable from pystencils import Target, Backend, Field -from pystencils.typing.typed_sympy import BasicType -from pystencils.typing.utilities import collate_types +from .sympyextensions.typed_sympy import BasicType import numpy as np diff --git a/src/pystencils/datahandling/serial_datahandling.py b/src/pystencils/datahandling/serial_datahandling.py index 0f5ddb431..e01db8dbb 100644 --- a/src/pystencils/datahandling/serial_datahandling.py +++ b/src/pystencils/datahandling/serial_datahandling.py @@ -9,7 +9,8 @@ from pystencils.datahandling.datahandling_interface import DataHandling from pystencils.enums import Target from pystencils.field import (Field, FieldType, create_numpy_array_with_layout, layout_string_to_tuple, spatial_layout_string_to_tuple) -from pystencils.gpu.gpu_array_handler import GPUArrayHandler, GPUNotAvailableHandler +# TODO replace with platform +# from pystencils.gpu.gpu_array_handler import GPUArrayHandler, GPUNotAvailableHandler from pystencils.slicing import normalize_slice, remove_ghost_layers from pystencils.utils import DotDict diff --git a/src/pystencils/display_utils.py b/src/pystencils/display_utils.py index 2fba12da6..bce5d493c 100644 --- a/src/pystencils/display_utils.py +++ b/src/pystencils/display_utils.py @@ -2,7 +2,6 @@ from typing import Any, Dict, Optional, Union import sympy as sp -from pystencils.sympyextensions.astnodes import KernelFunction from pystencils.enums import Backend from pystencils.kernel_wrapper import KernelWrapper @@ -41,7 +40,7 @@ def highlight_cpp(code: str): return HTML(highlight(code, CppLexer(), HtmlFormatter())) -def get_code_obj(ast: Union[KernelFunction, KernelWrapper], custom_backend=None): +def get_code_obj(ast: Union[KernelWrapper], custom_backend=None): """Returns an object to display generated code (C/C++ or CUDA) Can either be displayed as HTML in Jupyter notebooks or printed as normal string. @@ -87,7 +86,7 @@ def _isnotebook(): return False -def show_code(ast: Union[KernelFunction, KernelWrapper], custom_backend=None): +def show_code(ast: Union[KernelWrapper], custom_backend=None): code = get_code_obj(ast, custom_backend) if _isnotebook(): diff --git a/src/pystencils/fd/derivation.py b/src/pystencils/fd/derivation.py index bea0a9674..f739c8f02 100644 --- a/src/pystencils/fd/derivation.py +++ b/src/pystencils/fd/derivation.py @@ -6,7 +6,7 @@ import sympy as sp from pystencils.field import Field from pystencils.stencil import direction_string_to_offset -from pystencils.sympyextensions import multidimensional_sum, prod +from pystencils.sympyextensions.math import multidimensional_sum, prod from pystencils.utils import LinearEquationSystem, fully_contains diff --git a/src/pystencils/fd/finitedifferences.py b/src/pystencils/fd/finitedifferences.py index e8ff2c80c..9c4116ee5 100644 --- a/src/pystencils/fd/finitedifferences.py +++ b/src/pystencils/fd/finitedifferences.py @@ -7,8 +7,8 @@ from pystencils.fd import Diff from pystencils.fd.derivative import diff_args from pystencils.fd.spatial import fd_stencils_standard from pystencils.field import Field -from pystencils.simp.assignment_collection import AssignmentCollection -from pystencils.sympyextensions import fast_subs +from pystencils.sympyextensions import AssignmentCollection +from pystencils.sympyextensions.math import fast_subs FieldOrFieldAccess = Union[Field, Field.Access] diff --git a/src/pystencils/fd/spatial.py b/src/pystencils/fd/spatial.py index 17f708ed8..1bec75726 100644 --- a/src/pystencils/fd/spatial.py +++ b/src/pystencils/fd/spatial.py @@ -3,10 +3,10 @@ from typing import Tuple import sympy as sp -from pystencils.sympyextensions.astnodes import LoopOverCoordinate from pystencils.fd import Diff from pystencils.field import Field -from pystencils.sympyextensions import generic_visit +from pystencils.sympyextensions.astnodes import generic_visit +from pystencils.sympyextensions.typed_sympy import is_loop_counter_symbol from .derivation import FiniteDifferenceStencilDerivation from .derivative import diff_args @@ -112,7 +112,7 @@ def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard): elif isinstance(e, Field.Access): return (e.neighbor(coordinate, sign) + e) / 2 elif isinstance(e, sp.Symbol): - loop_idx = LoopOverCoordinate.is_loop_counter_symbol(e) + loop_idx = is_loop_counter_symbol(e) return e + sign / 2 if loop_idx == coordinate else e else: new_args = [staggered_visitor(a, coordinate, sign) for a in e.args] diff --git a/src/pystencils/kernel_decorator.py b/src/pystencils/kernel_decorator.py index a8246de07..deb94eec0 100644 --- a/src/pystencils/kernel_decorator.py +++ b/src/pystencils/kernel_decorator.py @@ -5,8 +5,8 @@ from typing import Callable, Union, List, Dict, Tuple import sympy as sp -from pystencils.sympyextensions.assignmentcollection.assignment import Assignment -from pystencils.sympyextensions import SymbolCreator +from .sympyextensions import Assignment +from .sympyextensions.math import SymbolCreator from pystencils.config import CreateKernelConfig __all__ = ['kernel', 'kernel_config'] diff --git a/src/pystencils/kernel_wrapper.py b/src/pystencils/kernel_wrapper.py index b76ec4e3d..d5dfbecca 100644 --- a/src/pystencils/kernel_wrapper.py +++ b/src/pystencils/kernel_wrapper.py @@ -8,7 +8,7 @@ class KernelWrapper: Can be called while still providing access to underlying AST. """ - def __init__(self, kernel, parameters, ast_node: pystencils.astnodes.KernelFunction): + def __init__(self, kernel, parameters, ast_node): self.kernel = kernel self.parameters = parameters self.ast = ast_node diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 52c70fbfb..fb4d6abfb 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -8,7 +8,7 @@ from .backend.kernelcreation.transformations import EraseAnonymousStructTypes from .enums import Target from .config import CreateKernelConfig -from pystencils.sympyextensions.assignmentcollection import AssignmentCollection +from .sympyextensions import AssignmentCollection def create_kernel( diff --git a/src/pystencils/spatial_coordinates.py b/src/pystencils/spatial_coordinates.py index a8be92c94..cc244b11c 100644 --- a/src/pystencils/spatial_coordinates.py +++ b/src/pystencils/spatial_coordinates.py @@ -1,19 +1,14 @@ - import sympy +from pystencils.sympyextensions.typed_sympy import get_loop_counter_symbol -import pystencils -import pystencils.sympyextensions.astnodes -x_, y_, z_ = tuple(pystencils.sympyextensions.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(3)) +x_, y_, z_ = tuple(get_loop_counter_symbol(i) for i in range(3)) x_staggered, y_staggered, z_staggered = x_ + 0.5, y_ + 0.5, z_ + 0.5 def x_vector(ndim): - return sympy.Matrix(tuple( - pystencils.sympyextensions.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(ndim))) + return sympy.Matrix(tuple(get_loop_counter_symbol(i) for i in range(ndim))) def x_staggered_vector(ndim): - return sympy.Matrix(tuple( - pystencils.sympyextensions.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) + 0.5 for i in range(ndim) - )) + return sympy.Matrix(tuple(get_loop_counter_symbol(i) + 0.5 for i in range(ndim))) diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py index d61e705d1..41d43d83c 100644 --- a/src/pystencils/sympyextensions/__init__.py +++ b/src/pystencils/sympyextensions/__init__.py @@ -1,5 +1,4 @@ -from .assignment import Assignment, AugmentedAssignment, AddAugmentedAssignment, assignment_from_stencil -from pystencils.sympyextensions.assignment_collection import AssignmentCollection +from .astnodes import Assignment, AugmentedAssignment, AddAugmentedAssignment, AssignmentCollection from .simplificationstrategy import SimplificationStrategy from .simplifications import (sympy_cse, sympy_cse_on_assignment_list, apply_to_all_assignments, apply_on_all_subexpressions, subexpression_substitution_in_existing_subexpressions, @@ -12,7 +11,7 @@ from .subexpression_insertion import ( insert_squares, insert_symbol_times_minus_one) -__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'assignment_from_stencil', +__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'AssignmentCollection', 'SimplificationStrategy', 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', diff --git a/src/pystencils/sympyextensions/assignment.py b/src/pystencils/sympyextensions/assignment.py deleted file mode 100644 index 591bde9a5..000000000 --- a/src/pystencils/sympyextensions/assignment.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np -import sympy as sp -from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment -from sympy.printing.latex import LatexPrinter - - -def print_assignment_latex(printer, expr): - binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else '' - """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer""" - printed_lhs = printer.doprint(expr.lhs) - printed_rhs = printer.doprint(expr.rhs) - return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}" - - -def assignment_str(assignment): - op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else 'â†' - return fr"{assignment.lhs} {op} {assignment.rhs}" - - -_old_new = sp.codegen.ast.Assignment.__new__ - - -# TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults -def _Assignment__new__(cls, lhs, rhs, *args, **kwargs): - if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)): - assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!' - return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs)) - return _old_new(cls, lhs, rhs, *args, **kwargs) - - -Assignment.__str__ = assignment_str -Assignment.__new__ = _Assignment__new__ -LatexPrinter._print_Assignment = print_assignment_latex - -AugmentedAssignment.__str__ = assignment_str -LatexPrinter._print_AugmentedAssignment = print_assignment_latex - -sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self)) - - -def assignment_from_stencil(stencil_array, input_field, output_field, - normalization_factor=None, order='visual') -> Assignment: - """Creates an assignment - - Args: - stencil_array: nested list of numpy array defining the stencil weights - input_field: field or field access, defining where the stencil should be applied to - output_field: field or field access where the result is written to - normalization_factor: optional normalization factor for the stencil - order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'. - For details see examples - - Returns: - Assignment that can be used to create a kernel - - Examples: - >>> import pystencils as ps - >>> f, g = ps.fields("f, g: [2D]") - >>> stencil = [[0, 2, 0], - ... [3, 4, 5], - ... [0, 6, 0]] - - By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down - >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0]) - >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output - True - - 'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc. - >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0]) - >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output - True - - You can also pass field accesses to apply the stencil at an already shifted position: - >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0]) - >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output - True - """ - from pystencils.field import Field - - stencil_array = np.array(stencil_array) - if order == 'visual': - stencil_array = np.swapaxes(stencil_array, 0, 1) - stencil_array = np.flip(stencil_array, axis=1) - elif order == 'numpy': - pass - else: - raise ValueError("'order' has to be either 'visual' or 'numpy'") - - if isinstance(input_field, Field): - input_field = input_field.center - if isinstance(output_field, Field): - output_field = output_field.center - - rhs = 0 - offset = tuple(s // 2 for s in stencil_array.shape) - - for index, factor in np.ndenumerate(stencil_array): - shift = tuple(i - o for i, o in zip(index, offset)) - rhs += factor * input_field.get_shifted(*shift) - - if normalization_factor: - rhs *= normalization_factor - - return Assignment(output_field, rhs) diff --git a/src/pystencils/sympyextensions/assignment_collection.py b/src/pystencils/sympyextensions/assignment_collection.py deleted file mode 100644 index de8f7b4e5..000000000 --- a/src/pystencils/sympyextensions/assignment_collection.py +++ /dev/null @@ -1,474 +0,0 @@ -import itertools -from copy import copy -from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union - -import sympy as sp - -import pystencils -from .assignment import Assignment -from .simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) -from pystencils.sympyextensions.math import count_operations, fast_subs - - -class AssignmentCollection: - """ - A collection of equations with subexpression definitions, also represented as assignments, - that are used in the main equations. AssignmentCollection can be passed to simplification methods. - These simplification methods can change the subexpressions, but the number and - left hand side of the main equations themselves is not altered. - Additionally a dictionary of simplification hints is stored, which are set by the functions that create - assignment collections to transport information to the simplification system. - - Args: - main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each - assignment is a field access. Thus the generated equations write on arrays. - subexpressions: List of assignments defining subexpressions used in main equations - simplification_hints: Dict that is used to annotate the assignment collection with hints that are - used by the simplification system. See documentation of the simplification rules for - potentially required hints and their meaning. - subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added - used to get new symbols that are unique for this AssignmentCollection - - """ - - __match_args__ = ("main_assignments", "subexpressions") - - # ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- - - def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], - subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None, - simplification_hints: Optional[Dict[str, Any]] = None, - subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: - - if subexpressions is None: - subexpressions = {} - - if isinstance(main_assignments, Dict): - main_assignments = [Assignment(k, v) - for k, v in main_assignments.items()] - if isinstance(subexpressions, Dict): - subexpressions = [Assignment(k, v) - for k, v in subexpressions.items()] - - main_assignments = list(itertools.chain.from_iterable( - [(a if isinstance(a, Iterable) else [a]) for a in main_assignments])) - subexpressions = list(itertools.chain.from_iterable( - [(a if isinstance(a, Iterable) else [a]) for a in subexpressions])) - - self.main_assignments = main_assignments - self.subexpressions = subexpressions - - if simplification_hints is None: - simplification_hints = {} - - self.simplification_hints = simplification_hints - - ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name] - max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0 - - if subexpression_symbol_generator is None: - self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr) - else: - self.subexpression_symbol_generator = subexpression_symbol_generator - - def add_simplification_hint(self, key: str, value: Any) -> None: - """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet.""" - assert key not in self.simplification_hints, "This hint already exists" - self.simplification_hints[key] = value - - def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol: - """Adds a subexpression to current collection. - - Args: - rhs: right hand side of new subexpression - lhs: optional left hand side of new subexpression. If None a new unique symbol is generated. - topological_sort: sort the subexpressions topologically after insertion, to make sure that - definition of a symbol comes before its usage. If False, subexpression is appended. - - Returns: - left hand side symbol (which could have been generated) - """ - if lhs is None: - lhs = next(self.subexpression_symbol_generator) - eq = Assignment(lhs, rhs) - self.subexpressions.append(eq) - if topological_sort: - self.topological_sort(sort_subexpressions=True, - sort_main_assignments=False) - return lhs - - def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: - """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition.""" - if sort_subexpressions: - self.subexpressions = sort_assignments_topologically(self.subexpressions) - if sort_main_assignments: - self.main_assignments = sort_assignments_topologically(self.main_assignments) - - # ---------------------------------------------- Properties ------------------------------------------------------- - - @property - def all_assignments(self) -> List[Assignment]: - """Subexpression and main equations as a single list.""" - return self.subexpressions + self.main_assignments - - @property - def rhs_symbols(self) -> Set[sp.Symbol]: - """All symbols used in the assignment collection, which occur on the rhs of any assignment.""" - rhs_symbols = set() - for eq in self.all_assignments: - if isinstance(eq, Assignment): - rhs_symbols.update(eq.rhs.atoms(sp.Symbol)) - elif isinstance(eq, pystencils.astnodes.Node): - rhs_symbols.update(eq.undefined_symbols) - - return rhs_symbols - - @property - def free_symbols(self) -> Set[sp.Symbol]: - """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" - return self.rhs_symbols - self.bound_symbols - - @property - def bound_symbols(self) -> Set[sp.Symbol]: - """All symbols which occur on the left hand side of a main assignment or a subexpression.""" - bound_symbols_set = set( - [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)] - ) - - assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \ - "Not in SSA form - same symbol assigned multiple times" - - bound_symbols_set = bound_symbols_set.union(*[ - assignment.symbols_defined for assignment in self.all_assignments - if isinstance(assignment, pystencils.astnodes.Node) - ]) - - return bound_symbols_set - - @property - def rhs_fields(self): - """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" - return {s.field for s in self.rhs_symbols if hasattr(s, 'field')} - - @property - def free_fields(self): - """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" - return {s.field for s in self.free_symbols if hasattr(s, 'field')} - - @property - def bound_fields(self): - """All field accessed on the left hand side of a main assignment or a subexpression.""" - return {s.field for s in self.bound_symbols if hasattr(s, 'field')} - - @property - def defined_symbols(self) -> Set[sp.Symbol]: - """All symbols which occur as left-hand-sides of one of the main equations""" - lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]) - return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments - if isinstance(assignment, pystencils.astnodes.Node)])) - - @property - def operation_count(self): - """See :func:`count_operations` """ - return count_operations(self.all_assignments, only_type=None) - - def atoms(self, *args): - return set().union(*[a.atoms(*args) for a in self.all_assignments]) - - def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]: - """Returns all symbols that depend on one of the passed symbols. - - A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when - 'b' is required to compute 'a'. - """ - - queue = list(symbols) - - def add_symbols_from_expr(expr): - dependent_symbols = expr.atoms(sp.Symbol) - for ds in dependent_symbols: - queue.append(ds) - - handled_symbols = set() - assignment_dict = {e.lhs: e.rhs for e in self.all_assignments} - - while len(queue) > 0: - e = queue.pop(0) - if e in handled_symbols: - continue - if e in assignment_dict: - add_symbols_from_expr(assignment_dict[e]) - handled_symbols.add(e) - - return handled_symbols - - def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None): - """Returns a python function to evaluate this equation collection. - - Args: - symbols: symbol(s) which are the parameter for the created function - fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify - module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy' - - Examples: - >>> a, b, c, d = sp.symbols("a b c d") - >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)], - ... subexpressions=[Assignment(b, a + b / 2)]) - >>> python_function = ac.lambdify([a], fixed_symbols={b: 2}) - >>> python_function(4) - {c: 6, d: 18} - """ - assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self - assignments = assignments.new_without_subexpressions().main_assignments - lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments} - - def f(*args, **kwargs): - return {s: func(*args, **kwargs) for s, func in lambdas.items()} - - return f - - # ---------------------------- Creating new modified collections --------------------------------------------------- - - def copy(self, - main_assignments: Optional[List[Assignment]] = None, - subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection': - """Returns a copy with optionally replaced main_assignments and/or subexpressions.""" - - res = copy(self) - res.simplification_hints = self.simplification_hints.copy() - res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator) - - if main_assignments is not None: - res.main_assignments = main_assignments - else: - res.main_assignments = self.main_assignments.copy() - - if subexpressions is not None: - res.subexpressions = subexpressions - else: - res.subexpressions = self.subexpressions.copy() - - return res - - def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False, - substitute_on_lhs: bool = True, - sort_topologically: bool = True) -> 'AssignmentCollection': - """Returns new object, where terms are substituted according to the passed substitution dict. - - Args: - substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions - add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions - substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments - sort_topologically: if subexpressions are added as substitutions and this parameters is true, - the subexpressions are sorted topologically after insertion - Returns: - New AssignmentCollection where substitutions have been applied, self is not altered. - """ - transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs - transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions) - transformed_assignments = transform(self.main_assignments, fast_subs, substitutions) - - if add_substitutions_as_subexpressions: - transformed_subexpressions = [Assignment(b, a) for a, b in - substitutions.items()] + transformed_subexpressions - if sort_topologically: - transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions) - return self.copy(transformed_assignments, transformed_subexpressions) - - def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection': - """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" - own_definitions = set([e.lhs for e in self.main_assignments]) - other_definitions = set([e.lhs for e in other.main_assignments]) - assert len(own_definitions.intersection(other_definitions)) == 0, \ - "Cannot merge collections, since both define the same symbols" - - own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} - substitution_dict = {} - - processed_other_subexpression_equations = [] - for other_subexpression_eq in other.subexpressions: - if other_subexpression_eq.lhs in own_subexpression_symbols: - if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]: - continue # exact the same subexpression equation exists already - else: - # different definition - a new name has to be introduced - new_lhs = next(self.subexpression_symbol_generator) - new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict)) - processed_other_subexpression_equations.append(new_eq) - substitution_dict[other_subexpression_eq.lhs] = new_lhs - else: - processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict)) - - processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments] - return self.copy(self.main_assignments + processed_other_main_assignments, - self.subexpressions + processed_other_subexpression_equations) - - def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection': - """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions. - - Returns: - new AssignmentCollection, self is not altered - """ - symbols_to_extract = set(symbols_to_extract) - dependent_symbols = self.dependent_symbols(symbols_to_extract) - new_assignments = [] - for eq in self.all_assignments: - if eq.lhs in symbols_to_extract: - new_assignments.append(eq) - - new_sub_expr = [eq for eq in self.all_assignments - if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract] - return self.copy(new_assignments, new_sub_expr) - - def new_without_unused_subexpressions(self) -> 'AssignmentCollection': - """Returns new collection that only contains subexpressions required to compute the main assignments.""" - all_lhs = [eq.lhs for eq in self.main_assignments] - return self.new_filtered(all_lhs) - - def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection': - """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere.""" - new_subexpressions = [] - subs_dict = None - for se in self.subexpressions: - if se.lhs == symbol: - subs_dict = {se.lhs: se.rhs} - else: - new_subexpressions.append(se) - if subs_dict is None: - return self - - new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions] - new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments] - return self.copy(new_eqs, new_subexpressions) - - def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection': - """Returns a new collection where all subexpressions have been inserted.""" - if subexpressions_to_keep is None: - subexpressions_to_keep = set() - if len(self.subexpressions) == 0: - return self.copy() - - subexpressions_to_keep = set(subexpressions_to_keep) - - kept_subexpressions = [] - if self.subexpressions[0].lhs in subexpressions_to_keep: - substitution_dict = {} - kept_subexpressions.append(self.subexpressions[0]) - else: - substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} - - subexpression = [e for e in self.subexpressions] - for i in range(1, len(subexpression)): - subexpression[i] = fast_subs(subexpression[i], substitution_dict) - if subexpression[i].lhs in subexpressions_to_keep: - kept_subexpressions.append(subexpression[i]) - else: - substitution_dict[subexpression[i].lhs] = subexpression[i].rhs - - new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments] - return self.copy(new_assignment, kept_subexpressions) - - # ----------------------------------------- Display and Printing ------------------------------------------------- - - def _repr_html_(self): - """Interface to Jupyter notebook, to display as a nicely formatted HTML table""" - - def make_html_equation_table(equations): - no_border = 'style="border:none"' - html_table = '<table style="border:none; width: 100%; ">' - line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> ' - for eq in equations: - format_dict = {'eq': sp.latex(eq), - 'nb': no_border, } - html_table += line.format(**format_dict) - html_table += "</table>" - return html_table - - result = "" - if len(self.subexpressions) > 0: - result += "<div>Subexpressions:</div>" - result += make_html_equation_table(self.subexpressions) - result += "<div>Main Assignments:</div>" - result += make_html_equation_table(self.main_assignments) - return result - - def __repr__(self): - return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}" - - def __str__(self): - result = "Subexpressions:\n" - for eq in self.subexpressions: - result += f"\t{eq}\n" - result += "Main Assignments:\n" - for eq in self.main_assignments: - result += f"\t{eq}\n" - return result - - def __iter__(self): - return self.all_assignments.__iter__() - - @property - def main_assignments_dict(self): - return {a.lhs: a.rhs for a in self.main_assignments} - - @property - def subexpressions_dict(self): - return {a.lhs: a.rhs for a in self.subexpressions} - - def set_main_assignments_from_dict(self, main_assignments_dict): - self.main_assignments = [Assignment(k, v) - for k, v in main_assignments_dict.items()] - - def set_sub_expressions_from_dict(self, sub_expressions_dict): - self.subexpressions = [Assignment(k, v) - for k, v in sub_expressions_dict.items()] - - def find(self, *args, **kwargs): - return set.union( - *[a.find(*args, **kwargs) for a in self.all_assignments] - ) - - def match(self, *args, **kwargs): - rtn = {} - for a in self.all_assignments: - partial_result = a.match(*args, **kwargs) - if partial_result: - rtn.update(partial_result) - return rtn - - def subs(self, *args, **kwargs): - return AssignmentCollection( - main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments], - subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions] - ) - - def replace(self, *args, **kwargs): - return AssignmentCollection( - main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments], - subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions] - ) - - def __eq__(self, other): - return set(self.all_assignments) == set(other.all_assignments) - - def __bool__(self): - return bool(self.all_assignments) - - -class SymbolGen: - """Default symbol generator producing number symbols ζ_0, ζ_1, ...""" - - def __init__(self, symbol="xi", dtype=None, ctr=0): - self._ctr = ctr - self._symbol = symbol - self._dtype = dtype - - def __iter__(self): - return self - - def __next__(self): - name = f"{self._symbol}_{self._ctr}" - self._ctr += 1 - if self._dtype is not None: - return pystencils.TypedSymbol(name, self._dtype) - return sp.Symbol(name) diff --git a/src/pystencils/sympyextensions/astnodes.py b/src/pystencils/sympyextensions/astnodes.py index 0cd82322c..8483977d8 100644 --- a/src/pystencils/sympyextensions/astnodes.py +++ b/src/pystencils/sympyextensions/astnodes.py @@ -1,689 +1,588 @@ -import collections.abc +from copy import copy import itertools import uuid -from typing import Any, List, Optional, Sequence, Set, Union - -from .assignment import Assignment -from pystencils.enums import Target, Backend +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union import sympy as sp +from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment +from sympy.printing.latex import LatexPrinter +import numpy as np -from .math import fast_subs -from .typed_sympy import (create_type, CastFunc, - FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol) +from .math import count_operations, fast_subs +from .simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) +from .typed_sympy import create_type, TypedSymbol -NodeOrExpr = Union['Node', sp.Expr] +def print_assignment_latex(printer, expr): + binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else '' + """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer""" + printed_lhs = printer.doprint(expr.lhs) + printed_rhs = printer.doprint(expr.rhs) + return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}" -class Node: - """Base class for all AST nodes.""" +def assignment_str(assignment): + op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else 'â†' + return fr"{assignment.lhs} {op} {assignment.rhs}" - def __init__(self, parent: Optional['Node'] = None): - self.parent = parent - @property - def args(self) -> List[NodeOrExpr]: - """Returns all arguments/children of this node.""" - raise NotImplementedError() +_old_new = sp.codegen.ast.Assignment.__new__ - @property - def symbols_defined(self) -> Set[sp.Symbol]: - """Set of symbols which are defined by this node.""" - raise NotImplementedError() - @property - def undefined_symbols(self) -> Set[sp.Symbol]: - """Symbols which are used but are not defined inside this node.""" - raise NotImplementedError() - - def subs(self, subs_dict) -> None: - """Inplace! Substitute, similar to sympy's but modifies the AST inplace.""" - for i, a in enumerate(self.args): - result = a.subs(subs_dict) - if isinstance(a, sp.Expr): # sympy expressions' subs is out-of-place - self.args[i] = result - else: # all other should be in-place - assert result is None +# TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults +def _Assignment__new__(cls, lhs, rhs, *args, **kwargs): + if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)): + assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!' + return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs)) + return _old_new(cls, lhs, rhs, *args, **kwargs) - @property - def func(self): - return self.__class__ - - def atoms(self, arg_type) -> Set[Any]: - """Returns a set of all descendants recursively, which are an instance of the given type.""" - result = set() - for arg in self.args: - if isinstance(arg, arg_type): - result.add(arg) - result.update(arg.atoms(arg_type)) - return result +Assignment.__str__ = assignment_str +Assignment.__new__ = _Assignment__new__ +LatexPrinter._print_Assignment = print_assignment_latex + +AugmentedAssignment.__str__ = assignment_str +LatexPrinter._print_AugmentedAssignment = print_assignment_latex -class Conditional(Node): - """Conditional that maps to a 'if' statement in C/C++. +sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self)) - Try to avoid using this node inside of loops, since currently this construction can not be vectorized. - Consider using assignments with sympy.Piecewise in this case. + +def assignment_from_stencil(stencil_array, input_field, output_field, + normalization_factor=None, order='visual') -> Assignment: + """Creates an assignment Args: - condition_expr: sympy relational expression - true_block: block which is run if conditional is true - false_block: optional block which is run if conditional is false + stencil_array: nested list of numpy array defining the stencil weights + input_field: field or field access, defining where the stencil should be applied to + output_field: field or field access where the result is written to + normalization_factor: optional normalization factor for the stencil + order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'. + For details see examples + + Returns: + Assignment that can be used to create a kernel + + Examples: + >>> import pystencils as ps + >>> f, g = ps.fields("f, g: [2D]") + >>> stencil = [[0, 2, 0], + ... [3, 4, 5], + ... [0, 6, 0]] + + By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down + >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0]) + >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output + True + + 'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc. + >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0]) + >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output + True + + You can also pass field accesses to apply the stencil at an already shifted position: + >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0]) + >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output + True """ + from pystencils.field import Field - def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'], - false_block: Optional['Block'] = None) -> None: - super(Conditional, self).__init__(parent=None) + stencil_array = np.array(stencil_array) + if order == 'visual': + stencil_array = np.swapaxes(stencil_array, 0, 1) + stencil_array = np.flip(stencil_array, axis=1) + elif order == 'numpy': + pass + else: + raise ValueError("'order' has to be either 'visual' or 'numpy'") - self.condition_expr = condition_expr + if isinstance(input_field, Field): + input_field = input_field.center + if isinstance(output_field, Field): + output_field = output_field.center - def handle_child(c): - if c is None: - return None - if not isinstance(c, Block): - c = Block([c]) - c.parent = self - return c + rhs = 0 + offset = tuple(s // 2 for s in stencil_array.shape) - self.true_block = handle_child(true_block) - self.false_block = handle_child(false_block) + for index, factor in np.ndenumerate(stencil_array): + shift = tuple(i - o for i, o in zip(index, offset)) + rhs += factor * input_field.get_shifted(*shift) - def subs(self, subs_dict): - self.true_block.subs(subs_dict) - if self.false_block: - self.false_block.subs(subs_dict) - self.condition_expr = self.condition_expr.subs(subs_dict) + if normalization_factor: + rhs *= normalization_factor - @property - def args(self): - result = [self.condition_expr, self.true_block] - if self.false_block: - result.append(self.false_block) - return result + return Assignment(output_field, rhs) - @property - def symbols_defined(self): - return set() - - @property - def undefined_symbols(self): - result = self.true_block.undefined_symbols - if self.false_block: - result.update(self.false_block.undefined_symbols) - if hasattr(self.condition_expr, 'atoms'): - result.update(self.condition_expr.atoms(sp.Symbol)) - return result - def __str__(self): - return self.__repr__() +class AssignmentCollection: + """ + A collection of equations with subexpression definitions, also represented as assignments, + that are used in the main equations. AssignmentCollection can be passed to simplification methods. + These simplification methods can change the subexpressions, but the number and + left hand side of the main equations themselves is not altered. + Additionally a dictionary of simplification hints is stored, which are set by the functions that create + assignment collections to transport information to the simplification system. - def __repr__(self): - result = f'if:({self.condition_expr!r}) ' - if self.true_block: - result += f'\n\t{self.true_block}) ' - if self.false_block: - result = 'else: ' - result += f'\n\t{self.false_block} ' + Args: + main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each + assignment is a field access. Thus the generated equations write on arrays. + subexpressions: List of assignments defining subexpressions used in main equations + simplification_hints: Dict that is used to annotate the assignment collection with hints that are + used by the simplification system. See documentation of the simplification rules for + potentially required hints and their meaning. + subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added + used to get new symbols that are unique for this AssignmentCollection - return result + """ - def replace_by_true_block(self): - """Replaces the conditional by its True block""" - self.parent.replace(self, [self.true_block]) + __match_args__ = ("main_assignments", "subexpressions") - def replace_by_false_block(self): - """Replaces the conditional by its False block""" - self.parent.replace(self, [self.false_block] if self.false_block else []) + # ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- + def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], + subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None, + simplification_hints: Optional[Dict[str, Any]] = None, + subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: -class SkipIteration(Node): - @property - def args(self): - return [] + if subexpressions is None: + subexpressions = {} - @property - def symbols_defined(self): - return set() + if isinstance(main_assignments, Dict): + main_assignments = [Assignment(k, v) + for k, v in main_assignments.items()] + if isinstance(subexpressions, Dict): + subexpressions = [Assignment(k, v) + for k, v in subexpressions.items()] - @property - def undefined_symbols(self): - return set() - - -class Block(Node): - def __init__(self, nodes: Union[Node, List[Node]]): - super(Block, self).__init__() - if not isinstance(nodes, list): - nodes = [nodes] - self._nodes = nodes - self.parent = None - for n in self._nodes: - try: - n.parent = self - except AttributeError: - pass + main_assignments = list(itertools.chain.from_iterable( + [(a if isinstance(a, Iterable) else [a]) for a in main_assignments])) + subexpressions = list(itertools.chain.from_iterable( + [(a if isinstance(a, Iterable) else [a]) for a in subexpressions])) - @property - def args(self): - return self._nodes + self.main_assignments = main_assignments + self.subexpressions = subexpressions - def subs(self, subs_dict) -> None: - for a in self.args: - a.subs(subs_dict) + if simplification_hints is None: + simplification_hints = {} - def fast_subs(self, subs_dict, skip=None): - self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes] - return self + self.simplification_hints = simplification_hints - def insert_front(self, node, if_not_exists=False): - if if_not_exists and len(self._nodes) > 0 and self._nodes[0] == node: - return - if isinstance(node, collections.abc.Iterable): - node = list(node) - for n in node: - n.parent = self + ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name] + max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0 - self._nodes = node + self._nodes - else: - node.parent = self - self._nodes.insert(0, node) - - def insert_before(self, new_node, insert_before, if_not_exists=False): - new_node.parent = self - assert self._nodes.count(insert_before) == 1 - idx = self._nodes.index(insert_before) - - if not if_not_exists or self._nodes[idx] != new_node: - self._nodes.insert(idx, new_node) - - def insert_after(self, new_node, insert_after, if_not_exists=False): - new_node.parent = self - assert self._nodes.count(insert_after) == 1 - idx = self._nodes.index(insert_after) + 1 - - if not if_not_exists or not (self._nodes[idx - 1] == new_node - or (idx < len(self._nodes) and self._nodes[idx] == new_node)): - self._nodes.insert(idx, new_node) - - def append(self, node): - if isinstance(node, list) or isinstance(node, tuple): - for n in node: - n.parent = self - self._nodes.append(n) - else: - node.parent = self - self._nodes.append(node) - - def take_child_nodes(self): - tmp = self._nodes - self._nodes = [] - return tmp - - def replace(self, child, replacements): - assert self._nodes.count(child) == 1 - idx = self._nodes.index(child) - del self._nodes[idx] - if type(replacements) is list: - for e in replacements: - e.parent = self - self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:] + if subexpression_symbol_generator is None: + self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr) else: - replacements.parent = self - self._nodes.insert(idx, replacements) + self.subexpression_symbol_generator = subexpression_symbol_generator - @property - def symbols_defined(self): - result = set() - for a in self.args: - if isinstance(a, Assignment): - result.update(a.free_symbols) - else: - result.update(a.symbols_defined) - return result - - @property - def undefined_symbols(self): - result = set() - defined_symbols = set() - for a in self.args: - if isinstance(a, Assignment): - result.update(a.free_symbols) - defined_symbols.update({a.lhs}) - else: - result.update(a.undefined_symbols) - defined_symbols.update(a.symbols_defined) - return result - defined_symbols - - def __str__(self): - return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes) - - def __repr__(self): - return "Block" + def add_simplification_hint(self, key: str, value: Any) -> None: + """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet.""" + assert key not in self.simplification_hints, "This hint already exists" + self.simplification_hints[key] = value + def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol: + """Adds a subexpression to current collection. -class PragmaBlock(Block): - def __init__(self, pragma_line, nodes): - super(PragmaBlock, self).__init__(nodes) - self.pragma_line = pragma_line - for n in nodes: - n.parent = self + Args: + rhs: right hand side of new subexpression + lhs: optional left hand side of new subexpression. If None a new unique symbol is generated. + topological_sort: sort the subexpressions topologically after insertion, to make sure that + definition of a symbol comes before its usage. If False, subexpression is appended. - def __repr__(self): - return self.pragma_line - - -class LoopOverCoordinate(Node): - LOOP_COUNTER_NAME_PREFIX = "ctr" - BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr" - - def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False, custom_loop_ctr=None): - super(LoopOverCoordinate, self).__init__(parent=None) - self.body = body - body.parent = self - self.coordinate_to_loop_over = coordinate_to_loop_over - self.start = start - self.stop = stop - self.step = step - self.body.parent = self - self.prefix_lines = [] - self.is_block_loop = is_block_loop - self.custom_loop_ctr = custom_loop_ctr - - def new_loop_with_different_body(self, new_body): - result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop, - self.step, self.is_block_loop, self.custom_loop_ctr) - result.prefix_lines = [prefix_line for prefix_line in self.prefix_lines] - return result - - def subs(self, subs_dict): - self.body.subs(subs_dict) - if hasattr(self.start, "subs"): - self.start = self.start.subs(subs_dict) - if hasattr(self.stop, "subs"): - self.stop = self.stop.subs(subs_dict) - if hasattr(self.step, "subs"): - self.step = self.step.subs(subs_dict) - - def fast_subs(self, subs_dict, skip=None): - self.body = fast_subs(self.body, subs_dict, skip) - if isinstance(self.start, sp.Basic): - self.start = fast_subs(self.start, subs_dict, skip) - if isinstance(self.stop, sp.Basic): - self.stop = fast_subs(self.stop, subs_dict, skip) - if isinstance(self.step, sp.Basic): - self.step = fast_subs(self.step, subs_dict, skip) - return self + Returns: + left hand side symbol (which could have been generated) + """ + if lhs is None: + lhs = next(self.subexpression_symbol_generator) + eq = Assignment(lhs, rhs) + self.subexpressions.append(eq) + if topological_sort: + self.topological_sort(sort_subexpressions=True, + sort_main_assignments=False) + return lhs - @property - def args(self): - result = [self.body] - for e in [self.start, self.stop, self.step]: - if hasattr(e, "args"): - result.append(e) - return result + def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: + """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition.""" + if sort_subexpressions: + self.subexpressions = sort_assignments_topologically(self.subexpressions) + if sort_main_assignments: + self.main_assignments = sort_assignments_topologically(self.main_assignments) - def replace(self, child, replacement): - if child == self.body: - self.body = replacement - elif child == self.start: - self.start = replacement - elif child == self.step: - self.step = replacement - elif child == self.stop: - self.stop = replacement + # ---------------------------------------------- Properties ------------------------------------------------------- @property - def symbols_defined(self): - return {self.loop_counter_symbol} + def all_assignments(self) -> List[Assignment]: + """Subexpression and main equations as a single list.""" + return self.subexpressions + self.main_assignments @property - def undefined_symbols(self): - result = self.body.undefined_symbols - for possible_symbol in [self.start, self.stop, self.step]: - if isinstance(possible_symbol, Node) or isinstance(possible_symbol, sp.Basic): - result.update(possible_symbol.atoms(sp.Symbol)) - return result - {self.loop_counter_symbol} - - @staticmethod - def get_loop_counter_name(coordinate_to_loop_over): - return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}" - - @staticmethod - def get_block_loop_counter_name(coordinate_to_loop_over): - return f"{LoopOverCoordinate.BLOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}" + def rhs_symbols(self) -> Set[sp.Symbol]: + """All symbols used in the assignment collection, which occur on the rhs of any assignment.""" + rhs_symbols = set() + for eq in self.all_assignments: + if isinstance(eq, Assignment): + rhs_symbols.update(eq.rhs.atoms(sp.Symbol)) + # TODO rewrite with SymPy AST + # elif isinstance(eq, pystencils.astnodes.Node): + # rhs_symbols.update(eq.undefined_symbols) - @property - def loop_counter_name(self): - if self.custom_loop_ctr: - return self.custom_loop_ctr.name - else: - if self.is_block_loop: - return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over) - else: - return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over) - - @staticmethod - def is_loop_counter_symbol(symbol): - prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX - if not symbol.name.startswith(prefix): - return None - if symbol.dtype != create_type('int'): - return None - coordinate = int(symbol.name[len(prefix) + 1:]) - return coordinate - - @staticmethod - def get_loop_counter_symbol(coordinate_to_loop_over): - return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True) - - @staticmethod - def get_block_loop_counter_symbol(coordinate_to_loop_over): - return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over), - 'int', - nonnegative=True) + return rhs_symbols @property - def loop_counter_symbol(self): - if self.custom_loop_ctr: - return self.custom_loop_ctr - else: - if self.is_block_loop: - return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over) - else: - return self.get_loop_counter_symbol(self.coordinate_to_loop_over) + def free_symbols(self) -> Set[sp.Symbol]: + """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" + return self.rhs_symbols - self.bound_symbols @property - def is_outermost_loop(self): - return get_next_parent_of_type(self, LoopOverCoordinate) is None + def bound_symbols(self) -> Set[sp.Symbol]: + """All symbols which occur on the left hand side of a main assignment or a subexpression.""" + bound_symbols_set = set( + [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)] + ) - @property - def is_innermost_loop(self): - return len(self.atoms(LoopOverCoordinate)) == 0 + assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \ + "Not in SSA form - same symbol assigned multiple times" - def __str__(self): - return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start, - self.loop_counter_name, self.stop, - self.loop_counter_name, self.step, - ("\t" + "\t".join(str(self.body).splitlines(True)))) + # TODO rewrite with SymPy AST + # bound_symbols_set = bound_symbols_set.union(*[ + # assignment.symbols_defined for assignment in self.all_assignments + # if isinstance(assignment, pystencils.astnodes.Node) + # ]) - def __repr__(self): - return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start, - self.loop_counter_name, self.stop, - self.loop_counter_name, self.step) - - -class SympyAssignment(Node): - def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False): - super(SympyAssignment, self).__init__(parent=None) - self._lhs_symbol = sp.sympify(lhs_symbol) - self._rhs = sp.sympify(rhs_expr) - self._is_const = is_const - self._is_declaration = self.__is_declaration() - self._use_auto = use_auto - - def __is_declaration(self): - if isinstance(self._lhs_symbol, CastFunc): - return False - if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)): - return False - return True + return bound_symbols_set @property - def lhs(self): - return self._lhs_symbol + def rhs_fields(self): + """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" + return {s.field for s in self.rhs_symbols if hasattr(s, 'field')} @property - def rhs(self): - return self._rhs - - @lhs.setter - def lhs(self, new_value): - self._lhs_symbol = new_value - self._is_declaration = self.__is_declaration() - - @rhs.setter - def rhs(self, new_rhs_expr): - self._rhs = new_rhs_expr - - def subs(self, subs_dict): - self.lhs = fast_subs(self.lhs, subs_dict) - self.rhs = fast_subs(self.rhs, subs_dict) - - def fast_subs(self, subs_dict, skip=None): - self.lhs = fast_subs(self.lhs, subs_dict, skip) - self.rhs = fast_subs(self.rhs, subs_dict, skip) - return self - - def optimize(self, optimizations): - try: - from sympy.codegen.rewriting import optimize - self.rhs = optimize(self.rhs, optimizations) - except Exception: - pass + def free_fields(self): + """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" + return {s.field for s in self.free_symbols if hasattr(s, 'field')} @property - def args(self): - return [self._lhs_symbol, self.rhs] + def bound_fields(self): + """All field accessed on the left hand side of a main assignment or a subexpression.""" + return {s.field for s in self.bound_symbols if hasattr(s, 'field')} @property - def symbols_defined(self): - if not self._is_declaration: - return set() - return {self._lhs_symbol} + def defined_symbols(self) -> Set[sp.Symbol]: + """All symbols which occur as left-hand-sides of one of the main equations""" + lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]) + return lhs_set + # TODO rewrite with SymPy AST + # return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments + # if isinstance(assignment, pystencils.astnodes.Node)])) @property - def undefined_symbols(self): - result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)} - # Add loop counters if there a field accesses - loop_counters = set() - for symbol in result: - if isinstance(symbol, Field.Access): - for i in range(len(symbol.offsets)): - loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i)) - result.update(loop_counters) - - result.update(self._lhs_symbol.atoms(sp.Symbol)) - - return result + def operation_count(self): + """See :func:`count_operations` """ + return count_operations(self.all_assignments, only_type=None) - @property - def is_declaration(self): - return self._is_declaration + def atoms(self, *args): + return set().union(*[a.atoms(*args) for a in self.all_assignments]) - @property - def is_const(self): - return self._is_const + def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]: + """Returns all symbols that depend on one of the passed symbols. - @property - def use_auto(self): - return self._use_auto - - def replace(self, child, replacement): - if child == self.lhs: - replacement.parent = self - self.lhs = replacement - elif child == self.rhs: - replacement.parent = self - self.rhs = replacement - else: - raise ValueError(f'{replacement} is not in args of {self.__class__}') + A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when + 'b' is required to compute 'a'. + """ - def __repr__(self): - return repr(self.lhs) + " ↠" + repr(self.rhs) + queue = list(symbols) - def _repr_html_(self): - printed_lhs = sp.latex(self.lhs) - printed_rhs = sp.latex(self.rhs) - return f"${printed_lhs} \\leftarrow {printed_rhs}$" + def add_symbols_from_expr(expr): + dependent_symbols = expr.atoms(sp.Symbol) + for ds in dependent_symbols: + queue.append(ds) - def __hash__(self): - return hash((self.lhs, self.rhs)) + handled_symbols = set() + assignment_dict = {e.lhs: e.rhs for e in self.all_assignments} - def __eq__(self, other): - return type(self) is type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs) - - -class ResolvedFieldAccess(sp.Indexed): - def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values): - if not isinstance(base, sp.IndexedBase): - assert isinstance(base, TypedSymbol) - base = sp.IndexedBase(base, shape=(1,)) - assert isinstance(base.label, TypedSymbol) - obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index) - obj.field = field - obj.offsets = offsets - obj.idx_coordinate_values = idx_coordinate_values - return obj - - def _eval_subs(self, old, new): - return ResolvedFieldAccess(self.args[0], - self.args[1].subs(old, new), - self.field, self.offsets, self.idx_coordinate_values) - - def fast_subs(self, substitutions, skip=None): - if self in substitutions: - return substitutions[self] - return ResolvedFieldAccess(self.args[0].subs(substitutions), - self.args[1].subs(substitutions), - self.field, self.offsets, self.idx_coordinate_values) - - def _hashable_content(self): - super_class_contents = super(ResolvedFieldAccess, self)._hashable_content() - return super_class_contents + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field)) + while len(queue) > 0: + e = queue.pop(0) + if e in handled_symbols: + continue + if e in assignment_dict: + add_symbols_from_expr(assignment_dict[e]) + handled_symbols.add(e) - @property - def typed_symbol(self): - return self.base.label + return handled_symbols - def __str__(self): - top = super(ResolvedFieldAccess, self).__str__() - return f"{top} ({self.typed_symbol.dtype})" + def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None): + """Returns a python function to evaluate this equation collection. - def __getnewargs__(self): - return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values + Args: + symbols: symbol(s) which are the parameter for the created function + fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify + module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy' - def __getnewargs_ex__(self): - return (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values), {} + Examples: + >>> a, b, c, d = sp.symbols("a b c d") + >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)], + ... subexpressions=[Assignment(b, a + b / 2)]) + >>> python_function = ac.lambdify([a], fixed_symbols={b: 2}) + >>> python_function(4) + {c: 6, d: 18} + """ + assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self + assignments = assignments.new_without_subexpressions().main_assignments + lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments} + def f(*args, **kwargs): + return {s: func(*args, **kwargs) for s, func in lambdas.items()} -class TemporaryMemoryAllocation(Node): - """Node for temporary memory buffer allocation. + return f - Always allocates aligned memory. + # ---------------------------- Creating new modified collections --------------------------------------------------- - Args: - typed_symbol: symbol used as pointer (has to be typed) - size: number of elements to allocate - align_offset: the align_offset's element is aligned - """ + def copy(self, + main_assignments: Optional[List[Assignment]] = None, + subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection': + """Returns a copy with optionally replaced main_assignments and/or subexpressions.""" - def __init__(self, typed_symbol: TypedSymbol, size, align_offset): - super(TemporaryMemoryAllocation, self).__init__(parent=None) - self.symbol = typed_symbol - self.size = size - self.headers = ['<stdlib.h>'] - self._align_offset = align_offset + res = copy(self) + res.simplification_hints = self.simplification_hints.copy() + res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator) - @property - def symbols_defined(self): - return {self.symbol} + if main_assignments is not None: + res.main_assignments = main_assignments + else: + res.main_assignments = self.main_assignments.copy() - @property - def undefined_symbols(self): - if isinstance(self.size, sp.Basic): - return self.size.atoms(sp.Symbol) + if subexpressions is not None: + res.subexpressions = subexpressions else: - return set() + res.subexpressions = self.subexpressions.copy() + + return res + + def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False, + substitute_on_lhs: bool = True, + sort_topologically: bool = True) -> 'AssignmentCollection': + """Returns new object, where terms are substituted according to the passed substitution dict. + + Args: + substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions + add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions + substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments + sort_topologically: if subexpressions are added as substitutions and this parameters is true, + the subexpressions are sorted topologically after insertion + Returns: + New AssignmentCollection where substitutions have been applied, self is not altered. + """ + transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs + transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions) + transformed_assignments = transform(self.main_assignments, fast_subs, substitutions) + + if add_substitutions_as_subexpressions: + transformed_subexpressions = [Assignment(b, a) for a, b in + substitutions.items()] + transformed_subexpressions + if sort_topologically: + transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions) + return self.copy(transformed_assignments, transformed_subexpressions) + + def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection': + """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" + own_definitions = set([e.lhs for e in self.main_assignments]) + other_definitions = set([e.lhs for e in other.main_assignments]) + assert len(own_definitions.intersection(other_definitions)) == 0, \ + "Cannot merge collections, since both define the same symbols" + + own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} + substitution_dict = {} + + processed_other_subexpression_equations = [] + for other_subexpression_eq in other.subexpressions: + if other_subexpression_eq.lhs in own_subexpression_symbols: + if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]: + continue # exact the same subexpression equation exists already + else: + # different definition - a new name has to be introduced + new_lhs = next(self.subexpression_symbol_generator) + new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict)) + processed_other_subexpression_equations.append(new_eq) + substitution_dict[other_subexpression_eq.lhs] = new_lhs + else: + processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict)) + + processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments] + return self.copy(self.main_assignments + processed_other_main_assignments, + self.subexpressions + processed_other_subexpression_equations) + + def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection': + """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions. + + Returns: + new AssignmentCollection, self is not altered + """ + symbols_to_extract = set(symbols_to_extract) + dependent_symbols = self.dependent_symbols(symbols_to_extract) + new_assignments = [] + for eq in self.all_assignments: + if eq.lhs in symbols_to_extract: + new_assignments.append(eq) + + new_sub_expr = [eq for eq in self.all_assignments + if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract] + return self.copy(new_assignments, new_sub_expr) + + def new_without_unused_subexpressions(self) -> 'AssignmentCollection': + """Returns new collection that only contains subexpressions required to compute the main assignments.""" + all_lhs = [eq.lhs for eq in self.main_assignments] + return self.new_filtered(all_lhs) + + def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection': + """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere.""" + new_subexpressions = [] + subs_dict = None + for se in self.subexpressions: + if se.lhs == symbol: + subs_dict = {se.lhs: se.rhs} + else: + new_subexpressions.append(se) + if subs_dict is None: + return self + + new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions] + new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments] + return self.copy(new_eqs, new_subexpressions) + + def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection': + """Returns a new collection where all subexpressions have been inserted.""" + if subexpressions_to_keep is None: + subexpressions_to_keep = set() + if len(self.subexpressions) == 0: + return self.copy() + + subexpressions_to_keep = set(subexpressions_to_keep) + + kept_subexpressions = [] + if self.subexpressions[0].lhs in subexpressions_to_keep: + substitution_dict = {} + kept_subexpressions.append(self.subexpressions[0]) + else: + substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} - @property - def args(self): - return [self.symbol] + subexpression = [e for e in self.subexpressions] + for i in range(1, len(subexpression)): + subexpression[i] = fast_subs(subexpression[i], substitution_dict) + if subexpression[i].lhs in subexpressions_to_keep: + kept_subexpressions.append(subexpression[i]) + else: + substitution_dict[subexpression[i].lhs] = subexpression[i].rhs - def offset(self, byte_alignment): - """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.""" - np_dtype = self.symbol.dtype.base_type.numpy_dtype - assert byte_alignment % np_dtype.itemsize == 0 - return -self._align_offset % (byte_alignment / np_dtype.itemsize) + new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments] + return self.copy(new_assignment, kept_subexpressions) + # ----------------------------------------- Display and Printing ------------------------------------------------- -class TemporaryMemoryFree(Node): - def __init__(self, alloc_node): - super(TemporaryMemoryFree, self).__init__(parent=None) - self.alloc_node = alloc_node + def _repr_html_(self): + """Interface to Jupyter notebook, to display as a nicely formatted HTML table""" + + def make_html_equation_table(equations): + no_border = 'style="border:none"' + html_table = '<table style="border:none; width: 100%; ">' + line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> ' + for eq in equations: + format_dict = {'eq': sp.latex(eq), + 'nb': no_border, } + html_table += line.format(**format_dict) + html_table += "</table>" + return html_table + + result = "" + if len(self.subexpressions) > 0: + result += "<div>Subexpressions:</div>" + result += make_html_equation_table(self.subexpressions) + result += "<div>Main Assignments:</div>" + result += make_html_equation_table(self.main_assignments) + return result - @property - def symbol(self): - return self.alloc_node.symbol + def __repr__(self): + return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}" - def offset(self, byte_alignment): - return self.alloc_node.offset(byte_alignment) + def __str__(self): + result = "Subexpressions:\n" + for eq in self.subexpressions: + result += f"\t{eq}\n" + result += "Main Assignments:\n" + for eq in self.main_assignments: + result += f"\t{eq}\n" + return result - @property - def symbols_defined(self): - return set() + def __iter__(self): + return self.all_assignments.__iter__() @property - def undefined_symbols(self): - return set() + def main_assignments_dict(self): + return {a.lhs: a.rhs for a in self.main_assignments} @property - def args(self): - return [] - - -def early_out(condition): - return Conditional(vec_all(condition), Block([SkipIteration()])) + def subexpressions_dict(self): + return {a.lhs: a.rhs for a in self.subexpressions} + def set_main_assignments_from_dict(self, main_assignments_dict): + self.main_assignments = [Assignment(k, v) + for k, v in main_assignments_dict.items()] -def get_dummy_symbol(dtype='bool'): - return TypedSymbol(f'dummy{uuid.uuid4().hex}', create_type(dtype)) + def set_sub_expressions_from_dict(self, sub_expressions_dict): + self.subexpressions = [Assignment(k, v) + for k, v in sub_expressions_dict.items()] + def find(self, *args, **kwargs): + return set.union( + *[a.find(*args, **kwargs) for a in self.all_assignments] + ) -class SourceCodeComment(Node): - def __init__(self, text): - self.text = text + def match(self, *args, **kwargs): + rtn = {} + for a in self.all_assignments: + partial_result = a.match(*args, **kwargs) + if partial_result: + rtn.update(partial_result) + return rtn - @property - def args(self): - return [] + def subs(self, *args, **kwargs): + return AssignmentCollection( + main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments], + subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions] + ) - @property - def symbols_defined(self): - return set() + def replace(self, *args, **kwargs): + return AssignmentCollection( + main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments], + subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions] + ) - @property - def undefined_symbols(self): - return set() - - def __str__(self): - return "/* " + self.text + " */" + def __eq__(self, other): + return set(self.all_assignments) == set(other.all_assignments) - def __repr__(self): - return self.__str__() + def __bool__(self): + return bool(self.all_assignments) -class EmptyLine(Node): - def __init__(self): - pass +class SymbolGen: + """Default symbol generator producing number symbols ζ_0, ζ_1, ...""" - @property - def args(self): - return [] + def __init__(self, symbol="xi", dtype=None, ctr=0): + self._ctr = ctr + self._symbol = symbol + self._dtype = dtype - @property - def symbols_defined(self): - return set() + def __iter__(self): + return self - @property - def undefined_symbols(self): - return set() + def __next__(self): + name = f"{self._symbol}_{self._ctr}" + self._ctr += 1 + if self._dtype is not None: + return TypedSymbol(name, self._dtype) + return sp.Symbol(name) - def __str__(self): - return "" - def __repr__(self): - return self.__str__() +def get_dummy_symbol(dtype='bool'): + return TypedSymbol(f'dummy{uuid.uuid4().hex}', create_type(dtype)) class ConditionalFieldAccess(sp.Function): @@ -712,3 +611,18 @@ class ConditionalFieldAccess(sp.Function): def __getnewargs_ex__(self): return (self.access, self.outofbounds_condition, self.outofbounds_value), {} + + +def generic_visit(term, visitor): + if isinstance(term, AssignmentCollection): + new_main_assignments = generic_visit(term.main_assignments, visitor) + new_subexpressions = generic_visit(term.subexpressions, visitor) + return term.copy(new_main_assignments, new_subexpressions) + elif isinstance(term, list): + return [generic_visit(e, visitor) for e in term] + elif isinstance(term, Assignment): + return Assignment(term.lhs, generic_visit(term.rhs, visitor)) + elif isinstance(term, sp.Matrix): + return term.applyfunc(lambda e: generic_visit(e, visitor)) + else: + return visitor(term) diff --git a/src/pystencils/sympyextensions/bit_masks.py b/src/pystencils/sympyextensions/bit_masks.py index f8b6b7ef0..57f2ab5fb 100644 --- a/src/pystencils/sympyextensions/bit_masks.py +++ b/src/pystencils/sympyextensions/bit_masks.py @@ -1,5 +1,4 @@ import sympy as sp -# from pystencils.typing import get_type_of_expression # noinspection PyPep8Naming diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py index 4a6f01a5c..00974809f 100644 --- a/src/pystencils/sympyextensions/math.py +++ b/src/pystencils/sympyextensions/math.py @@ -10,7 +10,7 @@ from sympy import PolynomialError from sympy.functions import Abs from sympy.core.numbers import Zero -from .assignment import Assignment +from .astnodes import Assignment from pystencils.functions import DivFunc from .typed_sympy import CastFunc, PointerType, VectorType, FieldPointerSymbol diff --git a/src/pystencils/sympyextensions/simplifications.py b/src/pystencils/sympyextensions/simplifications.py index c16e42f85..cdcad81e7 100644 --- a/src/pystencils/sympyextensions/simplifications.py +++ b/src/pystencils/sympyextensions/simplifications.py @@ -4,20 +4,21 @@ from collections import defaultdict import sympy as sp -from .assignment import Assignment -from pystencils.sympyextensions.astnodes import Node +from .astnodes import Assignment from .math import subs_additive, is_constant, recursive_collect from .typed_sympy import TypedSymbol -def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: +# TODO rewrite with SymPy AST +# def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: +def sort_assignments_topologically(assignments: Sequence[Union[Assignment]]) -> List[Union[Assignment]]: """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs""" edges = [] for c1, e1 in enumerate(assignments): if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'): symbols = [e1.lhs] - elif isinstance(e1, Node): - symbols = e1.symbols_defined + # elif isinstance(e1, Node): + # symbols = e1.symbols_defined else: raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.") @@ -25,8 +26,8 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node] for c2, e2 in enumerate(assignments): if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols: edges.append((c1, c2)) - elif isinstance(e2, Node) and lhs in e2.undefined_symbols: - edges.append((c1, c2)) + # elif isinstance(e2, Node) and lhs in e2.undefined_symbols: + # edges.append((c1, c2)) return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))] diff --git a/src/pystencils/sympyextensions/simplificationstrategy.py b/src/pystencils/sympyextensions/simplificationstrategy.py index cd2fa8faa..b76d12711 100644 --- a/src/pystencils/sympyextensions/simplificationstrategy.py +++ b/src/pystencils/sympyextensions/simplificationstrategy.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Optional, Sequence import sympy as sp -from .assignment_collection import AssignmentCollection +from .astnodes import AssignmentCollection class SimplificationStrategy: diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index 458fbae20..1f20624ce 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,5 +1,6 @@ from abc import abstractmethod -from typing import Union +from itertools import groupby +from typing import Sequence, Union import numpy as np import sympy as sp @@ -330,6 +331,72 @@ class StructType(AbstractType): return hash((self.numpy_dtype, self.const)) +def result_type(*args: np.dtype): + """Returns the type of the result if the np.dtype arguments would be collated. + We can't use numpy functionality, because numpy casts don't behave exactly like C casts""" + s = sorted(args, key=lambda x: x.itemsize) + + def kind_to_value(kind: str) -> int: + if kind == 'f': + return 3 + elif kind == 'i': + return 2 + elif kind == 'u': + return 1 + elif kind == 'b': + return 0 + else: + raise NotImplementedError(f'{kind=} is not a supported kind of a type. See "numpy.dtype.kind" for options') + s = sorted(s, key=lambda x: kind_to_value(x.kind)) + return s[-1] + + +def all_equal(iterable): + """ + Returns ``True`` if all the elements are equal to each other. + Copied from: more-itertools 8.12.0 + """ + g = groupby(iterable) + return next(g, True) and not next(g, False) + + +def collate_types(types: Sequence[Union[BasicType, VectorType]]): + """ + Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double + Uses the collation rules from numpy. + """ + # Pointer arithmetic case i.e. pointer + [int, uint] is allowed + if any(isinstance(t, PointerType) for t in types): + pointer_type = None + for t in types: + if isinstance(t, PointerType): + if pointer_type is not None: + raise ValueError(f'Cannot collate the combination of two pointer types "{pointer_type}" and "{t}"') + pointer_type = t + elif isinstance(t, BasicType): + if not (t.is_int() or t.is_uint()): + raise ValueError("Invalid pointer arithmetic") + else: + raise ValueError("Invalid pointer arithmetic") + return pointer_type + + # # peel of vector types, if at least one vector type occurred the result will also be the vector type + vector_type = [t for t in types if isinstance(t, VectorType)] + if not all_equal(t.width for t in vector_type): + raise ValueError("Collation failed because of vector types with different width") + + types = [t.base_type if isinstance(t, VectorType) else t for t in types] + + # now we should have a list of basic types - struct types are not yet supported + assert all(type(t) is BasicType for t in types) + + result_numpy_type = result_type(*(t.numpy_dtype for t in types)) + result = BasicType(result_numpy_type) + if vector_type: + result = VectorType(result, vector_type[0].width) + return result + + def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]): """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype @@ -419,6 +486,9 @@ class TypedSymbol(sp.Symbol): SHAPE_DTYPE = BasicType('int64', const=True) STRIDE_DTYPE = BasicType('int64', const=True) +LOOP_COUNTER_DTYPE = BasicType('int64', const=True) +LOOP_COUNTER_NAME_PREFIX = "ctr" +BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr" class FieldStrideSymbol(TypedSymbol): @@ -501,6 +571,25 @@ class FieldPointerSymbol(TypedSymbol): __xnew_cached_ = staticmethod(sp.core.cacheit(__new_stage2__)) +def get_loop_counter_symbol(coordinate_to_loop_over): + return TypedSymbol(f"{LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}", + LOOP_COUNTER_DTYPE, nonnegative=True) + + +def get_block_loop_counter_symbol(coordinate_to_loop_over): + return TypedSymbol(f"{BLOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}", + LOOP_COUNTER_DTYPE, nonnegative=True) + + +def is_loop_counter_symbol(symbol): + if not symbol.name.startswith(LOOP_COUNTER_NAME_PREFIX): + return None + if symbol.dtype != LOOP_COUNTER_DTYPE: + return None + coordinate = int(symbol.name[len(LOOP_COUNTER_NAME_PREFIX) + 1:]) + return coordinate + + class CastFunc(sp.Function): """ CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type -- GitLab