1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file tir/analysis/usmp/unified_static_memory_planner.cc
 * \brief This is the pass that integrates the USMP passes into
 * a single composite pass.
 */
25
26#include <tvm/relay/executor.h>
27#include <tvm/relay/runtime.h>
28#include <tvm/target/target.h>
29#include <tvm/tir/stmt_functor.h>
30#include <tvm/tir/transform.h>
31#include <tvm/tir/usmp/algorithms.h>
32#include <tvm/tir/usmp/analysis.h>
33#include <tvm/tir/usmp/transform.h>
34#include <tvm/tir/usmp/utils.h>
35
36#include <algorithm>
37#include <string>
38
39namespace tvm {
40
41TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPEnableOption, Bool);
42TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPAlgorithmOption, String);
43TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPUseWorkspaceIO, Bool);
44TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPCustomAlgorithmOption, String);
45
46namespace tir {
47namespace usmp {
48
/*! \brief Name of the built-in algorithm used when the pass config selects none. */
static constexpr const char* kDefaultAlgo = "greedy_by_size";
50
51static std::unordered_map<String, std::function<Map<BufferInfo, PoolAllocation>(
52 const Array<BufferInfo>&, const Integer&)>>
53 algorithms{{"greedy_by_size", algo::GreedyBySize},
54 {"greedy_by_conflicts", algo::GreedyByConflicts},
55 {"hill_climb", algo::HillClimb}};
56
57IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io,
58 Optional<String> opt_custom_algo) {
59 VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod);
60 IRModule module = mod->ShallowCopy();
61 if (use_workspace_io) {
62 module = transform::CreateAllocatesForIO()(module);
63 }
64 module = transform::AssignPoolInfo()(module);
65 PrimFunc main_func = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
66 BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, module);
67 Array<BufferInfo> buffer_info_arr =
68 ConvertToArrayOfBufferInfo(buffer_info_analysis->buffer_info_stmts);
69 decltype(algorithms)::mapped_type algorithm;
70 if (opt_custom_algo) {
71 String algo_func_name = "tir.usmp.algo." + opt_custom_algo.value();
72 const runtime::PackedFunc* pfAlgo = runtime::Registry::Get(algo_func_name);
73 CHECK(pfAlgo) << "The selected custom USMP algorithm : " << opt_custom_algo.value()
74 << " is not defined. Please register it as " << algo_func_name;
75 algorithm = *pfAlgo;
76 } else {
77 CHECK(algorithms.count(algo))
78 << "The selected USMP algorithm : " << algo
79 << " is not defined. Please define it in the above algorithms map.";
80 algorithm = algorithms[algo];
81 }
82 Map<BufferInfo, PoolAllocation> buffer_info_pool_allocations =
83 algorithm(buffer_info_arr, buffer_info_analysis->memory_pressure);
84
85 Map<Stmt, PoolAllocation> stmt_pool_allocations = AssignStmtPoolAllocations(
86 buffer_info_analysis->buffer_info_stmts, buffer_info_pool_allocations);
87
88 module = transform::ConvertPoolAllocationsToOffsets(stmt_pool_allocations)(module);
89 if (use_workspace_io) {
90 Map<String, PoolAllocation> io_pool_allocations =
91 GetIOPoolAllocations(buffer_info_pool_allocations);
92 module = WithAttr(module, tvm::attr::kIOTensorPoolAllocations, io_pool_allocations);
93 }
94 tir::PrimFunc tir_main_func =
95 Downcast<tir::PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
96 Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
97 tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
98 if (allocated_pool_infos) {
99 for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) {
100 VLOG(1) << "pool_size = " << allocated_pool_info->allocated_size;
101 }
102 }
103 return module;
104}
105
106} // namespace usmp
107
108namespace transform {
109
110tvm::transform::Pass UnifiedStaticMemoryPlanner() {
111 auto usmp_main_pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
112 auto algorithm_str = ctx->GetConfig(kUSMPAlgorithmOption, String(usmp::kDefaultAlgo));
113 auto use_workspace_io = ctx->GetConfig(kUSMPUseWorkspaceIO, Bool(false));
114 auto custom_algorithm_str = ctx->GetConfig<String>(kUSMPCustomAlgorithmOption);
115 tvm::relay::Executor executor_config =
116 m->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor).value();
117 String interface_api = executor_config->GetAttr<String>("interface-api").value_or("packed");
118 tvm::relay::Runtime runtime_config =
119 m->GetAttr<tvm::relay::Runtime>(tvm::attr::kRuntime).value();
120 if (use_workspace_io.value()) {
121 CHECK(interface_api == "c") << kUSMPUseWorkspaceIO
122 << " option is only compatible with interface_api c.\n"
123 << "Please use interface_api c to be able to enable "
124 << kUSMPUseWorkspaceIO << "\n";
125 }
126 return Downcast<IRModule>(
127 usmp::PlanMemory(m, algorithm_str.value_or(String(usmp::kDefaultAlgo)),
128 use_workspace_io.value_or(Bool(false)), custom_algorithm_str));
129 };
130
131 return tvm::transform::CreateModulePass(usmp_main_pass_func, 0,
132 "tir.transform.UnifiedStaticMemoryPlanner", {});
133}
134
135TVM_REGISTER_GLOBAL("tir.transform.UnifiedStaticMemoryPlanner")
136 .set_body_typed(UnifiedStaticMemoryPlanner);
137
138} // namespace transform
139} // namespace tir
140} // namespace tvm
141