1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/transforms/merge_composite.cc
22 * \brief Merges expressions matching patterns into functions marked
23 * as 'composite'. This is primarily intended to be used alongside the
24 * external codegen infrastructure to support the case where multiple
25 * Relay operators map to a single external operator.
26 */
27
28#include <tvm/relay/analysis.h>
29#include <tvm/relay/dataflow_matcher.h>
30#include <tvm/relay/expr_functor.h>
31#include <tvm/relay/op_attr_types.h>
32#include <tvm/relay/transform.h>
33#include <tvm/te/operation.h>
34
35namespace tvm {
36namespace relay {
37namespace merge_composite {
38
39Function InferType(const Function& expr, const IRModule& m) {
40 IRModule mod(m);
41 mod->Update(mod->GetGlobalVar("main"), expr);
42 mod = transform::InferType()(mod);
43 return Downcast<Function>(mod->Lookup("main"));
44}
45
46Expr MergeComposite(const Function& func, const Array<runtime::String>& pattern_names,
47 const Array<DFPattern>& patterns, const std::vector<PackedFunc>& checks,
48 const IRModule& m) {
49 ICHECK_EQ(pattern_names.size(), patterns.size());
50 Function merged_func = func;
51 // merge the patterns one-by-one in order
52 for (size_t i = 0; i < patterns.size(); i++) {
53 Map<String, ObjectRef> attrs;
54 attrs.Set("Composite", pattern_names[i]);
55 merged_func = Downcast<Function>(PartitionPattern(patterns[i], merged_func, attrs, checks[i]));
56 merged_func = InferType(merged_func, m);
57 }
58 return std::move(merged_func);
59}
60
61} // namespace merge_composite
62
63namespace transform {
64
65Pass MergeComposite(const tvm::Array<runtime::String>& pattern_names,
66 const tvm::Array<DFPattern>& patterns, const std::vector<PackedFunc>& checks) {
67 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
68 [=](Function f, IRModule m, PassContext pc) {
69 return Downcast<Function>(
70 relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks, m));
71 };
72 auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
73 return func_pass;
74}
75
76TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
77 tvm::Array<runtime::String> pattern_names = args[0];
78 tvm::Array<DFPattern> patterns = args[1];
79 std::vector<PackedFunc> checks;
80 for (int i = 2; i < args.size(); i++) {
81 checks.push_back(args[i]);
82 }
83 *rv = MergeComposite(pattern_names, patterns, checks);
84});
85
86} // namespace transform
87
88} // namespace relay
89} // namespace tvm
90