Commit 0d6780b8 authored by Jan Hönig's avatar Jan Hönig
Browse files

Merge branch 'Extend_testsuite' into 'master'

Extend testsuite

See merge request pycodegen/pystencils!168
parents 4f41d979 20118400
...@@ -298,7 +298,7 @@ class CBackend: ...@@ -298,7 +298,7 @@ class CBackend:
return node.get_code(self._dialect, self._vector_instruction_set) return node.get_code(self._dialect, self._vector_instruction_set)
def _print_SourceCodeComment(self, node): def _print_SourceCodeComment(self, node):
return "/* " + node.text + " */" return f"/* {node.text } */"
def _print_EmptyLine(self, node): def _print_EmptyLine(self, node):
return "" return ""
...@@ -316,7 +316,7 @@ class CBackend: ...@@ -316,7 +316,7 @@ class CBackend:
result = f"if ({condition_expr})\n{true_block} " result = f"if ({condition_expr})\n{true_block} "
if node.false_block: if node.false_block:
false_block = self._print_Block(node.false_block) false_block = self._print_Block(node.false_block)
result += "else " + false_block result += f"else {false_block}"
return result return result
...@@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter):
return self._typed_number(expr.evalf(), get_type_of_expression(expr)) return self._typed_number(expr.evalf(), get_type_of_expression(expr))
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})" return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})"
else: else:
...@@ -589,9 +589,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -589,9 +589,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self.instruction_set['&'].format(result, item) result = self.instruction_set['&'].format(result, item)
return result return result
def _print_Max(self, expr):
return "test"
def _print_Or(self, expr): def _print_Or(self, expr):
result = self._scalarFallback('_print_Or', expr) result = self._scalarFallback('_print_Or', expr)
if result: if result:
......
...@@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter):
assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given" assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__fsqrt_rn({self._print(expr.args[0])})" return f"__fsqrt_rn({self._print(expr.args[0])})"
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
print(len(expr.args) == 1)
assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given" assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__frsqrt_rn({self._print(expr.args[0])})" return f"__frsqrt_rn({self._print(expr.args[0])})"
return super()._print_Function(expr) return super()._print_Function(expr)
...@@ -86,6 +86,13 @@ class DataHandling(ABC): ...@@ -86,6 +86,13 @@ class DataHandling(ABC):
Args: Args:
description (str): String description of the fields to add description (str): String description of the fields to add
dtype: data type of the array as numpy data type dtype: data type of the array as numpy data type
ghost_layers: number of ghost layers - if not specified a default value specified in the constructor
is used
layout: memory layout of array, either structure of arrays 'SoA' or array of structures 'AoS'.
this is only important if values_per_cell > 1
cpu: allocate field on the CPU
gpu: allocate field on the GPU, if None, a GPU field is allocated if default_target is 'gpu'
alignment: either False for no alignment, or the number of bytes to align to
Returns: Returns:
Fields representing the just created arrays Fields representing the just created arrays
""" """
...@@ -200,6 +207,10 @@ class DataHandling(ABC): ...@@ -200,6 +207,10 @@ class DataHandling(ABC):
directly passed to the kernel function and override possible parameters from the DataHandling directly passed to the kernel function and override possible parameters from the DataHandling
""" """
@abstractmethod
def get_kernel_kwargs(self, kernel_function, **kwargs):
"""Returns the input arguments of a kernel"""
@abstractmethod @abstractmethod
def swap(self, name1, name2, gpu=False): def swap(self, name1, name2, gpu=False):
"""Swaps data of two arrays""" """Swaps data of two arrays"""
......
...@@ -266,10 +266,10 @@ class SerialDataHandling(DataHandling): ...@@ -266,10 +266,10 @@ class SerialDataHandling(DataHandling):
return name in self.gpu_arrays return name in self.gpu_arrays
def synchronization_function_cpu(self, names, stencil_name=None, **_): def synchronization_function_cpu(self, names, stencil_name=None, **_):
return self.synchronization_function(names, stencil_name, 'cpu') return self.synchronization_function(names, stencil_name, target='cpu')
def synchronization_function_gpu(self, names, stencil_name=None, **_): def synchronization_function_gpu(self, names, stencil_name=None, **_):
return self.synchronization_function(names, stencil_name, 'gpu') return self.synchronization_function(names, stencil_name, target='gpu')
def synchronization_function(self, names, stencil=None, target=None, **_): def synchronization_function(self, names, stencil=None, target=None, **_):
if target is None: if target is None:
...@@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling): ...@@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling):
np.savez_compressed(file, **self.cpu_arrays) np.savez_compressed(file, **self.cpu_arrays)
def load_all(self, file): def load_all(self, file):
if '.npz' not in file:
file += '.npz'
file_contents = np.load(file) file_contents = np.load(file)
for arr_name, arr_contents in self.cpu_arrays.items(): for arr_name, arr_contents in self.cpu_arrays.items():
if arr_name not in file_contents: if arr_name not in file_contents:
print(f"Skipping read data {arr_name} because there is no data with this name in data handling") print(f"Skipping read data {arr_name} because there is no data with this name in data handling")
continue continue
if file_contents[arr_name].shape != arr_contents.shape: if file_contents[arr_name].shape != arr_contents.shape:
print("Skipping read data {} because shapes don't match. " print(f"Skipping read data {arr_name} because shapes don't match. "
"Read array shape {}, existing array shape {}".format(arr_name, file_contents[arr_name].shape, f"Read array shape {file_contents[arr_name].shape}, existing array shape {arr_contents.shape}")
arr_contents.shape))
continue continue
np.copyto(arr_contents, file_contents[arr_name]) np.copyto(arr_contents, file_contents[arr_name])
...@@ -228,7 +228,9 @@ def diff_terms(expr): ...@@ -228,7 +228,9 @@ def diff_terms(expr):
Example: Example:
>>> x, y = sp.symbols("x, y") >>> x, y = sp.symbols("x, y")
>>> diff_terms( diff(x, 0, 0) ) >>> diff_terms( diff(x, 0, 0) )
{Diff(Diff(x, 0, -1), 0, -1)}
>>> diff_terms( diff(x, 0, 0) + y )
{Diff(Diff(x, 0, -1), 0, -1)} {Diff(Diff(x, 0, -1), 0, -1)}
""" """
result = set() result = set()
......
from .assignment_collection import AssignmentCollection from .assignment_collection import AssignmentCollection
from .simplifications import ( from .simplifications import (
add_subexpressions_for_divisions, add_subexpressions_for_field_reads, add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
apply_on_all_subexpressions, apply_to_all_assignments, add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions, subexpression_substitution_in_existing_subexpressions,
subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list) subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list)
from .simplificationstrategy import SimplificationStrategy from .simplificationstrategy import SimplificationStrategy
...@@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy', ...@@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions', 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions',
'add_subexpressions_for_field_reads'] 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads']
...@@ -6,8 +6,7 @@ import sympy as sp ...@@ -6,8 +6,7 @@ import sympy as sp
import pystencils import pystencils
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.simp.simplifications import ( from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs from pystencils.sympyextensions import count_operations, fast_subs
...@@ -263,7 +262,7 @@ class AssignmentCollection: ...@@ -263,7 +262,7 @@ class AssignmentCollection:
own_definitions = set([e.lhs for e in self.main_assignments]) own_definitions = set([e.lhs for e in self.main_assignments])
other_definitions = set([e.lhs for e in other.main_assignments]) other_definitions = set([e.lhs for e in other.main_assignments])
assert len(own_definitions.intersection(other_definitions)) == 0, \ assert len(own_definitions.intersection(other_definitions)) == 0, \
"Cannot new_merged, since both collection define the same symbols" "Cannot merge collections, since both define the same symbols"
own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
substitution_dict = {} substitution_dict = {}
...@@ -334,7 +333,7 @@ class AssignmentCollection: ...@@ -334,7 +333,7 @@ class AssignmentCollection:
kept_subexpressions = [] kept_subexpressions = []
if self.subexpressions[0].lhs in subexpressions_to_keep: if self.subexpressions[0].lhs in subexpressions_to_keep:
substitution_dict = {} substitution_dict = {}
kept_subexpressions = self.subexpressions[0] kept_subexpressions.append(self.subexpressions[0])
else: else:
substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
......
...@@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node] ...@@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
elif isinstance(e1, Node): elif isinstance(e1, Node):
symbols = e1.symbols_defined symbols = e1.symbols_defined
else: else:
raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.") raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.")
for lhs in symbols: for lhs in symbols:
for c2, e2 in enumerate(assignments): for c2, e2 in enumerate(assignments):
...@@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac): ...@@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac):
addends = [] addends = []
def contains_sum(term): def contains_sum(term):
if term.func == sp.add.Add: if term.func == sp.Add:
return True return True
if term.is_Atom: if term.is_Atom:
return False return False
return any([contains_sum(a) for a in term.args]) return any([contains_sum(a) for a in term.args])
def search_addends(term): def search_addends(term):
if term.func == sp.add.Add: if term.func == sp.Add:
if all([not contains_sum(a) for a in term.args]): if all([not contains_sum(a) for a in term.args]):
addends.extend(term.args) addends.extend(term.args)
for a in term.args: for a in term.args:
......
...@@ -34,6 +34,8 @@ def is_valid(stencil, max_neighborhood=None): ...@@ -34,6 +34,8 @@ def is_valid(stencil, max_neighborhood=None):
True True
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=1) >>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
False False
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
True
""" """
expected_dim = len(stencil[0]) expected_dim = len(stencil[0])
for d in stencil: for d in stencil:
...@@ -67,8 +69,11 @@ def have_same_entries(s1, s2): ...@@ -67,8 +69,11 @@ def have_same_entries(s1, s2):
Examples: Examples:
>>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)] >>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
>>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)] >>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
>>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
>>> have_same_entries(stencil1, stencil2) >>> have_same_entries(stencil1, stencil2)
True True
>>> have_same_entries(stencil1, stencil3)
False
""" """
if len(s1) != len(s2): if len(s1) != len(s2):
return False return False
......
...@@ -272,7 +272,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -272,7 +272,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol], def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
positive: Optional[bool] = None, positive: Optional[bool] = None,
replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr: replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
"""Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ). """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).
This makes the term longer - simplify usually is undoing these - however this This makes the term longer - simplify usually is undoing these - however this
transformation can be done to find more common sub-expressions transformation can be done to find more common sub-expressions
...@@ -293,7 +293,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym ...@@ -293,7 +293,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym
if expr.is_Mul: if expr.is_Mul:
distinct_search_symbols = set() distinct_search_symbols = set()
nr_of_search_terms = 0 nr_of_search_terms = 0
other_factors = 1 other_factors = sp.Integer(1)
for t in expr.args: for t in expr.args:
if t in search_symbols: if t in search_symbols:
nr_of_search_terms += 1 nr_of_search_terms += 1
...@@ -509,13 +509,14 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -509,13 +509,14 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
if t.exp >= 0: if t.exp >= 0:
result['muls'] += int(t.exp) - 1 result['muls'] += int(t.exp) - 1
else: else:
result['muls'] -= 1 if result['muls'] > 0:
result['muls'] -= 1
result['divs'] += 1 result['divs'] += 1
result['muls'] += (-int(t.exp)) - 1 result['muls'] += (-int(t.exp)) - 1
elif sp.nsimplify(t.exp) == sp.Rational(1, 2): elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1 result['sqrts'] += 1
else: else:
warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node") warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
else: else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, " warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate") "counting will be inaccurate")
...@@ -526,7 +527,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -526,7 +527,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(t, sp.Rel): elif isinstance(t, sp.Rel):
pass pass
else: else:
warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate") warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
if visit_children: if visit_children:
for a in t.args: for a in t.args:
......
...@@ -1206,13 +1206,13 @@ def get_loop_hierarchy(ast_node): ...@@ -1206,13 +1206,13 @@ def get_loop_hierarchy(ast_node):
return reversed(result) return reversed(result)
def get_loop_counter_symbol_hierarchy(astNode): def get_loop_counter_symbol_hierarchy(ast_node):
"""Determines the loop counter symbols around a given AST node. """Determines the loop counter symbols around a given AST node.
:param astNode: the AST node :param ast_node: the AST node
:return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop :return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
""" """
result = [] result = []
node = astNode node = ast_node
while node is not None: while node is not None:
node = get_next_parent_of_type(node, ast.LoopOverCoordinate) node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
if node: if node:
......
import os import os
import itertools
from collections import Counter from collections import Counter
from contextlib import contextmanager from contextlib import contextmanager
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
...@@ -96,16 +97,21 @@ def fully_contains(l1, l2): ...@@ -96,16 +97,21 @@ def fully_contains(l1, l2):
def boolean_array_bounding_box(boolean_array): def boolean_array_bounding_box(boolean_array):
"""Returns bounding box around "true" area of boolean array""" """Returns bounding box around "true" area of boolean array
dim = len(boolean_array.shape)
>>> a = np.zeros((4, 4), dtype=bool)
>>> a[1:-1, 1:-1] = True
>>> boolean_array_bounding_box(a)
[(1, 3), (1, 3)]
"""
dim = boolean_array.ndim
shape = boolean_array.shape
assert 0 not in shape, "Shape must not contain zero"
bounds = [] bounds = []
for i in range(dim): for ax in itertools.combinations(reversed(range(dim)), dim - 1):
for j in range(dim): nonzero = np.any(boolean_array, axis=ax)
if i != j: t = np.where(nonzero)[0][[0, -1]]
arr_1d = np.any(boolean_array, axis=j) bounds.append((t[0], t[1] + 1))
begin = np.argmax(arr_1d)
end = begin + np.argmin(arr_1d[begin:])
bounds.append((begin, end))
return bounds return bounds
...@@ -217,7 +223,8 @@ class LinearEquationSystem: ...@@ -217,7 +223,8 @@ class LinearEquationSystem:
return 'multiple' return 'multiple'
def solution(self): def solution(self):
"""Solves the system if it has a single solution. Returns a dictionary mapping symbol to solution value.""" """Solves the system. Under- and overdetermined systems are supported.
Returns a dictionary mapping symbol to solution value."""
return sp.solve_linear_system(self._matrix, *self.unknowns) return sp.solve_linear_system(self._matrix, *self.unknowns)
def _resize_if_necessary(self, new_rows=1): def _resize_if_necessary(self, new_rows=1):
...@@ -233,8 +240,3 @@ class LinearEquationSystem: ...@@ -233,8 +240,3 @@ class LinearEquationSystem:
break break
result -= 1 result -= 1
self.next_zero_row = result self.next_zero_row = result
def find_unique_solutions_with_zeros(system: LinearEquationSystem):
if not system.solution_structure() != 'multiple':
raise ValueError("Function works only for underdetermined systems")
import pytest import pytest
import sympy as sp import sympy as sp
import pystencils as ps
from pystencils import Assignment, AssignmentCollection from pystencils import Assignment, AssignmentCollection
from pystencils.astnodes import Conditional from pystencils.astnodes import Conditional
from pystencils.simp.assignment_collection import SymbolGen from pystencils.simp.assignment_collection import SymbolGen
a, b, c = sp.symbols("a b c")
x, y, z, t = sp.symbols("x y z t")
symbol_gen = SymbolGen("a")
f = ps.fields("f(2) : [2D]")
d = ps.fields("d(2) : [2D]")
def test_assignment_collection():
x, y, z, t = sp.symbols("x y z t")
symbol_gen = SymbolGen("a")
def test_assignment_collection():
ac = AssignmentCollection([Assignment(z, x + y)], ac = AssignmentCollection([Assignment(z, x + y)],
[], subexpression_symbol_generator=symbol_gen) [], subexpression_symbol_generator=symbol_gen)
...@@ -32,10 +36,6 @@ def test_assignment_collection(): ...@@ -32,10 +36,6 @@ def test_assignment_collection():
def test_free_and_defined_symbols(): def test_free_and_defined_symbols():
x, y, z, t = sp.symbols("x y z t")
a, b = sp.symbols("a b")
symbol_gen = SymbolGen("a")
ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))], ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))],
[], subexpression_symbol_generator=symbol_gen) [], subexpression_symbol_generator=symbol_gen)
...@@ -45,35 +45,128 @@ def test_free_and_defined_symbols(): ...@@ -45,35 +45,128 @@ def test_free_and_defined_symbols():
def test_vector_assignments(): def test_vector_assignments():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
assignments = ps.Assignment(sp.Matrix([a, b, c]), sp.Matrix([1, 2, 3]))
import pystencils as ps
import sympy as sp
a, b, c = sp.symbols("a b c")
assignments = ps.Assignment(sp.Matrix([a,b,c]), sp.Matrix([1,2,3]))
print(assignments) print(assignments)
def test_wrong_vector_assignments(): def test_wrong_vector_assignments():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b = sp.symbols("a b")
with pytest.raises(AssertionError, with pytest.raises(AssertionError,
match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'): match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'):
ps.Assignment(sp.Matrix([a,b]), sp.Matrix([1,2,3])) ps.Assignment(sp.Matrix([a, b]), sp.Matrix([1, 2, 3]))
def test_vector_assignment_collection(): def test_vector_assignment_collection():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps y_m, x_m = sp.Matrix([a, b, c]), sp.Matrix([1, 2, 3])
import sympy as sp assignments = ps.AssignmentCollection({y_m: x_m})
a, b, c = sp.symbols("a b c")
y, x = sp.Matrix([a,b,c]), sp.Matrix([1,2,3])
assignments = ps.AssignmentCollection({y: x})
print(assignments) print(assignments)
assignments = ps.AssignmentCollection([ps.Assignment(y,x)]) assignments = ps.AssignmentCollection([ps.Assignment(y_m, x_m)])
print(assignments) print(assignments)
def test_new_with_substitutions():
a1 = ps.Assignment(f[0, 0](0), a * b)
a2 = ps.Assignment(f[0, 0](1), b * c)
ac = ps.AssignmentCollection([a1, a2], subexpressions=[])
subs_dict = {f[0, 0](0): d[0, 0](0), f[0, 0](1): d[0, 0](1)}
subs_ac = ac.new_with_substitutions(subs_dict,
add_substitutions_as_subexpressions=False,
substitute_on_lhs=True,
sort_topologically=True)
assert subs_ac.main_assignments[0].lhs == d[0, 0](0)
assert subs_ac.main_assignments[1].lhs == d[0, 0](1)
subs_ac = ac.new_with_substitutions(subs_dict,