1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file storage_access.h |
22 | * \brief Common data structure for storage access analysis. |
23 | */ |
24 | #ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ |
25 | #define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ |
26 | |
27 | #include <tvm/arith/int_set.h> |
28 | #include <tvm/ir/attrs.h> |
29 | #include <tvm/tir/expr.h> |
30 | #include <tvm/tir/stmt_functor.h> |
31 | |
32 | #include <unordered_map> |
33 | #include <vector> |
34 | |
35 | #include "../../runtime/thread_storage_scope.h" |
36 | |
37 | namespace tvm { |
38 | namespace tir { |
39 | |
40 | using runtime::StorageRank; |
41 | using runtime::StorageScope; |
42 | /*! |
43 | * \brief Base class of storage access analysis |
44 | */ |
45 | class StorageAccessVisitor : public StmtExprVisitor { |
46 | public: |
47 | /*! \brief Storage access type */ |
48 | enum AccessType { |
49 | kRead, |
50 | kWrite, |
51 | kSync, |
52 | kAlloc, |
53 | // acquired version of read, only need to handle WAR dep. |
54 | kReadAcquire |
55 | }; |
56 | /*! \brief An access entry */ |
57 | struct AccessEntry { |
58 | /*! \brief The thread index that access this entry */ |
59 | Array<IterVar> threads; |
60 | /*! \brief The buffer variable, if any */ |
61 | Var buffer = NullValue<Var>(); |
62 | /*! \brief The access data type */ |
63 | DataType dtype; |
64 | /*! \brief The touched access range |
65 | * |
66 | * Has one IntSet for each index in the buffer being accessed. |
67 | */ |
68 | Array<arith::IntSet> touched; |
69 | /*! \brief The type of access */ |
70 | AccessType type; |
71 | /*! \brief The storage scope */ |
72 | StorageScope scope; |
73 | /*! \brief Whether the access is double buffer write */ |
74 | bool double_buffer_write = false; |
75 | }; |
76 | /*! \brief Access pattern about a single statement */ |
77 | struct StmtEntry { |
78 | /*! \brief The statement */ |
79 | const Object* stmt; |
80 | /*! \brief access patterns in the statement */ |
81 | std::vector<AccessEntry> access; |
82 | }; |
83 | // override visitor pattern |
84 | void VisitExpr_(const LoadNode* op) final; |
85 | void VisitStmt_(const StoreNode* op) final; |
86 | void VisitExpr_(const BufferLoadNode* op) final; |
87 | void VisitStmt_(const BufferStoreNode* op) final; |
88 | void VisitStmt_(const EvaluateNode* op) final; |
89 | void VisitStmt_(const AttrStmtNode* op) final; |
90 | void VisitStmt_(const ForNode* op) final; |
91 | void VisitStmt_(const IfThenElseNode* op) final; |
92 | void VisitStmt_(const WhileNode* op) final; |
93 | void VisitExpr_(const CallNode* op) final; |
94 | |
95 | protected: |
96 | StorageAccessVisitor() { scope_.push_back(std::vector<StmtEntry>()); } |
97 | /*! \return number of conditions in the current scope. */ |
98 | int condition_counter() const { return condition_counter_; } |
99 | /*! \return whether we are in device environment. */ |
100 | bool in_device_env() const { return in_device_env_; } |
101 | /*! \return environment threads */ |
102 | const Array<IterVar>& env_threads() const { return env_threads_; } |
103 | /*! |
104 | * \brief Whether we need analyze the buffer in current scope. |
105 | * \param buffer The buffer to be checked |
106 | * \param scope The scope of the buffer. |
107 | * \return Whether the analysis of buffer is enabled. |
108 | */ |
109 | virtual bool Enabled(const VarNode* buffer, const StorageScope& scope) const { return true; } |
110 | /*! |
111 | * \brief Summarize the sequence of operations into parent. |
112 | * |
113 | * Insert synchronization if necessary and remove un-necessary |
114 | * memory access which are already synced. |
115 | * |
116 | * \param seq The sequence of the access operations. |
117 | * \param loop Pass loop node if it is a loop, otherwise nullptr. |
118 | * \return The summarized sequence that represent access that |
119 | * the parent should taken care of to synchronize. |
120 | */ |
121 | virtual std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) = 0; |
122 | /*! |
123 | * \brief Get the scope of the buffer array. |
124 | * \return The scope of the final buffer array. |
125 | */ |
126 | StorageScope GetScope(Var buffer_var) const; |
127 | // access scope |
128 | std::vector<std::vector<StmtEntry>> scope_; |
129 | |
130 | private: |
131 | // whether access appending is enabled. |
132 | bool allow_append_{false}; |
133 | // Whether we are in device environment |
134 | bool in_device_env_{false}; |
135 | // Whether we are inside condition. |
136 | int condition_counter_{0}; |
137 | // The current double buffer write scope. |
138 | const VarNode* double_buffer_write_{nullptr}; |
139 | // the current free stmt entry. |
140 | StmtEntry curr_stmt_; |
141 | // The involving threads |
142 | Array<IterVar> env_threads_; |
143 | }; |
144 | |
145 | } // namespace tir |
146 | } // namespace tvm |
147 | #endif // TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ |
148 | |