1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file split_args.cc
22 */
23#include <tvm/relay/expr_functor.h>
24#include <tvm/relay/transform.h>
25
26#include "../op/annotation/annotation.h"
27#include "./pattern_utils.h"
28
29namespace tvm {
30namespace relay {
31
32class ArgumentSplitter : public ExprRewriter {
33 public:
34 explicit ArgumentSplitter(int max_function_args)
35 : max_function_args_(max_function_args), concat_op_(Op::Get("concatenate")) {}
36
37 Expr Rewrite_(const CallNode* call, const Expr& post) final {
38 if (max_function_args_ < 0) return post;
39 if (call->op == concat_op_) {
40 auto tuple_node = call->args[0].as<TupleNode>();
41 const auto param = call->attrs.as<ConcatenateAttrs>();
42 int outputsNum = 1;
43 if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
44 outputsNum = tuple_type->fields.size();
45 }
46 const int limit = max_function_args_ - outputsNum;
47 int argsNum = tuple_node->fields.size();
48 if (argsNum < limit) return post;
49 int splitNum = argsNum / limit;
50 splitNum = (argsNum % limit) ? splitNum + 1 : splitNum;
51
52 std::vector<Expr> splitted(splitNum);
53 for (int i = 0; i < splitNum; ++i) {
54 int startIdx = i * limit;
55 int argsCount = std::min(limit, argsNum - startIdx);
56 tvm::Array<Expr> args;
57 args.reserve(argsCount);
58
59 for (int j = 0; j < argsCount; ++j) {
60 args.push_back(tuple_node->fields[j + startIdx]);
61 }
62 Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), args);
63 Expr body = MakeConcatenate(new_tuple, param->axis);
64 splitted[i] = StopFusion(body);
65 }
66 tvm::Array<Expr> tuple_args(splitted);
67 Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), tuple_args);
68 return MakeConcatenate(new_tuple, param->axis);
69 }
70 return post;
71 }
72
73 private:
74 const int max_function_args_;
75 const Op& concat_op_;
76};
77
78Expr SplitArgs(const Expr& expr, int max_function_args) {
79 auto rewriter = ArgumentSplitter(max_function_args);
80 return PostOrderRewrite(expr, &rewriter);
81}
82
83namespace transform {
84
85Pass SplitArgs(int max_function_args) {
86 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
87 [=](Function f, IRModule m, PassContext pc) {
88 auto r = Downcast<Function>(SplitArgs(f, max_function_args));
89 return m->attrs.defined() ? WithAttrs(r, {m->attrs->dict}) : r;
90 };
91 return CreateFunctionPass(pass_func, 1, "SplitArgs", {"InferType"});
92}
93
94TVM_REGISTER_GLOBAL("relay._transform.SplitArgs").set_body_typed(SplitArgs);
95
96} // namespace transform
97
98} // namespace relay
99} // namespace tvm
100