1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file storage_access.h
22 * \brief Common data structure for storage access analysis.
23 */
24#ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_
25#define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_
26
27#include <tvm/arith/int_set.h>
28#include <tvm/ir/attrs.h>
29#include <tvm/tir/expr.h>
30#include <tvm/tir/stmt_functor.h>
31
32#include <unordered_map>
33#include <vector>
34
35#include "../../runtime/thread_storage_scope.h"
36
37namespace tvm {
38namespace tir {
39
40using runtime::StorageRank;
41using runtime::StorageScope;
42/*!
43 * \brief Base class of storage access analysis
44 */
45class StorageAccessVisitor : public StmtExprVisitor {
46 public:
47 /*! \brief Storage access type */
48 enum AccessType {
49 kRead,
50 kWrite,
51 kSync,
52 kAlloc,
53 // acquired version of read, only need to handle WAR dep.
54 kReadAcquire
55 };
56 /*! \brief An access entry */
57 struct AccessEntry {
58 /*! \brief The thread index that access this entry */
59 Array<IterVar> threads;
60 /*! \brief The buffer variable, if any */
61 Var buffer = NullValue<Var>();
62 /*! \brief The access data type */
63 DataType dtype;
64 /*! \brief The touched access range
65 *
66 * Has one IntSet for each index in the buffer being accessed.
67 */
68 Array<arith::IntSet> touched;
69 /*! \brief The type of access */
70 AccessType type;
71 /*! \brief The storage scope */
72 StorageScope scope;
73 /*! \brief Whether the access is double buffer write */
74 bool double_buffer_write = false;
75 };
76 /*! \brief Access pattern about a single statement */
77 struct StmtEntry {
78 /*! \brief The statement */
79 const Object* stmt;
80 /*! \brief access patterns in the statement */
81 std::vector<AccessEntry> access;
82 };
83 // override visitor pattern
84 void VisitExpr_(const LoadNode* op) final;
85 void VisitStmt_(const StoreNode* op) final;
86 void VisitExpr_(const BufferLoadNode* op) final;
87 void VisitStmt_(const BufferStoreNode* op) final;
88 void VisitStmt_(const EvaluateNode* op) final;
89 void VisitStmt_(const AttrStmtNode* op) final;
90 void VisitStmt_(const ForNode* op) final;
91 void VisitStmt_(const IfThenElseNode* op) final;
92 void VisitStmt_(const WhileNode* op) final;
93 void VisitExpr_(const CallNode* op) final;
94
95 protected:
96 StorageAccessVisitor() { scope_.push_back(std::vector<StmtEntry>()); }
97 /*! \return number of conditions in the current scope. */
98 int condition_counter() const { return condition_counter_; }
99 /*! \return whether we are in device environment. */
100 bool in_device_env() const { return in_device_env_; }
101 /*! \return environment threads */
102 const Array<IterVar>& env_threads() const { return env_threads_; }
103 /*!
104 * \brief Whether we need analyze the buffer in current scope.
105 * \param buffer The buffer to be checked
106 * \param scope The scope of the buffer.
107 * \return Whether the analysis of buffer is enabled.
108 */
109 virtual bool Enabled(const VarNode* buffer, const StorageScope& scope) const { return true; }
110 /*!
111 * \brief Summarize the sequence of operations into parent.
112 *
113 * Insert synchronization if necessary and remove un-necessary
114 * memory access which are already synced.
115 *
116 * \param seq The sequence of the access operations.
117 * \param loop Pass loop node if it is a loop, otherwise nullptr.
118 * \return The summarized sequence that represent access that
119 * the parent should taken care of to synchronize.
120 */
121 virtual std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) = 0;
122 /*!
123 * \brief Get the scope of the buffer array.
124 * \return The scope of the final buffer array.
125 */
126 StorageScope GetScope(Var buffer_var) const;
127 // access scope
128 std::vector<std::vector<StmtEntry>> scope_;
129
130 private:
131 // whether access appending is enabled.
132 bool allow_append_{false};
133 // Whether we are in device environment
134 bool in_device_env_{false};
135 // Whether we are inside condition.
136 int condition_counter_{0};
137 // The current double buffer write scope.
138 const VarNode* double_buffer_write_{nullptr};
139 // the current free stmt entry.
140 StmtEntry curr_stmt_;
141 // The involving threads
142 Array<IterVar> env_threads_;
143};
144
145} // namespace tir
146} // namespace tvm
147#endif // TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_
148