1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file convert_op_layout.cc |
 * \brief Alter the layouts of operators or replace primitive operators with
          other expressions. This pass can be used for computing convolution in
          custom layouts or other general weight pre-transformation.
25 | */ |
26 | #include <tvm/relay/analysis.h> |
27 | #include <tvm/relay/attrs/transform.h> |
28 | #include <tvm/relay/op_attr_types.h> |
29 | #include <tvm/relay/transform.h> |
30 | #include <tvm/te/operation.h> |
31 | |
32 | #include <functional> |
33 | #include <string> |
34 | #include <tuple> |
35 | #include <unordered_map> |
36 | #include <utility> |
37 | #include <vector> |
38 | |
39 | #include "pattern_utils.h" |
40 | #include "transform_layout.h" |
41 | |
42 | namespace tvm { |
43 | namespace relay { |
44 | |
45 | namespace convert_op_layout { |
46 | |
47 | /*! |
48 | * \brief Container for the transformations for ConvertLayout. |
49 | */ |
50 | class ConvertTransformMemorizerNode : public TransformMemorizerNode { |
51 | public: |
52 | /*! |
53 | * \brief Initializes the desired_layout. |
54 | * \param desired_layouts Specify mapping of op_name to array of desired layouts for each input. |
55 | * For example: Map("nn.conv2d", Array("NHWC", "OHWI")), |
56 | * this specifies the desired layout for data then kernel for nn.conv2d. |
57 | */ |
58 | explicit ConvertTransformMemorizerNode(Map<String, Array<String>> desired_layouts) |
59 | : desired_layouts_(std::move(desired_layouts)) {} |
60 | |
61 | /*! |
62 | * \brief Defines the call transformation for ConvertLayout pass. The new layouts should be the |
63 | * desired layout as specified by the user. |
64 | * \param ref_call The original call. |
65 | * \param new_attrs Updated attributes consistent with new layouts. |
66 | * \param new_args The traversed/recursed args to the call. |
67 | * \return The new Call after calling the packed func. |
68 | */ |
69 | Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs, |
70 | const std::vector<Expr>& new_args) override { |
71 | static auto fconvert_layout = Op::GetAttrMap<FTVMConvertOpLayout>("FTVMConvertOpLayout" ); |
72 | Op op = Downcast<Op>(ref_call->op); |
73 | Expr new_e; |
74 | bool modified = false; |
75 | if (fconvert_layout.count(op)) { |
76 | auto desired_layouts = desired_layouts_; |
77 | if (desired_layouts.find(op->name) != desired_layouts.end()) { |
78 | tvm::Array<tvm::te::Tensor> tinfos; |
79 | for (auto& expr : ref_call->args) { |
80 | if (expr->checked_type()->IsInstance<TupleTypeNode>()) { |
81 | auto tuple_ttype_node = expr->type_as<TupleTypeNode>(); |
82 | for (auto& ttype : tuple_ttype_node->fields) { |
83 | auto ttype_node = ttype.as<TensorTypeNode>(); |
84 | tinfos.push_back(tvm::te::placeholder(ttype_node->shape, ttype_node->dtype)); |
85 | } |
86 | } else { |
87 | auto ttype = expr->type_as<TensorTypeNode>(); |
88 | tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype)); |
89 | } |
90 | } |
91 | |
92 | Array<String> op_desired_layouts = desired_layouts.at(op->name); |
93 | Expr altered_value = fconvert_layout[op](new_attrs, new_args, tinfos, op_desired_layouts); |
94 | if (altered_value.defined()) { |
95 | new_e = altered_value; |
96 | modified = true; |
97 | } |
98 | } else { |
99 | LOG(WARNING) << "Desired layout(s) not specified for op: " << op->name; |
100 | } |
101 | } |
102 | if (!modified) { |
103 | new_e = Call(ref_call->op, new_args, new_attrs); |
104 | } |
105 | |
106 | const CallNode* new_call = new_e.as<CallNode>(); |
107 | ICHECK(new_call) << "Can only replace the original operator with another call node" ; |
108 | return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args, ref_call->span); |
109 | } |
110 | |
111 | Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override { |
112 | return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); |
113 | } |
114 | |
115 | /*! \brief A mapping of op_name to array of desired layouts for each input. */ |
116 | Map<String, Array<String>> desired_layouts_; |
117 | }; |
118 | |
119 | /*! |
120 | * \brief Container that provides the transformation function for convert layout. |
121 | */ |
122 | class ConvertTransformMemorizer : public TransformMemorizer { |
123 | public: |
124 | ConvertTransformMemorizer() = default; |
125 | explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {} |
126 | |
127 | ConvertTransformMemorizerNode* operator->() { |
128 | return static_cast<ConvertTransformMemorizerNode*>(get_mutable()); |
129 | } |
130 | |
131 | using ContainerType = ConvertTransformMemorizerNode; |
132 | }; |
133 | |
134 | /*! |
135 | * Limitations: |
136 | * 1. The altered op should have the same number of arguments as the previous one. |
137 | * 2. Do not support nested tuple arguments. |
138 | */ |
139 | Expr ConvertLayout(const Expr& expr, const Map<String, Array<String>>& desired_layouts) { |
140 | ConvertTransformMemorizer transformMemorizer( |
141 | make_object<ConvertTransformMemorizerNode>(desired_layouts)); |
142 | auto fcontext = [&](const Call& call) -> ObjectRef { return transformMemorizer; }; |
143 | |
144 | return ForwardRewrite(expr, LayoutRewriter<ConvertTransformMemorizer>, fcontext); |
145 | } |
146 | |
147 | } // namespace convert_op_layout |
148 | |
149 | namespace transform { |
150 | |
151 | Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts) { |
152 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
153 | [=](Function f, IRModule m, PassContext pc) { |
154 | return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layouts)); |
155 | }; |
156 | return CreateFunctionPass(pass_func, 3, "ConvertLayout" , {"InferType" , "CanonicalizeOps" }); |
157 | } |
158 | |
159 | TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout" ).set_body_typed(ConvertLayout); |
160 | |
161 | TVM_REGISTER_GLOBAL("relay._transform.InferCorrectLayoutOutput" ) |
162 | .set_body_typed([](Array<Layout> input_layouts, Array<Layout> output_layouts, Attrs new_attrs) { |
163 | return InferCorrectLayoutOutput(input_layouts, output_layouts, new_attrs); |
164 | }); |
165 | |
166 | TVM_REGISTER_NODE_TYPE(InferCorrectLayoutOutputNode); |
167 | |
168 | } // namespace transform |
169 | |
170 | } // namespace relay |
171 | } // namespace tvm |
172 | |