From 0178dcc25f04d5a796e2ce9f89f08088b87d986d Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 6 Jul 2018 11:13:51 +0200
Subject: [PATCH] CBackend: support for negative integer exponentials

---
 backends/cbackend.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/backends/cbackend.py b/backends/cbackend.py
index 9becd107d..e5d9730ff 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -220,6 +220,8 @@ class CustomSympyPrinter(CCodePrinter):
         """Don't use std::pow function, for small integer exponents, write as multiplication"""
         if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
             return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
+        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
+            return "1 / ({})".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
         else:
             return super(CustomSympyPrinter, self)._print_Pow(expr)
 
@@ -359,14 +361,17 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
 
         if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
             return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
+        elif expr.exp == -1:
+            one = self.instruction_set['makeVec'].format(1.0)
+            return self.instruction_set['/'].format(one, self._print(expr.base))
+        elif expr.exp == 0.5:
+            return self.instruction_set['sqrt'].format(self._print(expr.base))
+        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
+            one = self.instruction_set['makeVec'].format(1.0)
+            return self.instruction_set['/'].format(one,
+                                                    self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
         else:
-            if expr.exp == -1:
-                one = self.instruction_set['makeVec'].format(1.0)
-                return self.instruction_set['/'].format(one, self._print(expr.base))
-            elif expr.exp == 0.5:
-                return self.instruction_set['sqrt'].format(self._print(expr.base))
-            else:
-                raise ValueError("Generic exponential not supported")
+            raise ValueError("Generic exponential not supported: " + str(expr))
 
     def _print_Mul(self, expr, inside_add=False):
         # noinspection PyProtectedMember
-- 
GitLab