From 843ea751bd228a54357896d08a9efa70660400d5 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Tue, 13 Nov 2018 13:53:28 +0100
Subject: [PATCH] Dirichlet Boundary Condition for pystencils

---
 boundaries/__init__.py           |  4 ++--
 boundaries/boundaryconditions.py | 27 ++++++++++++++++++++++++++-
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/boundaries/__init__.py b/boundaries/__init__.py
index afce6f99c..cf94f330f 100644
--- a/boundaries/__init__.py
+++ b/boundaries/__init__.py
@@ -1,5 +1,5 @@
 from pystencils.boundaries.boundaryhandling import BoundaryHandling
-from pystencils.boundaries.boundaryconditions import Neumann
+from pystencils.boundaries.boundaryconditions import Neumann, Dirichlet
 from pystencils.boundaries.inkernel import add_neumann_boundary
 
-__all__ = ['BoundaryHandling', 'Neumann', 'add_neumann_boundary']
+__all__ = ['BoundaryHandling', 'Neumann', 'Dirichlet', 'add_neumann_boundary']
diff --git a/boundaries/boundaryconditions.py b/boundaries/boundaryconditions.py
index db898d982..f1e99e04f 100644
--- a/boundaries/boundaryconditions.py
+++ b/boundaries/boundaryconditions.py
@@ -1,6 +1,7 @@
+from typing import List, Tuple, Any
 from pystencils import Assignment
 from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
-from typing import List, Tuple, Any
+from pystencils.data_types import create_type
 
 
 class Boundary:
@@ -68,3 +69,27 @@ class Neumann(Boundary):
 
     def __eq__(self, other):
         return type(other) == Neumann
+
+
+class Dirichlet(Boundary):
+    def __init__(self, value, name="Dirchlet"):
+        super().__init__(name)
+        self._value = value
+
+    @property
+    def additional_data(self):
+        if callable(self._value):
+            return [('value', create_type("double"))]
+        else:
+            return []
+
+    @property
+    def additional_data_init_callback(self):
+        if callable(self._value):
+            return self._value
+
+    def __call__(self, field, direction_symbol, index_field, **kwargs):
+        if self.additional_data:
+            return [Assignment(field.center, index_field("value"))]
+        if field.index_dimensions == 0:
+            return [Assignment(field.center, self._value)]
-- 
GitLab