1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/ir/op_strategy.cc |
 * \brief The Relay operator strategy and its related data structures.
23 | */ |
24 | |
25 | #include <tvm/relay/op_strategy.h> |
26 | |
27 | namespace tvm { |
28 | namespace relay { |
29 | |
// Register the three node types used by operator strategies (declared in
// tvm/relay/op_strategy.h) through TVM_REGISTER_NODE_TYPE.
TVM_REGISTER_NODE_TYPE(OpImplementationNode);
TVM_REGISTER_NODE_TYPE(OpSpecializationNode);
TVM_REGISTER_NODE_TYPE(OpStrategyNode);
33 | |
34 | Array<te::Tensor> OpImplementation::Compute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
35 | const Type& out_type) { |
36 | return (*this)->fcompute(attrs, inputs, out_type); |
37 | } |
38 | |
39 | te::Schedule OpImplementation::Schedule(const Attrs& attrs, const Array<te::Tensor>& outs, |
40 | const Target& target) { |
41 | return (*this)->fschedule(attrs, outs, target); |
42 | } |
43 | |
44 | void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute, |
45 | tvm::relay::FTVMSchedule fschedule, String name, |
46 | int plevel) { |
47 | auto n = make_object<OpImplementationNode>(); |
48 | n->fcompute = fcompute; |
49 | n->fschedule = fschedule; |
50 | n->name = std::move(name); |
51 | n->plevel = plevel; |
52 | (*this)->implementations.push_back(OpImplementation(n)); |
53 | } |
54 | |
55 | void OpStrategy::AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name, |
56 | int plevel) { |
57 | auto curr_cond = te::SpecializedCondition::Current(); |
58 | auto self = this->operator->(); |
59 | Array<OpSpecialization> specializations = self->specializations; |
60 | OpSpecialization op_spec; |
61 | for (OpSpecialization op_spec : specializations) { |
62 | if (op_spec->condition == curr_cond) { |
63 | op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel); |
64 | return; |
65 | } |
66 | } |
67 | ObjectPtr<OpSpecializationNode> n = make_object<OpSpecializationNode>(); |
68 | n->condition = curr_cond; |
69 | op_spec = OpSpecialization(n); |
70 | op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel); |
71 | self->specializations.push_back(op_spec); |
72 | } |
73 | |
74 | TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute" ) |
75 | .set_body([](TVMArgs args, TVMRetValue* rv) { |
76 | OpImplementation imp = args[0]; |
77 | Attrs attrs = args[1]; |
78 | Array<te::Tensor> inputs = args[2]; |
79 | Type out_type = args[3]; |
80 | *rv = imp.Compute(attrs, inputs, out_type); |
81 | }); |
82 | |
83 | TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule" ) |
84 | .set_body([](TVMArgs args, TVMRetValue* rv) { |
85 | OpImplementation imp = args[0]; |
86 | Attrs attrs = args[1]; |
87 | Array<te::Tensor> outs = args[2]; |
88 | Target target = args[3]; |
89 | *rv = imp.Schedule(attrs, outs, target); |
90 | }); |
91 | |
92 | TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
93 | ObjectPtr<OpStrategyNode> n = make_object<OpStrategyNode>(); |
94 | *rv = OpStrategy(n); |
95 | }); |
96 | |
97 | TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation" ) |
98 | .set_body([](TVMArgs args, TVMRetValue* rv) { |
99 | OpStrategy strategy = args[0]; |
100 | FTVMCompute compute = args[1]; |
101 | FTVMSchedule schedule = args[2]; |
102 | std::string name = args[3]; |
103 | int plevel = args[4]; |
104 | strategy.AddImplementation(compute, schedule, name, plevel); |
105 | }); |
106 | |
107 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
108 | .set_dispatch<OpStrategyNode>([](const ObjectRef& node, ReprPrinter* p) { |
109 | auto* op = static_cast<const OpStrategyNode*>(node.get()); |
110 | p->stream << "op_strategy(" << op->specializations << ")" ; |
111 | }) |
112 | .set_dispatch<OpSpecializationNode>([](const ObjectRef& node, ReprPrinter* p) { |
113 | auto* op = static_cast<const OpSpecializationNode*>(node.get()); |
114 | p->stream << "op_spec(" << op->condition << ", " << op->implementations << ")" ; |
115 | }) |
116 | .set_dispatch<OpImplementationNode>([](const ObjectRef& node, ReprPrinter* p) { |
117 | auto* op = static_cast<const OpImplementationNode*>(node.get()); |
118 | p->stream << "op_impl(name=" << op->name << ", level=" << op->plevel << ")" ; |
119 | }); |
120 | |
121 | } // namespace relay |
122 | } // namespace tvm |
123 | |