diff --git a/doc/notebooks/02_tutorial_basic_kernels.ipynb b/doc/notebooks/02_tutorial_basic_kernels.ipynb index 413572375e075500898e146cc4c46a5dbb058f8c..c19bd011e0e1abe4700c1b5d875339d13ea5bc06 100644 --- a/doc/notebooks/02_tutorial_basic_kernels.ipynb +++ b/doc/notebooks/02_tutorial_basic_kernels.ipynb @@ -207,7 +207,7 @@ "<div>Subexpressions:</div><table style=\"border:none; width: 100%; \"><tr style=\"border:none\"> <td style=\"border:none\">$$a \\leftarrow {src}_{(0,1)} + {src}_{(-1,0)}$$</td> </tr> <tr style=\"border:none\"> <td style=\"border:none\">$$b \\leftarrow 2 {src}_{(1,0)} + {src}_{(0,-1)}$$</td> </tr> <tr style=\"border:none\"> <td style=\"border:none\">$$c \\leftarrow - {src}_{(0,0)} + 2 {src}_{(1,0)} + {src}_{(0,1)} + {src}_{(0,-1)} + {src}_{(-1,0)}$$</td> </tr> </table><div>Main Assignments:</div><table style=\"border:none; width: 100%; \"><tr style=\"border:none\"> <td style=\"border:none\">$${dst}_{(0,0)} \\leftarrow a + b + c$$</td> </tr> </table>" ], "text/plain": [ - "AssignmentCollection: dst_C, <- f(src_W, src_S, src_N, src_C, src_E)" + "AssignmentCollection: dst_C, <- f(src_E, src_C, src_N, src_W, src_S)" ] }, "execution_count": 7, @@ -274,7 +274,7 @@ "<div>Subexpressions:</div><table style=\"border:none; width: 100%; \"><tr style=\"border:none\"> <td style=\"border:none\">$$a \\leftarrow {src}_{(0,1)} + {src}_{(-1,0)}$$</td> </tr> <tr style=\"border:none\"> <td style=\"border:none\">$$b \\leftarrow 2 {src}_{(1,0)} + {src}_{(0,-1)}$$</td> </tr> <tr style=\"border:none\"> <td style=\"border:none\">$$c \\leftarrow - {src}_{(0,0)} + a + b$$</td> </tr> </table><div>Main Assignments:</div><table style=\"border:none; width: 100%; \"><tr style=\"border:none\"> <td style=\"border:none\">$${dst}_{(0,0)} \\leftarrow a + b + c$$</td> </tr> </table>" ], "text/plain": [ - "AssignmentCollection: dst_C, <- f(src_W, src_S, src_N, src_C, src_E)" + "AssignmentCollection: dst_C, <- f(src_E, src_C, src_N, src_W, src_S)" ] }, "execution_count": 9, @@ -415,11 +415,11 @@ { "data": { "text/html": [ - "<div class=\"highlight\"><pre><span></span><span class=\"n\">FUNC_PREFIX</span><span class=\"w\"> </span><span class=\"kt\">void</span><span class=\"w\"> </span><span class=\"n\">kernel</span><span class=\"p\">(</span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"p\">)</span><span class=\"w\"></span>\n", + "<div class=\"highlight\"><pre><span></span><span class=\"n\">FUNC_PREFIX</span><span class=\"w\"> </span><span class=\"kt\">void</span><span class=\"w\"> </span><span class=\"n\">kernel</span><span class=\"p\">(</span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"p\">)</span><span class=\"w\"></span>\n", "<span class=\"p\">{</span><span class=\"w\"></span>\n", "<span class=\"w\"> </span><span class=\"k\">for</span><span class=\"w\"> </span><span class=\"p\">(</span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"mi\">2</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\"><</span><span class=\"w\"> </span><span class=\"mi\">18</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">+=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">)</span><span class=\"w\"></span>\n", "<span class=\"w\"> </span><span class=\"p\">{</span><span class=\"w\"></span>\n", - "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_src_02</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">60</span><span class=\"p\">;</span><span class=\"w\"></span>\n", "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_src_0m1</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">-</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"p\">;</span><span class=\"w\"></span>\n", "<span class=\"w\"> </span><span class=\"k\">for</span><span class=\"w\"> </span><span class=\"p\">(</span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"mi\">2</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\"><</span><span class=\"w\"> </span><span class=\"mi\">28</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\">+=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">)</span><span class=\"w\"></span>\n", @@ -431,11 +431,11 @@ "</pre></div>\n" ], "text/plain": [ - "FUNC_PREFIX void kernel(double * RESTRICT _data_dst, double * RESTRICT const _data_src)\n", + "FUNC_PREFIX void kernel(double * RESTRICT _data_dst, double * RESTRICT const _data_src)\n", "{\n", " for (int64_t ctr_0 = 2; ctr_0 < 18; ctr_0 += 1)\n", " {\n", - " double * RESTRICT _data_dst_00 = _data_dst + 30*ctr_0;\n", + " double * RESTRICT _data_dst_00 = _data_dst + 30*ctr_0;\n", " double * RESTRICT _data_src_02 = _data_src + 30*ctr_0 + 60;\n", " double * RESTRICT _data_src_0m1 = _data_src + 30*ctr_0 - 30;\n", " for (int64_t ctr_1 = 2; ctr_1 < 28; ctr_1 += 1)\n", @@ -556,11 +556,11 @@ { "data": { "text/html": [ - "<div class=\"highlight\"><pre><span></span><span class=\"n\">FUNC_PREFIX</span><span class=\"w\"> </span><span class=\"kt\">void</span><span class=\"w\"> </span><span class=\"n\">kernel</span><span class=\"p\">(</span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"p\">)</span><span class=\"w\"></span>\n", + "<div class=\"highlight\"><pre><span></span><span class=\"n\">FUNC_PREFIX</span><span class=\"w\"> </span><span class=\"kt\">void</span><span class=\"w\"> </span><span class=\"n\">kernel</span><span class=\"p\">(</span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"p\">)</span><span class=\"w\"></span>\n", "<span class=\"p\">{</span><span class=\"w\"></span>\n", "<span class=\"w\"> </span><span class=\"k\">for</span><span class=\"w\"> </span><span class=\"p\">(</span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"mi\">0</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\"><</span><span class=\"w\"> </span><span class=\"mi\">18</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">+=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">)</span><span class=\"w\"></span>\n", "<span class=\"w\"> </span><span class=\"p\">{</span><span class=\"w\"></span>\n", - "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_src_02</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">60</span><span class=\"p\">;</span><span class=\"w\"></span>\n", "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_src_0m1</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">-</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"p\">;</span><span class=\"w\"></span>\n", "<span class=\"w\"> </span><span class=\"k\">for</span><span class=\"w\"> </span><span class=\"p\">(</span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\"><</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\">+=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">)</span><span class=\"w\"></span>\n", @@ -572,11 +572,11 @@ "</pre></div>\n" ], "text/plain": [ - "FUNC_PREFIX void kernel(double * RESTRICT _data_dst, double * RESTRICT const _data_src)\n", + "FUNC_PREFIX void kernel(double * RESTRICT _data_dst, double * RESTRICT const _data_src)\n", "{\n", " for (int64_t ctr_0 = 0; ctr_0 < 18; ctr_0 += 1)\n", " {\n", - " double * RESTRICT _data_dst_00 = _data_dst + 30*ctr_0;\n", + " double * RESTRICT _data_dst_00 = _data_dst + 30*ctr_0;\n", " double * RESTRICT _data_src_02 = _data_src + 30*ctr_0 + 60;\n", " double * RESTRICT _data_src_0m1 = _data_src + 30*ctr_0 - 30;\n", " for (int64_t ctr_1 = 1; ctr_1 < 30; ctr_1 += 1)\n", @@ -716,31 +716,183 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "However, for right hand sides that are Field.Accesses this is allowed:" + "Also it is not allowed to write a field at the same location" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Field dst is written twice at the same location\n" + ] + } + ], + "source": [ + "@ps.kernel\n", + "def not_allowed():\n", + " dst[0, 0] @= src[0, 1] + src[1, 0]\n", + " dst[0, 0] @= 2 * dst[0, 0]\n", + "\n", + "try:\n", + " ps.create_kernel(not_allowed)\n", + " assert False\n", + "except ValueError as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This situation should be resolved by introducing temporary variables" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, "outputs": [ { "data": { + "text/html": [ + "<style>pre { line-height: 125%; }\n", + "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", + "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", + "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", + "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", + ".highlight .hll { background-color: #ffffcc }\n", + ".highlight { background: #f8f8f8; }\n", + ".highlight .c { color: #408080; font-style: italic } /* Comment */\n", + ".highlight .err { border: 1px solid #FF0000 } /* Error */\n", + ".highlight .k { color: #008000; font-weight: bold } /* Keyword */\n", + ".highlight .o { color: #666666 } /* Operator */\n", + ".highlight .ch { color: #408080; font-style: italic } /* Comment.Hashbang */\n", + ".highlight .cm { color: #408080; font-style: italic } /* Comment.Multiline */\n", + ".highlight .cp { color: #BC7A00 } /* Comment.Preproc */\n", + ".highlight .cpf { color: #408080; font-style: italic } /* Comment.PreprocFile */\n", + ".highlight .c1 { color: #408080; font-style: italic } /* Comment.Single */\n", + ".highlight .cs { color: #408080; font-style: italic } /* Comment.Special */\n", + ".highlight .gd { color: #A00000 } /* Generic.Deleted */\n", + ".highlight .ge { font-style: italic } /* Generic.Emph */\n", + ".highlight .gr { color: #FF0000 } /* Generic.Error */\n", + ".highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n", + ".highlight .gi { color: #00A000 } /* Generic.Inserted */\n", + ".highlight .go { color: #888888 } /* Generic.Output */\n", + ".highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n", + ".highlight .gs { font-weight: bold } /* Generic.Strong */\n", + ".highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n", + ".highlight .gt { color: #0044DD } /* Generic.Traceback */\n", + ".highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n", + ".highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n", + ".highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n", + ".highlight .kp { color: #008000 } /* Keyword.Pseudo */\n", + ".highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n", + ".highlight .kt { color: #B00040 } /* Keyword.Type */\n", + ".highlight .m { color: #666666 } /* Literal.Number */\n", + ".highlight .s { color: #BA2121 } /* Literal.String */\n", + ".highlight .na { color: #7D9029 } /* Name.Attribute */\n", + ".highlight .nb { color: #008000 } /* Name.Builtin */\n", + ".highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */\n", + ".highlight .no { color: #880000 } /* Name.Constant */\n", + ".highlight .nd { color: #AA22FF } /* Name.Decorator */\n", + ".highlight .ni { color: #999999; font-weight: bold } /* Name.Entity */\n", + ".highlight .ne { color: #D2413A; font-weight: bold } /* Name.Exception */\n", + ".highlight .nf { color: #0000FF } /* Name.Function */\n", + ".highlight .nl { color: #A0A000 } /* Name.Label */\n", + ".highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n", + ".highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */\n", + ".highlight .nv { color: #19177C } /* Name.Variable */\n", + ".highlight .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n", + ".highlight .w { color: #bbbbbb } /* Text.Whitespace */\n", + ".highlight .mb { color: #666666 } /* Literal.Number.Bin */\n", + ".highlight .mf { color: #666666 } /* Literal.Number.Float */\n", + ".highlight .mh { color: #666666 } /* Literal.Number.Hex */\n", + ".highlight .mi { color: #666666 } /* Literal.Number.Integer */\n", + ".highlight .mo { color: #666666 } /* Literal.Number.Oct */\n", + ".highlight .sa { color: #BA2121 } /* Literal.String.Affix */\n", + ".highlight .sb { color: #BA2121 } /* Literal.String.Backtick */\n", + ".highlight .sc { color: #BA2121 } /* Literal.String.Char */\n", + ".highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */\n", + ".highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n", + ".highlight .s2 { color: #BA2121 } /* Literal.String.Double */\n", + ".highlight .se { color: #BB6622; font-weight: bold } /* Literal.String.Escape */\n", + ".highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */\n", + ".highlight .si { color: #BB6688; font-weight: bold } /* Literal.String.Interpol */\n", + ".highlight .sx { color: #008000 } /* Literal.String.Other */\n", + ".highlight .sr { color: #BB6688 } /* Literal.String.Regex */\n", + ".highlight .s1 { color: #BA2121 } /* Literal.String.Single */\n", + ".highlight .ss { color: #19177C } /* Literal.String.Symbol */\n", + ".highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */\n", + ".highlight .fm { color: #0000FF } /* Name.Function.Magic */\n", + ".highlight .vc { color: #19177C } /* Name.Variable.Class */\n", + ".highlight .vg { color: #19177C } /* Name.Variable.Global */\n", + ".highlight .vi { color: #19177C } /* Name.Variable.Instance */\n", + ".highlight .vm { color: #19177C } /* Name.Variable.Magic */\n", + ".highlight .il { color: #666666 } /* Literal.Number.Integer.Long */</style>" + ], "text/plain": [ - "KernelFunction kernel([_data_dst, _data_src])" + "<IPython.core.display.HTML object>" ] }, - "execution_count": 17, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "<div class=\"highlight\"><pre><span></span><span class=\"n\">FUNC_PREFIX</span><span class=\"w\"> </span><span class=\"kt\">void</span><span class=\"w\"> </span><span class=\"n\">kernel</span><span class=\"p\">(</span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"p\">,</span><span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"p\">)</span><span class=\"w\"></span>\n", + "<span class=\"p\">{</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"k\">for</span><span class=\"w\"> </span><span class=\"p\">(</span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\"><</span><span class=\"w\"> </span><span class=\"mi\">19</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">+=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">)</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">{</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_src_01</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_src_00</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"o\">*</span><span class=\"w\"> </span><span class=\"n\">RESTRICT</span><span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_dst</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">30</span><span class=\"o\">*</span><span class=\"n\">ctr_0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"k\">for</span><span class=\"w\"> </span><span class=\"p\">(</span><span class=\"kt\">int64_t</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\"><</span><span class=\"w\"> </span><span class=\"mi\">29</span><span class=\"p\">;</span><span class=\"w\"> </span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\">+=</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">)</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">{</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"k\">const</span><span class=\"w\"> </span><span class=\"kt\">double</span><span class=\"w\"> </span><span class=\"n\">a</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">_data_src_00</span><span class=\"p\">[</span><span class=\"n\">ctr_1</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"mi\">1</span><span class=\"p\">]</span><span class=\"w\"> </span><span class=\"o\">+</span><span class=\"w\"> </span><span class=\"n\">_data_src_01</span><span class=\"p\">[</span><span class=\"n\">ctr_1</span><span class=\"p\">];</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"n\">_data_dst_00</span><span class=\"p\">[</span><span class=\"n\">ctr_1</span><span class=\"p\">]</span><span class=\"w\"> </span><span class=\"o\">=</span><span class=\"w\"> </span><span class=\"n\">a</span><span class=\"o\">*</span><span class=\"mf\">2.0</span><span class=\"p\">;</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">}</span><span class=\"w\"></span>\n", + "<span class=\"w\"> </span><span class=\"p\">}</span><span class=\"w\"></span>\n", + "<span class=\"p\">}</span><span class=\"w\"></span>\n", + "</pre></div>\n" + ], + "text/plain": [ + "FUNC_PREFIX void kernel(double * RESTRICT _data_dst, double * RESTRICT const _data_src)\n", + "{\n", + " for (int64_t ctr_0 = 1; ctr_0 < 19; ctr_0 += 1)\n", + " {\n", + " double * RESTRICT _data_src_01 = _data_src + 30*ctr_0 + 30;\n", + " double * RESTRICT _data_src_00 = _data_src + 30*ctr_0;\n", + " double * RESTRICT _data_dst_00 = _data_dst + 30*ctr_0;\n", + " for (int64_t ctr_1 = 1; ctr_1 < 29; ctr_1 += 1)\n", + " {\n", + " const double a = _data_src_00[ctr_1 + 1] + _data_src_01[ctr_1];\n", + " _data_dst_00[ctr_1] = a*2.0;\n", + " }\n", + " }\n", + "}" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ + "tmp_var = sp.Symbol(\"a\")\n", + "\n", "@ps.kernel\n", "def allowed():\n", - " dst[0, 0] @= src[0, 1] + src[1, 0]\n", - " dst[0, 0] @= 2 * dst[0, 0]\n", - "ps.create_kernel(allowed)" + " tmp_var @= src[0, 1] + src[1, 0]\n", + " dst[0, 0] @= 2 * tmp_var\n", + "\n", + "\n", + "ast = ps.create_kernel(allowed)\n", + "ps.show_code(ast)" ] } ], @@ -760,7 +912,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.9.9" } }, "nbformat": 4, diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py index 7456e3f618269a8e8da5fca5a151e84810c700f6..a39c7d0ca12d06acfcf413821b467e9a4270b5a5 100644 --- a/pystencils/kernel_contrains_check.py +++ b/pystencils/kernel_contrains_check.py @@ -35,7 +35,7 @@ class KernelConstraintsCheck: """ FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index']) - def __init__(self, check_independence_condition, check_double_write_condition=True): + def __init__(self, check_independence_condition=True, check_double_write_condition=True): self.scopes = NestedScopes() self.field_writes = defaultdict(set) self.fields_read = set() @@ -98,7 +98,11 @@ class KernelConstraintsCheck: def update_accesses_lhs(self, lhs): if isinstance(lhs, Field.Access): fai = self.FieldAndIndex(lhs.field, lhs.index) + if self.check_double_write_condition and lhs.offsets in self.field_writes[fai]: + raise ValueError(f"Field {lhs.field.name} is written twice at the same location") + self.field_writes[fai].add(lhs.offsets) + if self.check_double_write_condition and len(self.field_writes[fai]) > 1: raise ValueError( f"Field {lhs.field.name} is written at two different locations") diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index 8bfee1b21a70e2c660a869f44ba876fcc2aae5d0..0f943c5eab0ba937b0bfac6b0776cf7b4a108426 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -128,8 +128,8 @@ def create_domain_kernel(assignments: Union[AssignmentCollection, NodeCollection # FUTURE WORK from here we shouldn't NEED sympy # --- check constrains - check = KernelConstraintsCheck(check_independence_condition=config.skip_independence_check, - check_double_write_condition=config.allow_double_writes) + check = KernelConstraintsCheck(check_independence_condition=not config.skip_independence_check, + check_double_write_condition=not config.allow_double_writes) check.visit(assignments) if isinstance(assignments, AssignmentCollection):