Commit 9f6bbd4c authored by Martin Bauer's avatar Martin Bauer
Browse files

Boolean Function: import from sympy.logic to work with recent sympy versions

parent e2b61455
......@@ -4,16 +4,17 @@ from typing import Set
import numpy as np
import sympy as sp
from sympy.core import S
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
reinterpret_cast_func, vector_memory_access)
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func,
vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil)
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
......@@ -292,9 +293,9 @@ class CBackend:
return ""
def _print_Conditional(self, node):
if type(node.condition_expr) is sp.boolalg.BooleanTrue:
if type(node.condition_expr) is BooleanTrue:
return self._print_Block(node.true_block)
elif type(node.condition_expr) is sp.boolalg.BooleanFalse:
elif type(node.condition_expr) is BooleanFalse:
return self._print_Block(node.false_block)
cond_type = get_type_of_expression(node.condition_expr)
if isinstance(cond_type, VectorType):
......
......@@ -2,18 +2,17 @@ import warnings
from typing import Container, Union
import sympy as sp
from sympy.logic.boolalg import BooleanFunction
import pystencils.astnodes as ast
from pystencils.backends.simd_instruction_sets import get_vector_instruction_set
from pystencils.data_types import (
PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression,
vector_memory_access)
PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.field import Field
from pystencils.integer_functions import modulo_ceil, modulo_floor
from pystencils.sympyextensions import fast_subs
from pystencils.transformations import (
cut_loop, filtered_tree_iteration, replace_inner_stride_with_one)
from pystencils.transformations import cut_loop, filtered_tree_iteration, replace_inner_stride_with_one
# noinspection PyPep8Naming
......@@ -177,7 +176,7 @@ def insert_vector_casts(ast_node):
visit_expr(expr.args[4]))
elif isinstance(expr, cast_func):
return expr
elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction):
elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
new_args = [visit_expr(a) for a in expr.args]
arg_types = [get_type_of_expression(a) for a in new_args]
if not any(type(t) is VectorType for t in arg_types):
......
......@@ -4,14 +4,14 @@ from functools import partial
from typing import Tuple
import numpy as np
import pystencils
import sympy as sp
import sympy.codegen.ast
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils
from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
try:
import llvmlite.ir as ir
......@@ -541,7 +541,7 @@ def get_type_of_expression(expr,
elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label
return typed_symbol.dtype.base_type
elif isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)):
elif isinstance(expr, (Boolean, BooleanFunction)):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
......
......@@ -8,14 +8,14 @@ from types import MappingProxyType
import numpy as np
import sympy as sp
from sympy.core.numbers import ImaginaryUnit
from sympy.logic.boolalg import Boolean
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils.astnodes as ast
import pystencils.integer_functions
from pystencils.assignment import Assignment
from pystencils.data_types import (
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
......@@ -851,7 +851,7 @@ class KernelConstraintsCheck:
return cast_func(
self.process_expression(rhs.args[0], type_constants=False),
rhs.dtype)
elif isinstance(rhs, sp.boolalg.BooleanFunction) or \
elif isinstance(rhs, BooleanFunction) or \
type(rhs) in pystencils.integer_functions.__dict__.values():
new_args = [self.process_expression(a, type_constants) for a in rhs.args]
types_of_expressions = [get_type_of_expression(a) for a in new_args]
......@@ -1030,7 +1030,7 @@ def insert_casts(node):
types = [get_type_of_expression(arg) for arg in args]
assert len(types) > 0
# Never ever, ever collate to float type for boolean functions!
target = collate_types(types, forbid_collation_to_float=isinstance(node.func, sp.boolalg.BooleanFunction))
target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction))
zipped = list(zip(args, types))
if target.func is PointerType:
assert node.func is sp.Add
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment