1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "graph.h" |
20 | |
21 | #include <tvm/node/reflection.h> |
22 | #include <tvm/runtime/container/array.h> |
23 | #include <tvm/runtime/object.h> |
24 | #include <tvm/runtime/registry.h> |
25 | |
26 | #include <algorithm> |
27 | #include <stack> |
28 | #include <unordered_set> |
29 | #include <utility> |
30 | #include <vector> |
31 | |
32 | #include "common.h" |
33 | #include "stripe_config.h" |
34 | |
35 | namespace tvm { |
36 | namespace contrib { |
37 | namespace ethosu { |
38 | namespace cascader { |
39 | |
40 | void PerformanceInfoNode::VisitAttrs(AttrVisitor* v) { |
41 | v->Visit("_compute_cycles" , &compute_cycles); |
42 | Array<IntImm> tmp_reads = make_array(read_bytes); |
43 | v->Visit("_read_bytes" , &tmp_reads); |
44 | v->Visit("_write_bytes" , &write_bytes); |
45 | v->Visit("_block_config" , &block_config); |
46 | } |
47 | |
48 | TVM_REGISTER_NODE_TYPE(PerformanceInfoNode); |
49 | |
50 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
51 | .set_dispatch<PerformanceInfoNode>([](const ObjectRef& ref, ReprPrinter* p) { |
52 | auto* node = static_cast<const PerformanceInfoNode*>(ref.get()); |
53 | p->stream << "PerformanceInfo(compute_cycles=" << node->compute_cycles << ", read_bytes=[" ; |
54 | for (auto rb : node->read_bytes) { |
55 | p->stream << rb << ", " ; |
56 | } |
57 | p->stream << "], write_bytes=" << node->write_bytes << ")" ; |
58 | }); |
59 | |
60 | void TensorNode::VisitAttrs(AttrVisitor* v) { |
61 | Array<Integer> tmp_arr = make_array(shape_); |
62 | v->Visit("_shape" , &tmp_arr); |
63 | v->Visit("_dtype" , &dtype_); |
64 | v->Visit("_is_constant" , &is_constant_); |
65 | double compression_ratio = static_cast<double>(compression_ratio_); |
66 | v->Visit("_compression_ratio" , &compression_ratio); |
67 | Array<Part> tmp_prods(producers_); |
68 | v->Visit("_producers" , &tmp_prods); |
69 | Array<Part> tmp_cons(consumers_); |
70 | v->Visit("_consumers" , &tmp_cons); |
71 | v->Visit("_size" , &size_); |
72 | } |
73 | |
74 | Tensor::Tensor(const std::vector<int>& shape, DataType dtype, bool is_constant = false, |
75 | float compression_ratio = 1.0) { |
76 | auto n = make_object<TensorNode>(); |
77 | n->shape_ = std::move(shape); |
78 | n->dtype_ = dtype; |
79 | n->is_constant_ = is_constant; |
80 | n->compression_ratio_ = compression_ratio; |
81 | n->size_ = mul_reduce(n->shape_) * n->dtype_.bytes(); |
82 | data_ = std::move(n); |
83 | } |
84 | |
85 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.Tensor" ) |
86 | .set_body_typed([](Array<Integer> shape, DataType dtype, bool is_constant, |
87 | double compression_ratio) { |
88 | std::vector<int> vshape = make_vector<int, Integer>(shape); |
89 | return Tensor(vshape, dtype, is_constant, compression_ratio); |
90 | }); |
91 | |
92 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorAddProducer" ) |
93 | .set_body_method<Tensor>(&TensorNode::AddProducer); |
94 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorAddConsumer" ) |
95 | .set_body_method<Tensor>(&TensorNode::AddConsumer); |
96 | |
97 | TVM_REGISTER_NODE_TYPE(TensorNode); |
98 | |
99 | void PartNode::VisitAttrs(AttrVisitor* v) { |
100 | Array<Propagator> tmp_prp(propagators_); |
101 | v->Visit("_propagators" , &tmp_prp); |
102 | Array<Tensor> tmp_ins(input_tensors_); |
103 | v->Visit("_input_tensors" , &tmp_ins); |
104 | v->Visit("_output_tensor" , &output_tensor_); |
105 | v->Visit("_in_line" , &in_line_); |
106 | Array<te::Tensor> tmp_te_ins(subgraph_.input_tensors); |
107 | v->Visit("_te_input_tensors" , &tmp_te_ins); |
108 | v->Visit("_te_output_tensor" , &subgraph_.output_tensor); |
109 | } |
110 | |
111 | void PartNode::SetInput(uint64_t input_index, const Tensor& input_tensor) { |
112 | ICHECK_LT(input_index, input_tensors_.size()); |
113 | input_tensors_[input_index] = std::move(input_tensor); |
114 | } |
115 | |
116 | std::vector<StripeConfig> PartNode::CalculateInputStripeConfigs( |
117 | const StripeConfig& output_stripe_config) { |
118 | std::vector<StripeConfig> input_stripe_configs; |
119 | for (const auto& propagator : propagators_) { |
120 | input_stripe_configs.push_back(propagator->propagate(output_stripe_config)); |
121 | } |
122 | return input_stripe_configs; |
123 | } |
124 | |
125 | const std::vector<int> PartNode::GetStripeAlignHint() const { |
126 | ICHECK_GT(propagators_.size(), 0); |
127 | size_t dims = propagators_[0]->GetOutputDims(); |
128 | std::vector<int> compute_quantum(dims); |
129 | for (size_t i = 0; i < dims; i++) { |
130 | compute_quantum[i] = 1; |
131 | } |
132 | return compute_quantum; |
133 | } |
134 | |
135 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartSetInput" ) |
136 | .set_body_method<Part>(&PartNode::SetInput); |
137 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartSetOutput" ) |
138 | .set_body_method<Part>(&PartNode::SetOutput); |
139 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartCalculateInputStripeConfigs" ) |
140 | .set_body_typed([](Part part, StripeConfig output_stripe_config) { |
141 | auto input_stripe_configs = part->CalculateInputStripeConfigs(output_stripe_config); |
142 | return Array<StripeConfig>(input_stripe_configs); |
143 | }); |
144 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartGetStripeAlignHint" ).set_body_typed([](Part part) { |
145 | std::vector<int> align_hint = part->GetStripeAlignHint(); |
146 | return make_array(align_hint); |
147 | }); |
148 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartGetPerformanceInfo" ) |
149 | .set_body_typed([](Part part, StripeConfig stripe_config, int buffer_mode) { |
150 | BufferMode ebuffer_mode = static_cast<BufferMode>(buffer_mode); |
151 | return part->GetPerformanceInfo(stripe_config, ebuffer_mode); |
152 | }); |
153 | |
154 | CascaderGraphNode::CascaderGraphNode(std::vector<Tensor> input_tensors, |
155 | std::vector<Tensor> output_tensors) |
156 | : input_tensors_(input_tensors), output_tensors_(output_tensors) { |
157 | Init_(); |
158 | } |
159 | |
160 | bool VisitedInputs( |
161 | const Part& part, |
162 | const std::unordered_set<Tensor, ObjectPtrHash, ObjectPtrEqual>& visited_tensors) { |
163 | for (const auto& input_tensor : part->GetInputTensors()) { |
164 | if (visited_tensors.find(input_tensor) == visited_tensors.end()) { |
165 | return false; |
166 | } |
167 | } |
168 | return true; |
169 | } |
170 | |
171 | void CascaderGraphNode::Init_() { |
172 | std::stack<Tensor> stack; |
173 | std::unordered_set<Tensor, ObjectPtrHash, ObjectPtrEqual> visited_tensors; |
174 | std::unordered_set<Part, ObjectPtrHash, ObjectPtrEqual> visited_parts; |
175 | for (const auto& input : input_tensors_) { |
176 | stack.push(input); |
177 | } |
178 | // Visit the Parts/Tensors in depth-first order using a non-recursive algorithm |
179 | while (!stack.empty()) { |
180 | Tensor tensor = stack.top(); |
181 | stack.pop(); |
182 | if (visited_tensors.find(tensor) == visited_tensors.end()) { |
183 | visited_tensors.insert(tensor); |
184 | tensor_order_.push_back(tensor); |
185 | for (const auto& part : tensor->GetConsumers()) { |
186 | if (visited_parts.find(part) == visited_parts.end()) { |
187 | // Only visit a Part once we've visited all its input Tensors |
188 | if (!VisitedInputs(part, visited_tensors)) continue; |
189 | visited_parts.insert(part); |
190 | part_order_.push_back(part); |
191 | stack.push(part->GetOutputTensor()); |
192 | } |
193 | } |
194 | } |
195 | } |
196 | std::reverse(tensor_order_.begin(), tensor_order_.end()); |
197 | std::reverse(part_order_.begin(), part_order_.end()); |
198 | int id = 0; |
199 | for (const auto& part : part_order_) { |
200 | part_id_map_[part] = id; |
201 | id++; |
202 | } |
203 | id = 0; |
204 | for (const auto& tensor : tensor_order_) { |
205 | tensor_id_map_[tensor] = id; |
206 | id++; |
207 | } |
208 | } |
209 | |
210 | void CascaderGraphNode::VisitAttrs(AttrVisitor* v) { |
211 | Array<Tensor> tmp_ins(input_tensors_); |
212 | v->Visit("_input_tensors" , &tmp_ins); |
213 | Array<Tensor> tmp_outs(output_tensors_); |
214 | v->Visit("_output_tensors" , &tmp_outs); |
215 | Array<Part> tmp_parr(part_order_); |
216 | v->Visit("_part_order" , &tmp_parr); |
217 | Array<Tensor> tmp_tarr(tensor_order_); |
218 | v->Visit("_tensor_order" , &tmp_tarr); |
219 | } |
220 | |
221 | int CascaderGraphNode::GetPartID(const Part& part) const { |
222 | if (part_id_map_.find(part) == part_id_map_.end()) { |
223 | return -1; |
224 | } |
225 | return part_id_map_.at(part); |
226 | } |
227 | |
228 | int CascaderGraphNode::GetTensorID(const Tensor& tensor) const { |
229 | if (tensor_id_map_.find(tensor) == tensor_id_map_.end()) { |
230 | return -1; |
231 | } |
232 | return tensor_id_map_.at(tensor); |
233 | } |
234 | |
235 | CascaderGraph::CascaderGraph(std::vector<Tensor> input_tensors, |
236 | std::vector<Tensor> output_tensors) { |
237 | auto n = make_object<CascaderGraphNode>(input_tensors, output_tensors); |
238 | data_ = std::move(n); |
239 | } |
240 | |
241 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CascaderGraph" ) |
242 | .set_body_typed([](Array<Tensor> input_tensors, Array<Tensor> output_tensors) { |
243 | std::vector<Tensor> vinput_tensors(input_tensors.begin(), input_tensors.end()); |
244 | std::vector<Tensor> voutput_tensors(output_tensors.begin(), output_tensors.end()); |
245 | return CascaderGraph(vinput_tensors, voutput_tensors); |
246 | }); |
247 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CascaderGraphGetPartID" ) |
248 | .set_body_method<CascaderGraph>(&CascaderGraphNode::GetPartID); |
249 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CascaderGraphGetTensorID" ) |
250 | .set_body_method<CascaderGraph>(&CascaderGraphNode::GetTensorID); |
251 | |
252 | TVM_REGISTER_NODE_TYPE(CascaderGraphNode); |
253 | |
254 | } // namespace cascader |
255 | } // namespace ethosu |
256 | } // namespace contrib |
257 | } // namespace tvm |
258 | |