Commit 9f6bbd4c authored by Martin Bauer's avatar Martin Bauer
Browse files

Boolean Function: import from sympy.logic to work with recent sympy versions

parent e2b61455
...@@ -4,16 +4,17 @@ from typing import Set ...@@ -4,16 +4,17 @@ from typing import Set
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import KernelFunction, Node from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func,
reinterpret_cast_func, vector_memory_access) vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import ( from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil)
int_div, int_power_of_2, modulo_ceil)
try: try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter from sympy.printing.ccode import C99CodePrinter as CCodePrinter
...@@ -292,9 +293,9 @@ class CBackend: ...@@ -292,9 +293,9 @@ class CBackend:
return "" return ""
def _print_Conditional(self, node): def _print_Conditional(self, node):
if type(node.condition_expr) is sp.boolalg.BooleanTrue: if type(node.condition_expr) is BooleanTrue:
return self._print_Block(node.true_block) return self._print_Block(node.true_block)
elif type(node.condition_expr) is sp.boolalg.BooleanFalse: elif type(node.condition_expr) is BooleanFalse:
return self._print_Block(node.false_block) return self._print_Block(node.false_block)
cond_type = get_type_of_expression(node.condition_expr) cond_type = get_type_of_expression(node.condition_expr)
if isinstance(cond_type, VectorType): if isinstance(cond_type, VectorType):
......
...@@ -2,18 +2,17 @@ import warnings ...@@ -2,18 +2,17 @@ import warnings
from typing import Container, Union from typing import Container, Union
import sympy as sp import sympy as sp
from sympy.logic.boolalg import BooleanFunction
import pystencils.astnodes as ast import pystencils.astnodes as ast
from pystencils.backends.simd_instruction_sets import get_vector_instruction_set from pystencils.backends.simd_instruction_sets import get_vector_instruction_set
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, vector_memory_access)
vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.field import Field from pystencils.field import Field
from pystencils.integer_functions import modulo_ceil, modulo_floor from pystencils.integer_functions import modulo_ceil, modulo_floor
from pystencils.sympyextensions import fast_subs from pystencils.sympyextensions import fast_subs
from pystencils.transformations import ( from pystencils.transformations import cut_loop, filtered_tree_iteration, replace_inner_stride_with_one
cut_loop, filtered_tree_iteration, replace_inner_stride_with_one)
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -177,7 +176,7 @@ def insert_vector_casts(ast_node): ...@@ -177,7 +176,7 @@ def insert_vector_casts(ast_node):
visit_expr(expr.args[4])) visit_expr(expr.args[4]))
elif isinstance(expr, cast_func): elif isinstance(expr, cast_func):
return expr return expr
elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction): elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
new_args = [visit_expr(a) for a in expr.args] new_args = [visit_expr(a) for a in expr.args]
arg_types = [get_type_of_expression(a) for a in new_args] arg_types = [get_type_of_expression(a) for a in new_args]
if not any(type(t) is VectorType for t in arg_types): if not any(type(t) is VectorType for t in arg_types):
......
...@@ -4,14 +4,14 @@ from functools import partial ...@@ -4,14 +4,14 @@ from functools import partial
from typing import Tuple from typing import Tuple
import numpy as np import numpy as np
import pystencils
import sympy as sp import sympy as sp
import sympy.codegen.ast import sympy.codegen.ast
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils
from pystencils.cache import memorycache, memorycache_if_hashable from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal from pystencils.utils import all_equal
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
try: try:
import llvmlite.ir as ir import llvmlite.ir as ir
...@@ -541,7 +541,7 @@ def get_type_of_expression(expr, ...@@ -541,7 +541,7 @@ def get_type_of_expression(expr,
elif isinstance(expr, sp.Indexed): elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label typed_symbol = expr.base.label
return typed_symbol.dtype.base_type return typed_symbol.dtype.base_type
elif isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)): elif isinstance(expr, (Boolean, BooleanFunction)):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool") result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)] vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
......
...@@ -8,14 +8,14 @@ from types import MappingProxyType ...@@ -8,14 +8,14 @@ from types import MappingProxyType
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.core.numbers import ImaginaryUnit from sympy.core.numbers import ImaginaryUnit
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils.astnodes as ast import pystencils.astnodes as ast
import pystencils.integer_functions import pystencils.integer_functions
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
...@@ -851,7 +851,7 @@ class KernelConstraintsCheck: ...@@ -851,7 +851,7 @@ class KernelConstraintsCheck:
return cast_func( return cast_func(
self.process_expression(rhs.args[0], type_constants=False), self.process_expression(rhs.args[0], type_constants=False),
rhs.dtype) rhs.dtype)
elif isinstance(rhs, sp.boolalg.BooleanFunction) or \ elif isinstance(rhs, BooleanFunction) or \
type(rhs) in pystencils.integer_functions.__dict__.values(): type(rhs) in pystencils.integer_functions.__dict__.values():
new_args = [self.process_expression(a, type_constants) for a in rhs.args] new_args = [self.process_expression(a, type_constants) for a in rhs.args]
types_of_expressions = [get_type_of_expression(a) for a in new_args] types_of_expressions = [get_type_of_expression(a) for a in new_args]
...@@ -1030,7 +1030,7 @@ def insert_casts(node): ...@@ -1030,7 +1030,7 @@ def insert_casts(node):
types = [get_type_of_expression(arg) for arg in args] types = [get_type_of_expression(arg) for arg in args]
assert len(types) > 0 assert len(types) > 0
# Never ever, ever collate to float type for boolean functions! # Never ever, ever collate to float type for boolean functions!
target = collate_types(types, forbid_collation_to_float=isinstance(node.func, sp.boolalg.BooleanFunction)) target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction))
zipped = list(zip(args, types)) zipped = list(zip(args, types))
if target.func is PointerType: if target.func is PointerType:
assert node.func is sp.Add assert node.func is sp.Add
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment