1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/ir/transform.cc |
22 | * \brief TIR specific transformation passes. |
23 | */ |
24 | #include <tvm/node/repr_printer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/transform.h> |
27 | |
28 | namespace tvm { |
29 | namespace tir { |
30 | namespace transform { |
31 | |
32 | /*! |
33 | * \brief Function level pass that applies transformations to all |
34 | * TIR functions within the module. |
35 | */ |
36 | class PrimFuncPassNode : public PassNode { |
37 | public: |
38 | /* \brief The pass meta data.*/ |
39 | PassInfo pass_info; |
40 | |
41 | /*! \brief The pass function called on each. */ |
42 | runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func; |
43 | |
44 | void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info" , &pass_info); } |
45 | |
46 | /*! |
47 | * \brief Run a function pass on given pass context. |
48 | * |
49 | * \param mod The module that an optimization pass is applied on. |
50 | * \param pass_ctx The context that an optimization pass executes on. |
51 | * |
52 | * \return Return the updated module. |
53 | */ |
54 | IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; |
55 | |
56 | /*! |
57 | * \brief Get the pass information/meta data. |
58 | */ |
59 | PassInfo Info() const override { return pass_info; } |
60 | |
61 | static constexpr const char* _type_key = "tir.PrimFuncPass" ; |
62 | TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncPassNode, PassNode); |
63 | }; |
64 | |
65 | class PrimFuncPass : public Pass { |
66 | public: |
67 | /*! |
68 | * \brief The constructor |
69 | * \param pass_func The packed function which implements a pass. |
70 | * \param pass_info The pass info. |
71 | */ |
72 | TVM_DLL PrimFuncPass( |
73 | runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func, |
74 | PassInfo pass_info); |
75 | |
76 | TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode); |
77 | }; |
78 | |
79 | PrimFuncPass::PrimFuncPass( |
80 | runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func, |
81 | PassInfo pass_info) { |
82 | auto n = make_object<PrimFuncPassNode>(); |
83 | n->pass_func = std::move(pass_func); |
84 | n->pass_info = std::move(pass_info); |
85 | data_ = std::move(n); |
86 | } |
87 | |
88 | // Perform Module -> Module optimizations at the PrimFunc level. |
89 | IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { |
90 | ICHECK(mod.defined()); |
91 | std::vector<ObjectRef> deleted_list; |
92 | IRModuleNode* mod_ptr = mod.CopyOnWrite(); |
93 | auto* func_dict = mod_ptr->functions.CopyOnWrite(); |
94 | // directly loop over the underlying dict |
95 | for (auto& kv : *func_dict) { |
96 | // only picks up tir::PrimFunc |
97 | if (kv.second->IsInstance<PrimFuncNode>()) { |
98 | // move out the function so that it is the only copy. |
99 | PrimFunc func = Downcast<PrimFunc>(std::move(kv.second)); |
100 | func = pass_func(std::move(func), mod, pass_ctx); |
101 | kv.second = std::move(func); |
102 | |
103 | if (!kv.second.defined()) { |
104 | deleted_list.push_back(kv.first); |
105 | } |
106 | } |
107 | } |
108 | |
109 | // automatic removal of None |
110 | for (const auto& gv : deleted_list) { |
111 | func_dict->erase(gv); |
112 | } |
113 | return mod; |
114 | } |
115 | |
116 | Pass CreatePrimFuncPass( |
117 | const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, |
118 | int opt_level, String name, tvm::Array<String> required) { |
119 | PassInfo pass_info = PassInfo(opt_level, name, required); |
120 | return PrimFuncPass(pass_func, pass_info); |
121 | } |
122 | |
123 | TVM_REGISTER_NODE_TYPE(PrimFuncPassNode); |
124 | |
125 | TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass" ) |
126 | .set_body_typed( |
127 | [](runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func, |
128 | PassInfo pass_info) { return PrimFuncPass(pass_func, pass_info); }); |
129 | |
130 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
131 | .set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) { |
132 | auto* node = static_cast<const PrimFuncPassNode*>(ref.get()); |
133 | const PassInfo info = node->Info(); |
134 | p->stream << "PrimFuncPass(" << info->name << ", opt_level=" << info->opt_level << ")" ; |
135 | }); |
136 | |
137 | } // namespace transform |
138 | } // namespace tir |
139 | } // namespace tvm |
140 | |