From 0178dcc25f04d5a796e2ce9f89f08088b87d986d Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 6 Jul 2018 11:13:51 +0200 Subject: [PATCH] CBackend: support for negative integer exponentials --- backends/cbackend.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/backends/cbackend.py b/backends/cbackend.py index 9becd107d..e5d9730ff 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -220,6 +220,8 @@ class CustomSympyPrinter(CCodePrinter): """Don't use std::pow function, for small integer exponents, write as multiplication""" if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" + elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: + return "1 / ({})".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) else: return super(CustomSympyPrinter, self)._print_Pow(expr) @@ -359,14 +361,17 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" + elif expr.exp == -1: + one = self.instruction_set['makeVec'].format(1.0) + return self.instruction_set['/'].format(one, self._print(expr.base)) + elif expr.exp == 0.5: + return self.instruction_set['sqrt'].format(self._print(expr.base)) + elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: + one = self.instruction_set['makeVec'].format(1.0) + return self.instruction_set['/'].format(one, + self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) else: - if expr.exp == -1: - one = self.instruction_set['makeVec'].format(1.0) - return self.instruction_set['/'].format(one, self._print(expr.base)) - elif expr.exp == 0.5: - return self.instruction_set['sqrt'].format(self._print(expr.base)) - else: - raise ValueError("Generic exponential not supported") + raise ValueError("Generic exponential not supported: " + str(expr)) def _print_Mul(self, expr, inside_add=False): # noinspection PyProtectedMember -- GitLab