1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/analysis/calculate_allocated_memory.cc
22 * \brief Calculate allocated memory per memory scope required by PrimFuncs.
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/runtime/container/map.h>
26#include <tvm/runtime/device_api.h>
27#include <tvm/tir/analysis.h>
28#include <tvm/tir/function.h>
29#include <tvm/tir/stmt_functor.h>
30#include <tvm/tir/usmp/utils.h>
31
#include <algorithm>
#include <limits>
#include <map>
#include <unordered_map>
35
36namespace tvm {
37namespace tir {
38
39template <typename T>
40class AllocationCalculator : public StmtExprVisitor {
41 public:
42 AllocationCalculator() = default;
43 tvm::Map<String, Integer> operator()(const PrimFunc& func);
44
45 private:
46 void VisitStmt_(const T* op) override;
47 std::unordered_map<std::string, int64_t> _max_size;
48 std::unordered_map<std::string, int64_t> _current_size;
49};
50
51template <typename T>
52tvm::Map<String, Integer> AllocationCalculator<T>::operator()(const PrimFunc& func) {
53 this->VisitStmt(func->body);
54 tvm::Map<String, Integer> res;
55 for (auto [k, v] : _max_size) {
56 res.Set(String(k), Integer(v));
57 }
58 return res;
59}
60
61std::string GetStorageScope(const Var& var) {
62 auto* ptr = var->type_annotation.as<PointerTypeNode>();
63 ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType";
64 return ptr->storage_scope;
65}
66
67template <typename T>
68void AllocationCalculator<T>::VisitStmt_(const T* op) {
69 std::string storage_scope = GetStorageScope(op->buffer_var);
70 auto search = _current_size.find(storage_scope);
71 if (search == _current_size.end()) {
72 _current_size[storage_scope] = 0;
73 _max_size[storage_scope] = 0;
74 }
75 auto size = op->ConstantAllocationSize() * op->dtype.bytes() * op->dtype.lanes();
76 _current_size[storage_scope] += size;
77 _max_size[storage_scope] = std::max(_current_size[storage_scope], _max_size[storage_scope]);
78 StmtExprVisitor::VisitStmt(op->body);
79 _current_size[storage_scope] -= size;
80}
81
82tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func) {
83 return AllocationCalculator<AllocateNode>()(func);
84}
85
86TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](PrimFunc func) {
87 return CalculateAllocatedBytes(func);
88});
89
90bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
91 auto sizes = CalculateAllocatedBytes(func);
92 const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
93 if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
94 return false;
95 }
96 return true;
97}
98
99namespace transform {
100
101Pass VerifyVTCMLimit(const Integer& limit) {
102 auto pass_func = [=](IRModule mod, PassContext ctx) {
103 for (auto kv : mod->functions) {
104 if (auto* n = kv.second.as<PrimFuncNode>()) {
105 auto func = GetRef<PrimFunc>(n);
106 auto sizes = CalculateAllocatedBytes(func);
107 const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
108 if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
109 LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been "
110 "exceeded(allocated: "
111 << vtcm_allocated << ", limit: " << limit << ").\n"
112 << "In function\n"
113 << func;
114 }
115 }
116 }
117 return mod;
118 };
119 return tvm::transform::CreateModulePass(pass_func, 0, "tir.calculate_allocated_bytes", {});
120}
121
122TVM_REGISTER_GLOBAL("tir.transform.VerifyVTCMLimit").set_body_typed(VerifyVTCMLimit);
123
124} // namespace transform
125} // namespace tir
126} // namespace tvm
127