Commit 4a7299f1 authored by Martin Bauer's avatar Martin Bauer
Browse files

Rest of PEP8 renaming

parent 7acdc31c
......@@ -5,9 +5,9 @@ def aligned_empty(shape, byte_alignment=32, dtype=np.float64, byte_offset=0, ord
"""
Creates an aligned empty numpy array
:param shape: size of the array
:param byte_alignment: alignment in bytes, for the start address of the array holds (a % byteAlignment) == 0
:param byte_alignment: alignment in bytes, for the start address of the array holds (a % byte_alignment) == 0
:param dtype: numpy data type
:param byte_offset: offset in bytes for position that should be aligned i.e. (a+byte_offset) % byteAlignment == 0
:param byte_offset: offset in bytes for position that should be aligned i.e. (a+byte_offset) % byte_alignment == 0
typically used to align first inner cell instead of ghost layer
:param order: storage linearization order
:param align_inner_coordinate: if True, the start of the innermost coordinate lines are aligned as well
......
......@@ -111,7 +111,7 @@ class AssignmentCollection:
def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
"""Returns all symbols that depend on one of the passed symbols.
A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- someExpression(b)' i.e. when
A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
'b' is required to compute 'a'.
"""
......@@ -217,18 +217,18 @@ class AssignmentCollection:
substitution_dict = {}
processed_other_subexpression_equations = []
for otherSubexpressionEq in other.subexpressions:
if otherSubexpressionEq.lhs in own_subexpression_symbols:
if otherSubexpressionEq.rhs == own_subexpression_symbols[otherSubexpressionEq.lhs]:
for other_subexpression_eq in other.subexpressions:
if other_subexpression_eq.lhs in own_subexpression_symbols:
if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
continue # exact the same subexpression equation exists already
else:
# different definition - a new name has to be introduced
new_lhs = next(self.subexpression_symbol_generator)
new_eq = Assignment(new_lhs, fast_subs(otherSubexpressionEq.rhs, substitution_dict))
new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
processed_other_subexpression_equations.append(new_eq)
substitution_dict[otherSubexpressionEq.lhs] = new_lhs
substitution_dict[other_subexpression_eq.lhs] = new_lhs
else:
processed_other_subexpression_equations.append(fast_subs(otherSubexpressionEq, substitution_dict))
processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))
processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
return self.copy(self.main_assignments + processed_other_main_assignments,
......
......@@ -50,10 +50,10 @@ def apply_on_all_subexpressions(ac: AssignmentCollection,
def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
"""Goes through the subexpressions list and replaces the term in the following subexpressions."""
result = []
for outerCtr, s in enumerate(ac.subexpressions):
for outer_ctr, s in enumerate(ac.subexpressions):
new_rhs = s.rhs
for innerCtr in range(outerCtr):
sub_expr = ac.subexpressions[innerCtr]
for inner_ctr in range(outer_ctr):
sub_expr = ac.subexpressions[inner_ctr]
new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
new_rhs = new_rhs.subs(sub_expr.rhs, sub_expr.lhs)
result.append(Assignment(s.lhs, new_rhs))
......@@ -66,8 +66,8 @@ def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) ->
result = []
for s in ac.main_assignments:
new_rhs = s.rhs
for subExpr in ac.subexpressions:
new_rhs = subs_additive(new_rhs, subExpr.lhs, subExpr.rhs, required_match_replacement=1.0)
for sub_expr in ac.subexpressions:
new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
result.append(Assignment(s.lhs, new_rhs))
return ac.copy(result)
......@@ -91,5 +91,5 @@ def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentColl
search_divisors(eq.rhs)
new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {divisor: newSymbol for newSymbol, divisor in zip(new_symbol_gen, divisors)}
substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
return ac.new_with_substitutions(substitutions, True)
......@@ -64,7 +64,7 @@ class Conditional(Node):
super(Conditional, self).__init__(parent=None)
assert condition_expr.is_Boolean or condition_expr.is_Relational
self.conditionExpr = condition_expr
self.condition_expr = condition_expr
def handle_child(c):
if c is None:
......@@ -74,20 +74,20 @@ class Conditional(Node):
c.parent = self
return c
self.trueBlock = handle_child(true_block)
self.falseBlock = handle_child(false_block)
self.true_block = handle_child(true_block)
self.false_block = handle_child(false_block)
def subs(self, *args, **kwargs):
self.trueBlock.subs(*args, **kwargs)
if self.falseBlock:
self.falseBlock.subs(*args, **kwargs)
self.conditionExpr = self.conditionExpr.subs(*args, **kwargs)
self.true_block.subs(*args, **kwargs)
if self.false_block:
self.false_block.subs(*args, **kwargs)
self.condition_expr = self.condition_expr.subs(*args, **kwargs)
@property
def args(self):
result = [self.conditionExpr, self.trueBlock]
if self.falseBlock:
result.append(self.falseBlock)
result = [self.condition_expr, self.true_block]
if self.false_block:
result.append(self.false_block)
return result
@property
......@@ -96,17 +96,17 @@ class Conditional(Node):
@property
def undefined_symbols(self):
result = self.trueBlock.undefined_symbols
if self.falseBlock:
result.update(self.falseBlock.undefined_symbols)
result.update(self.conditionExpr.atoms(sp.Symbol))
result = self.true_block.undefined_symbols
if self.false_block:
result.update(self.false_block.undefined_symbols)
result.update(self.condition_expr.atoms(sp.Symbol))
return result
def __str__(self):
return 'if:({!s}) '.format(self.conditionExpr)
return 'if:({!s}) '.format(self.condition_expr)
def __repr__(self):
return 'if:({!r}) '.format(self.conditionExpr)
return 'if:({!r}) '.format(self.condition_expr)
class KernelFunction(Node):
......@@ -116,39 +116,39 @@ class KernelFunction(Node):
from pystencils.transformations import symbol_name_to_variable_name
self.name = name
self.dtype = dtype
self.isFieldPtrArgument = False
self.isFieldShapeArgument = False
self.isFieldStrideArgument = False
self.isFieldArgument = False
self.is_field_ptr_argument = False
self.is_field_shape_argument = False
self.is_field_stride_argument = False
self.is_field_argument = False
self.field_name = ""
self.coordinate = None
self.symbol = symbol
if name.startswith(Field.DATA_PREFIX):
self.isFieldPtrArgument = True
self.isFieldArgument = True
self.is_field_ptr_argument = True
self.is_field_argument = True
self.field_name = name[len(Field.DATA_PREFIX):]
elif name.startswith(Field.SHAPE_PREFIX):
self.isFieldShapeArgument = True
self.isFieldArgument = True
self.is_field_shape_argument = True
self.is_field_argument = True
self.field_name = name[len(Field.SHAPE_PREFIX):]
elif name.startswith(Field.STRIDE_PREFIX):
self.isFieldStrideArgument = True
self.isFieldArgument = True
self.is_field_stride_argument = True
self.is_field_argument = True
self.field_name = name[len(Field.STRIDE_PREFIX):]
self.field = None
if self.isFieldArgument:
if self.is_field_argument:
field_map = {symbol_name_to_variable_name(f.name): f for f in kernel_function_node.fields_accessed}
self.field = field_map[self.field_name]
def __lt__(self, other):
def score(l):
if l.isFieldPtrArgument:
if l.is_field_ptr_argument:
return -4
elif l.isFieldShapeArgument:
elif l.is_field_shape_argument:
return -3
elif l.isFieldStrideArgument:
elif l.is_field_stride_argument:
return -2
return 0
......@@ -298,12 +298,12 @@ class Block(Node):
class PragmaBlock(Block):
def __init__(self, pragma_line, nodes):
super(PragmaBlock, self).__init__(nodes)
self.pragmaLine = pragma_line
self.pragma_line = pragma_line
for n in nodes:
n.parent = self
def __repr__(self):
return self.pragmaLine
return self.pragma_line
class LoopOverCoordinate(Node):
......@@ -313,16 +313,16 @@ class LoopOverCoordinate(Node):
super(LoopOverCoordinate, self).__init__(parent=None)
self.body = body
body.parent = self
self.coordinateToLoopOver = coordinate_to_loop_over
self.coordinate_to_loop_over = coordinate_to_loop_over
self.start = start
self.stop = stop
self.step = step
self.body.parent = self
self.prefixLines = []
self.prefix_lines = []
def new_loop_with_different_body(self, new_body):
result = LoopOverCoordinate(new_body, self.coordinateToLoopOver, self.start, self.stop, self.step)
result.prefixLines = [l for l in self.prefixLines]
result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop, self.step)
result.prefix_lines = [l for l in self.prefix_lines]
return result
def subs(self, *args, **kwargs):
......@@ -359,9 +359,9 @@ class LoopOverCoordinate(Node):
@property
def undefined_symbols(self):
result = self.body.undefined_symbols
for possibleSymbol in [self.start, self.stop, self.step]:
if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic):
result.update(possibleSymbol.atoms(sp.Symbol))
for possible_symbol in [self.start, self.stop, self.step]:
if isinstance(possible_symbol, Node) or isinstance(possible_symbol, sp.Basic):
result.update(possible_symbol.atoms(sp.Symbol))
return result - {self.loop_counter_symbol}
@staticmethod
......@@ -370,7 +370,7 @@ class LoopOverCoordinate(Node):
@property
def loop_counter_name(self):
return LoopOverCoordinate.get_loop_counter_name(self.coordinateToLoopOver)
return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
@staticmethod
def is_loop_counter_symbol(symbol):
......@@ -388,7 +388,7 @@ class LoopOverCoordinate(Node):
@property
def loop_counter_symbol(self):
return LoopOverCoordinate.get_loop_counter_symbol(self.coordinateToLoopOver)
return LoopOverCoordinate.get_loop_counter_symbol(self.coordinate_to_loop_over)
@property
def is_outermost_loop(self):
......@@ -414,25 +414,25 @@ class LoopOverCoordinate(Node):
class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True):
super(SympyAssignment, self).__init__(parent=None)
self._lhsSymbol = lhs_symbol
self._lhs_symbol = lhs_symbol
self.rhs = rhs_expr
self._isDeclaration = True
is_cast = self._lhsSymbol.func == cast_func
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, ResolvedFieldAccess) or is_cast:
self._isDeclaration = False
self._isConst = is_const
self._is_declaration = True
is_cast = self._lhs_symbol.func == cast_func
if isinstance(self._lhs_symbol, Field.Access) or isinstance(self._lhs_symbol, ResolvedFieldAccess) or is_cast:
self._is_declaration = False
self._is_const = is_const
@property
def lhs(self):
return self._lhsSymbol
return self._lhs_symbol
@lhs.setter
def lhs(self, new_value):
self._lhsSymbol = new_value
self._isDeclaration = True
is_cast = self._lhsSymbol.func == cast_func
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or is_cast:
self._isDeclaration = False
self._lhs_symbol = new_value
self._is_declaration = True
is_cast = self._lhs_symbol.func == cast_func
if isinstance(self._lhs_symbol, Field.Access) or isinstance(self._lhs_symbol, sp.Indexed) or is_cast:
self._is_declaration = False
def subs(self, *args, **kwargs):
self.lhs = fast_subs(self.lhs, *args, **kwargs)
......@@ -440,13 +440,13 @@ class SympyAssignment(Node):
@property
def args(self):
return [self._lhsSymbol, self.rhs]
return [self._lhs_symbol, self.rhs]
@property
def symbols_defined(self):
if not self._isDeclaration:
if not self._is_declaration:
return set()
return {self._lhsSymbol}
return {self._lhs_symbol}
@property
def undefined_symbols(self):
......@@ -458,16 +458,16 @@ class SympyAssignment(Node):
for i in range(len(symbol.offsets)):
loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
result.update(loop_counters)
result.update(self._lhsSymbol.atoms(sp.Symbol))
result.update(self._lhs_symbol.atoms(sp.Symbol))
return result
@property
def is_declaration(self):
return self._isDeclaration
return self._is_declaration
@property
def is_const(self):
return self._isConst
return self._is_const
def replace(self, child, replacement):
if child == self.lhs:
......@@ -495,24 +495,24 @@ class ResolvedFieldAccess(sp.Indexed):
obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index)
obj.field = field
obj.offsets = offsets
obj.idxCoordinateValues = idx_coordinate_values
obj.idx_coordinate_values = idx_coordinate_values
return obj
def _eval_subs(self, old, new):
return ResolvedFieldAccess(self.args[0],
self.args[1].subs(old, new),
self.field, self.offsets, self.idxCoordinateValues)
self.field, self.offsets, self.idx_coordinate_values)
def fast_subs(self, substitutions):
if self in substitutions:
return substitutions[self]
return ResolvedFieldAccess(self.args[0].subs(substitutions),
self.args[1].subs(substitutions),
self.field, self.offsets, self.idxCoordinateValues)
self.field, self.offsets, self.idx_coordinate_values)
def _hashable_content(self):
super_class_contents = super(ResolvedFieldAccess, self)._hashable_content()
return super_class_contents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
return super_class_contents + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))
@property
def typed_symbol(self):
......@@ -523,7 +523,7 @@ class ResolvedFieldAccess(sp.Indexed):
return "%s (%s)" % (top, self.typed_symbol.dtype)
def __getnewargs__(self):
return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues
return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values
class TemporaryMemoryAllocation(Node):
......
......@@ -2,6 +2,6 @@ from .cbackend import generate_c
try:
from .dot import print_dot
from .llvm import generateLLVM
from .llvm import generate_llvm
except ImportError:
pass
......@@ -11,7 +11,7 @@ except ImportError:
from pystencils.bitoperations import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, bitwise_or
from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func
from pystencils.backends.simd_instruction_sets import selectedInstructionSet
from pystencils.backends.simd_instruction_sets import selected_instruction_set
__all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers']
......@@ -36,7 +36,7 @@ def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants
double = create_type('double')
use_float_constants = double not in field_types
vector_is = selectedInstructionSet['double']
vector_is = selected_instruction_set['double']
printer = CBackend(constants_as_floats=use_float_constants, signature_only=signature_only,
vector_instruction_set=vector_is)
return printer(ast_node)
......@@ -50,7 +50,7 @@ def get_headers(ast_node: Node) -> Set[str]:
headers.update(ast_node.headers)
elif isinstance(ast_node, SympyAssignment):
if type(get_type_of_expression(ast_node.rhs)) is VectorType:
headers.update(selectedInstructionSet['double']['headers'])
headers.update(selected_instruction_set['double']['headers'])
for a in ast_node.args:
if isinstance(a, Node):
......@@ -104,23 +104,23 @@ class CBackend:
def __init__(self, constants_as_floats=False, sympy_printer=None,
signature_only=False, vector_instruction_set=None):
if sympy_printer is None:
self.sympyPrinter = CustomSympyPrinter(constants_as_floats)
self.sympy_printer = CustomSympyPrinter(constants_as_floats)
if vector_instruction_set is not None:
self.sympyPrinter = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats)
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats)
else:
self.sympyPrinter = CustomSympyPrinter(constants_as_floats)
self.sympy_printer = CustomSympyPrinter(constants_as_floats)
else:
self.sympyPrinter = sympy_printer
self.sympy_printer = sympy_printer
self._vectorInstructionSet = vector_instruction_set
self._indent = " "
self._signatureOnly = signature_only
def __call__(self, node):
prev_is = VectorType.instructionSet
VectorType.instructionSet = self._vectorInstructionSet
prev_is = VectorType.instruction_set
VectorType.instruction_set = self._vectorInstructionSet
result = str(self._print(node))
VectorType.instructionSet = prev_is
VectorType.instruction_set = prev_is
return result
def _print(self, node):
......@@ -144,49 +144,49 @@ class CBackend:
return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))
def _print_PragmaBlock(self, node):
return "%s\n%s" % (node.pragmaLine, self._print_Block(node))
return "%s\n%s" % (node.pragma_line, self._print_Block(node))
def _print_LoopOverCoordinate(self, node):
counter_symbol = node.loop_counter_name
start = "int %s = %s" % (counter_symbol, self.sympyPrinter.doprint(node.start))
condition = "%s < %s" % (counter_symbol, self.sympyPrinter.doprint(node.stop))
update = "%s += %s" % (counter_symbol, self.sympyPrinter.doprint(node.step),)
loopStr = "for (%s; %s; %s)" % (start, condition, update)
start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
loop_str = "for (%s; %s; %s)" % (start, condition, update)
prefix = "\n".join(node.prefixLines)
prefix = "\n".join(node.prefix_lines)
if prefix:
prefix += "\n"
return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))
return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))
def _print_SympyAssignment(self, node):
if node.is_declaration:
data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
return "%s %s = %s;" % (data_type, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
return "%s %s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))
else:
lhs_type = get_type_of_expression(node.lhs)
if type(lhs_type) is VectorType and node.lhs.func == cast_func:
return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]),
self.sympyPrinter.doprint(node.rhs)) + ';'
return self._vectorInstructionSet['storeU'].format("&" + self.sympy_printer.doprint(node.lhs.args[0]),
self.sympy_printer.doprint(node.rhs)) + ';'
else:
return "%s = %s;" % (self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))
def _print_TemporaryMemoryAllocation(self, node):
return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name),
node.symbol.dtype.base_type, self.sympyPrinter.doprint(node.size))
return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympy_printer.doprint(node.symbol.name),
node.symbol.dtype.base_type, self.sympy_printer.doprint(node.size))
def _print_TemporaryMemoryFree(self, node):
return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),)
return "delete [] %s;" % (self.sympy_printer.doprint(node.symbol.name),)
@staticmethod
def _print_CustomCppCode(node):
return node.code
def _print_Conditional(self, node):
condition_expr = self.sympyPrinter.doprint(node.conditionExpr)
true_block = self._print_Block(node.trueBlock)
condition_expr = self.sympy_printer.doprint(node.condition_expr)
true_block = self._print_Block(node.true_block)
result = "if (%s)\n%s " % (condition_expr, true_block)
if node.falseBlock:
false_block = self._print_Block(node.falseBlock)
if node.false_block:
false_block = self._print_Block(node.false_block)
result += "else " + false_block
return result
......@@ -253,14 +253,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def __init__(self, instruction_set, constants_as_floats=False):
super(VectorizedCustomSympyPrinter, self).__init__(constants_as_floats)
self.instructionSet = instruction_set
self.instruction_set = instruction_set
def _scalarFallback(self, func_name, expr, *args, **kwargs):
expr_type = get_type_of_expression(expr)
if type(expr_type) is not VectorType:
return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
else:
assert self.instructionSet['width'] == expr_type.width
assert self.instruction_set['width'] == expr_type.width
return None
def _print_Function(self, expr):
......@@ -268,9 +268,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
arg, data_type = expr.args
if type(data_type) is VectorType:
if type(arg) is ResolvedFieldAccess:
return self.instructionSet['loadU'].format("& " + self._print(arg))
return self.instruction_set['loadU'].format("& " + self._print(arg))
else:
return self.instructionSet['makeVec'].format(self._print(arg))
return self.instruction_set['makeVec'].format(self._print(arg))
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
......@@ -283,7 +283,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(arg_strings) > 0
result = arg_strings[0]
for item in arg_strings[1:]:
result = self.instructionSet['&'].format(result, item)
result = self.instruction_set['&'].format(result, item)
return result
def _print_Or(self, expr):
......@@ -295,7 +295,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(arg_strings) > 0
result = arg_strings[0]
for item in arg_strings[1:]:
result = self.instructionSet['|'].format(result, item)
result = self.instruction_set['|'].format(result, item)
return result
def _print_Add(self, expr, order=None):
......@@ -320,7 +320,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(summands) >= 2
processed = summands[0].term
for summand in summands[1:]:
func = self.instructionSet['-'] if summand.sign == -1 else self.instructionSet['+']
func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
processed = func.format(processed, summand.term)
return processed
......@@ -333,10 +333,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
else:
if expr.exp == -1:
one = self.instructionSet['makeVec'].format(1.0)
return self.instructionSet['/'].format(one, self._print(expr.base))
one = self.instruction_set['makeVec'].format(1.0)
return self.instruction_set['/'].format(one, self._print(expr.base))
elif expr.exp == 0.5:
return self.instructionSet['sqrt'].format(self._print(expr.base))
return self.instruction_set['sqrt'].format(self._print(expr.base))
else:
raise ValueError("Generic exponential not supported")
......@@ -369,26 +369,26 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
a.append(item)
a = a or [S.One]
# a = a or [castFunc(S.One, VectorType(createTypeFromString("double"), exprType.width))]
# a = a or [cast_func(S.One, VectorType(create_type_from_string("double"), expr_type.width))]
a_str = [self._print(x) for x in a]
b_str = [self._print(x) for x in b]
result = a_str[0]
for item in a_str[1:]:
result = self.instructionSet['*'].format(result, item)
result = self.instruction_set['*'].format(result, item)
if len(b) > 0:
denominator_str = b_str[0]
for item in b_str[1:]:
denominator_str = self.instructionSet['*'].format(denominator_str, item)
result = self.instructionSet['/'].format(result, denominator_str)