1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file split_args.cc |
22 | */ |
23 | #include <tvm/relay/expr_functor.h> |
24 | #include <tvm/relay/transform.h> |
25 | |
26 | #include "../op/annotation/annotation.h" |
27 | #include "./pattern_utils.h" |
28 | |
29 | namespace tvm { |
30 | namespace relay { |
31 | |
32 | class ArgumentSplitter : public ExprRewriter { |
33 | public: |
34 | explicit ArgumentSplitter(int max_function_args) |
35 | : max_function_args_(max_function_args), concat_op_(Op::Get("concatenate" )) {} |
36 | |
37 | Expr Rewrite_(const CallNode* call, const Expr& post) final { |
38 | if (max_function_args_ < 0) return post; |
39 | if (call->op == concat_op_) { |
40 | auto tuple_node = call->args[0].as<TupleNode>(); |
41 | const auto param = call->attrs.as<ConcatenateAttrs>(); |
42 | int outputsNum = 1; |
43 | if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) { |
44 | outputsNum = tuple_type->fields.size(); |
45 | } |
46 | const int limit = max_function_args_ - outputsNum; |
47 | int argsNum = tuple_node->fields.size(); |
48 | if (argsNum < limit) return post; |
49 | int splitNum = argsNum / limit; |
50 | splitNum = (argsNum % limit) ? splitNum + 1 : splitNum; |
51 | |
52 | std::vector<Expr> splitted(splitNum); |
53 | for (int i = 0; i < splitNum; ++i) { |
54 | int startIdx = i * limit; |
55 | int argsCount = std::min(limit, argsNum - startIdx); |
56 | tvm::Array<Expr> args; |
57 | args.reserve(argsCount); |
58 | |
59 | for (int j = 0; j < argsCount; ++j) { |
60 | args.push_back(tuple_node->fields[j + startIdx]); |
61 | } |
62 | Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), args); |
63 | Expr body = MakeConcatenate(new_tuple, param->axis); |
64 | splitted[i] = StopFusion(body); |
65 | } |
66 | tvm::Array<Expr> tuple_args(splitted); |
67 | Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), tuple_args); |
68 | return MakeConcatenate(new_tuple, param->axis); |
69 | } |
70 | return post; |
71 | } |
72 | |
73 | private: |
74 | const int max_function_args_; |
75 | const Op& concat_op_; |
76 | }; |
77 | |
78 | Expr SplitArgs(const Expr& expr, int max_function_args) { |
79 | auto rewriter = ArgumentSplitter(max_function_args); |
80 | return PostOrderRewrite(expr, &rewriter); |
81 | } |
82 | |
83 | namespace transform { |
84 | |
85 | Pass SplitArgs(int max_function_args) { |
86 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
87 | [=](Function f, IRModule m, PassContext pc) { |
88 | auto r = Downcast<Function>(SplitArgs(f, max_function_args)); |
89 | return m->attrs.defined() ? WithAttrs(r, {m->attrs->dict}) : r; |
90 | }; |
91 | return CreateFunctionPass(pass_func, 1, "SplitArgs" , {"InferType" }); |
92 | } |
93 | |
94 | TVM_REGISTER_GLOBAL("relay._transform.SplitArgs" ).set_body_typed(SplitArgs); |
95 | |
96 | } // namespace transform |
97 | |
98 | } // namespace relay |
99 | } // namespace tvm |
100 | |