diff --git a/pystencils/boundaries/boundaryconditions.py b/pystencils/boundaries/boundaryconditions.py index 39338634a51cae5bfc782efaae2ac69ca523152e..dc01224d02a04fd466c4dda6000acb87326a7706 100644 --- a/pystencils/boundaries/boundaryconditions.py +++ b/pystencils/boundaries/boundaryconditions.py @@ -84,7 +84,7 @@ class Dirichlet(Boundary): inner_or_boundary = False single_link = True - def __init__(self, value, name="Dirchlet"): + def __init__(self, value, name=None): super().__init__(name) self._value = value @@ -103,7 +103,7 @@ class Dirichlet(Boundary): def __call__(self, field, direction_symbol, index_field, **kwargs): if field.index_dimensions == 0: - return [Assignment(field, index_field("value") if self.additional_data else self._value)] + return [Assignment(field.center, index_field("value") if self.additional_data else self._value)] elif field.index_dimensions == 1: assert not self.additional_data if not field.has_fixed_index_shape: diff --git a/pystencils_tests/test_boundary.py b/pystencils_tests/test_boundary.py index 23770c8ef6d61c15110f94875b0000c3e7d11fac..421cc2565094687e6acde6de580bcaff9cbc375b 100644 --- a/pystencils_tests/test_boundary.py +++ b/pystencils_tests/test_boundary.py @@ -2,11 +2,10 @@ import os from tempfile import TemporaryDirectory import numpy as np - import pytest from pystencils import Assignment, create_kernel -from pystencils.boundaries import BoundaryHandling, Neumann, add_neumann_boundary +from pystencils.boundaries import BoundaryHandling, Dirichlet, Neumann, add_neumann_boundary from pystencils.datahandling import SerialDataHandling from pystencils.slicing import slice_from_direction @@ -88,3 +87,27 @@ def test_kernel_vs_copy_boundary(): pytest.importorskip('pyevtk') boundary_handling.geometry_to_vtk(file_name=os.path.join(tmp_dir, 'test_output1'), ghost_layers=False) boundary_handling.geometry_to_vtk(file_name=os.path.join(tmp_dir, 'test_output2'), ghost_layers=True) + + +@pytest.mark.parametrize('with_indices', ('with_indices', False)) +def test_dirichlet(with_indices): + value = (1, 20, 3) if with_indices else 1 + + dh = SerialDataHandling(domain_size=(7, 7)) + src = dh.add_array('src', values_per_cell=3 if with_indices else 1) + dh.cpu_arrays.src[...] = np.random.rand(*src.shape) + boundary_stencil = [(1, 0), (-1, 0), (0, 1), (0, -1)] + boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil) + dirichlet = Dirichlet(value) + assert dirichlet.name == 'Dirichlet' + dirichlet.name = "wall" + assert dirichlet.name == 'wall' + + for d in ('N', 'S', 'W', 'E'): + boundary_handling.set_boundary(dirichlet, slice_from_direction(d, dim=2)) + boundary_handling() + + assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[1:-2, 0]]) + assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[1:-2, -1]]) + assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[0, 1:-2]]) + assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[-1, 1:-2]])