1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file convert_op_layout.cc
 * \brief Alter the layouts of operators or replace primitive operators with
 *        other expressions. This pass can be used for computing convolution in
 *        custom layouts or other general weight pre-transformation.
 */
26#include <tvm/relay/analysis.h>
27#include <tvm/relay/attrs/transform.h>
28#include <tvm/relay/op_attr_types.h>
29#include <tvm/relay/transform.h>
30#include <tvm/te/operation.h>
31
32#include <functional>
33#include <string>
34#include <tuple>
35#include <unordered_map>
36#include <utility>
37#include <vector>
38
39#include "pattern_utils.h"
40#include "transform_layout.h"
41
42namespace tvm {
43namespace relay {
44
45namespace convert_op_layout {
46
47/*!
48 * \brief Container for the transformations for ConvertLayout.
49 */
50class ConvertTransformMemorizerNode : public TransformMemorizerNode {
51 public:
52 /*!
53 * \brief Initializes the desired_layout.
54 * \param desired_layouts Specify mapping of op_name to array of desired layouts for each input.
55 * For example: Map("nn.conv2d", Array("NHWC", "OHWI")),
56 * this specifies the desired layout for data then kernel for nn.conv2d.
57 */
58 explicit ConvertTransformMemorizerNode(Map<String, Array<String>> desired_layouts)
59 : desired_layouts_(std::move(desired_layouts)) {}
60
61 /*!
62 * \brief Defines the call transformation for ConvertLayout pass. The new layouts should be the
63 * desired layout as specified by the user.
64 * \param ref_call The original call.
65 * \param new_attrs Updated attributes consistent with new layouts.
66 * \param new_args The traversed/recursed args to the call.
67 * \return The new Call after calling the packed func.
68 */
69 Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs,
70 const std::vector<Expr>& new_args) override {
71 static auto fconvert_layout = Op::GetAttrMap<FTVMConvertOpLayout>("FTVMConvertOpLayout");
72 Op op = Downcast<Op>(ref_call->op);
73 Expr new_e;
74 bool modified = false;
75 if (fconvert_layout.count(op)) {
76 auto desired_layouts = desired_layouts_;
77 if (desired_layouts.find(op->name) != desired_layouts.end()) {
78 tvm::Array<tvm::te::Tensor> tinfos;
79 for (auto& expr : ref_call->args) {
80 if (expr->checked_type()->IsInstance<TupleTypeNode>()) {
81 auto tuple_ttype_node = expr->type_as<TupleTypeNode>();
82 for (auto& ttype : tuple_ttype_node->fields) {
83 auto ttype_node = ttype.as<TensorTypeNode>();
84 tinfos.push_back(tvm::te::placeholder(ttype_node->shape, ttype_node->dtype));
85 }
86 } else {
87 auto ttype = expr->type_as<TensorTypeNode>();
88 tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
89 }
90 }
91
92 Array<String> op_desired_layouts = desired_layouts.at(op->name);
93 Expr altered_value = fconvert_layout[op](new_attrs, new_args, tinfos, op_desired_layouts);
94 if (altered_value.defined()) {
95 new_e = altered_value;
96 modified = true;
97 }
98 } else {
99 LOG(WARNING) << "Desired layout(s) not specified for op: " << op->name;
100 }
101 }
102 if (!modified) {
103 new_e = Call(ref_call->op, new_args, new_attrs);
104 }
105
106 const CallNode* new_call = new_e.as<CallNode>();
107 ICHECK(new_call) << "Can only replace the original operator with another call node";
108 return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args, ref_call->span);
109 }
110
111 Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
112 return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
113 }
114
115 /*! \brief A mapping of op_name to array of desired layouts for each input. */
116 Map<String, Array<String>> desired_layouts_;
117};
118
119/*!
120 * \brief Container that provides the transformation function for convert layout.
121 */
122class ConvertTransformMemorizer : public TransformMemorizer {
123 public:
124 ConvertTransformMemorizer() = default;
125 explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}
126
127 ConvertTransformMemorizerNode* operator->() {
128 return static_cast<ConvertTransformMemorizerNode*>(get_mutable());
129 }
130
131 using ContainerType = ConvertTransformMemorizerNode;
132};
133
134/*!
135 * Limitations:
136 * 1. The altered op should have the same number of arguments as the previous one.
137 * 2. Do not support nested tuple arguments.
138 */
139Expr ConvertLayout(const Expr& expr, const Map<String, Array<String>>& desired_layouts) {
140 ConvertTransformMemorizer transformMemorizer(
141 make_object<ConvertTransformMemorizerNode>(desired_layouts));
142 auto fcontext = [&](const Call& call) -> ObjectRef { return transformMemorizer; };
143
144 return ForwardRewrite(expr, LayoutRewriter<ConvertTransformMemorizer>, fcontext);
145}
146
147} // namespace convert_op_layout
148
149namespace transform {
150
151Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts) {
152 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
153 [=](Function f, IRModule m, PassContext pc) {
154 return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layouts));
155 };
156 return CreateFunctionPass(pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"});
157}
158
159TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
160
161TVM_REGISTER_GLOBAL("relay._transform.InferCorrectLayoutOutput")
162 .set_body_typed([](Array<Layout> input_layouts, Array<Layout> output_layouts, Attrs new_attrs) {
163 return InferCorrectLayoutOutput(input_layouts, output_layouts, new_attrs);
164 });
165
166TVM_REGISTER_NODE_TYPE(InferCorrectLayoutOutputNode);
167
168} // namespace transform
169
170} // namespace relay
171} // namespace tvm
172