1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file domain_touched.cc
 * \brief Utility to compute the region of a buffer touched by a statement
 */
#include <tvm/runtime/registry.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "ir_visitor_with_analyzer.h"
34 | |
35 | namespace tvm { |
36 | namespace arith { |
37 | |
38 | using namespace tir; |
39 | |
40 | namespace { |
41 | |
42 | using BufferTouches = std::vector<std::vector<IntSet>>; |
43 | |
44 | struct LoadAccess { |
45 | BufferTouches set; |
46 | }; |
47 | |
48 | struct StoreAccess { |
49 | BufferTouches set; |
50 | }; |
51 | |
52 | struct CombinedAccess { |
53 | BufferTouches set; |
54 | }; |
55 | |
56 | using BufferDomainAccess = std::tuple<LoadAccess, StoreAccess, CombinedAccess>; |
57 | |
58 | } // namespace |
59 | |
60 | // Find Read region of the tensor in the stmt. |
61 | class BufferTouchedDomain final : public IRVisitorWithAnalyzer { |
62 | public: |
63 | BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); } |
64 | |
65 | std::unordered_map<const BufferNode*, BufferDomainAccess>& GetAccessedBufferRegions() { |
66 | return buffer_access_map_; |
67 | } |
68 | |
69 | Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) { |
70 | Region ret; |
71 | auto kv = buffer_access_map_.find(buffer.get()); |
72 | if (kv == buffer_access_map_.end()) { |
73 | LOG(WARNING) << "[arith::BufferDomainTouched] " |
74 | << "The requested buffer is not contained in the provided stmt body: " << buffer; |
75 | return ret; |
76 | } |
77 | |
78 | Range none; |
79 | BufferTouches bounds; |
80 | if (consider_loads && consider_stores) { |
81 | bounds = std::get<CombinedAccess>(kv->second).set; |
82 | } else if (consider_loads) { |
83 | bounds = std::get<LoadAccess>(kv->second).set; |
84 | } else if (consider_stores) { |
85 | bounds = std::get<StoreAccess>(kv->second).set; |
86 | } else { |
87 | CHECK(false) << "Must consider at least on of either loads and stores, but both are false" ; |
88 | } |
89 | for (size_t i = 0; i < bounds.size(); ++i) { |
90 | ret.push_back(arith::Union(bounds[i]).CoverRange(none)); |
91 | } |
92 | return ret; |
93 | } |
94 | |
95 | private: |
96 | using Parent = IRVisitorWithAnalyzer; |
97 | using Parent::VisitExpr_; |
98 | using Parent::VisitStmt_; |
99 | |
100 | void VisitExpr_(const BufferLoadNode* op) final { |
101 | // Record load-exclusive buffer access |
102 | Touch(&std::get<LoadAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices); |
103 | // Record load-store inclusive buffer access |
104 | Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices); |
105 | Parent::VisitExpr_(op); |
106 | } |
107 | |
108 | void VisitStmt_(const BufferStoreNode* op) final { |
109 | // Record store-exclusive buffer access |
110 | Touch(&std::get<StoreAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices); |
111 | // Record load-store inclusive buffer access |
112 | Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices); |
113 | Parent::VisitStmt_(op); |
114 | } |
115 | |
116 | private: |
117 | void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) { |
118 | if (args.size() > bounds->size()) { |
119 | bounds->resize(args.size()); |
120 | } |
121 | for (size_t i = 0; i < args.size(); ++i) { |
122 | if (args[i].as<RampNode>()) { |
123 | (*bounds)[i].emplace_back(IntSet::Vector(args[i])); |
124 | } else { |
125 | (*bounds)[i].emplace_back(analyzer_.int_set(args[i])); |
126 | } |
127 | } |
128 | } |
129 | |
130 | std::unordered_map<const BufferNode*, BufferDomainAccess> buffer_access_map_; |
131 | }; |
132 | |
133 | Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, |
134 | bool consider_stores) { |
135 | return BufferTouchedDomain(stmt).FindUnion(buffer, consider_loads, consider_stores); |
136 | } |
137 | |
138 | Map<Buffer, runtime::ADT> DomainTouchedAccessMap(const PrimFunc& func) { |
139 | auto buffer_access_map = BufferTouchedDomain(func->body).GetAccessedBufferRegions(); |
140 | Map<Buffer, runtime::ADT> ret; |
141 | auto& buffer_map = func->buffer_map; |
142 | for (auto& var : func->params) { |
143 | auto& buffer = buffer_map[var]; |
144 | auto& access = buffer_access_map[buffer.get()]; |
145 | Array<Array<IntSet>> loads, stores, combined; |
146 | for (std::vector<IntSet>& touch : std::get<LoadAccess>(access).set) { |
147 | loads.push_back(Array<IntSet>(touch)); |
148 | } |
149 | for (std::vector<IntSet>& touch : std::get<StoreAccess>(access).set) { |
150 | stores.push_back(Array<IntSet>(touch)); |
151 | } |
152 | for (std::vector<IntSet>& touch : std::get<CombinedAccess>(access).set) { |
153 | combined.push_back(Array<IntSet>(touch)); |
154 | } |
155 | |
156 | std::vector<ObjectRef> fields; |
157 | fields.push_back(loads); |
158 | fields.push_back(stores); |
159 | fields.push_back(combined); |
160 | ret.Set(buffer, runtime::ADT::Tuple(fields)); |
161 | } |
162 | return ret; |
163 | } |
164 | |
165 | TVM_REGISTER_GLOBAL("arith.DomainTouched" ).set_body_typed(DomainTouched); |
166 | TVM_REGISTER_GLOBAL("arith.DomainTouchedAccessMap" ).set_body_typed(DomainTouchedAccessMap); |
167 | |
168 | } // namespace arith |
169 | } // namespace tvm |
170 | |