Skip to content
Snippets Groups Projects
Commit b962a099 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add assertion that headers follow the pattern /"..."/ or /<...>/

parent de3489c4
No related merge requests found
import re
from collections import namedtuple
from typing import Set
......@@ -24,6 +25,9 @@ except ImportError:
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
HEADER_REGEX = re.compile(r'^[<"].*[">]$')
KERNCRAFT_NO_TERNARY_MODE = False
......@@ -112,6 +116,9 @@ def get_headers(ast_node: Node) -> Set[str]:
if isinstance(g, Node):
headers.update(get_headers(g))
for h in headers:
assert HEADER_REGEX.match(h), f'header /{h}/ does not follow the pattern /"..."/ or /<...>/'
return sorted(headers)
......
"""
"""
import pytest
from pystencils.astnodes import Block
from pystencils.backends.cbackend import CustomCodeNode, get_headers
def test_headers_have_quotes_or_brackets():
class ErrorNode1(CustomCodeNode):
def __init__(self):
super().__init__("", [], [])
self.headers = ["iostream"]
class ErrorNode2(CustomCodeNode):
headers = ["<iostream>", "foo"]
def __init__(self):
super().__init__("", [], [])
self.headers = ["<iostream>", "foo"]
class OkNode3(CustomCodeNode):
def __init__(self):
super().__init__("", [], [])
self.headers = ["<iostream>", '"foo"']
with pytest.raises(AssertionError, match='.* does not follow the pattern .*'):
get_headers(Block([ErrorNode1()]))
with pytest.raises(AssertionError, match='.* does not follow the pattern .*'):
get_headers(ErrorNode2())
get_headers(OkNode3())
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment