/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file alter_op_layout.cc
 * \brief Alter the layouts of operators or replace primitive operators with
 *        other expressions. This pass can be used for computing convolution in
 *        custom layouts or other general weight pre-transformation.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>

#include <functional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "pattern_utils.h"
#include "transform_layout.h"

namespace tvm {
namespace relay {

namespace alter_op_layout {

/*!
 * \brief Container to instantiate a Node for alter op layouts.
 */
class AlterTransformMemorizerNode : public TransformMemorizerNode {
 public:
  static constexpr const char* _type_key = "relay.alter_op_layout.AlterTransformMemorizerNode";

  /*!
   * \brief Defines the call transformation for the AlterOpLayout pass. The new layouts are
   * defined for different targets using a packed func.
   * \param ref_call The original call.
   * \param new_attrs Updated attributes consistent with new layouts.
   * \param new_args The traversed/recursed args to the call.
   * \return The new Call after calling the packed func.
   */
  Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs,
                          const std::vector<Expr>& new_args) override {
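    // Look up the FTVMAlterOpLayout attribute map, which holds the per-operator
    // alter-layout hooks registered for different targets/backends.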
    static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout");
    Op op = Downcast<Op>(ref_call->op);

    Expr new_e;
    bool modified = false;
    if (falter_layout.count(op)) {
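      // Build placeholder tensors describing the shape/dtype of each input; the
      // registered hook only needs this type information to pick new layouts.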
      tvm::Array<tvm::te::Tensor> tinfos;
      for (auto expr : ref_call->args) {
        auto ttype = expr->type_as<TensorTypeNode>();
        tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
      }
      // TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs are dynamic shapes.
      // Probably we need to disable the AlterOpLayout when compiling dynamic models.
      Expr altered_value = falter_layout[op](new_attrs, new_args, tinfos, ref_call->checked_type());
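      // The hook returns an undefined Expr when it chooses not to alter this call.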
      if (altered_value.defined()) {
        new_e = altered_value;
        modified = true;
      }
    }
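    // No hook fired (or it declined to rewrite): rebuild the call with the original
    // op, keeping the already-rewritten args and the updated attributes.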
    if (!modified) {
      new_e = Call(ref_call->op, new_args, new_attrs);
    }

    const CallNode* new_call = new_e.as<CallNode>();
    ICHECK(new_call) << "Can only replace the original operator with another call node";
    return GetRef<Call>(new_call);
  }

  Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
    return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
  }
};

/*!
 * \brief Container that provides the transformation function for alter layout.
 */
class AlterTransformMemorizer : public TransformMemorizer {
 public:
  AlterTransformMemorizer() = default;
  explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

  AlterTransformMemorizerNode* operator->() {
    return static_cast<AlterTransformMemorizerNode*>(get_mutable());
  }

  using ContainerType = AlterTransformMemorizerNode;
};

/*!
 * Limitations:
 * 1. The altered op should have the same number of arguments as the previous one.
 * 2. Do not support nested tuple arguments.
 */
Expr AlterOpLayout(const Expr& expr) {
  // TODO(@icemelon9): need to rerun type inference after applying an alter op.
  AlterTransformMemorizer alter_memorizer(make_object<AlterTransformMemorizerNode>());
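  // Hand the same memorizer to every call site as the rewrite context so that
  // layout-transformed expressions are cached and reused across the traversal.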
  std::function<ObjectRef(const Call&)> fcontext = [=](const Call& call) -> ObjectRef {
    return alter_memorizer;
  };
  FForwardRewrite rewrite_func = LayoutRewriter<AlterTransformMemorizer>;
  return ForwardRewrite(expr, rewrite_func, fcontext);
}

}  // namespace alter_op_layout

namespace transform {

Pass AlterOpLayout() {
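  // Wrap the expression-level rewrite above into a function-level pass: it runs at
  // opt level 3 and requires InferType, since the rewrite relies on checked types.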
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
      };
  return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"});
}
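// Illustrative sketch only (not part of this file's logic), assuming the standard
// Pass interface: the pass can be applied to a module directly, e.g.
//
//   IRModule mod = ...;          // hypothetical module containing Relay functions
//   mod = AlterOpLayout()(mod);  // runs the function pass over every function
//
// or from Python via relay.transform.AlterOpLayout() inside a PassContext.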

TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout").set_body_typed(AlterOpLayout);

}  // namespace transform

}  // namespace relay
}  // namespace tvm