From 14986080c84e0c530588fb2e0b3b6ec49429749c Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 15 Mar 2024 16:27:34 +0100
Subject: [PATCH] Change FieldShapeSymbol to only store a single field name

---
 src/pystencils/field.py                       | 4 ++--
 src/pystencils/sympyextensions/typed_sympy.py | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/pystencils/field.py b/src/pystencils/field.py
index 3f019f566..2a82b8a43 100644
--- a/src/pystencils/field.py
+++ b/src/pystencils/field.py
@@ -151,9 +151,9 @@ class Field:
 
         total_dimensions = spatial_dimensions + index_dimensions
         if index_shape is None or len(index_shape) == 0:
-            shape = tuple([FieldShapeSymbol([field_name], i) for i in range(total_dimensions)])
+            shape = tuple([FieldShapeSymbol(field_name, i) for i in range(total_dimensions)])
         else:
-            shape = tuple([FieldShapeSymbol([field_name], i) for i in range(spatial_dimensions)] + list(index_shape))
+            shape = tuple([FieldShapeSymbol(field_name, i) for i in range(spatial_dimensions)] + list(index_shape))
 
         strides = tuple([FieldStrideSymbol(field_name, i) for i in range(total_dimensions)])
 
diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py
index c49c46c39..7e9edaab1 100644
--- a/src/pystencils/sympyextensions/typed_sympy.py
+++ b/src/pystencils/sympyextensions/typed_sympy.py
@@ -106,7 +106,7 @@ class FieldStrideSymbol(TypedSymbol):
         obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds)
         return obj
 
-    def __new_stage2__(cls, field_name, coordinate):
+    def __new_stage2__(cls, field_name: str, coordinate: int):
         from ..defaults import DEFAULTS
 
         name = f"_stride_{field_name}_{coordinate}"
@@ -139,7 +139,7 @@ class FieldShapeSymbol(TypedSymbol):
         obj = FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds)
         return obj
 
-    def __new_stage2__(cls, field_names, coordinate):
+    def __new_stage2__(cls, field_name: str, coordinate: int):
         from ..defaults import DEFAULTS
 
         names = "_".join([field_name for field_name in field_names])
@@ -147,7 +147,7 @@ class FieldShapeSymbol(TypedSymbol):
         obj = super(FieldShapeSymbol, cls).__xnew__(
             cls, name, DEFAULTS.index_dtype, positive=True
         )
-        obj.field_names = tuple(field_names)
+        obj.field_name = field_name
         obj.coordinate = coordinate
         return obj
 
-- 
GitLab