1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/ir/transform.cc
22 * \brief TIR specific transformation passes.
23 */
24#include <tvm/node/repr_printer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/transform.h>
27
28namespace tvm {
29namespace tir {
30namespace transform {
31
32/*!
33 * \brief Function level pass that applies transformations to all
34 * TIR functions within the module.
35 */
36class PrimFuncPassNode : public PassNode {
37 public:
38 /* \brief The pass meta data.*/
39 PassInfo pass_info;
40
41 /*! \brief The pass function called on each. */
42 runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func;
43
44 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
45
46 /*!
47 * \brief Run a function pass on given pass context.
48 *
49 * \param mod The module that an optimization pass is applied on.
50 * \param pass_ctx The context that an optimization pass executes on.
51 *
52 * \return Return the updated module.
53 */
54 IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
55
56 /*!
57 * \brief Get the pass information/meta data.
58 */
59 PassInfo Info() const override { return pass_info; }
60
61 static constexpr const char* _type_key = "tir.PrimFuncPass";
62 TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncPassNode, PassNode);
63};
64
65class PrimFuncPass : public Pass {
66 public:
67 /*!
68 * \brief The constructor
69 * \param pass_func The packed function which implements a pass.
70 * \param pass_info The pass info.
71 */
72 TVM_DLL PrimFuncPass(
73 runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
74 PassInfo pass_info);
75
76 TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode);
77};
78
79PrimFuncPass::PrimFuncPass(
80 runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
81 PassInfo pass_info) {
82 auto n = make_object<PrimFuncPassNode>();
83 n->pass_func = std::move(pass_func);
84 n->pass_info = std::move(pass_info);
85 data_ = std::move(n);
86}
87
88// Perform Module -> Module optimizations at the PrimFunc level.
89IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
90 ICHECK(mod.defined());
91 std::vector<ObjectRef> deleted_list;
92 IRModuleNode* mod_ptr = mod.CopyOnWrite();
93 auto* func_dict = mod_ptr->functions.CopyOnWrite();
94 // directly loop over the underlying dict
95 for (auto& kv : *func_dict) {
96 // only picks up tir::PrimFunc
97 if (kv.second->IsInstance<PrimFuncNode>()) {
98 // move out the function so that it is the only copy.
99 PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
100 func = pass_func(std::move(func), mod, pass_ctx);
101 kv.second = std::move(func);
102
103 if (!kv.second.defined()) {
104 deleted_list.push_back(kv.first);
105 }
106 }
107 }
108
109 // automatic removal of None
110 for (const auto& gv : deleted_list) {
111 func_dict->erase(gv);
112 }
113 return mod;
114}
115
116Pass CreatePrimFuncPass(
117 const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
118 int opt_level, String name, tvm::Array<String> required) {
119 PassInfo pass_info = PassInfo(opt_level, name, required);
120 return PrimFuncPass(pass_func, pass_info);
121}
122
123TVM_REGISTER_NODE_TYPE(PrimFuncPassNode);
124
125TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass")
126 .set_body_typed(
127 [](runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
128 PassInfo pass_info) { return PrimFuncPass(pass_func, pass_info); });
129
130TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
131 .set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
132 auto* node = static_cast<const PrimFuncPassNode*>(ref.get());
133 const PassInfo info = node->Info();
134 p->stream << "PrimFuncPass(" << info->name << ", opt_level=" << info->opt_level << ")";
135 });
136
137} // namespace transform
138} // namespace tir
139} // namespace tvm
140