From e4874494c096bd36f8d8bca1cd1c444b2acb208b Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 29 Oct 2020 17:06:48 +0100
Subject: [PATCH] Fix Dirichlet boundary condition for scalar case

---
 pystencils/boundaries/boundaryconditions.py |  4 +--
 pystencils_tests/test_boundary.py           | 27 +++++++++++++++++++--
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/pystencils/boundaries/boundaryconditions.py b/pystencils/boundaries/boundaryconditions.py
index 39338634a..dc01224d0 100644
--- a/pystencils/boundaries/boundaryconditions.py
+++ b/pystencils/boundaries/boundaryconditions.py
@@ -84,7 +84,7 @@ class Dirichlet(Boundary):
     inner_or_boundary = False
     single_link = True
 
-    def __init__(self, value, name="Dirchlet"):
+    def __init__(self, value, name=None):
         super().__init__(name)
         self._value = value
 
@@ -103,7 +103,7 @@ class Dirichlet(Boundary):
     def __call__(self, field, direction_symbol, index_field, **kwargs):
 
         if field.index_dimensions == 0:
-            return [Assignment(field, index_field("value") if self.additional_data else self._value)]
+            return [Assignment(field.center, index_field("value") if self.additional_data else self._value)]
         elif field.index_dimensions == 1:
             assert not self.additional_data
             if not field.has_fixed_index_shape:
diff --git a/pystencils_tests/test_boundary.py b/pystencils_tests/test_boundary.py
index 23770c8ef..421cc2565 100644
--- a/pystencils_tests/test_boundary.py
+++ b/pystencils_tests/test_boundary.py
@@ -2,11 +2,10 @@ import os
 from tempfile import TemporaryDirectory
 
 import numpy as np
-
 import pytest
 
 from pystencils import Assignment, create_kernel
-from pystencils.boundaries import BoundaryHandling, Neumann, add_neumann_boundary
+from pystencils.boundaries import BoundaryHandling, Dirichlet, Neumann, add_neumann_boundary
 from pystencils.datahandling import SerialDataHandling
 from pystencils.slicing import slice_from_direction
 
@@ -88,3 +87,27 @@ def test_kernel_vs_copy_boundary():
         pytest.importorskip('pyevtk')
         boundary_handling.geometry_to_vtk(file_name=os.path.join(tmp_dir, 'test_output1'), ghost_layers=False)
         boundary_handling.geometry_to_vtk(file_name=os.path.join(tmp_dir, 'test_output2'), ghost_layers=True)
+
+
+@pytest.mark.parametrize('with_indices', ('with_indices', False))
+def test_dirichlet(with_indices):
+    value = (1, 20, 3) if with_indices else 1
+
+    dh = SerialDataHandling(domain_size=(7, 7))
+    src = dh.add_array('src', values_per_cell=3 if with_indices else 1)
+    dh.cpu_arrays.src[...] = np.random.rand(*src.shape)
+    boundary_stencil = [(1, 0), (-1, 0), (0, 1), (0, -1)]
+    boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil)
+    dirichlet = Dirichlet(value)
+    assert dirichlet.name == 'Dirichlet'
+    dirichlet.name = "wall"
+    assert dirichlet.name == 'wall'
+
+    for d in ('N', 'S', 'W', 'E'):
+        boundary_handling.set_boundary(dirichlet, slice_from_direction(d, dim=2))
+    boundary_handling()
+
+    assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[1:-2, 0]])
+    assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[1:-2, -1]])
+    assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[0, 1:-2]])
+    assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[-1, 1:-2]])
-- 
GitLab