1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/analysis/calculate_allocated_memory.cc |
22 | * \brief Calculate allocated memory per memory scope required by PrimFuncs. |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/runtime/container/map.h> |
26 | #include <tvm/runtime/device_api.h> |
27 | #include <tvm/tir/analysis.h> |
28 | #include <tvm/tir/function.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | #include <tvm/tir/usmp/utils.h> |
31 | |
32 | #include <algorithm> |
33 | #include <map> |
34 | #include <unordered_map> |
35 | |
36 | namespace tvm { |
37 | namespace tir { |
38 | |
39 | template <typename T> |
40 | class AllocationCalculator : public StmtExprVisitor { |
41 | public: |
42 | AllocationCalculator() = default; |
43 | tvm::Map<String, Integer> operator()(const PrimFunc& func); |
44 | |
45 | private: |
46 | void VisitStmt_(const T* op) override; |
47 | std::unordered_map<std::string, int64_t> _max_size; |
48 | std::unordered_map<std::string, int64_t> _current_size; |
49 | }; |
50 | |
51 | template <typename T> |
52 | tvm::Map<String, Integer> AllocationCalculator<T>::operator()(const PrimFunc& func) { |
53 | this->VisitStmt(func->body); |
54 | tvm::Map<String, Integer> res; |
55 | for (auto [k, v] : _max_size) { |
56 | res.Set(String(k), Integer(v)); |
57 | } |
58 | return res; |
59 | } |
60 | |
61 | std::string GetStorageScope(const Var& var) { |
62 | auto* ptr = var->type_annotation.as<PointerTypeNode>(); |
63 | ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType" ; |
64 | return ptr->storage_scope; |
65 | } |
66 | |
67 | template <typename T> |
68 | void AllocationCalculator<T>::VisitStmt_(const T* op) { |
69 | std::string storage_scope = GetStorageScope(op->buffer_var); |
70 | auto search = _current_size.find(storage_scope); |
71 | if (search == _current_size.end()) { |
72 | _current_size[storage_scope] = 0; |
73 | _max_size[storage_scope] = 0; |
74 | } |
75 | auto size = op->ConstantAllocationSize() * op->dtype.bytes() * op->dtype.lanes(); |
76 | _current_size[storage_scope] += size; |
77 | _max_size[storage_scope] = std::max(_current_size[storage_scope], _max_size[storage_scope]); |
78 | StmtExprVisitor::VisitStmt(op->body); |
79 | _current_size[storage_scope] -= size; |
80 | } |
81 | |
82 | tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func) { |
83 | return AllocationCalculator<AllocateNode>()(func); |
84 | } |
85 | |
86 | TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes" ).set_body_typed([](PrimFunc func) { |
87 | return CalculateAllocatedBytes(func); |
88 | }); |
89 | |
90 | bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) { |
91 | auto sizes = CalculateAllocatedBytes(func); |
92 | const auto vtcm_allocated = sizes.Get("global.vtcm" ).value_or(0); |
93 | if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) { |
94 | return false; |
95 | } |
96 | return true; |
97 | } |
98 | |
99 | namespace transform { |
100 | |
101 | Pass VerifyVTCMLimit(const Integer& limit) { |
102 | auto pass_func = [=](IRModule mod, PassContext ctx) { |
103 | for (auto kv : mod->functions) { |
104 | if (auto* n = kv.second.as<PrimFuncNode>()) { |
105 | auto func = GetRef<PrimFunc>(n); |
106 | auto sizes = CalculateAllocatedBytes(func); |
107 | const auto vtcm_allocated = sizes.Get("global.vtcm" ).value_or(0); |
108 | if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) { |
109 | LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been " |
110 | "exceeded(allocated: " |
111 | << vtcm_allocated << ", limit: " << limit << ").\n" |
112 | << "In function\n" |
113 | << func; |
114 | } |
115 | } |
116 | } |
117 | return mod; |
118 | }; |
119 | return tvm::transform::CreateModulePass(pass_func, 0, "tir.calculate_allocated_bytes" , {}); |
120 | } |
121 | |
122 | TVM_REGISTER_GLOBAL("tir.transform.VerifyVTCMLimit" ).set_body_typed(VerifyVTCMLimit); |
123 | |
124 | } // namespace transform |
125 | } // namespace tir |
126 | } // namespace tvm |
127 | |