diff --git a/backends/cbackend.py b/backends/cbackend.py index 9becd107d649c5a01b215d52a5f6dc53ecd3a22a..e5d9730ff83aeb46de8b0ecc840b063119946275 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -220,6 +220,8 @@ class CustomSympyPrinter(CCodePrinter): """Don't use std::pow function, for small integer exponents, write as multiplication""" if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" + elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: + return "1 / ({})".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) else: return super(CustomSympyPrinter, self)._print_Pow(expr) @@ -359,14 +361,17 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" + elif expr.exp == -1: + one = self.instruction_set['makeVec'].format(1.0) + return self.instruction_set['/'].format(one, self._print(expr.base)) + elif expr.exp == 0.5: + return self.instruction_set['sqrt'].format(self._print(expr.base)) + elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: + one = self.instruction_set['makeVec'].format(1.0) + return self.instruction_set['/'].format(one, + self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) else: - if expr.exp == -1: - one = self.instruction_set['makeVec'].format(1.0) - return self.instruction_set['/'].format(one, self._print(expr.base)) - elif expr.exp == 0.5: - return self.instruction_set['sqrt'].format(self._print(expr.base)) - else: - raise ValueError("Generic exponential not supported") + raise ValueError("Generic exponential not supported: " + str(expr)) def _print_Mul(self, expr, inside_add=False): # noinspection PyProtectedMember