diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 3f019f5660d67bda97ebf5ec74cd586d4a4ddbed..2a82b8a43ebf0f0acb78d267b1aeced6b06ce3cd 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -151,9 +151,9 @@ class Field: total_dimensions = spatial_dimensions + index_dimensions if index_shape is None or len(index_shape) == 0: - shape = tuple([FieldShapeSymbol([field_name], i) for i in range(total_dimensions)]) + shape = tuple([FieldShapeSymbol(field_name, i) for i in range(total_dimensions)]) else: - shape = tuple([FieldShapeSymbol([field_name], i) for i in range(spatial_dimensions)] + list(index_shape)) + shape = tuple([FieldShapeSymbol(field_name, i) for i in range(spatial_dimensions)] + list(index_shape)) strides = tuple([FieldStrideSymbol(field_name, i) for i in range(total_dimensions)]) diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index c49c46c399bd991fa58b799de905cad1b3987b99..7e9edaab123cf9dbb837a6588c4084fe58bda9de 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -106,7 +106,7 @@ class FieldStrideSymbol(TypedSymbol): obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_name, coordinate): + def __new_stage2__(cls, field_name: str, coordinate: int): from ..defaults import DEFAULTS name = f"_stride_{field_name}_{coordinate}" @@ -139,7 +139,7 @@ class FieldShapeSymbol(TypedSymbol): obj = FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_names, coordinate): + def __new_stage2__(cls, field_name: str, coordinate: int): from ..defaults import DEFAULTS names = "_".join([field_name for field_name in field_names]) @@ -147,7 +147,7 @@ class FieldShapeSymbol(TypedSymbol): obj = super(FieldShapeSymbol, cls).__xnew__( cls, name, DEFAULTS.index_dtype, positive=True ) - obj.field_names = tuple(field_names) + obj.field_name = field_name obj.coordinate = coordinate return obj