1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "graph.h"
20
21#include <tvm/node/reflection.h>
22#include <tvm/runtime/container/array.h>
23#include <tvm/runtime/object.h>
24#include <tvm/runtime/registry.h>
25
26#include <algorithm>
27#include <stack>
28#include <unordered_set>
29#include <utility>
30#include <vector>
31
32#include "common.h"
33#include "stripe_config.h"
34
35namespace tvm {
36namespace contrib {
37namespace ethosu {
38namespace cascader {
39
40void PerformanceInfoNode::VisitAttrs(AttrVisitor* v) {
41 v->Visit("_compute_cycles", &compute_cycles);
42 Array<IntImm> tmp_reads = make_array(read_bytes);
43 v->Visit("_read_bytes", &tmp_reads);
44 v->Visit("_write_bytes", &write_bytes);
45 v->Visit("_block_config", &block_config);
46}
47
48TVM_REGISTER_NODE_TYPE(PerformanceInfoNode);
49
50TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
51 .set_dispatch<PerformanceInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
52 auto* node = static_cast<const PerformanceInfoNode*>(ref.get());
53 p->stream << "PerformanceInfo(compute_cycles=" << node->compute_cycles << ", read_bytes=[";
54 for (auto rb : node->read_bytes) {
55 p->stream << rb << ", ";
56 }
57 p->stream << "], write_bytes=" << node->write_bytes << ")";
58 });
59
60void TensorNode::VisitAttrs(AttrVisitor* v) {
61 Array<Integer> tmp_arr = make_array(shape_);
62 v->Visit("_shape", &tmp_arr);
63 v->Visit("_dtype", &dtype_);
64 v->Visit("_is_constant", &is_constant_);
65 double compression_ratio = static_cast<double>(compression_ratio_);
66 v->Visit("_compression_ratio", &compression_ratio);
67 Array<Part> tmp_prods(producers_);
68 v->Visit("_producers", &tmp_prods);
69 Array<Part> tmp_cons(consumers_);
70 v->Visit("_consumers", &tmp_cons);
71 v->Visit("_size", &size_);
72}
73
74Tensor::Tensor(const std::vector<int>& shape, DataType dtype, bool is_constant = false,
75 float compression_ratio = 1.0) {
76 auto n = make_object<TensorNode>();
77 n->shape_ = std::move(shape);
78 n->dtype_ = dtype;
79 n->is_constant_ = is_constant;
80 n->compression_ratio_ = compression_ratio;
81 n->size_ = mul_reduce(n->shape_) * n->dtype_.bytes();
82 data_ = std::move(n);
83}
84
85TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.Tensor")
86 .set_body_typed([](Array<Integer> shape, DataType dtype, bool is_constant,
87 double compression_ratio) {
88 std::vector<int> vshape = make_vector<int, Integer>(shape);
89 return Tensor(vshape, dtype, is_constant, compression_ratio);
90 });
91
92TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorAddProducer")
93 .set_body_method<Tensor>(&TensorNode::AddProducer);
94TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorAddConsumer")
95 .set_body_method<Tensor>(&TensorNode::AddConsumer);
96
97TVM_REGISTER_NODE_TYPE(TensorNode);
98
99void PartNode::VisitAttrs(AttrVisitor* v) {
100 Array<Propagator> tmp_prp(propagators_);
101 v->Visit("_propagators", &tmp_prp);
102 Array<Tensor> tmp_ins(input_tensors_);
103 v->Visit("_input_tensors", &tmp_ins);
104 v->Visit("_output_tensor", &output_tensor_);
105 v->Visit("_in_line", &in_line_);
106 Array<te::Tensor> tmp_te_ins(subgraph_.input_tensors);
107 v->Visit("_te_input_tensors", &tmp_te_ins);
108 v->Visit("_te_output_tensor", &subgraph_.output_tensor);
109}
110
111void PartNode::SetInput(uint64_t input_index, const Tensor& input_tensor) {
112 ICHECK_LT(input_index, input_tensors_.size());
113 input_tensors_[input_index] = std::move(input_tensor);
114}
115
116std::vector<StripeConfig> PartNode::CalculateInputStripeConfigs(
117 const StripeConfig& output_stripe_config) {
118 std::vector<StripeConfig> input_stripe_configs;
119 for (const auto& propagator : propagators_) {
120 input_stripe_configs.push_back(propagator->propagate(output_stripe_config));
121 }
122 return input_stripe_configs;
123}
124
125const std::vector<int> PartNode::GetStripeAlignHint() const {
126 ICHECK_GT(propagators_.size(), 0);
127 size_t dims = propagators_[0]->GetOutputDims();
128 std::vector<int> compute_quantum(dims);
129 for (size_t i = 0; i < dims; i++) {
130 compute_quantum[i] = 1;
131 }
132 return compute_quantum;
133}
134
135TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartSetInput")
136 .set_body_method<Part>(&PartNode::SetInput);
137TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartSetOutput")
138 .set_body_method<Part>(&PartNode::SetOutput);
139TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartCalculateInputStripeConfigs")
140 .set_body_typed([](Part part, StripeConfig output_stripe_config) {
141 auto input_stripe_configs = part->CalculateInputStripeConfigs(output_stripe_config);
142 return Array<StripeConfig>(input_stripe_configs);
143 });
144TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartGetStripeAlignHint").set_body_typed([](Part part) {
145 std::vector<int> align_hint = part->GetStripeAlignHint();
146 return make_array(align_hint);
147});
148TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartGetPerformanceInfo")
149 .set_body_typed([](Part part, StripeConfig stripe_config, int buffer_mode) {
150 BufferMode ebuffer_mode = static_cast<BufferMode>(buffer_mode);
151 return part->GetPerformanceInfo(stripe_config, ebuffer_mode);
152 });
153
154CascaderGraphNode::CascaderGraphNode(std::vector<Tensor> input_tensors,
155 std::vector<Tensor> output_tensors)
156 : input_tensors_(input_tensors), output_tensors_(output_tensors) {
157 Init_();
158}
159
160bool VisitedInputs(
161 const Part& part,
162 const std::unordered_set<Tensor, ObjectPtrHash, ObjectPtrEqual>& visited_tensors) {
163 for (const auto& input_tensor : part->GetInputTensors()) {
164 if (visited_tensors.find(input_tensor) == visited_tensors.end()) {
165 return false;
166 }
167 }
168 return true;
169}
170
171void CascaderGraphNode::Init_() {
172 std::stack<Tensor> stack;
173 std::unordered_set<Tensor, ObjectPtrHash, ObjectPtrEqual> visited_tensors;
174 std::unordered_set<Part, ObjectPtrHash, ObjectPtrEqual> visited_parts;
175 for (const auto& input : input_tensors_) {
176 stack.push(input);
177 }
178 // Visit the Parts/Tensors in depth-first order using a non-recursive algorithm
179 while (!stack.empty()) {
180 Tensor tensor = stack.top();
181 stack.pop();
182 if (visited_tensors.find(tensor) == visited_tensors.end()) {
183 visited_tensors.insert(tensor);
184 tensor_order_.push_back(tensor);
185 for (const auto& part : tensor->GetConsumers()) {
186 if (visited_parts.find(part) == visited_parts.end()) {
187 // Only visit a Part once we've visited all its input Tensors
188 if (!VisitedInputs(part, visited_tensors)) continue;
189 visited_parts.insert(part);
190 part_order_.push_back(part);
191 stack.push(part->GetOutputTensor());
192 }
193 }
194 }
195 }
196 std::reverse(tensor_order_.begin(), tensor_order_.end());
197 std::reverse(part_order_.begin(), part_order_.end());
198 int id = 0;
199 for (const auto& part : part_order_) {
200 part_id_map_[part] = id;
201 id++;
202 }
203 id = 0;
204 for (const auto& tensor : tensor_order_) {
205 tensor_id_map_[tensor] = id;
206 id++;
207 }
208}
209
210void CascaderGraphNode::VisitAttrs(AttrVisitor* v) {
211 Array<Tensor> tmp_ins(input_tensors_);
212 v->Visit("_input_tensors", &tmp_ins);
213 Array<Tensor> tmp_outs(output_tensors_);
214 v->Visit("_output_tensors", &tmp_outs);
215 Array<Part> tmp_parr(part_order_);
216 v->Visit("_part_order", &tmp_parr);
217 Array<Tensor> tmp_tarr(tensor_order_);
218 v->Visit("_tensor_order", &tmp_tarr);
219}
220
221int CascaderGraphNode::GetPartID(const Part& part) const {
222 if (part_id_map_.find(part) == part_id_map_.end()) {
223 return -1;
224 }
225 return part_id_map_.at(part);
226}
227
228int CascaderGraphNode::GetTensorID(const Tensor& tensor) const {
229 if (tensor_id_map_.find(tensor) == tensor_id_map_.end()) {
230 return -1;
231 }
232 return tensor_id_map_.at(tensor);
233}
234
235CascaderGraph::CascaderGraph(std::vector<Tensor> input_tensors,
236 std::vector<Tensor> output_tensors) {
237 auto n = make_object<CascaderGraphNode>(input_tensors, output_tensors);
238 data_ = std::move(n);
239}
240
241TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CascaderGraph")
242 .set_body_typed([](Array<Tensor> input_tensors, Array<Tensor> output_tensors) {
243 std::vector<Tensor> vinput_tensors(input_tensors.begin(), input_tensors.end());
244 std::vector<Tensor> voutput_tensors(output_tensors.begin(), output_tensors.end());
245 return CascaderGraph(vinput_tensors, voutput_tensors);
246 });
247TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CascaderGraphGetPartID")
248 .set_body_method<CascaderGraph>(&CascaderGraphNode::GetPartID);
249TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CascaderGraphGetTensorID")
250 .set_body_method<CascaderGraph>(&CascaderGraphNode::GetTensorID);
251
252TVM_REGISTER_NODE_TYPE(CascaderGraphNode);
253
254} // namespace cascader
255} // namespace ethosu
256} // namespace contrib
257} // namespace tvm
258