From b962a099b28fcd106a9b7dc1ee87665e2f03b079 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 24 Jan 2020 10:29:45 +0100
Subject: [PATCH] Add assertion that headers follow the pattern /"..."/ or
 /<...>/

---
 pystencils/backends/cbackend.py         |  7 +++++
 pystencils_tests/test_helpful_errors.py | 37 +++++++++++++++++++++++++
 2 files changed, 44 insertions(+)
 create mode 100644 pystencils_tests/test_helpful_errors.py

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index ad339a4bd..dd75f42d5 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -1,3 +1,4 @@
+import re
 from collections import namedtuple
 from typing import Set
 
@@ -24,6 +25,9 @@ except ImportError:
 
 __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
 
+
+HEADER_REGEX = re.compile(r'^[<"].*[">]$')
+
 KERNCRAFT_NO_TERNARY_MODE = False
 
 
@@ -112,6 +116,9 @@ def get_headers(ast_node: Node) -> Set[str]:
         if isinstance(g, Node):
             headers.update(get_headers(g))
 
+    for h in headers:
+        assert HEADER_REGEX.match(h), f'header /{h}/ does not follow the pattern /"..."/ or /<...>/'
+
     return sorted(headers)
 
 
diff --git a/pystencils_tests/test_helpful_errors.py b/pystencils_tests/test_helpful_errors.py
new file mode 100644
index 000000000..a13178afe
--- /dev/null
+++ b/pystencils_tests/test_helpful_errors.py
@@ -0,0 +1,37 @@
+"""
+
+"""
+
+import pytest
+
+from pystencils.astnodes import Block
+from pystencils.backends.cbackend import CustomCodeNode, get_headers
+
+
+def test_headers_have_quotes_or_brackets():
+    class ErrorNode1(CustomCodeNode):
+
+        def __init__(self):
+            super().__init__("", [], [])
+            self.headers = ["iostream"]
+
+    class ErrorNode2(CustomCodeNode):
+        headers = ["<iostream>", "foo"]
+
+        def __init__(self):
+            super().__init__("", [], [])
+            self.headers = ["<iostream>", "foo"]
+
+    class OkNode3(CustomCodeNode):
+
+        def __init__(self):
+            super().__init__("", [], [])
+            self.headers = ["<iostream>", '"foo"']
+
+    with pytest.raises(AssertionError, match='.* does not follow the pattern .*'):
+        get_headers(Block([ErrorNode1()]))
+
+    with pytest.raises(AssertionError, match='.* does not follow the pattern .*'):
+        get_headers(ErrorNode2())
+
+    get_headers(OkNode3())
-- 
GitLab