1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file legalize.cc
 * \brief Legalize pass: rewrites operator calls by invoking the legalize function registered
 * for each op under a given op-attribute name. It can be used to transform an op, based on its
 * shape, dtype or layout, into another op or a sequence of ops.
 */
25 | |
26 | #include <tvm/relay/expr_functor.h> |
27 | #include <tvm/relay/op_attr_types.h> |
28 | #include <tvm/relay/transform.h> |
29 | #include <tvm/te/operation.h> |
30 | |
31 | namespace tvm { |
32 | namespace relay { |
33 | |
34 | namespace legalize { |
35 | |
36 | // Call registered FTVMLegalize of an op |
37 | // Returns the legalized expression |
38 | class Legalizer : public ExprRewriter { |
39 | public: |
40 | explicit Legalizer(const std::string& legalize_map_attr_name) |
41 | : legalize_map_attr_name_{legalize_map_attr_name} {} |
42 | |
43 | Expr Rewrite_(const CallNode* call_node, const Expr& post) override { |
44 | // Get the new_call node without any changes to current call node. |
45 | Call new_call = Downcast<Call>(post); |
46 | |
47 | // Check if the string is registered. |
48 | if (!Op::HasAttrMap(legalize_map_attr_name_)) { |
49 | return post; |
50 | } |
51 | |
52 | // Collect the registered legalize function. |
53 | auto fop_legalize = Op::GetAttrMap<FTVMLegalize>(legalize_map_attr_name_); |
54 | auto call_op = call_node->op; |
55 | if (call_op.as<OpNode>()) { |
56 | Op op = Downcast<Op>(call_node->op); |
57 | |
58 | if (fop_legalize.count(op)) { |
59 | // Collect the new_args. |
60 | tvm::Array<Expr> call_args = new_call->args; |
61 | |
62 | // Collect input and output dtypes to pass on to Legalize API. |
63 | tvm::Array<tvm::relay::Type> types; |
64 | for (auto arg : call_node->args) { |
65 | types.push_back(arg->checked_type()); |
66 | } |
67 | types.push_back(call_node->checked_type()); |
68 | |
69 | // Transform the op by calling the registered legalize function. |
70 | Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types); |
71 | |
72 | // Return the new expr if the transformation succeeded. |
73 | if (legalized_value.defined()) { |
74 | // Check that the returned Expr from legalize is CallNode. |
75 | const CallNode* legalized_call_node = legalized_value.as<CallNode>(); |
76 | ICHECK(legalized_call_node) |
77 | << "Can only replace the original operator with another call node" ; |
78 | return legalized_value; |
79 | } |
80 | } |
81 | } |
82 | |
83 | return post; |
84 | } |
85 | |
86 | private: |
87 | std::string legalize_map_attr_name_; |
88 | }; |
89 | |
90 | Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) { |
91 | auto rewriter = Legalizer(legalize_map_attr_name); |
92 | return PostOrderRewrite(expr, &rewriter); |
93 | } |
94 | |
95 | } // namespace legalize |
96 | |
97 | namespace transform { |
98 | |
99 | Pass Legalize(const String& legalize_map_attr_name) { |
100 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
101 | [=](Function f, IRModule m, PassContext pc) { |
102 | return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name)); |
103 | }; |
104 | return CreateFunctionPass(pass_func, 1, "Legalize" , {"InferType" }); |
105 | } |
106 | |
107 | TVM_REGISTER_GLOBAL("relay._transform.Legalize" ).set_body_typed(Legalize); |
108 | |
109 | } // namespace transform |
110 | |
111 | } // namespace relay |
112 | } // namespace tvm |
113 | |