Commit 82af488a authored by Jan Hönig's avatar Jan Hönig
Browse files

Merge branch 'Skip_interpolation_tests' into 'master'

Adapted test cases to Sympy Version 1.6

See merge request pycodegen/pystencils!158
parents 09de00cf 76c3727b
Pipeline #24583 passed with stage
in 6 minutes and 14 seconds
...@@ -3,4 +3,4 @@ max-line-length=120 ...@@ -3,4 +3,4 @@ max-line-length=120
exclude=pystencils/jupyter.py, exclude=pystencils/jupyter.py,
pystencils/plot.py pystencils/plot.py
pystencils/session.py pystencils/session.py
ignore = W293 W503 W291 C901 ignore = W293 W503 W291 C901 E741
...@@ -13,6 +13,7 @@ tests-and-coverage: ...@@ -13,6 +13,7 @@ tests-and-coverage:
- $ENABLE_NIGHTLY_BUILDS - $ENABLE_NIGHTLY_BUILDS
image: i10git.cs.fau.de:5005/pycodegen/pycodegen/full image: i10git.cs.fau.de:5005/pycodegen/pycodegen/full
script: script:
- env
- pip list - pip list
- export NUM_CORES=$(nproc --all) - export NUM_CORES=$(nproc --all)
- mkdir -p ~/.config/matplotlib - mkdir -p ~/.config/matplotlib
......
...@@ -152,6 +152,7 @@ class IPyNbFile(pytest.File): ...@@ -152,6 +152,7 @@ class IPyNbFile(pytest.File):
notebook = nbformat.read(notebook_contents, 4) notebook = nbformat.read(notebook_contents, 4)
code, _ = exporter.from_notebook_node(notebook) code, _ = exporter.from_notebook_node(notebook)
yield IPyNbTest(self.name, self, code) yield IPyNbTest(self.name, self, code)
# pytest v 2.4>: yield IPyNbTest.from_parent(name=self.name, parent=self, code=code)
def teardown(self): def teardown(self):
pass pass
...@@ -161,3 +162,4 @@ def pytest_collect_file(path, parent): ...@@ -161,3 +162,4 @@ def pytest_collect_file(path, parent):
glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"] glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
if any(path.fnmatch(g) for g in glob_exprs): if any(path.fnmatch(g) for g in glob_exprs):
return IPyNbFile(path, parent) return IPyNbFile(path, parent)
# pytest v 2.4 >: return IPyNbFile.from_parent(fspath=path, parent=parent)
...@@ -53,7 +53,7 @@ else: ...@@ -53,7 +53,7 @@ else:
# Tuple of things that can be on the lhs of an assignment # Tuple of things that can be on the lhs of an assignment
assignable = (sp.Symbol, MatrixSymbol, MatrixElement, sp.Indexed) assignable = (sp.Symbol, MatrixSymbol, MatrixElement, sp.Indexed)
if not isinstance(lhs, assignable): if not isinstance(lhs, assignable):
raise TypeError("Cannot assign to lhs of type %s." % type(lhs)) raise TypeError(f"Cannot assign to lhs of type {type(lhs)}.")
return sp.Rel.__new__(cls, lhs, rhs, **assumptions) return sp.Rel.__new__(cls, lhs, rhs, **assumptions)
__str__ = assignment_str __str__ = assignment_str
......
...@@ -113,14 +113,14 @@ class Conditional(Node): ...@@ -113,14 +113,14 @@ class Conditional(Node):
return self.__repr__() return self.__repr__()
def __repr__(self): def __repr__(self):
repr = 'if:({!r}) '.format(self.condition_expr) result = f'if:({self.condition_expr!r}) '
if self.true_block: if self.true_block:
repr += '\n\t{}) '.format(self.true_block) result += f'\n\t{self.true_block}) '
if self.false_block: if self.false_block:
repr = 'else: '.format(self.false_block) result = 'else: '
repr += '\n\t{} '.format(self.false_block) result += f'\n\t{self.false_block} '
return repr return result
def replace_by_true_block(self): def replace_by_true_block(self):
"""Replaces the conditional by its True block""" """Replaces the conditional by its True block"""
...@@ -264,7 +264,7 @@ class KernelFunction(Node): ...@@ -264,7 +264,7 @@ class KernelFunction(Node):
def __repr__(self): def __repr__(self):
params = [p.symbol for p in self.get_parameters()] params = [p.symbol for p in self.get_parameters()]
return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params) return f'{type(self).__name__} {self.function_name}({params})'
def compile(self, *args, **kwargs): def compile(self, *args, **kwargs):
if self._compile_function is None: if self._compile_function is None:
...@@ -475,11 +475,11 @@ class LoopOverCoordinate(Node): ...@@ -475,11 +475,11 @@ class LoopOverCoordinate(Node):
@staticmethod @staticmethod
def get_loop_counter_name(coordinate_to_loop_over): def get_loop_counter_name(coordinate_to_loop_over):
return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over) return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
@staticmethod @staticmethod
def get_block_loop_counter_name(coordinate_to_loop_over): def get_block_loop_counter_name(coordinate_to_loop_over):
return "%s_%s" % (LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over) return f"{LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
@property @property
def loop_counter_name(self): def loop_counter_name(self):
...@@ -612,7 +612,7 @@ class SympyAssignment(Node): ...@@ -612,7 +612,7 @@ class SympyAssignment(Node):
replacement.parent = self replacement.parent = self
self.rhs = replacement self.rhs = replacement
else: else:
raise ValueError('%s is not in args of %s' % (replacement, self.__class__)) raise ValueError(f'{replacement} is not in args of {self.__class__}')
def __repr__(self): def __repr__(self):
return repr(self.lhs) + " ← " + repr(self.rhs) return repr(self.lhs) + " ← " + repr(self.rhs)
...@@ -620,7 +620,7 @@ class SympyAssignment(Node): ...@@ -620,7 +620,7 @@ class SympyAssignment(Node):
def _repr_html_(self): def _repr_html_(self):
printed_lhs = sp.latex(self.lhs) printed_lhs = sp.latex(self.lhs)
printed_rhs = sp.latex(self.rhs) printed_rhs = sp.latex(self.rhs)
return "${printed_lhs} \\leftarrow {printed_rhs}$".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs) return f"${printed_lhs} \\leftarrow {printed_rhs}$"
def __hash__(self): def __hash__(self):
return hash((self.lhs, self.rhs)) return hash((self.lhs, self.rhs))
...@@ -663,7 +663,7 @@ class ResolvedFieldAccess(sp.Indexed): ...@@ -663,7 +663,7 @@ class ResolvedFieldAccess(sp.Indexed):
def __str__(self): def __str__(self):
top = super(ResolvedFieldAccess, self).__str__() top = super(ResolvedFieldAccess, self).__str__()
return "%s (%s)" % (top, self.typed_symbol.dtype) return f"{top} ({self.typed_symbol.dtype})"
def __getnewargs__(self): def __getnewargs__(self):
return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values
...@@ -740,7 +740,7 @@ def early_out(condition): ...@@ -740,7 +740,7 @@ def early_out(condition):
def get_dummy_symbol(dtype='bool'): def get_dummy_symbol(dtype='bool'):
return TypedSymbol('dummy%s' % uuid.uuid4().hex, create_type(dtype)) return TypedSymbol(f'dummy{uuid.uuid4().hex}', create_type(dtype))
class SourceCodeComment(Node): class SourceCodeComment(Node):
......
...@@ -158,7 +158,7 @@ class CustomCodeNode(Node): ...@@ -158,7 +158,7 @@ class CustomCodeNode(Node):
class PrintNode(CustomCodeNode): class PrintNode(CustomCodeNode):
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
def __init__(self, symbol_to_print): def __init__(self, symbol_to_print):
code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbol_to_print.name, symbol_to_print.name) code = f'\nstd::cout << "{symbol_to_print.name} = " << {symbol_to_print.name} << std::endl; \n'
super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set()) super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set())
self.headers.append("<iostream>") self.headers.append("<iostream>")
...@@ -203,12 +203,12 @@ class CBackend: ...@@ -203,12 +203,12 @@ class CBackend:
return str(node) return str(node)
def _print_KernelFunction(self, node): def _print_KernelFunction(self, node):
function_arguments = ["%s %s" % (self._print(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()] function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters()]
launch_bounds = "" launch_bounds = ""
if self._dialect == 'cuda': if self._dialect == 'cuda':
max_threads = node.indexing.max_threads_per_block() max_threads = node.indexing.max_threads_per_block()
if max_threads: if max_threads:
launch_bounds = "__launch_bounds__({}) ".format(max_threads) launch_bounds = f"__launch_bounds__({max_threads}) "
func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name, func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name,
", ".join(function_arguments)) ", ".join(function_arguments))
if self._signatureOnly: if self._signatureOnly:
...@@ -222,19 +222,19 @@ class CBackend: ...@@ -222,19 +222,19 @@ class CBackend:
return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True))) return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))
def _print_PragmaBlock(self, node): def _print_PragmaBlock(self, node):
return "%s\n%s" % (node.pragma_line, self._print_Block(node)) return f"{node.pragma_line}\n{self._print_Block(node)}"
def _print_LoopOverCoordinate(self, node): def _print_LoopOverCoordinate(self, node):
counter_symbol = node.loop_counter_name counter_symbol = node.loop_counter_name
start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start)) start = f"int {counter_symbol} = {self.sympy_printer.doprint(node.start)}"
condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop)) condition = f"{counter_symbol} < {self.sympy_printer.doprint(node.stop)}"
update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),) update = f"{counter_symbol} += {self.sympy_printer.doprint(node.step)}"
loop_str = "for (%s; %s; %s)" % (start, condition, update) loop_str = f"for ({start}; {condition}; {update})"
prefix = "\n".join(node.prefix_lines) prefix = "\n".join(node.prefix_lines)
if prefix: if prefix:
prefix += "\n" prefix += "\n"
return "%s%s\n%s" % (prefix, loop_str, self._print(node.body)) return f"{prefix}{loop_str}\n{self._print(node.body)}"
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
if node.is_declaration: if node.is_declaration:
...@@ -262,7 +262,7 @@ class CBackend: ...@@ -262,7 +262,7 @@ class CBackend:
instr = 'maskStore' if aligned else 'maskStoreU' instr = 'maskStore' if aligned else 'maskStoreU'
printed_mask = self.sympy_printer.doprint(mask) printed_mask = self.sympy_printer.doprint(mask)
if self._vector_instruction_set['dataTypePrefix']['double'] == '__mm256d': if self._vector_instruction_set['dataTypePrefix']['double'] == '__mm256d':
printed_mask = "_mm256_castpd_si256({})".format(printed_mask) printed_mask = f"_mm256_castpd_si256({printed_mask})"
rhs_type = get_type_of_expression(node.rhs) rhs_type = get_type_of_expression(node.rhs)
if type(rhs_type) is not VectorType: if type(rhs_type) is not VectorType:
...@@ -274,7 +274,7 @@ class CBackend: ...@@ -274,7 +274,7 @@ class CBackend:
self.sympy_printer.doprint(rhs), self.sympy_printer.doprint(rhs),
printed_mask) + ';' printed_mask) + ';'
else: else:
return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"
def _print_TemporaryMemoryAllocation(self, node): def _print_TemporaryMemoryAllocation(self, node):
align = 64 align = 64
...@@ -314,7 +314,7 @@ class CBackend: ...@@ -314,7 +314,7 @@ class CBackend:
raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all") raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")
condition_expr = self.sympy_printer.doprint(node.condition_expr) condition_expr = self.sympy_printer.doprint(node.condition_expr)
true_block = self._print_Block(node.true_block) true_block = self._print_Block(node.true_block)
result = "if (%s)\n%s " % (condition_expr, true_block) result = f"if ({condition_expr})\n{true_block} "
if node.false_block: if node.false_block:
false_block = self._print_Block(node.false_block) false_block = self._print_Block(node.false_block)
result += "else " + false_block result += "else " + false_block
...@@ -343,7 +343,7 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -343,7 +343,7 @@ class CustomSympyPrinter(CCodePrinter):
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
return "1 / ({})".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})"
else: else:
return super(CustomSympyPrinter, self)._print_Pow(expr) return super(CustomSympyPrinter, self)._print_Pow(expr)
...@@ -362,10 +362,10 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -362,10 +362,10 @@ class CustomSympyPrinter(CCodePrinter):
return result.replace("\n", "") return result.replace("\n", "")
def _print_Abs(self, expr): def _print_Abs(self, expr):
if expr.is_integer: if expr.args[0].is_integer:
return 'abs({0})'.format(self._print(expr.args[0])) return f'abs({self._print(expr.args[0])})'
else: else:
return 'fabs({0})'.format(self._print(expr.args[0])) return f'fabs({self._print(expr.args[0])})'
def _print_Type(self, node): def _print_Type(self, node):
return str(node) return str(node)
...@@ -382,37 +382,37 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -382,37 +382,37 @@ class CustomSympyPrinter(CCodePrinter):
return expr.to_c(self._print) return expr.to_c(self._print)
if isinstance(expr, reinterpret_cast_func): if isinstance(expr, reinterpret_cast_func):
arg, data_type = expr.args arg, data_type = expr.args
return "*((%s)(& %s))" % (self._print(PointerType(data_type, restrict=False)), self._print(arg)) return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
elif isinstance(expr, address_of): elif isinstance(expr, address_of):
assert len(expr.args) == 1, "address_of must only have one argument" assert len(expr.args) == 1, "address_of must only have one argument"
return "&(%s)" % self._print(expr.args[0]) return f"&({self._print(expr.args[0])})"
elif isinstance(expr, cast_func): elif isinstance(expr, cast_func):
arg, data_type = expr.args arg, data_type = expr.args
if isinstance(arg, sp.Number) and arg.is_finite: if isinstance(arg, sp.Number) and arg.is_finite:
return self._typed_number(arg, data_type) return self._typed_number(arg, data_type)
else: else:
return "((%s)(%s))" % (data_type, self._print(arg)) return f"(({data_type})({self._print(arg)}))"
elif isinstance(expr, fast_division): elif isinstance(expr, fast_division):
return "({})".format(self._print(expr.args[0] / expr.args[1])) return f"({self._print(expr.args[0] / expr.args[1])})"
elif isinstance(expr, fast_sqrt): elif isinstance(expr, fast_sqrt):
return "({})".format(self._print(sp.sqrt(expr.args[0]))) return f"({self._print(sp.sqrt(expr.args[0]))})"
elif isinstance(expr, vec_any) or isinstance(expr, vec_all): elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return self._print(expr.args[0]) return self._print(expr.args[0])
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, sp.Abs): elif isinstance(expr, sp.Abs):
return "abs({})".format(self._print(expr.args[0])) return f"abs({self._print(expr.args[0])})"
elif isinstance(expr, sp.Mod): elif isinstance(expr, sp.Mod):
if expr.args[0].is_integer and expr.args[1].is_integer: if expr.args[0].is_integer and expr.args[1].is_integer:
return "({} % {})".format(self._print(expr.args[0]), self._print(expr.args[1])) return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})"
else: else:
return "fmod({}, {})".format(self._print(expr.args[0]), self._print(expr.args[1])) return f"fmod({self._print(expr.args[0])}, {self._print(expr.args[1])})"
elif expr.func in infix_functions: elif expr.func in infix_functions:
return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1])) return f"({self._print(expr.args[0])} {infix_functions[expr.func]} {self._print(expr.args[1])})"
elif expr.func == int_power_of_2: elif expr.func == int_power_of_2:
return "(1 << (%s))" % (self._print(expr.args[0])) return f"(1 << ({self._print(expr.args[0])}))"
elif expr.func == int_div: elif expr.func == int_div:
return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1])) return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))"
else: else:
name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__ name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
arg_str = ', '.join(self._print(a) for a in expr.args) arg_str = ', '.join(self._print(a) for a in expr.args)
...@@ -540,14 +540,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -540,14 +540,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
return result return result
elif expr.func == fast_sqrt: elif expr.func == fast_sqrt:
return "({})".format(self._print(sp.sqrt(expr.args[0]))) return f"({self._print(sp.sqrt(expr.args[0]))})"
elif expr.func == fast_inv_sqrt: elif expr.func == fast_inv_sqrt:
result = self._scalarFallback('_print_Function', expr) result = self._scalarFallback('_print_Function', expr)
if not result: if not result:
if self.instruction_set['rsqrt']: if self.instruction_set['rsqrt']:
return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
else: else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, vec_any): elif isinstance(expr, vec_any):
expr_type = get_type_of_expression(expr.args[0]) expr_type = get_type_of_expression(expr.args[0])
if type(expr_type) is not VectorType: if type(expr_type) is not VectorType:
......
...@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
if isinstance(expr, fast_division): if isinstance(expr, fast_division):
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args) return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_sqrt): elif isinstance(expr, fast_sqrt):
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) return f"__fsqrt_rn({tuple(self._print(a) for a in expr.args)})"
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) return f"__frsqrt_rn({tuple(self._print(a) for a in expr.args)})"
return super()._print_Function(expr) return super()._print_Function(expr)
...@@ -57,7 +57,7 @@ def __shortened(node): ...@@ -57,7 +57,7 @@ def __shortened(node):
params = node.get_parameters() params = node.get_parameters()
param_names = [p.field_name for p in params if p.is_field_pointer] param_names = [p.field_name for p in params if p.is_field_pointer]
param_names += [p.symbol.name for p in params if not p.is_field_parameter] param_names += [p.symbol.name for p in params if not p.is_field_parameter]
return "Func: %s (%s)" % (node.function_name, ",".join(param_names)) return f"Func: {node.function_name} ({','.join(param_names)})"
elif isinstance(node, SympyAssignment): elif isinstance(node, SympyAssignment):
return repr(node.lhs) return repr(node.lhs)
elif isinstance(node, Block): elif isinstance(node, Block):
...@@ -65,7 +65,7 @@ def __shortened(node): ...@@ -65,7 +65,7 @@ def __shortened(node):
elif isinstance(node, Conditional): elif isinstance(node, Conditional):
return repr(node) return repr(node)
else: else:
raise NotImplementedError("Cannot handle node type %s" % (type(node),)) raise NotImplementedError(f"Cannot handle node type {type(node)}")
def print_dot(node, view=False, short=False, **kwargs): def print_dot(node, view=False, short=False, **kwargs):
......
...@@ -98,7 +98,7 @@ class OpenClSympyPrinter(CudaSympyPrinter): ...@@ -98,7 +98,7 @@ class OpenClSympyPrinter(CudaSympyPrinter):
if isinstance(expr, fast_division): if isinstance(expr, fast_division):
return "native_divide(%s, %s)" % tuple(self._print(a) for a in expr.args) return "native_divide(%s, %s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_sqrt): elif isinstance(expr, fast_sqrt):
return "native_sqrt(%s)" % tuple(self._print(a) for a in expr.args) return f"native_sqrt({tuple(self._print(a) for a in expr.args)})"
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
return "native_rsqrt(%s)" % tuple(self._print(a) for a in expr.args) return f"native_rsqrt({tuple(self._print(a) for a in expr.args)})"
return CustomSympyPrinter._print_Function(self, expr) return CustomSympyPrinter._print_Function(self, expr)
...@@ -51,7 +51,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -51,7 +51,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
}) })
for comparison_op, constant in comparisons.items(): for comparison_op, constant in comparisons.items():
base_names[comparison_op] = 'cmp[0, 1, %s]' % (constant,) base_names[comparison_op] = f'cmp[0, 1, {constant}]'
headers = { headers = {
'avx512': ['<immintrin.h>'], 'avx512': ['<immintrin.h>'],
...@@ -89,16 +89,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -89,16 +89,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
name = function_shortcut[:function_shortcut.index('[')] name = function_shortcut[:function_shortcut.index('[')]
if intrinsic_id == 'makeVecConst': if intrinsic_id == 'makeVecConst':
arg_string = "({})".format(",".join(["{0}"] * result['width'])) arg_string = f"({','.join(['{0}'] * result['width'])})"
elif intrinsic_id == 'makeVec': elif intrinsic_id == 'makeVec':
params = ["{" + str(i) + "}" for i in reversed(range(result['width']))] params = ["{" + str(i) + "}" for i in reversed(range(result['width']))]
arg_string = "({})".format(",".join(params)) arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecBool': elif intrinsic_id == 'makeVecBool':
params = ["(({{{i}}} ? -1.0 : 0.0)".format(i=i) for i in reversed(range(result['width']))] params = [f"(({{{i}}} ? -1.0 : 0.0)" for i in reversed(range(result['width']))]
arg_string = "({})".format(",".join(params)) arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecConstBool': elif intrinsic_id == 'makeVecConstBool':
params = ["(({0}) ? -1.0 : 0.0)" for _ in range(result['width'])] params = ["(({0}) ? -1.0 : 0.0)" for _ in range(result['width'])]
arg_string = "({})".format(",".join(params)) arg_string = f"({','.join(params)})"
else: else:
args = function_shortcut[function_shortcut.index('[') + 1: -1] args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "(" arg_string = "("
...@@ -141,9 +141,9 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -141,9 +141,9 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result['bool'] = "__mmask%d" % (size,) result['bool'] = "__mmask%d" % (size,)
params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)]) params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
result['makeVecBool'] = "__mmask8(({}) )".format(params) result['makeVecBool'] = f"__mmask8(({params}) )"
params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)]) params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
result['makeVecConstBool'] = "__mmask8(({}) )".format(params) result['makeVecConstBool'] = f"__mmask8(({params}) )"
if instruction_set == 'avx' and data_type == 'float': if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = "_mm256_rsqrt_ps({0})" result['rsqrt'] = "_mm256_rsqrt_ps({0})"
......
...@@ -66,13 +66,13 @@ class FlagInterface: ...@@ -66,13 +66,13 @@ class FlagInterface:
self._used_flags.add(flag) self._used_flags.add(flag)
assert self._is_power_of_2(flag) assert self._is_power_of_2(flag)
return flag return flag
raise ValueError("All available {} flags are reserved".format(self.max_bits)) raise ValueError(f"All available {self.max_bits} flags are reserved")
def reserve_flag(self, flag): def reserve_flag(self, flag):
assert self._is_power_of_2(flag) assert self._is_power_of_2(flag)
flag = self.dtype(flag) flag = self.dtype(flag)
if flag in self._used_flags: if flag in self._used_flags:
raise ValueError("The flag {flag} is already reserved".format(flag=flag)) raise ValueError(f"The flag {flag} is already reserved")
self._used_flags.add(flag) self._used_flags.add(flag)
return flag return flag
<