1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file alter_op_layout.cc |
 * \brief Alter the layouts of operators or replace primitive operators with
 *        other expressions. This pass can be used to compute convolution in
 *        custom layouts or to apply other general weight pre-transformations.
25 | */ |
26 | #include <tvm/relay/analysis.h> |
27 | #include <tvm/relay/attrs/transform.h> |
28 | #include <tvm/relay/op_attr_types.h> |
29 | #include <tvm/relay/transform.h> |
30 | #include <tvm/te/operation.h> |
31 | |
32 | #include <functional> |
33 | #include <string> |
34 | #include <tuple> |
35 | #include <unordered_map> |
36 | #include <utility> |
37 | #include <vector> |
38 | |
39 | #include "pattern_utils.h" |
40 | #include "transform_layout.h" |
41 | |
42 | namespace tvm { |
43 | namespace relay { |
44 | |
45 | namespace alter_op_layout { |
46 | |
47 | /*! |
48 | * \brief Container to instantiate a Node for alter op layouts. |
49 | */ |
50 | class AlterTransformMemorizerNode : public TransformMemorizerNode { |
51 | public: |
52 | static constexpr const char* _type_key = "relay.alter_op_layout.AlterTransformMemorizerNode" ; |
53 | |
54 | /*! |
55 | * \brief Defines the call transformation for AlterOpLayout pass. The new layouts are defined by |
56 | * used for different targets using a packed func. |
57 | * \param ref_call The original call. |
58 | * \param new_attrs Updated attributes consistent with new layouts. |
59 | * \param new_args The traversed/recursed args to the call. |
60 | * \return The new Call after calling the packed func. |
61 | */ |
62 | Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs, |
63 | const std::vector<Expr>& new_args) override { |
64 | static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout" ); |
65 | Op op = Downcast<Op>(ref_call->op); |
66 | |
67 | Expr new_e; |
68 | bool modified = false; |
69 | if (falter_layout.count(op)) { |
70 | tvm::Array<tvm::te::Tensor> tinfos; |
71 | for (auto expr : ref_call->args) { |
72 | auto ttype = expr->type_as<TensorTypeNode>(); |
73 | tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype)); |
74 | } |
75 | // TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs are dynamic shapes. |
76 | // Probably we need to disable the AlterOpLayout when compiling dynamic models. |
77 | Expr altered_value = falter_layout[op](new_attrs, new_args, tinfos, ref_call->checked_type()); |
78 | if (altered_value.defined()) { |
79 | new_e = altered_value; |
80 | modified = true; |
81 | } |
82 | } |
83 | if (!modified) { |
84 | new_e = Call(ref_call->op, new_args, new_attrs); |
85 | } |
86 | |
87 | const CallNode* new_call = new_e.as<CallNode>(); |
88 | ICHECK(new_call) << "Can only replace the original operator with another call node" ; |
89 | return GetRef<Call>(new_call); |
90 | } |
91 | |
92 | Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override { |
93 | return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); |
94 | } |
95 | }; |
96 | |
97 | /*! |
98 | * \brief Container that provides the transformation function for alter layout.. |
99 | */ |
100 | class AlterTransformMemorizer : public TransformMemorizer { |
101 | public: |
102 | AlterTransformMemorizer() = default; |
103 | explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {} |
104 | |
105 | AlterTransformMemorizerNode* operator->() { |
106 | return static_cast<AlterTransformMemorizerNode*>(get_mutable()); |
107 | } |
108 | |
109 | using ContainerType = AlterTransformMemorizerNode; |
110 | }; |
111 | |
112 | /*! |
113 | * Limitations: |
114 | * 1. The altered op should have the same number of arguments as the previous one. |
115 | * 2. Do not support nested tuple arguments. |
116 | */ |
117 | Expr AlterOpLayout(const Expr& expr) { |
118 | // TODO(@icemelon9): need to rerun type inference after applying an alter op. |
119 | AlterTransformMemorizer alter_memorizer(make_object<AlterTransformMemorizerNode>()); |
120 | std::function<ObjectRef(const Call&)> fcontext = [=](const Call& call) -> ObjectRef { |
121 | return alter_memorizer; |
122 | }; |
123 | FForwardRewrite rewrite_func = LayoutRewriter<AlterTransformMemorizer>; |
124 | return ForwardRewrite(expr, rewrite_func, fcontext); |
125 | } |
126 | |
127 | } // namespace alter_op_layout |
128 | |
129 | namespace transform { |
130 | |
131 | Pass AlterOpLayout() { |
132 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
133 | [=](Function f, IRModule m, PassContext pc) { |
134 | return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f)); |
135 | }; |
136 | return CreateFunctionPass(pass_func, 3, "AlterOpLayout" , {"InferType" }); |
137 | } |
138 | |
139 | TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout" ).set_body_typed(AlterOpLayout); |
140 | |
141 | } // namespace transform |
142 | |
143 | } // namespace relay |
144 | } // namespace tvm |
145 | |