/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/ir/op_strategy.cc
 * \brief The Relay operator Strategy and related data structures.
 */

#include <tvm/relay/op_strategy.h>

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(OpImplementationNode);
TVM_REGISTER_NODE_TYPE(OpSpecializationNode);
TVM_REGISTER_NODE_TYPE(OpStrategyNode);

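// Build the output tensor expressions by invoking the compute function
// registered with this implementation.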
Array<te::Tensor> OpImplementation::Compute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                            const Type& out_type) {
  return (*this)->fcompute(attrs, inputs, out_type);
}

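// Build a schedule for the given output tensors by invoking the schedule
// function registered with this implementation.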
te::Schedule OpImplementation::Schedule(const Attrs& attrs, const Array<te::Tensor>& outs,
                                        const Target& target) {
  return (*this)->fschedule(attrs, outs, target);
}

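// Append a new implementation (a compute/schedule pair) to this specialization.
// plevel is the implementation's priority level, consulted when an
// implementation is later selected for the operator.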
void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute,
                                         tvm::relay::FTVMSchedule fschedule, String name,
                                         int plevel) {
  auto n = make_object<OpImplementationNode>();
  n->fcompute = fcompute;
  n->fschedule = fschedule;
  n->name = std::move(name);
  n->plevel = plevel;
  (*this)->implementations.push_back(OpImplementation(n));
}

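// Register an implementation under the specialized condition currently in
// scope. If a specialization for that condition already exists, the
// implementation is appended to it; otherwise a new specialization is created.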
void OpStrategy::AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name,
                                   int plevel) {
  auto curr_cond = te::SpecializedCondition::Current();
  auto self = this->operator->();
  Array<OpSpecialization> specializations = self->specializations;
  // Reuse the existing specialization for the current condition, if any.
  for (OpSpecialization op_spec : specializations) {
    if (op_spec->condition == curr_cond) {
      op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel);
      return;
    }
  }
  // Otherwise create a new specialization for the current condition.
  ObjectPtr<OpSpecializationNode> n = make_object<OpSpecializationNode>();
  n->condition = curr_cond;
  OpSpecialization op_spec(n);
  op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel);
  self->specializations.push_back(op_spec);
}

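// FFI bindings that expose the strategy API through the global function
// registry, so it can be driven from the Python frontend.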
TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      OpImplementation imp = args[0];
      Attrs attrs = args[1];
      Array<te::Tensor> inputs = args[2];
      Type out_type = args[3];
      *rv = imp.Compute(attrs, inputs, out_type);
    });

TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      OpImplementation imp = args[0];
      Attrs attrs = args[1];
      Array<te::Tensor> outs = args[2];
      Target target = args[3];
      *rv = imp.Schedule(attrs, outs, target);
    });

TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy").set_body([](TVMArgs args, TVMRetValue* rv) {
  ObjectPtr<OpStrategyNode> n = make_object<OpStrategyNode>();
  *rv = OpStrategy(n);
});

TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      OpStrategy strategy = args[0];
      FTVMCompute compute = args[1];
      FTVMSchedule schedule = args[2];
      std::string name = args[3];
      int plevel = args[4];
      strategy.AddImplementation(compute, schedule, name, plevel);
    });

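// Repr printing for the strategy-related nodes.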
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<OpStrategyNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const OpStrategyNode*>(node.get());
      p->stream << "op_strategy(" << op->specializations << ")";
    })
    .set_dispatch<OpSpecializationNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const OpSpecializationNode*>(node.get());
      p->stream << "op_spec(" << op->condition << ", " << op->implementations << ")";
    })
    .set_dispatch<OpImplementationNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const OpImplementationNode*>(node.get());
      p->stream << "op_impl(name=" << op->name << ", level=" << op->plevel << ")";
    });

}  // namespace relay
}  // namespace tvm