1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file legalize.cc
22 * \brief Converts an expr to another expr. This pass can be used to transform an op based on its
23 * shape, dtype or layout to another op or a sequence of ops.
24 */
25
26#include <tvm/relay/expr_functor.h>
27#include <tvm/relay/op_attr_types.h>
28#include <tvm/relay/transform.h>
29#include <tvm/te/operation.h>
30
31namespace tvm {
32namespace relay {
33
34namespace legalize {
35
36// Call registered FTVMLegalize of an op
37// Returns the legalized expression
38class Legalizer : public ExprRewriter {
39 public:
40 explicit Legalizer(const std::string& legalize_map_attr_name)
41 : legalize_map_attr_name_{legalize_map_attr_name} {}
42
43 Expr Rewrite_(const CallNode* call_node, const Expr& post) override {
44 // Get the new_call node without any changes to current call node.
45 Call new_call = Downcast<Call>(post);
46
47 // Check if the string is registered.
48 if (!Op::HasAttrMap(legalize_map_attr_name_)) {
49 return post;
50 }
51
52 // Collect the registered legalize function.
53 auto fop_legalize = Op::GetAttrMap<FTVMLegalize>(legalize_map_attr_name_);
54 auto call_op = call_node->op;
55 if (call_op.as<OpNode>()) {
56 Op op = Downcast<Op>(call_node->op);
57
58 if (fop_legalize.count(op)) {
59 // Collect the new_args.
60 tvm::Array<Expr> call_args = new_call->args;
61
62 // Collect input and output dtypes to pass on to Legalize API.
63 tvm::Array<tvm::relay::Type> types;
64 for (auto arg : call_node->args) {
65 types.push_back(arg->checked_type());
66 }
67 types.push_back(call_node->checked_type());
68
69 // Transform the op by calling the registered legalize function.
70 Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
71
72 // Return the new expr if the transformation succeeded.
73 if (legalized_value.defined()) {
74 // Check that the returned Expr from legalize is CallNode.
75 const CallNode* legalized_call_node = legalized_value.as<CallNode>();
76 ICHECK(legalized_call_node)
77 << "Can only replace the original operator with another call node";
78 return legalized_value;
79 }
80 }
81 }
82
83 return post;
84 }
85
86 private:
87 std::string legalize_map_attr_name_;
88};
89
90Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
91 auto rewriter = Legalizer(legalize_map_attr_name);
92 return PostOrderRewrite(expr, &rewriter);
93}
94
95} // namespace legalize
96
97namespace transform {
98
99Pass Legalize(const String& legalize_map_attr_name) {
100 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
101 [=](Function f, IRModule m, PassContext pc) {
102 return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
103 };
104 return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"});
105}
106
107TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
108
109} // namespace transform
110
111} // namespace relay
112} // namespace tvm
113