
Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Showing with 252 additions and 30 deletions
@@ -5,7 +5,7 @@ from typing import Union, List
 import sympy as sp
 from pystencils.config import CreateKernelConfig
-from pystencils.assignment import Assignment
+from pystencils.assignment import Assignment, AddAugmentedAssignment
 from pystencils.astnodes import Node, Block, Conditional, LoopOverCoordinate, SympyAssignment
 from pystencils.cpu.vectorization import vectorize
 from pystencils.enums import Target, Backend
@@ -19,7 +19,10 @@ from pystencils.transformations import (
     loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel)
-def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCollection, List[Node], NodeCollection], *,
+def create_kernel(assignments: Union[Assignment, List[Assignment],
+                                     AddAugmentedAssignment, List[AddAugmentedAssignment],
+                                     AssignmentCollection, List[Node], NodeCollection],
+                  *,
                   config: CreateKernelConfig = None, **kwargs):
     """
     Creates abstract syntax tree (AST) of kernel, using a list of update equations.
@@ -59,7 +62,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
             setattr(config, k, v)
     # ---- Normalizing parameters
-    if isinstance(assignments, Assignment):
+    if isinstance(assignments, (Assignment, AddAugmentedAssignment)):
         assignments = [assignments]
     assert assignments, "Assignments must not be empty!"
     if isinstance(assignments, list):
@@ -86,13 +89,13 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
 def create_domain_kernel(assignments: NodeCollection, *, config: CreateKernelConfig):
     """
-    Creates abstract syntax tree (AST) of kernel, using a list of update equations.
+    Creates abstract syntax tree (AST) of kernel, using a NodeCollection.
     Note that `create_domain_kernel` is a lower level function which shoul be accessed by not providing `index_fields`
     to create_kernel
     Args:
-        assignments: can be a single assignment, sequence of assignments or an `AssignmentCollection`
+        assignments: `pystencils.node_collection.NodeCollection` containing all assignements and nodes to be processed
         config: CreateKernelConfig which includes the needed configuration
     Returns:
@@ -125,6 +128,7 @@ def create_domain_kernel(assignments: NodeCollection, *, config: CreateKernelCon
     # --- check constrains
     check = KernelConstraintsCheck(check_independence_condition=not config.skip_independence_check,
                                    check_double_write_condition=not config.allow_double_writes)
     check.visit(assignments)
     assignments.bound_fields = check.fields_written
@@ -159,7 +163,7 @@ def create_domain_kernel(assignments: NodeCollection, *, config: CreateKernelCon
         raise ValueError("Invalid value for cpu_vectorize_info")
     elif config.target == Target.GPU:
         if config.backend == Backend.CUDA:
-            from pystencils.gpucuda import create_cuda_kernel
+            from pystencils.gpu import create_cuda_kernel
             ast = create_cuda_kernel(assignments, config=config)
             if not ast:
@@ -187,7 +191,7 @@ def create_indexed_kernel(assignments: NodeCollection, *, config: CreateKernelCo
     to create_kernel
     Args:
-        assignments: can be a single assignment, sequence of assignments or an `AssignmentCollection`
+        assignments: `pystencils.node_collection.NodeCollection` containing all assignements and nodes to be processed
         config: CreateKernelConfig which includes the needed configuration
     Returns:
@@ -241,7 +245,7 @@ def create_indexed_kernel(assignments: NodeCollection, *, config: CreateKernelCo
         add_openmp(ast, num_threads=config.cpu_openmp)
     elif config.target == Target.GPU:
         if config.backend == Backend.CUDA:
-            from pystencils.gpucuda import created_indexed_cuda_kernel
+            from pystencils.gpu import created_indexed_cuda_kernel
             ast = created_indexed_cuda_kernel(assignments, config=config)
             if not ast:
...
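For orientation (not part of the diff): with this change an augmented add assignment can be handed to create_kernel directly. A rough sketch, assuming the usual pystencils field API; field names and shapes are made up:

import pystencils as ps
from pystencils.assignment import AddAugmentedAssignment

# two double-precision 2D fields (illustrative names)
f, g = ps.fields("f, g: double[2D]")

# g[0, 0] += 2 * f[0, 0] -- previously this had to be rewritten as a plain Assignment first
update = AddAugmentedAssignment(g.center, 2 * f.center)

kernel = ps.create_kernel(update).compile()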
+from typing import Any, Dict, List, Union, Optional, Set
+
+import sympy
+import sympy as sp
+from sympy.codegen.rewriting import ReplaceOptim, optimize
+
+from pystencils.assignment import Assignment, AddAugmentedAssignment
+import pystencils.astnodes as ast
+from pystencils.backends.cbackend import CustomCodeNode
+from pystencils.functions import DivFunc
+from pystencils.simp import AssignmentCollection
+from pystencils.typing import FieldPointerSymbol
+
+
+class NodeCollection:
+    def __init__(self, assignments: List[Union[ast.Node, Assignment]],
+                 simplification_hints: Optional[Dict[str, Any]] = None,
+                 bound_fields: Set[sp.Symbol] = None, rhs_fields: Set[sp.Symbol] = None):
+
+        def visit(obj):
+            if isinstance(obj, (list, tuple)):
+                return [visit(e) for e in obj]
+            if isinstance(obj, Assignment):
+                if isinstance(obj.lhs, FieldPointerSymbol):
+                    return ast.SympyAssignment(obj.lhs, obj.rhs, is_const=obj.lhs.dtype.const)
+                return ast.SympyAssignment(obj.lhs, obj.rhs)
+            elif isinstance(obj, AddAugmentedAssignment):
+                return ast.SympyAssignment(obj.lhs, obj.lhs + obj.rhs)
+            elif isinstance(obj, ast.SympyAssignment):
+                return obj
+            elif isinstance(obj, ast.Conditional):
+                true_block = visit(obj.true_block)
+                false_block = None if obj.false_block is None else visit(obj.false_block)
+                return ast.Conditional(obj.condition_expr, true_block=true_block, false_block=false_block)
+            elif isinstance(obj, ast.Block):
+                return ast.Block([visit(e) for e in obj.args])
+            elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
+                return obj
+            else:
+                raise ValueError("Invalid object in the List of Assignments " + str(type(obj)))
+
+        self.all_assignments = visit(assignments)
+        self.simplification_hints = simplification_hints if simplification_hints else {}
+        self.bound_fields = bound_fields if bound_fields else {}
+        self.rhs_fields = rhs_fields if rhs_fields else {}
+
+    @staticmethod
+    def from_assignment_collection(assignment_collection: AssignmentCollection):
+        return NodeCollection(assignments=assignment_collection.all_assignments,
+                              simplification_hints=assignment_collection.simplification_hints,
+                              bound_fields=assignment_collection.bound_fields,
+                              rhs_fields=assignment_collection.rhs_fields)
+
+    def evaluate_terms(self):
+        evaluate_constant_terms = ReplaceOptim(
+            lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
+            lambda p: p.evalf()
+        )
+
+        evaluate_pow = ReplaceOptim(
+            lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8,
+            lambda p: sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
+            (DivFunc(sp.Integer(1), p.base) if p.exp == -1 else
+             DivFunc(sp.Integer(1), sp.UnevaluatedExpr(sp.Mul(*([p.base] * -p.exp), evaluate=False))))
+        )
+        sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
+
+        def visitor(node):
+            if isinstance(node, CustomCodeNode):
+                return node
+            elif isinstance(node, ast.Block):
+                return node.func([visitor(child) for child in node.args])
+            elif isinstance(node, ast.SympyAssignment):
+                new_lhs = visitor(node.lhs)
+                new_rhs = visitor(node.rhs)
+                return node.func(new_lhs, new_rhs, node.is_const, node.use_auto)
+            elif isinstance(node, ast.Node):
+                return node.func(*[visitor(child) for child in node.args])
+            elif isinstance(node, sympy.Basic):
+                return optimize(node, sympy_optimisations)
+            else:
+                raise NotImplementedError(f'{node} {type(node)} has no valid visitor')
+
+        self.all_assignments = [visitor(assignment) for assignment in self.all_assignments]
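For orientation (not part of the diff), this is roughly how the collection is used: wrap an AssignmentCollection (or a plain list of assignments and AST nodes), then call evaluate_terms to fold floating-point constants and to rewrite small integer powers and reciprocals into unevaluated products and DivFunc calls. A sketch, assuming the usual pystencils field API:

import pystencils as ps
from pystencils.node_collection import NodeCollection

f, g = ps.fields("f, g: double[2D]")
ac = ps.AssignmentCollection([ps.Assignment(g.center, f.center ** 3 + 1 / f.center)])

nodes = NodeCollection.from_assignment_collection(ac)
nodes.evaluate_terms()   # x**3 -> x*x*x (unevaluated), 1/x -> DivFunc(1, x)
print(nodes.all_assignments)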
File moved
@@ -61,6 +61,15 @@ class RNGBase(CustomCodeNode):
         return ", ".join([str(s) for s in self.result_symbols]) + " \\leftarrow " + \
                self._name.capitalize() + "_RNG(" + ", ".join([str(a) for a in self.args]) + ")"
 
+    def _hashable_content(self):
+        return (self._name, *self.result_symbols, *self.args)
+
+    def __eq__(self, other):
+        return type(self) is type(other) and self._hashable_content() == other._hashable_content()
+
+    def __hash__(self):
+        return hash(self._hashable_content())
+
 
 class PhiloxTwoDoubles(RNGBase):
     _name = "philox_double2"
...
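Not part of the diff: the added methods follow the standard value-equality pattern, so two RNG nodes with the same name, result symbols and arguments compare equal and collapse in sets or as dict keys. A toy illustration of the same contract (the class below is made up, not pystencils API):

class _TaggedNode:
    """Toy stand-in showing the _hashable_content / __eq__ / __hash__ pattern."""

    def __init__(self, name, *args):
        self._name = name
        self.args = args

    def _hashable_content(self):
        return (self._name, *self.args)

    def __eq__(self, other):
        return type(self) is type(other) and self._hashable_content() == other._hashable_content()

    def __hash__(self):
        return hash(self._hashable_content())


assert _TaggedNode("philox_double2", 1, 2) == _TaggedNode("philox_double2", 1, 2)
assert len({_TaggedNode("philox_double2", 1, 2), _TaggedNode("philox_double2", 1, 2)}) == 1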
 import socket
 import time
+from types import MappingProxyType
 from typing import Dict, Iterator, Sequence
 
 import blitzdb
+import six
+from blitzdb.backends.file.backend import serializer_classes
+from blitzdb.backends.file.utils import JsonEncoder
 
 from pystencils.cpu.cpujit import get_compiler_config
+from pystencils import CreateKernelConfig, Target, Backend, Field
+
+import json
+import sympy as sp
+
+from pystencils.typing import BasicType
+
+
+class PystencilsJsonEncoder(JsonEncoder):
+
+    def default(self, obj):
+        if isinstance(obj, CreateKernelConfig):
+            return obj.__dict__
+        if isinstance(obj, (sp.Float, sp.Rational)):
+            return float(obj)
+        if isinstance(obj, sp.Integer):
+            return int(obj)
+        if isinstance(obj, (BasicType, MappingProxyType)):
+            return str(obj)
+        if isinstance(obj, (Target, Backend, sp.Symbol)):
+            return obj.name
+        if isinstance(obj, Field):
+            return f"pystencils.Field(name = {obj.name}, field_type = {obj.field_type.name}, " \
+                   f"dtype = {str(obj.dtype)}, layout = {obj.layout}, shape = {obj.shape}, " \
+                   f"strides = {obj.strides})"
+        return JsonEncoder.default(self, obj)
+
+
+class PystencilsJsonSerializer(object):
+
+    @classmethod
+    def serialize(cls, data):
+        if six.PY3:
+            if isinstance(data, bytes):
+                return json.dumps(data.decode('utf-8'), cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
+            else:
+                return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
+        else:
+            return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
+
+    @classmethod
+    def deserialize(cls, data):
+        if six.PY3:
+            return json.loads(data.decode('utf-8'))
+        else:
+            return json.loads(data.decode('utf-8'))
+
 
 class Database:
@@ -46,7 +96,7 @@ class Database:
     class SimulationResult(blitzdb.Document):
         pass
 
-    def __init__(self, file: str) -> None:
+    def __init__(self, file: str, serializer_info: tuple = None) -> None:
         if file.startswith("mongo://"):
             from pymongo import MongoClient
             db_name = file[len("mongo://"):]
@@ -57,6 +107,10 @@ class Database:
             self.backend.autocommit = True
 
+        if serializer_info:
+            serializer_classes.update({serializer_info[0]: serializer_info[1]})
+            self.backend.load_config({'serializer_class': serializer_info[0]}, True)
+
     def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None:
         """Stores a simulation result in the database.
@@ -146,10 +200,15 @@ class Database:
             'cpuCompilerConfig': get_compiler_config(),
         }
         try:
-            from git import Repo, InvalidGitRepositoryError
+            from git import Repo
+        except ImportError:
+            return result
+
+        try:
+            from git import InvalidGitRepositoryError
             repo = Repo(search_parent_directories=True)
             result['git_hash'] = str(repo.head.commit)
-        except (ImportError, InvalidGitRepositoryError):
+        except InvalidGitRepositoryError:
             pass
         return result
...
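Not part of the diff: a minimal sketch of the new serializer hook, assuming the blitzdb file backend. The encoder handles CreateKernelConfig, sympy numbers and symbols, Target/Backend enums and Field objects that are otherwise not JSON-serializable:

from pystencils.runhelper import Database
from pystencils.runhelper.db import PystencilsJsonSerializer

# register the pystencils JSON serializer with the blitzdb file backend
db = Database('./results', serializer_info=('pystencils_serializer', PystencilsJsonSerializer))
db.save(params={'method': 'srt', 'resolution': 32}, result={'residual': 1e-6})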
@@ -9,6 +9,7 @@ from time import sleep
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple
 
 from pystencils.runhelper import Database
+from pystencils.runhelper.db import PystencilsJsonSerializer
 from pystencils.utils import DotDict
 
 ParameterDict = Dict[str, Any]
@@ -54,10 +55,11 @@ class ParameterStudy:
     Run = namedtuple("Run", ['parameter_dict', 'weight'])
 
     def __init__(self, run_function: Callable[..., Dict], runs: Sequence = (),
-                 database_connector: str = './db') -> None:
+                 database_connector: str = './db',
+                 serializer_info: tuple = ('pystencils_serializer', PystencilsJsonSerializer)) -> None:
         self.runs = list(runs)
         self.run_function = run_function
-        self.db = Database(database_connector)
+        self.db = Database(database_connector, serializer_info)
 
     def add_run(self, parameter_dict: ParameterDict, weight: int = 1) -> None:
         """Schedule a dictionary of parameters to run in this parameter study.
...
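Not part of the diff: existing studies keep working without changes, since the serializer defaults to the pystencils one; a different (name, serializer_class) pair can be passed explicitly. A rough sketch with a made-up run function:

from pystencils.runhelper import ParameterStudy

def run(resolution=32, method='srt', **kwargs):
    # ... run the simulation and return a result dictionary ...
    return {'residual': 1e-6, 'resolution': resolution, 'method': method}

# the database behind the study now registers PystencilsJsonSerializer by default
study = ParameterStudy(run, database_connector='./db')
study.add_run({'resolution': 64, 'method': 'trt'}, weight=2)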
File moved
@@ -61,8 +61,11 @@ class AssignmentCollection:
         self.simplification_hints = simplification_hints
 
+        ctrs = [int(n.name[3:]) for n in self.rhs_symbols if "xi_" in n.name]
+        max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
+
         if subexpression_symbol_generator is None:
-            self.subexpression_symbol_generator = SymbolGen()
+            self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
         else:
             self.subexpression_symbol_generator = subexpression_symbol_generator
@@ -283,12 +286,13 @@ class AssignmentCollection:
         processed_other_subexpression_equations = []
         for other_subexpression_eq in other.subexpressions:
             if other_subexpression_eq.lhs in own_subexpression_symbols:
-                if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
+                new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
+                if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
                     continue  # exact the same subexpression equation exists already
                 else:
                     # different definition - a new name has to be introduced
                     new_lhs = next(self.subexpression_symbol_generator)
-                    new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
+                    new_eq = Assignment(new_lhs, new_rhs)
                     processed_other_subexpression_equations.append(new_eq)
                     substitution_dict[other_subexpression_eq.lhs] = new_lhs
             else:
@@ -453,8 +457,8 @@ class AssignmentCollection:
 class SymbolGen:
     """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
 
-    def __init__(self, symbol="xi", dtype=None):
+    def __init__(self, symbol="xi", dtype=None, ctr=0):
-        self._ctr = 0
+        self._ctr = ctr
         self._symbol = symbol
         self._dtype = dtype
...
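Not part of the diff: the counter offset avoids name clashes when a collection is built from assignments that already contain generated xi_<n> symbols; newly drawn subexpression symbols continue after the highest existing index. A sketch:

import sympy as sp
from pystencils import Assignment
from pystencils.simp import AssignmentCollection

x, y = sp.symbols("x y")
xi_0 = sp.Symbol("xi_0")

ac = AssignmentCollection(main_assignments=[Assignment(y, xi_0 + x)],
                          subexpressions=[Assignment(xi_0, 2 * x)])

# the generator starts counting after the highest existing xi_<n>,
# so the next drawn symbol is xi_1 rather than a colliding xi_0
print(next(ac.subexpression_symbol_generator))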
@@ -8,6 +8,7 @@ from pystencils.assignment import Assignment
 from pystencils.astnodes import Node
 from pystencils.field import Field
 from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
+from pystencils.typing import TypedSymbol
 
 
 def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
@@ -168,12 +169,14 @@ def add_subexpressions_for_sums(ac):
     return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
 
 
-def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True):
+def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
     r"""Substitutes field accesses on rhs of assignments with subexpressions
 
     Can change semantics of the update rule (which is the goal of this transformation)
     This is useful if a field should be update in place - all values are loaded before into subexpression variables,
     then the new values are computed and written to the same field in-place.
+    Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
+    this data type. This is useful for mixed precision kernels
     """
     field_reads = set()
     to_iterate = []
@@ -185,7 +188,17 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
     for assignment in to_iterate:
         if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
             field_reads.update(assignment.rhs.atoms(Field.Access))
-    substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
+
+    if not field_reads:
+        return ac
+
+    substitutions = dict()
+    for fa in field_reads:
+        lhs = next(ac.subexpression_symbol_generator)
+        if data_type is not None:
+            substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
+        else:
+            substitutions.update({fa: lhs})
+
     return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
                                      substitute_on_lhs=False, sort_topologically=False)
...
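Not part of the diff: a short sketch of the mixed-precision use case, assuming the usual field API. The isolated read of f gets a float32-typed subexpression symbol while the rest of the update stays in double precision:

import pystencils as ps
from pystencils.simp import add_subexpressions_for_field_reads

f, g = ps.fields("f, g: double[2D]")
ac = ps.AssignmentCollection([ps.Assignment(g.center, 2 * f.center + 1)])

# the field read f[0, 0] is pulled into a subexpression whose symbol is typed float32
ac_mixed = add_subexpressions_for_field_reads(ac, data_type="float32")
print(ac_mixed)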
File moved
@@ -5,6 +5,8 @@ from typing import Sequence
 import numpy as np
 import sympy as sp
 
+from pystencils.utils import binary_numbers
+
 
 def inverse_direction(direction):
     """Returns inverse i.e. negative of given direction tuple
@@ -293,6 +295,38 @@ def direction_string_to_offset(direction: str, dim: int = 3):
     return offset[:dim]
 
 
+def adjacent_directions(direction):
+    """
+    Returns all adjacent directions for a direction as tuple of tuples. This is useful for exmple to find all directions
+    relevant for neighbour communication.
+
+    Args:
+        direction: tuple representing a direction. For example (0, 1, 0) for the northern side
+
+    Examples:
+        >>> adjacent_directions((0, 0, 0))
+        ((0, 0, 0),)
+        >>> adjacent_directions((0, 1, 0))
+        ((0, 1, 0),)
+        >>> adjacent_directions((0, 1, 1))
+        ((0, 0, 1), (0, 1, 0), (0, 1, 1))
+        >>> adjacent_directions((-1, -1))
+        ((-1, -1), (-1, 0), (0, -1))
+    """
+    result = set()
+
+    if all(e == 0 for e in direction):
+        result.add(direction)
+        return tuple(result)
+
+    binary_numbers_list = binary_numbers(len(direction))
+
+    for adjacent_direction in binary_numbers_list:
+        for i, entry in enumerate(direction):
+            if entry == 0:
+                adjacent_direction[i] = 0
+            if entry == -1 and adjacent_direction[i] == 1:
+                adjacent_direction[i] = -1
+        if not all(e == 0 for e in adjacent_direction):
+            result.add(tuple(adjacent_direction))
+
+    return tuple(sorted(result))
+
 # -------------------------------------- Visualization -----------------------------------------------------------------
@@ -341,7 +375,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs)
     for direction, annotation in zip(stencil, data):
         assert len(direction) == 2, "Works only for 2D stencils"
         direction = tuple(int(i) for i in direction)
-        if not(direction[0] == 0 and direction[1] == 0):
+        if not (direction[0] == 0 and direction[1] == 0):
             axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k')
 
         if isinstance(annotation, sp.Basic):
@@ -421,11 +455,12 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
             FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
             self._verts3d = xs, ys, zs
 
-        def draw(self, renderer):
+        def do_3d_projection(self, *_):
             xs3d, ys3d, zs3d = self._verts3d
             xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
             self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
-            FancyArrowPatch.draw(self, renderer)
+
+            return np.min(zs)
 
     if axes is None:
         if figure is None:
...
@@ -6,6 +6,7 @@ from functools import partial, reduce
 from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
 
 import sympy as sp
+from sympy import PolynomialError
 from sympy.functions import Abs
 from sympy.core.numbers import Zero
@@ -355,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
         factor_count = 0
         if type(product) is Mul:
             for factor in product.args:
-                if type(factor) == Pow:
+                if type(factor) is Pow:
                     if factor.args[0] in symbols:
                         factor_count += factor.args[1]
                 if factor in symbols:
@@ -365,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
             factor_count += product.args[1]
         return factor_count
 
-    if type(expr) == Mul or type(expr) == Pow:
+    if type(expr) is Mul or type(expr) is Pow:
         if velocity_factors_in_product(expr) <= order:
             return expr
         else:
             return Zero()
 
-    if type(expr) != Add:
+    if type(expr) is not Add:
        return expr
 
    for sum_term in expr.args:
@@ -442,11 +443,14 @@ def extract_most_common_factor(term):
 def recursive_collect(expr, symbols, order_by_occurences=False):
     """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
     and so on.
+
+    ``expr`` must be rewritable as a polynomial in the given ``symbols``.
+    It it is not, ``recursive_collect`` will fail quietly, returning the original expression.
 
     Args:
-        expr: A sympy expression
+        expr: A sympy expression.
         symbols: A sequence of symbols
         order_by_occurences: If True, during recursive descent, always collect the symbol occuring
             most often in the expression.
@@ -457,7 +461,13 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
     if len(symbols) == 0:
         return expr
     symbol = symbols[0]
-    collected_poly = sp.Poly(expr.collect(symbol), symbol)
+    collected = expr.collect(symbol)
+
+    try:
+        collected_poly = sp.Poly(collected, symbol)
+    except PolynomialError:
+        return expr
+
     coeffs = collected_poly.all_coeffs()[::-1]
     rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
     return rec_sum
@@ -629,8 +639,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
             for child_term, condition in t.args:
                 visit(child_term)
             visit_children = False
-        elif isinstance(t, sp.Rel):
+        elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
             pass
+        elif isinstance(t, DivFunc):
+            result["divs"] += 1
         else:
             warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
...
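Not part of the diff: the new guard matters for expressions that sympy cannot treat as polynomials in the collect symbols; for example, Piecewise coefficients typically raise PolynomialError inside sp.Poly. A sketch of the new behaviour:

import sympy as sp
from pystencils.sympyextensions import recursive_collect

x, y = sp.symbols("x y")

print(recursive_collect(x * y + x + 2 * y, (x, y)))   # collected: x*(y + 1) + 2*y

# an expression that sp.Poly rejects (Piecewise coefficient): previously this raised,
# now the input expression is returned unchanged
expr = x * sp.Piecewise((1, y > 0), (2, True)) + x
print(recursive_collect(expr, (x,)))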