CUDA_GatherVariableAccesses.scala 5.67 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
//=============================================================================
//
//  This file is part of the ExaStencils code generation framework. ExaStencils
//  is free software: you can redistribute it and/or modify it under the terms
//  of the GNU General Public License as published by the Free Software
//  Foundation, either version 3 of the License, or (at your option) any later
//  version.
//
//  ExaStencils is distributed in the hope that it will be useful, but WITHOUT
//  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
//  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
//  more details.
//
//  You should have received a copy of the GNU General Public License along
//  with ExaStencils. If not, see <http://www.gnu.org/licenses/>.
//
//=============================================================================

19
package exastencils.parallelization.api.cuda
Sebastian Kuckuk's avatar
Sebastian Kuckuk committed
20
21
22
23

import scala.collection.mutable

import exastencils.base.ir._
24
25
import exastencils.baseExt.ir.IR_LoopOverFragments
import exastencils.config.Knowledge
Sebastian Kuckuk's avatar
Sebastian Kuckuk committed
26
import exastencils.datastructures._
27
import exastencils.logger.Logger
28
29
import exastencils.optimization.ir.EvaluationException
import exastencils.optimization.ir.IR_SimplifyExpression
30
import exastencils.parallelization.api.cuda.CUDA_Util._
Sebastian Kuckuk's avatar
Sebastian Kuckuk committed
31

32
/**
  * Gathers the variable accesses in the traversed code that have to be handed to a CUDA kernel.
  *
  * Accesses are recorded in three ways:
  *  - plain variable accesses, keyed by the variable name, in [[evaluableAccesses]];
  *  - array accesses whose index can be evaluated to an integral constant, keyed by "base name + index",
  *    in [[evaluableAccesses]];
  *  - arrays accessed with at least one non-evaluable index, recorded once (per `kernelCount`) in
  *    [[nonEvaluableAccesses]]. Such arrays are passed as a whole (a device copy) and none of their
  *    elements are kept in [[evaluableAccesses]].
  *
  * [[clear]] resets the gathered state (but not `kernelCount`).
  */
object CUDA_GatherVariableAccesses extends QuietDefaultStrategy("Gather local VariableAccess nodes") {
  /** Reduction target of the processed code, if any; checked via `isReductionVariableAccess` / `isReductionTarget`. */
  var reductionTarget : Option[IR_Expression] = None
  /** Suffix that makes device-copy names unique (see [[arrayVariableAccessAsString]]); not reset by [[clear]]. */
  var kernelCount : Int = 0

  /** By-value accesses; the key doubles as the name of the replacement variable (see [[replaceAccess]]). */
  var evaluableAccesses = mutable.HashMap[String, (IR_Access, IR_Datatype)]()
  /** Arrays passed as a whole, keyed by [[arrayVariableAccessAsString]]. */
  var nonEvaluableAccesses = mutable.HashMap[String, (IR_VariableAccess, IR_Datatype)]()

  /** Names that are never gathered: variables declared in the traversed code and the std stream objects. */
  var ignoredAccesses = mutable.SortedSet[String]()
  /** Names already seen as array bases; bare accesses to these names are not gathered as scalars. */
  var ignoredArrayVariableAccesses = mutable.SortedSet[String]()

  def basePrefix(base : IR_VariableAccess) : String = base.name

  // regular, evaluable indexed array accesses
  def arrayAccessAsString(base : IR_VariableAccess, idx : IR_Expression) : String = basePrefix(base) + idx.prettyprint()
  def containsArrayAccess(base : IR_VariableAccess, idx : IR_Expression) : Boolean = evaluableAccesses.contains(arrayAccessAsString(base, idx))

  // array variable accesses in case that a kernel is passed whole array as argument (for non-evaluable indices)
  def arrayVariableAccessAsString(base : IR_VariableAccess) : String = s"${ basePrefix(base) }_deviceCopy_$kernelCount"
  def containsArrayVariableAccess(base : IR_VariableAccess) : Boolean = nonEvaluableAccesses.contains(arrayVariableAccessAsString(base))

  /** True if `base[idx]` was gathered either as a by-value element or as part of a whole-array device copy. */
  def isReplaceable(base : IR_VariableAccess, idx : IR_Expression) : Boolean =
    containsArrayAccess(base, idx) || containsArrayVariableAccess(base)

  /**
    * Builds the kernel-side replacement for the array access `base[idx]`.
    *
    * @return a variable access to the by-value copy if `base[idx]` was gathered as evaluable,
    *         an array access into the device copy if `base` is passed as a whole,
    *         or `None` if the access was not gathered at all
    */
  def replaceAccess(base : IR_VariableAccess, idx : IR_Expression) : Option[IR_Expression] = {
    if (containsArrayAccess(base, idx)) {
      val name = arrayAccessAsString(base, idx)
      Some(IR_VariableAccess(name, evaluableAccesses(name)._2))
    } else if (containsArrayVariableAccess(base)) {
      val name = arrayVariableAccessAsString(base)
      Some(IR_ArrayAccess(IR_VariableAccess(name, base.datatype), idx))
    } else {
      None
    }
  }

  /** True if `idx` can be evaluated to an integral value by `IR_SimplifyExpression.evalIntegral`. */
  def isEvaluable(idx : IR_Expression) : Boolean =
    try {
      IR_SimplifyExpression.evalIntegral(idx)
      true
    } catch {
      case _ : EvaluationException | _ : MatchError => false
    }

  /** Fragment loop index; gathered explicitly when a reduction buffer requires it (see the transformation below). */
  val fragIdx = IR_LoopOverFragments.defIt

  /** Resets all gathered accesses and re-seeds the always-ignored names; `kernelCount` is left untouched. */
  def clear() = {
    reductionTarget = None
    evaluableAccesses = mutable.HashMap[String, (IR_Access, IR_Datatype)]()
    nonEvaluableAccesses = mutable.HashMap[String, (IR_VariableAccess, IR_Datatype)]()
    ignoredArrayVariableAccesses = mutable.SortedSet[String]()
    ignoredAccesses = mutable.SortedSet[String]()

    ignoredAccesses += "std::cout"
    ignoredAccesses += "std::cerr"
    ignoredAccesses += "std::endl"
  }

  this += new Transformation("Searching", {
    case decl : IR_VariableDeclaration =>
      // variables declared within the traversed code are never gathered
      ignoredAccesses += decl.name
      decl

    case arrAcc @ IR_ArrayAccess(base : IR_VariableAccess, idx, _) if !ignoredAccesses.contains(base.name) =>
      ignoredArrayVariableAccesses += base.name

      if (isEvaluable(idx)) {
        // single, evaluable array accesses -> count "base[idx]" as variable access,
        // unless "base" is already passed as a whole; this keeps all accesses to one array
        // consistent regardless of the order in which they are encountered
        if (!containsArrayVariableAccess(base))
          evaluableAccesses.put(arrayAccessAsString(base, idx), (arrAcc, base.datatype.resolveBaseDatatype))
      } else {
        // we found a non-evaluable index -> remove previous evaluable accesses to this array
        // select them by the stored access (a key-prefix test would also hit unrelated names such as
        // "base2" or "baseMax") and remove them only after the traversal of the map has finished
        val staleKeys = evaluableAccesses.collect {
          case (key, (IR_ArrayAccess(other : IR_VariableAccess, _, _), _)) if other.name == base.name => key
        }.toList
        evaluableAccesses --= staleKeys

        // copy "base" to device data and pass device pointer to the kernel -> count as single variable access to "base"
        nonEvaluableAccesses.put(arrayVariableAccessAsString(base), (base, base.datatype))
      }

      // it can happen that no fragmentIdx is accessed in a loop, but the resulting CudaReductionBuffer requires it
      if (Knowledge.domain_numFragmentsPerBlock > 1 && isReductionVariableAccess(reductionTarget, arrAcc))
        evaluableAccesses.put(fragIdx.name, (fragIdx, fragIdx.datatype))

      arrAcc

    case vAcc : IR_VariableAccess if !ignoredAccesses.contains(vAcc.name) && !ignoredArrayVariableAccesses.contains(vAcc.name) =>
      evaluableAccesses.put(vAcc.name, (vAcc, vAcc.datatype))
      vAcc

    // same phenomenon: fragmentIdx is required by CudaReductionBuffer, but not present in loop body
    case expr : IR_Expression if Knowledge.domain_numFragmentsPerBlock > 1 && isReductionTarget(reductionTarget, expr) =>
      evaluableAccesses.put(fragIdx.name, (fragIdx, fragIdx.datatype))
      expr
  })
}