CUDA_GatherFieldAccess.scala 3.76 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
//=============================================================================
//
//  This file is part of the ExaStencils code generation framework. ExaStencils
//  is free software: you can redistribute it and/or modify it under the terms
//  of the GNU General Public License as published by the Free Software
//  Foundation, either version 3 of the License, or (at your option) any later
//  version.
//
//  ExaStencils is distributed in the hope that it will be useful, but WITHOUT
//  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
//  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
//  more details.
//
//  You should have received a copy of the GNU General Public License along
//  with ExaStencils. If not, see <http://www.gnu.org/licenses/>.
//
//=============================================================================

19
package exastencils.parallelization.api.cuda
Sebastian Kuckuk's avatar
Sebastian Kuckuk committed
20
21
22
23

import scala.collection.mutable._

import exastencils.base.ir._
24
import exastencils.core.collectors.Collector
Sebastian Kuckuk's avatar
Sebastian Kuckuk committed
25
import exastencils.datastructures._
26
import exastencils.domain.ir.IR_IV_NeighborFragmentIdx
Sebastian Kuckuk's avatar
Sebastian Kuckuk committed
27
import exastencils.field.ir._
28
import exastencils.logger.Logger
29
import exastencils.util.ir.IR_Read
Sebastian Kuckuk's avatar
Sebastian Kuckuk committed
30

31
/**
  * Collector that records every [[IR_MultiDimFieldAccess]] it enters.
  *
  * Assignments and [[IR_Read]] statements tag their operands with an access kind when they are
  * entered; the tag is picked up when the tagged operand itself is entered and removed again when
  * it is left. Field accesses entered in between are stored in `fieldAccesses` under a key prefixed
  * with `read_` and/or `write_`, depending on the access kind that is active at that point.
  * Without an active tag, accesses are treated as reads.
  */
class CUDA_GatherFieldAccess extends Collector {

  /** annotation key and the access kinds that are stored under it */
  private final object Access extends Enumeration {
    type Access = Value
    val ANNOT : String = "CUDAAcc"
    val READ, WRITE, UPDATE = Value

    exastencils.core.Duplicate.registerConstant(this)
  }

  /** gathered accesses; keys are "read_" or "write_" followed by the identifier of the accessed field */
  val fieldAccesses = HashMap[String, IR_MultiDimFieldAccess]()

  // access kind applied to field accesses that are entered next
  private var isRead : Boolean = true
  private var isWrite : Boolean = false

  /** sets both access flags at once */
  private def setAccessMode(read : Boolean, write : Boolean) : Unit = {
    isRead = read
    isWrite = write
  }

  /** switches back to the default mode, i.e. read-only */
  private def restoreDefaultMode() : Unit = setAccessMode(read = true, write = false)

  /**
    * Builds the identifier under which an access is stored: the code name of the field, followed by
    * a slot suffix if the field has more than one slot, followed by a neighbor suffix if the
    * fragment index of the access is an [[IR_IV_NeighborFragmentIdx]].
    */
  private def composeFieldIdentifier(access : IR_MultiDimFieldAccess) : String = {
    val accessedField = access.field
    val baseName = accessedField.codeName

    // TODO: array fields
    val slotSuffix =
      if (accessedField.numSlots > 1) {
        access.slot match {
          case IR_SlotAccess(_, offset) => s"_o$offset"
          case IR_IntegerConstant(slot) => s"_s$slot"
          case other                    => s"_s${ other.prettyprint }"
        }
      } else {
        ""
      }

    // also consider neighbor fragment accesses
    val neighborSuffix = access.fragIdx match {
      case neigh : IR_IV_NeighborFragmentIdx => s"_n${ neigh.neighIdx }"
      case _                                 => ""
    }

    baseName + slotSuffix + neighborSuffix
  }

  /** discards all gathered accesses and switches back to the default mode */
  override def reset() : Unit = {
    restoreDefaultMode()
    fieldAccesses.clear()
  }

  /**
    * Adopts the access kind the node has been tagged with (if any), then either tags the operands
    * of assignments and read statements or records the node if it is a field access.
    */
  override def enter(node : Node) : Unit = {
    // step 1: a tag placed on this node by an enclosing statement determines the current mode
    node.getAnnotation(Access.ANNOT) match {
      case None                => // untagged node -> keep the current mode
      case Some(Access.READ)   => setAccessMode(read = true, write = false)
      case Some(Access.WRITE)  => setAccessMode(read = false, write = true)
      case Some(Access.UPDATE) => setAccessMode(read = true, write = true)
      case _                   => Logger.error("Invalid annotation")
    }

    // step 2: tag operands for later visits or record the access itself
    node match {
      case assign : IR_Assignment =>
        // a plain "=" only writes its destination, every other operator is treated as an update
        val destKind = if (assign.op == "=") Access.WRITE else Access.UPDATE
        assign.dest.annotate(Access.ANNOT, destKind)
        assign.src.annotate(Access.ANNOT, Access.READ)

      case read : IR_Read =>
        // expressions handed to a read statement are tagged as written
        read.toRead.foreach {
          case target : IR_Expression => target.annotate(Access.ANNOT, Access.WRITE)
          case _                      =>
        }

      case access : IR_MultiDimFieldAccess =>
        val identifier = composeFieldIdentifier(access)

        if (isRead)
          fieldAccesses(s"read_$identifier") = access
        if (isWrite)
          fieldAccesses(s"write_$identifier") = access

      case _ =>
    }
  }

  /** removes the access tag from the node; if there was one, the default mode is restored */
  override def leave(node : Node) : Unit = {
    node.removeAnnotation(Access.ANNOT).foreach(_ => restoreDefaultMode())
  }
}