1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/analysis/usmp/unified_static_memory_planner.cc |
 * \brief This is the pass that integrates the USMP passes into
 * a single composite pass.
24 | */ |
25 | |
26 | #include <tvm/relay/executor.h> |
27 | #include <tvm/relay/runtime.h> |
28 | #include <tvm/target/target.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | #include <tvm/tir/transform.h> |
31 | #include <tvm/tir/usmp/algorithms.h> |
32 | #include <tvm/tir/usmp/analysis.h> |
33 | #include <tvm/tir/usmp/transform.h> |
34 | #include <tvm/tir/usmp/utils.h> |
35 | |
36 | #include <algorithm> |
37 | #include <string> |
38 | |
39 | namespace tvm { |
40 | |
// Register the USMP PassContext config options together with their value types.
// NOTE(review): kUSMPEnableOption is not read anywhere in this file; presumably it is consulted
// by whichever caller decides whether to run the planner at all -- confirm.
TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPEnableOption, Bool);
// Name of a built-in planning algorithm from usmp::algorithms (default: usmp::kDefaultAlgo).
TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPAlgorithmOption, String);
// When true, PlanMemory also creates allocates for the I/O tensors (CreateAllocatesForIO) and
// attaches their pool allocations to the module; only allowed with executor interface-api "c".
TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPUseWorkspaceIO, Bool);
// Name of a custom algorithm registered as the global "tir.usmp.algo.<name>"; when set it takes
// precedence over kUSMPAlgorithmOption.
TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPCustomAlgorithmOption, String);
45 | |
46 | namespace tir { |
47 | namespace usmp { |
48 | |
/*! \brief Built-in algorithm used when kUSMPAlgorithmOption is not set in the PassContext. */
static constexpr char const* kDefaultAlgo{"greedy_by_size"};
50 | |
51 | static std::unordered_map<String, std::function<Map<BufferInfo, PoolAllocation>( |
52 | const Array<BufferInfo>&, const Integer&)>> |
53 | algorithms{{"greedy_by_size" , algo::GreedyBySize}, |
54 | {"greedy_by_conflicts" , algo::GreedyByConflicts}, |
55 | {"hill_climb" , algo::HillClimb}}; |
56 | |
57 | IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io, |
58 | Optional<String> opt_custom_algo) { |
59 | VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod); |
60 | IRModule module = mod->ShallowCopy(); |
61 | if (use_workspace_io) { |
62 | module = transform::CreateAllocatesForIO()(module); |
63 | } |
64 | module = transform::AssignPoolInfo()(module); |
65 | PrimFunc main_func = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); |
66 | BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, module); |
67 | Array<BufferInfo> buffer_info_arr = |
68 | ConvertToArrayOfBufferInfo(buffer_info_analysis->buffer_info_stmts); |
69 | decltype(algorithms)::mapped_type algorithm; |
70 | if (opt_custom_algo) { |
71 | String algo_func_name = "tir.usmp.algo." + opt_custom_algo.value(); |
72 | const runtime::PackedFunc* pfAlgo = runtime::Registry::Get(algo_func_name); |
73 | CHECK(pfAlgo) << "The selected custom USMP algorithm : " << opt_custom_algo.value() |
74 | << " is not defined. Please register it as " << algo_func_name; |
75 | algorithm = *pfAlgo; |
76 | } else { |
77 | CHECK(algorithms.count(algo)) |
78 | << "The selected USMP algorithm : " << algo |
79 | << " is not defined. Please define it in the above algorithms map." ; |
80 | algorithm = algorithms[algo]; |
81 | } |
82 | Map<BufferInfo, PoolAllocation> buffer_info_pool_allocations = |
83 | algorithm(buffer_info_arr, buffer_info_analysis->memory_pressure); |
84 | |
85 | Map<Stmt, PoolAllocation> stmt_pool_allocations = AssignStmtPoolAllocations( |
86 | buffer_info_analysis->buffer_info_stmts, buffer_info_pool_allocations); |
87 | |
88 | module = transform::ConvertPoolAllocationsToOffsets(stmt_pool_allocations)(module); |
89 | if (use_workspace_io) { |
90 | Map<String, PoolAllocation> io_pool_allocations = |
91 | GetIOPoolAllocations(buffer_info_pool_allocations); |
92 | module = WithAttr(module, tvm::attr::kIOTensorPoolAllocations, io_pool_allocations); |
93 | } |
94 | tir::PrimFunc tir_main_func = |
95 | Downcast<tir::PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); |
96 | Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos = |
97 | tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs); |
98 | if (allocated_pool_infos) { |
99 | for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { |
100 | VLOG(1) << "pool_size = " << allocated_pool_info->allocated_size; |
101 | } |
102 | } |
103 | return module; |
104 | } |
105 | |
106 | } // namespace usmp |
107 | |
108 | namespace transform { |
109 | |
110 | tvm::transform::Pass UnifiedStaticMemoryPlanner() { |
111 | auto usmp_main_pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { |
112 | auto algorithm_str = ctx->GetConfig(kUSMPAlgorithmOption, String(usmp::kDefaultAlgo)); |
113 | auto use_workspace_io = ctx->GetConfig(kUSMPUseWorkspaceIO, Bool(false)); |
114 | auto custom_algorithm_str = ctx->GetConfig<String>(kUSMPCustomAlgorithmOption); |
115 | tvm::relay::Executor executor_config = |
116 | m->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor).value(); |
117 | String interface_api = executor_config->GetAttr<String>("interface-api" ).value_or("packed" ); |
118 | tvm::relay::Runtime runtime_config = |
119 | m->GetAttr<tvm::relay::Runtime>(tvm::attr::kRuntime).value(); |
120 | if (use_workspace_io.value()) { |
121 | CHECK(interface_api == "c" ) << kUSMPUseWorkspaceIO |
122 | << " option is only compatible with interface_api c.\n" |
123 | << "Please use interface_api c to be able to enable " |
124 | << kUSMPUseWorkspaceIO << "\n" ; |
125 | } |
126 | return Downcast<IRModule>( |
127 | usmp::PlanMemory(m, algorithm_str.value_or(String(usmp::kDefaultAlgo)), |
128 | use_workspace_io.value_or(Bool(false)), custom_algorithm_str)); |
129 | }; |
130 | |
131 | return tvm::transform::CreateModulePass(usmp_main_pass_func, 0, |
132 | "tir.transform.UnifiedStaticMemoryPlanner" , {}); |
133 | } |
134 | |
135 | TVM_REGISTER_GLOBAL("tir.transform.UnifiedStaticMemoryPlanner" ) |
136 | .set_body_typed(UnifiedStaticMemoryPlanner); |
137 | |
138 | } // namespace transform |
139 | } // namespace tir |
140 | } // namespace tvm |
141 | |