/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file auto_scheduler/feature.cc
 * \brief Feature extraction for the cost model
 */

#include <tvm/arith/analyzer.h>
#include <tvm/auto_scheduler/feature.h>
#include <tvm/auto_scheduler/measure.h>
#include <tvm/auto_scheduler/measure_record.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "search_policy/utils.h"
#include "utils.h"

namespace tvm {
namespace auto_scheduler {

using namespace tvm::tir;
using arith::Analyzer;
using arith::ConstIntBound;

template <class T>
using BufferMap = std::unordered_map<Var, T, ObjectHash, ObjectEqual>;

// The number of samples to extract for arithmetic intensity curves
static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10;

// Annotation position encoding
enum class AnnotationPosType : int {
  kPosNone = 0,           // Does not have this kind of annotation
  kPosInnerSpatial = 1,   // The annotated iterator is the innermost spatial iterator
  kPosMiddleSpatial = 2,  // The annotated iterator is a middle spatial iterator
  kPosOuterSpatial = 3,   // The annotated iterator is the outermost spatial iterator
  kPosInnerReduce = 4,    // The annotated iterator is the innermost reduce iterator
  kPosMiddleReduce = 5,   // The annotated iterator is a middle reduce iterator
  kPosOuterReduce = 6,    // The annotated iterator is the outermost reduce iterator
  kPosMixed = 7           // The annotated iterator is a mixed spatial and reduce iterator
};

// Buffer access type
enum class BufferAccessType : int { kRead = 0, kWrite = 1, kReadWrite = 2, kUnknownRW = 3 };

// Accesses to a buffer
struct BufferAccess {
  // The type of the access (read, write, read + write, or unknown)
  BufferAccessType acc_type{BufferAccessType::kUnknownRW};
  // Use a two-dimensional array to store multiple multi-dimensional accesses.
  // The innermost vector stores the multi-dimensional indices of one access.
  std::vector<std::vector<PrimExpr>> indices;
};

// Data reuse type
enum class ReuseType : int { kLoopMultipleRead = 0, kSerialMultipleReadWrite = 1, kNoReuse = 2 };
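// kLoopMultipleRead: the same elements are re-read across iterations of an
// outer loop (the loop variable does not appear in the access indices).
// kSerialMultipleReadWrite: the same elements are accessed several times
// within one loop body, e.g. written and then read back by an update.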

// Feature for an access of a buffer
struct BufferAccessFeature {
  std::string buffer_name;        // The name of the buffer
  BufferAccessType acc_type;      // The type of the access
  float bytes;                    // The touched memory in bytes
  float unique_bytes;             // The touched unique memory in bytes
  float lines;                    // The number of touched cache lines
  float unique_lines;             // The number of touched unique cache lines
  ReuseType reuse_type;           // The type of data reuse
  float reuse_dis_iter;           // The reuse distance in number of iterations
  float reuse_dis_bytes;          // The reuse distance in total touched bytes
  float reuse_ct;                 // The number of reuses
  float bytes_d_reuse_ct;         // bytes / reuse_ct
  float unique_bytes_d_reuse_ct;  // unique_bytes / reuse_ct
  float lines_d_reuse_ct;         // lines / reuse_ct
  float unique_lines_d_reuse_ct;  // unique_lines / reuse_ct
  float stride;                   // The stride in access
};

// Feature set of a BufferStore statement
struct FeatureSet {
  // Group 1: Computation related features
  float float_mad;                  // The number of float MAD (Multiply–add) ops
  float float_addsub;               // The number of float add and sub ops
  float float_mul;                  // The number of float multiply ops
  float float_divmod;               // The number of float div and mod ops
  float float_cmp;                  // The number of float comparison ops
  float float_math_func;            // The number of float math func calls
  float float_other_func;           // The number of other float func calls
  float int_mad;                    // The number of integer MAD (Multiply–add) ops
  float int_addsub;                 // The number of integer add and sub ops
  float int_mul;                    // The number of integer multiply ops
  float int_divmod;                 // The number of integer div and mod ops
  float int_cmp;                    // The number of integer comparison ops
  float int_math_func;              // The number of integer math func calls
  float int_other_func;             // The number of other integer func calls
  float bool_op;                    // The number of bool ops
  float select_op;                  // The number of select ops
  float vec_num;                    // The number of vectorized iterators
  float vec_prod;                   // The product of the lengths of vectorized iterators
  float vec_len;                    // The length of the innermost vectorized iterator
  AnnotationPosType vec_type;       // The type of vectorization position
  float unroll_num;                 // The number of unrolled iterators
  float unroll_prod;                // The product of the lengths of unrolled iterators
  float unroll_len;                 // The length of the innermost unrolled iterator
  AnnotationPosType unroll_type;    // The type of unroll position
  float parallel_num;               // The number of parallelized iterators
  float parallel_prod;              // The product of the lengths of parallelized iterators
  float parallel_len;               // The length of the innermost parallelized iterator
  AnnotationPosType parallel_type;  // The type of parallel position
  float is_gpu;                     // Whether it is a GPU task
  float blockIdx_x_len;             // The length of blockIdx.x
  float blockIdx_y_len;             // The length of blockIdx.y
  float blockIdx_z_len;             // The length of blockIdx.z
  float threadIdx_x_len;            // The length of threadIdx.x
  float threadIdx_y_len;            // The length of threadIdx.y
  float threadIdx_z_len;            // The length of threadIdx.z
  float vthread_len;                // The length of virtual thread

  // Group 2: Buffer access related features (per buffer)
  std::vector<BufferAccessFeature> access_feas;

  // Group 3: Arithmetic intensity related features
  float arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N];  // points sampled from the
                                                                // arithmetic intensity curve

  // Group 4: Allocation related features
  float alloc_size;        // The size of the allocated buffer in bytes
  float alloc_outer_prod;  // The product of lengths of loops outside the scope of the allocation
  float alloc_inner_prod;  // The product of lengths of loops inside the scope of the allocation
  float alloc_prod;        // alloc_outer_prod * alloc_inner_prod

  // Group 5: Outer scope related features
  float outer_prod;            // The product of lengths of outer loops
  float num_loops;             // The number of outer loops
  float auto_unroll_max_step;  // The value of pragma "auto_unroll_max_step"
};

// Return whether a var is in an expr
bool VarInExpr(const Var& var, const PrimExpr& expr) {
  bool find = false;

  PostOrderVisit(expr, [&find, &var](const ObjectRef& node) {
    if (find) {
      return;
    }

    if (const VarNode* op = node.as<VarNode>()) {
      if (op == var.get()) {
        find = true;
      }
    }
  });

  return find;
}

// Get position encoding for annotation
AnnotationPosType GetAnnotationPosEncoding(const Var& var, const Array<PrimExpr>& spatial_args,
                                           const Array<IterVar>& axis,
                                           const Array<IterVar>& reduce_axis) {
  // Try to match spatial args first
  size_t find_i = 0;
  size_t find_ct = 0;
  for (size_t i = 0; i < spatial_args.size(); ++i) {
    if (VarInExpr(var, spatial_args[i])) {
      find_i = i;
      find_ct += 1;
    }
  }

  if (find_ct == 0) {
    // If it is not found in spatial args, then it is a reduce iterator.
    // Use the name to match
    const std::string& var_name = var->name_hint;
    for (size_t i = 0; i < reduce_axis.size(); ++i) {
      if (var_name.find(reduce_axis[i]->var->name_hint) != std::string::npos) {
        find_i = i;
        find_ct++;
      }
    }
    if (find_ct >= 1) {
      if (find_i == 0) {
        return AnnotationPosType::kPosInnerReduce;
      } else if (find_i == reduce_axis.size() - 1) {
        return AnnotationPosType::kPosOuterReduce;
      } else {
        return AnnotationPosType::kPosMiddleReduce;
      }
    } else {
      // If the axis is found in neither the spatial args nor the reduce axis,
      // then this stage must compute_at somewhere under this axis and this axis is simplified out.
      // We assume it is an outer spatial axis.
      return AnnotationPosType::kPosOuterSpatial;
    }
  } else if (find_ct == 1) {
    if (find_i == spatial_args.size() - 1) {
      return AnnotationPosType::kPosInnerSpatial;
    } else if (find_i == 0) {
      return AnnotationPosType::kPosOuterSpatial;
    } else {
      return AnnotationPosType::kPosMiddleSpatial;
    }
  } else {
    return AnnotationPosType::kPosMixed;
  }
}

// Return the maximum extent of a for loop
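// For example, `for (i, 0, 16)` yields 16. If the analyzer cannot find a
// constant upper bound for the extent (e.g. a symbolic shape), we
// conservatively fall back to 1.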
int64_t GetLoopExtent(const ForNode* node, const Analyzer& ana) {
  int64_t bound = ana.const_int_bound(node->extent)->max_value;
  if (bound == ConstIntBound::kPosInf) {
    return 1;  // Analyzer could not determine a valid bound, use 1 instead.
  } else {
    return bound;
  }
}

// Count math ops in an expr
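// For example, counting the expression `A[i] * 2.0f + B[i]` yields
// float_mul = 1 and float_addsub = 1. Typical usage mirrors the call sites
// below: `MathOpCounter counter; counter(expr);` (the functor call comes from
// the visitor base classes).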
class MathOpCounter : public StmtExprVisitor {
 public:
#define VisitBinary(Type, float_ct, int_ct)                        \
  void VisitExpr_(const Type* op) final {                          \
    if (op->a.dtype().is_float() || op->a.dtype().is_bfloat16()) { \
      float_ct += op->a.dtype().lanes();                           \
    } else {                                                       \
      int_ct += op->a.dtype().lanes();                             \
    }                                                              \
    StmtExprVisitor::VisitExpr_(op);                               \
  }

  VisitBinary(AddNode, float_addsub, int_addsub);
  VisitBinary(SubNode, float_addsub, int_addsub);
  VisitBinary(MulNode, float_mul, int_mul);
  VisitBinary(DivNode, float_divmod, int_divmod);
  VisitBinary(ModNode, float_divmod, int_divmod);
  VisitBinary(FloorDivNode, float_divmod, int_divmod);
  VisitBinary(FloorModNode, float_divmod, int_divmod);
  VisitBinary(MaxNode, float_cmp, int_cmp);
  VisitBinary(MinNode, float_cmp, int_cmp);
  VisitBinary(EQNode, float_cmp, int_cmp);
  VisitBinary(NENode, float_cmp, int_cmp);
  VisitBinary(LTNode, float_cmp, int_cmp);
  VisitBinary(LENode, float_cmp, int_cmp);
  VisitBinary(GTNode, float_cmp, int_cmp);
  VisitBinary(GENode, float_cmp, int_cmp);

#undef VisitBinary

  void VisitExpr_(const AndNode* op) final {
    bool_op++;
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const OrNode* op) final {
    bool_op++;
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const NotNode* op) final {
    bool_op++;
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const SelectNode* op) final {
    select_op++;
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const CallNode* op) final {
    auto* pop = op->op.as<OpNode>();
    ICHECK(pop != nullptr);
    auto effect_kind = op_call_effect_[GetRef<Op>(pop)];
    bool is_pure =
        effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;

    if (is_pure) {
      if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
        float_math_func++;
      } else {
        int_math_func++;
      }
    } else {
      if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
        float_other_func++;
      } else {
        int_other_func++;
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  // todo(merrymercy): Detect MAD (Multiply–add)
  size_t float_mad{0};         // The number of float MAD (Multiply–add) ops
  size_t float_addsub{0};      // The number of float add and sub ops
  size_t float_mul{0};         // The number of float multiply ops
  size_t float_divmod{0};      // The number of float div and mod ops
  size_t float_cmp{0};         // The number of float comparison ops
  size_t float_math_func{0};   // The number of float math func calls
  size_t float_other_func{0};  // The number of other float func calls
  size_t int_mad{0};           // The number of integer MAD (Multiply–add) ops
  size_t int_addsub{0};        // The number of integer add and sub ops
  size_t int_mul{0};           // The number of integer multiply ops
  size_t int_divmod{0};        // The number of integer div and mod ops
  size_t int_cmp{0};           // The number of integer comparison ops
  size_t int_math_func{0};     // The number of integer math func calls
  size_t int_other_func{0};    // The number of other integer func calls
  size_t bool_op{0};           // The number of bool ops
  size_t select_op{0};         // The number of select ops

  OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
};

// Extract all buffer accesses in an expr
class BufferAccessExtractor : public StmtExprVisitor {
 public:
  void ExtractReads(const PrimExpr& expr) { this->VisitExpr(expr); }

  void InsertAccess(const Var& buf, BufferAccessType acc_type, const Array<PrimExpr>& indices) {
    BufferAccess& acc = buf_accesses[buf];
    acc.acc_type = acc_type;
    acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    AddAccess(op->buffer->data, op->indices);
    StmtExprVisitor::VisitExpr_(op);
  }

  void AddAccess(const Var& buffer, const Array<PrimExpr>& indices) {
    BufferAccess& acc = buf_accesses[buffer];
    switch (acc.acc_type) {
      case BufferAccessType::kRead:
        break;
      case BufferAccessType::kWrite:
        acc.acc_type = BufferAccessType::kReadWrite;
        break;
      case BufferAccessType::kReadWrite:
        break;
      case BufferAccessType::kUnknownRW:
      default:
        acc.acc_type = BufferAccessType::kRead;
        break;
    }

    if (acc.acc_type != BufferAccessType::kReadWrite) {
      // If a buffer is both read and written, in the TVM DSL it must be an update,
      // so the indices should be the same. Then we can skip appending indices for it.
      // Otherwise we do the following.
      buf_accesses[buffer].indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
    }
  }

  BufferMap<BufferAccess> buf_accesses;
};

// Compute the coefficient of a loop iterator in an expression.
// Note: we use an approximation strategy to find the coefficient.
// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear)
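// For example, given the index expression `i * 16 + j`, the extracted
// coefficient of `i` is 16 and the coefficient of `j` is 1.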
class CoefficientExtractor : public StmtExprVisitor {
 public:
  void VisitExpr_(const MulNode* node) final {
    StmtExprVisitor::VisitExpr_(node);
    if (visited_var) {
      if (!visited_add) {
        if (auto a = node->a.as<IntImmNode>()) {
          visited_mul = true;
          stride = a->value;
        } else if (auto b = node->b.as<IntImmNode>()) {
          visited_mul = true;
          stride = b->value;
        }
      }
    }
  }

  void VisitExpr_(const AddNode* node) final {
    StmtExprVisitor::VisitExpr_(node);
    if (visited_var) {
      if (!visited_mul) {
        visited_add = true;
        stride = 1;
      }
    }
  }

  void VisitExpr_(const VarNode* node) final {
    if (node == var_) {
      visited_var = true;
      // This is a magic default stride in case our approximation strategy fails
      stride = 2;
    }
  }

  int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) {
    visited_var = visited_mul = visited_add = false;
    var_ = var;

    this->VisitExpr(expr);

    if (visited_var && !visited_mul && !visited_add) {
      return 1;
    } else {
      return stride;
    }
  }

  bool visited_var{false};
  bool visited_mul{false};
  bool visited_add{false};
  int stride{0};

 private:
  const VarNode* var_{nullptr};
};

// Compute stride for the accesses to a buffer
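// For example, for accesses `A[i][j]` to a buffer of shape (M, N), the stride
// with respect to `j` is 1 and the stride with respect to `i` is N.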
int64_t ComputeStride(const std::vector<std::vector<PrimExpr>>& indices,
                      const std::vector<int>& shape, const VarNode* stride_var) {
  // Use a stride of 1 for 0-dimensional buffers. 0-dim buffers have a single
  // index access, so we have to check here.
  if (shape.size() == 0) {
    return 1;
  }
  int64_t min_stride = std::numeric_limits<int64_t>::max();
  bool find = false;
  CoefficientExtractor extractor;

  for (const auto& index : indices) {
    int64_t shape_stride = 1;
    for (int i = static_cast<int>(index.size()) - 1; i >= 0; i--) {
      int coefficient = extractor.ExtractCoefficient(index[i], stride_var);
      if (extractor.visited_var) {
        find = true;
        min_stride = std::min(min_stride, std::abs(coefficient) * shape_stride);
        break;
      }
      shape_stride *= shape[i];
    }
  }

  return find ? min_stride : 0;
}

// Compute the touched region (per-dimension extents) for accesses to a buffer.
// This is later used to derive touched bytes and cache lines.
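// For example, for two accesses A[i] and A[i + 1] with i in [0, 9], the
// touched region of the single dimension is [0, 10], i.e. 11 elements.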
void ComputeRegion(const std::vector<std::vector<PrimExpr>>& indices, arith::Analyzer* ana,
                   std::vector<int>* region) {
  region->clear();

  if (indices.empty()) {
    return;
  }

  region->reserve(indices[0].size());

  if (indices.size() == 1) {
    for (const auto& index : indices[0]) {
      ConstIntBound bound = ana->const_int_bound(index);
      region->push_back(bound->max_value - bound->min_value + 1);
    }
  } else {
    // future(lmzheng): implement a more accurate IntSet?
    for (size_t i = 0; i < indices[0].size(); ++i) {
      int64_t minimum = ConstIntBound::kPosInf, maximum = ConstIntBound::kNegInf;
      for (size_t j = 0; j < indices.size(); ++j) {
        ConstIntBound bound = ana->const_int_bound(indices[j][i]);

        minimum = std::min(minimum, bound->min_value);
        maximum = std::max(maximum, bound->max_value);
      }
      region->push_back(maximum - minimum + 1);
    }
  }
}

// Compute reuse distance and reuse ratio for accesses to a buffer
// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct
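// For example, in `for i { for j { ... = A[i] } }` the access A[i] does not
// depend on `j`, so every iteration of `j` re-reads the same element: this is
// kLoopMultipleRead reuse with reuse_ct equal to the extent of `j`.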
std::tuple<ReuseType, float, float, float> ComputeReuse(
    const Var& buf, const std::vector<std::vector<PrimExpr>>& indices,
    const std::vector<const ForNode*>& for_loop_stack,
    const std::unordered_map<const ForNode*,
                             BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>>&
        for_touch_regions,
    const Analyzer& ana) {
  float reuse_dis_iter = 1.0f;
  float reuse_dis_bytes = -1.0f;

  for (int i = static_cast<int>(for_loop_stack.size()) - 1; i >= 0; --i) {
    const ForNode* cur_for = for_loop_stack[i];
    bool find = false;

    for (size_t j = 0; j < indices.size(); j++) {
      for (size_t k = 0; k < indices[j].size(); k++) {
        if (VarInExpr(cur_for->loop_var, indices[j][k])) {
          find = true;
          break;
        }
      }
      if (find) {
        break;
      }
    }

    int64_t extent = GetLoopExtent(for_loop_stack[i], ana);
    if (find) {
      // accumulate/update reuse distance
      reuse_dis_iter *= extent;
      reuse_dis_bytes = 0.0f;
      for (const auto& iter : for_touch_regions.at(cur_for)) {
        for (const auto& access : iter.second) {
          reuse_dis_bytes += std::get<1>(access) * std::get<2>(access);
        }
      }
    } else {
      // Have LoopMultipleRead reuse
      if (reuse_dis_bytes < 0) {
        // For the reuse in the innermost axis, the above code won't be executed.
        // So we compute bytes here
        reuse_dis_bytes = 0.0f;
        for (const auto& iter : for_touch_regions.at(cur_for)) {
          for (const auto& access : iter.second) {
            reuse_dis_bytes += 1 * std::get<2>(access);
          }
        }
      }
      return std::make_tuple(ReuseType::kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent);
    }

    const BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>& buffer_map =
        for_touch_regions.at(cur_for);

    int serial_reuse = static_cast<int>(buffer_map.at(buf).size()) - 1;
    if (serial_reuse > 0) {
      int64_t extent = GetLoopExtent(cur_for, ana);

      // Have SerialMultipleReadWrite reuse
      reuse_dis_iter = std::numeric_limits<float>::max();
      for (const auto& acc_info : buffer_map.at(buf)) {
        reuse_dis_iter = std::min(reuse_dis_iter, static_cast<float>(std::get<1>(acc_info)));
      }

      reuse_dis_bytes = 0.0f;
      for (const auto& iter : for_touch_regions.at(cur_for)) {
        for (const auto& access : iter.second) {
          reuse_dis_bytes += std::get<1>(access) * std::get<2>(access);
        }
      }

      return std::make_tuple(ReuseType::kSerialMultipleReadWrite, reuse_dis_iter / extent,
                             reuse_dis_bytes / extent, serial_reuse);
    }
  }

  return std::make_tuple(ReuseType::kNoReuse, 0, 0, 0);
}

// Extract features for every BufferStore statement
//
// This visitor assumes that loop bounds do not depend on data or on parent
// loop bounds. For example, `for i in .. { for j in range(i, ..) }` would
// result in inaccurate features. This visitor also does not take conditionals
// into consideration when creating features: both branches of a conditional
// are treated as if they were always taken.
class PerStoreFeatureExtractor : public StmtExprVisitor {
 public:
  explicit PerStoreFeatureExtractor(int cache_line_size, const Map<Var, Buffer>& existing_buffers)
      : cache_line_size_(cache_line_size) {
    for (const auto& buffer : existing_buffers) {
      buffer_shapes[buffer.first] = buffer.second->shape;
      buffer_dtypes[buffer.first] = buffer.second->dtype;
      // Also add an entry keyed by the buffer's internal data variable. This
      // is usually how buffers are referenced within the body of a PrimFunc.
      buffer_shapes[buffer.second->data] = buffer.second->shape;
      buffer_dtypes[buffer.second->data] = buffer.second->dtype;
    }
  }

  void VisitStmt_(const AttrStmtNode* node) final {
    if (node->attr_key == tir::attr::thread_extent || node->attr_key == tir::attr::virtual_thread) {
      const Var& var = node->node.as<IterVarNode>()->var;
      int extent = GetIntImm(node->value);

      int* plen = nullptr;

      const std::string& name = var.get()->name_hint;
      if (node->attr_key == tir::attr::thread_extent) {
        if (name == "blockIdx.x") {
          plen = &blockIdx_x_len_;
        } else if (name == "blockIdx.y") {
          plen = &block_idx_y_len_;
        } else if (name == "blockIdx.z") {
          plen = &block_idx_z_len_;
        } else if (name == "threadIdx.x") {
          plen = &threadIdx_x_len_;
        } else if (name == "threadIdx.y") {
          plen = &thread_idx_y_len_;
        } else if (name == "threadIdx.z") {
          plen = &thread_idx_z_len_;
        } else {
          LOG(FATAL) << "invalid thread itervar " + name;
        }
      } else {
        plen = &vthread_len_;
      }

      int extent_before = *plen;
      if (node->attr_key == tir::attr::thread_extent) {
        *plen = extent;
      } else {
        *plen *= extent;
      }

      is_gpu_ = true;

      // make a fake For node for blockIdx.x or threadIdx.x
      Stmt fake_for_node = For(var, 0, extent, ForKind::kParallel, node->body);

      outer_loop_prod_ *= extent;
      for_loop_stack_.push_back(fake_for_node.as<ForNode>());
      variable_definition_stack_.push_back({});
      StmtExprVisitor::VisitStmt_(node);
      variable_definition_stack_.pop_back();
      for_loop_stack_.pop_back();
      outer_loop_prod_ /= extent;

      *plen = extent_before;
    } else if (node->attr_key == "pragma_auto_unroll_max_step") {
      int value = GetIntImm(node->value);

      int16_t old_value = cur_auto_unroll_max_step_;
      cur_auto_unroll_max_step_ = value;
      StmtExprVisitor::VisitStmt_(node);
      cur_auto_unroll_max_step_ = old_value;
    } else {
      StmtExprVisitor::VisitStmt_(node);
    }
  }

  void VisitStmt_(const ForNode* node) final {
    ana_.Bind(node->loop_var, Range::FromMinExtent(node->min, node->extent));
    int64_t loop_extent = GetLoopExtent(node, ana_);

    if (node->kind == ForKind::kVectorized) {
      vec_for_stack_.push_back(node);
    } else if (node->kind == ForKind::kUnrolled) {
      unroll_for_stack_.push_back(node);
    } else if (node->kind == ForKind::kParallel) {
      parallel_for_stack_.push_back(node);
    }

    outer_loop_prod_ *= loop_extent;
    for_loop_stack_.push_back(node);
    variable_definition_stack_.push_back({});
    StmtExprVisitor::VisitStmt_(node);
    variable_definition_stack_.pop_back();
    for_loop_stack_.pop_back();
    outer_loop_prod_ /= loop_extent;

    if (node->kind == ForKind::kVectorized) {
      vec_for_stack_.pop_back();
    } else if (node->kind == ForKind::kUnrolled) {
      unroll_for_stack_.pop_back();
    } else if (node->kind == ForKind::kParallel) {
      parallel_for_stack_.pop_back();
    }
  }

  void VisitExpr_(const BufferLoadNode* node) final {
    // Store buffer shape/dtype. It may already be stored.
    buffer_shapes[node->buffer->data] = node->buffer->shape;
    buffer_dtypes[node->buffer->data] = node->buffer->dtype;
    StmtExprVisitor::VisitExpr_(node);
  }

  void VisitStmt_(const BufferStoreNode* node) final {
    // Store buffer shape/dtype. It may already be stored.
    buffer_shapes[node->buffer->data] = node->buffer->shape;
    buffer_dtypes[node->buffer->data] = node->buffer->dtype;

    MathOpCounter math_op_counter;
    math_op_counter(node->value);
    std::vector<float> mem_bytes_list;
    std::vector<float> compute_ops_list;
    double cur_compute_ops;

    // Group 1: Computation related features
    ExtractComputationFeature(node->buffer->data, node->indices, math_op_counter);

    // Group 2: Buffer access related features (per buffer)
    ExtractBufferAccessFeature(node->buffer->data, node->indices, node->value, math_op_counter,
                               &cur_compute_ops, &compute_ops_list, &mem_bytes_list);

    // Group 3: Arithmetic intensity related features
    ExtractArithmeticIntensityFeature(node->buffer->data, cur_compute_ops, compute_ops_list,
                                      mem_bytes_list);

    // Group 5: Outer scope related features
    ExtractOuterScopeFeature(node->buffer->data);
  }

  void VisitStmt_(const BufferRealizeNode* node) final {
    // Store buffer shape/dtype. It may already be stored.
    buffer_shapes[node->buffer->data] = node->buffer->shape;
    buffer_dtypes[node->buffer->data] = node->buffer->dtype;
    StmtExprVisitor::VisitStmt_(node);

    // Group 4: Allocation related features
    ExtractAllocationFeature(node);
  }

  void VisitStmt_(const AllocateNode* node) final {
    buffer_dtypes[node->buffer_var] = node->dtype;
    buffer_shapes[node->buffer_var] = node->extents;
    StmtExprVisitor::VisitStmt_(node);

    // Group 4: Allocation related features
    ExtractAllocationFeature(node);
  }

  void VisitStmt_(const LetStmtNode* node) final {
    // TODO(tkonolige): add arithmetic counts from this statement to counts of inner stores.
    ana_.Bind(node->var, node->value);
    ICHECK(variable_definition_stack_.size() > 0)
        << "Variable definition outside of a for loop is not handled by feature extraction";
    variable_definition_stack_.back().push_back(std::make_tuple(node->var, node->value));
    StmtExprVisitor::VisitStmt_(node);
  }

  // Extract computation related features (group 1)
  void ExtractComputationFeature(const Var& buffer, const Array<PrimExpr>& indices,
                                 const MathOpCounter& math_op_counter) {
    FeatureSet& fea = buffer_features[buffer];

    // Computation related features
    fea.float_mad += outer_loop_prod_ * math_op_counter.float_mad;
    fea.float_addsub += outer_loop_prod_ * math_op_counter.float_addsub;
    fea.float_mul += outer_loop_prod_ * math_op_counter.float_mul;
    fea.float_divmod += outer_loop_prod_ * math_op_counter.float_divmod;
    fea.float_cmp += outer_loop_prod_ * math_op_counter.float_cmp;
    fea.float_math_func += outer_loop_prod_ * math_op_counter.float_math_func;
    fea.float_other_func += outer_loop_prod_ * math_op_counter.float_other_func;
    fea.int_mad += outer_loop_prod_ * math_op_counter.int_mad;
    fea.int_addsub += outer_loop_prod_ * math_op_counter.int_addsub;
    fea.int_mul += outer_loop_prod_ * math_op_counter.int_mul;
    fea.int_divmod += outer_loop_prod_ * math_op_counter.int_divmod;
    fea.int_math_func += outer_loop_prod_ * math_op_counter.int_math_func;
    fea.int_cmp += outer_loop_prod_ * math_op_counter.int_cmp;
    fea.int_other_func += outer_loop_prod_ * math_op_counter.int_other_func;
    fea.bool_op += outer_loop_prod_ * math_op_counter.bool_op;
    fea.select_op += outer_loop_prod_ * math_op_counter.select_op;

    fea.vec_len = fea.unroll_len = fea.parallel_len = 0.0f;
    fea.vec_type = fea.unroll_type = fea.parallel_type = AnnotationPosType::kPosNone;

    fea.vec_num = vec_for_stack_.size();
    if (!vec_for_stack_.empty()) {
      fea.vec_len = GetLoopExtent(vec_for_stack_.back(), ana_);
      fea.vec_prod = 1.0;
      for (const ForNode* pfor : vec_for_stack_) {
        fea.vec_prod *= GetLoopExtent(pfor, ana_);
      }
      fea.vec_type = AnnotationPosType::kPosMixed;
      // todo(merrymercy): this feature requires operation (tvm.compute) information
      // GetAnnotationPosEncoding(vec_for_stack_.back()->loop_var,
      // node->args, pcompute->axis, pcompute->reduce_axis);
    }

    fea.unroll_num = unroll_for_stack_.size();
    if (!unroll_for_stack_.empty()) {
      fea.unroll_len = GetLoopExtent(unroll_for_stack_.back(), ana_);
      fea.unroll_prod = 1.0;
      for (const ForNode* pfor : unroll_for_stack_) {
        fea.unroll_prod *= GetLoopExtent(pfor, ana_);
      }
      fea.unroll_type = AnnotationPosType::kPosMixed;
      // GetAnnotationPosEncoding(unroll_for_stack_.back()->loop_var,
      // node->args, pcompute->axis, pcompute->reduce_axis);
    }

    fea.parallel_num = parallel_for_stack_.size();
    if (!parallel_for_stack_.empty()) {
      fea.parallel_len = GetLoopExtent(parallel_for_stack_.back(), ana_);
      fea.parallel_prod = 1.0;
      for (const ForNode* pfor : parallel_for_stack_) {
        fea.parallel_prod *= GetLoopExtent(pfor, ana_);
      }
      fea.parallel_type = AnnotationPosType::kPosMixed;
      // GetAnnotationPosEncoding(parallel_for_stack_.back()->loop_var,
      // node->args, pcompute->axis, pcompute->reduce_axis);
    }

    // GPU threads
    fea.is_gpu = is_gpu_;
    fea.blockIdx_x_len = blockIdx_x_len_;
    fea.blockIdx_y_len = block_idx_y_len_;
    fea.blockIdx_z_len = block_idx_z_len_;
    fea.threadIdx_x_len = threadIdx_x_len_;
    fea.threadIdx_y_len = thread_idx_y_len_;
    fea.threadIdx_z_len = thread_idx_z_len_;
    fea.vthread_len = vthread_len_;
  }

  // Extract buffer access related features (group 2)
  void ExtractBufferAccessFeature(const Var& buffer, const Array<PrimExpr>& indices,
                                  const PrimExpr& value, const MathOpCounter& math_op_counter,
                                  double* cur_compute_ops, std::vector<float>* compute_ops_list,
                                  std::vector<float>* mem_bytes_list) {
    FeatureSet& fea = buffer_features[buffer];

    // Extract all buffer accesses
    std::vector<BufferAccessFeature> acc_feas;
    BufferAccessExtractor buf_extractor;
    buf_extractor.InsertAccess(buffer, BufferAccessType::kWrite, indices);
    buf_extractor.ExtractReads(value);

    mem_bytes_list->reserve(for_loop_stack_.size());
    compute_ops_list->reserve(for_loop_stack_.size());

    *cur_compute_ops = math_op_counter.float_mad + math_op_counter.float_addsub +
                       math_op_counter.float_mul + math_op_counter.float_divmod +
                       math_op_counter.float_cmp + math_op_counter.float_math_func +
                       math_op_counter.float_other_func;

    ICHECK_EQ(for_loop_stack_.size(), variable_definition_stack_.size())
        << "variable_definition_stack_ should mirror for_loop_stack_ in size";
    std::vector<int> tmp_region;
    for (int i = static_cast<int>(for_loop_stack_.size()) - 1; i >= 0; i--) {
      const ForNode* p_for = for_loop_stack_[i];

      // Construct a local analyzer context which contains definitions (for and
      // let) from innermost loops up to and including `i`. Loop variables
      // defined in loops outside of `i` are bound to extent 1 so that we can
      // get per-loop-iteration features. Note that we add these definitions
      // from outermost to innermost because inner definitions may depend on
      // outer ones.
      Analyzer local_analyzer;
      for (int j = 0; j < i; j++) {
        local_analyzer.Bind(for_loop_stack_.at(j)->loop_var,
                            Range::FromMinExtent(for_loop_stack_.at(j)->min, 1));
      }
      for (int j = i; j < static_cast<int>(for_loop_stack_.size()); j++) {
        local_analyzer.Bind(
            for_loop_stack_.at(j)->loop_var,
            Range::FromMinExtent(for_loop_stack_.at(j)->min, for_loop_stack_.at(j)->extent));
        for (auto definition : variable_definition_stack_.at(j)) {
          local_analyzer.Bind(std::get<0>(definition), std::get<1>(definition));
        }
      }

      // Note that we overwrite here: if there are multiple BufferStoreNodes,
      // the last one overwrites the earlier ones (e.g., the update part of a
      // GEMM overwrites the init part).
      BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>& buffer_regions_map =
          for_touch_regions_[p_for];

      int64_t mem_bytes = 0;
      for (const auto& x : buf_extractor.buf_accesses) {
        const Var& t = x.first;
        const BufferAccess& acc = x.second;

        ComputeRegion(acc.indices, &local_analyzer, &tmp_region);
        int64_t touched_size = ElementProduct(tmp_region);
        touched_size = std::max<int64_t>(0, touched_size);
        buffer_regions_map[t].push_back(
            std::make_tuple(acc.acc_type, touched_size, buffer_dtypes.at(t).bytes()));
        mem_bytes += touched_size * buffer_dtypes.at(t).bytes();
      }

      mem_bytes_list->push_back(mem_bytes);
      *cur_compute_ops *= GetLoopExtent(for_loop_stack_[i], local_analyzer);
      compute_ops_list->push_back(*cur_compute_ops);
    }

    // Buffer access related features (per buffer)
    for (const auto& x : buf_extractor.buf_accesses) {
      const Var& t = x.first;
      const BufferAccess& acc = x.second;

      std::vector<int> int_shape;
      for (const auto& dim : buffer_shapes.at(t)) {
        int_shape.push_back(GetIntImm(dim));
      }

      size_t ele_bytes = buffer_dtypes.at(t).bytes();

      // calculate bytes
      float bytes = outer_loop_prod_ * ele_bytes;
      float unique_bytes;

      // calculate cache lines
      int64_t stride;
      float lines;
      float unique_lines;

      if (for_loop_stack_.empty()) {
        unique_bytes = ele_bytes;
        stride = 0;
        lines = 1.0f;
        unique_lines = 1.0f;
      } else {
        unique_bytes = static_cast<float>(
                           std::get<1>(for_touch_regions_[for_loop_stack_.front()][t].front())) *
                       ele_bytes;

        stride = 0;
        int64_t reduce_ratio = 1;

        int i;
        for (i = static_cast<int>(for_loop_stack_.size()) - 1; i >= 0; i--) {
          stride = ComputeStride(acc.indices, int_shape, for_loop_stack_[i]->loop_var.get());
          if (stride != 0) {
            break;
          }
          reduce_ratio *= GetLoopExtent(for_loop_stack_.back(), ana_);
        }

        lines = outer_loop_prod_ / reduce_ratio *
                std::min(1.0f, 1.0f * stride * ele_bytes / cache_line_size_);
        lines = std::max(lines, 1.0f);

        // convert `stride` back to the stride of the innermost iterator
        stride = (i == static_cast<int>(for_loop_stack_.size()) - 1 ? stride : 0);

        float n_continuous = ele_bytes;
        for (int i = std::min(static_cast<int>(tmp_region.size()) - 1,
                              static_cast<int>(int_shape.size()) - 1);
             i >= 0; i--) {
          if (tmp_region[i] == int_shape[i]) {
            n_continuous *= tmp_region[i];
            break;
          }
        }
        unique_lines = unique_bytes / std::min(n_continuous, static_cast<float>(cache_line_size_));
        unique_lines = std::max(unique_lines, 1.0f);
      }

      auto [reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct] =
          ComputeReuse(t, acc.indices, for_loop_stack_, for_touch_regions_, ana_);

      acc_feas.emplace_back();
      BufferAccessFeature& acc_fea = acc_feas.back();

      // TODO(tkonolige): save buffer names and use those instead?
      acc_fea.buffer_name = t->name_hint;
      acc_fea.acc_type = acc.acc_type;
      acc_fea.stride = stride;
      acc_fea.bytes = bytes;
      acc_fea.unique_bytes = unique_bytes;
      acc_fea.lines = lines;
      acc_fea.unique_lines = unique_lines;
      acc_fea.reuse_type = reuse_type;
      acc_fea.reuse_dis_iter = reuse_dis_iter;
      acc_fea.reuse_dis_bytes = reuse_dis_bytes;
      acc_fea.reuse_ct = reuse_ct;
      if (acc_fea.reuse_ct > 0.5) {
        acc_fea.bytes_d_reuse_ct = bytes / reuse_ct;
        acc_fea.unique_bytes_d_reuse_ct = unique_bytes / reuse_ct;
        acc_fea.lines_d_reuse_ct = lines / reuse_ct;
        acc_fea.unique_lines_d_reuse_ct = unique_lines / reuse_ct;
      } else {
        // no reuse, multiply by a magic number '2'
        acc_fea.bytes_d_reuse_ct = bytes * 2;
        acc_fea.unique_bytes_d_reuse_ct = unique_bytes * 2;
        acc_fea.lines_d_reuse_ct = lines * 2;
        acc_fea.unique_lines_d_reuse_ct = unique_lines * 2;
      }
    }

    fea.access_feas = acc_feas;
  }

  // Extract arithmetic intensity related features (group 3)
  void ExtractArithmeticIntensityFeature(const Var& buffer, double cur_compute_ops,
                                         const std::vector<float>& compute_ops_list,
                                         const std::vector<float>& mem_bytes_list) {
    FeatureSet& fea = buffer_features[buffer];

    // Compute the arithmetic intensity curve (y axis: arithmetic intensity, x axis: flops).
    // We use piecewise linear interpolation to fit this curve.
    int pt = 0;
    if (cur_compute_ops <= 0 || compute_ops_list.empty()) {
      std::fill(fea.arith_intensity_curve,
                fea.arith_intensity_curve + ARITH_INTENSITY_CURVE_SAMPLE_N, 0.0);
    } else {
      for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) {
        float cur_compute_ops = compute_ops_list.back() * (i + 1) / ARITH_INTENSITY_CURVE_SAMPLE_N;
        while (compute_ops_list[pt] < cur_compute_ops - 1e-4) {
          pt++;
        }
        ICHECK_LT(pt, compute_ops_list.size());

        float value;
        if (pt == 0) {
          value = compute_ops_list[pt] / mem_bytes_list[pt];
        } else {
          float base = compute_ops_list[pt - 1] / mem_bytes_list[pt - 1];
          float slope = (compute_ops_list[pt] / mem_bytes_list[pt] -
                         compute_ops_list[pt - 1] / mem_bytes_list[pt - 1]) /
                        (compute_ops_list[pt] - compute_ops_list[pt - 1]);
          value = base + slope * (cur_compute_ops - compute_ops_list[pt - 1]);
        }
        fea.arith_intensity_curve[i] = value;
      }
    }
  }

  // Extract allocation related features (group 4)
  void ExtractAllocationFeature(const BufferRealizeNode* node) {
    FeatureSet& fea = buffer_features[node->buffer->data];

    float allocation_size = 1.0f;
    for (const auto& x : node->bounds) {
      allocation_size *= GetIntImm(x->extent);
    }
    // allocation feature
    fea.alloc_size = allocation_size * node->buffer->dtype.bytes();
    fea.alloc_prod = allocation_size * outer_loop_prod_;
    fea.alloc_outer_prod = outer_loop_prod_;
    fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod_;
  }

  void ExtractAllocationFeature(const AllocateNode* node) {
    FeatureSet& fea = buffer_features[node->buffer_var];

    float allocation_size = 1.0f;
    for (const auto& x : node->extents) {
      // TODO(tkonolige): will not handle dynamic shape
      allocation_size *= GetIntImm(x);
    }
    // allocation feature
    fea.alloc_size = allocation_size * node->dtype.bytes();
    fea.alloc_prod = allocation_size * outer_loop_prod_;
    fea.alloc_outer_prod = outer_loop_prod_;
    fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod_;
  }

  // Extract outer scope related features (group 5)
  void ExtractOuterScopeFeature(const Var& buffer) {
    FeatureSet& fea = buffer_features[buffer];

    fea.outer_prod = outer_loop_prod_;
    fea.num_loops = for_loop_stack_.size();
    fea.auto_unroll_max_step = cur_auto_unroll_max_step_;
  }

  // Stores FeatureSet for every buffer
  BufferMap<FeatureSet> buffer_features;

 private:
  // The shared arithmetic analyzer
  Analyzer ana_;

  // The product of the extents of outer loops
  float outer_loop_prod_ = 1.0f;

  // The stacks to store parent loops during DFS
  std::vector<const ForNode*> for_loop_stack_;
  std::vector<const ForNode*> vec_for_stack_;
  std::vector<const ForNode*> unroll_for_stack_;
  std::vector<const ForNode*> parallel_for_stack_;
  std::vector<std::vector<std::tuple<Var, PrimExpr>>> variable_definition_stack_;

  // GPU-related features
  bool is_gpu_{false};
  int blockIdx_x_len_{1};
  int block_idx_y_len_{1};
  int block_idx_z_len_{1};
  int threadIdx_x_len_{1};
  int thread_idx_y_len_{1};
  int thread_idx_z_len_{1};
  int vthread_len_{1};
  int16_t cur_auto_unroll_max_step_{0};

  // Store touch region information for all for loops. The format of this nested map:
  // For a loop, for all its touched buffers, for all different accesses to the buffers,
  // its (access type, number of touched elements, number of bytes of a single element)
  std::unordered_map<const ForNode*,
                     BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>>
      for_touch_regions_;

  // The default cache line size in bytes
  const int cache_line_size_ = 64;

  // Storage of buffer shape and dtype information. Needed because Load/Store
  // nodes do not contain this information.
  BufferMap<Array<PrimExpr>> buffer_shapes;
  BufferMap<DataType> buffer_dtypes;
};

// shifted log to incorporate the property that log2p(0) = 0
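// For example, log2p(0) = 0, log2p(3) = 2, and log2p(-3) = -2 (the function is
// odd, so negative inputs map to negative outputs).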
inline float log2p(float x) { return x < 0 ? -std::log2(-x + 1) : std::log2(x + 1); }

void GetPerStoreFeature(const PrimFunc& func, int cache_line_size, int max_n_bufs,
                        std::vector<float>* ret, bool log_scale = true) {
  PerStoreFeatureExtractor extractor(cache_line_size, func->buffer_map);
  extractor(func->body);

  auto slog = log_scale ? log2p : [](float x) { return x; };

  ret->push_back(extractor.buffer_features.size());
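  // Each store's feature set then contributes a fixed-length block of
  // 57 + max_n_bufs * 18 + ARITH_INTENSITY_CURVE_SAMPLE_N + 4 + 3 floats;
  // see GetPerStoreFeatureName below for the exact layout.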

  for (const auto& x : extractor.buffer_features) {
    const FeatureSet& fea_set = x.second;

    /***** Group 1: Computation related features *****/
    ret->push_back(slog(fea_set.float_mad));
    ret->push_back(slog(fea_set.float_addsub));
    ret->push_back(slog(fea_set.float_mul));
    ret->push_back(slog(fea_set.float_divmod));
    ret->push_back(slog(fea_set.float_cmp));
    ret->push_back(slog(fea_set.float_math_func));
    ret->push_back(slog(fea_set.float_other_func));
    ret->push_back(slog(fea_set.int_mad));
    ret->push_back(slog(fea_set.int_addsub));
    ret->push_back(slog(fea_set.int_mul));
    ret->push_back(slog(fea_set.int_divmod));
    ret->push_back(slog(fea_set.int_cmp));
    ret->push_back(slog(fea_set.int_math_func));
    ret->push_back(slog(fea_set.int_other_func));
    ret->push_back(slog(fea_set.bool_op));
    ret->push_back(slog(fea_set.select_op));

    ret->push_back(slog(fea_set.vec_num));
    ret->push_back(slog(fea_set.vec_prod));
    ret->push_back(slog(fea_set.vec_len));
    for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) {
      ret->push_back(i == static_cast<int>(fea_set.vec_type));
    }

    ret->push_back(slog(fea_set.unroll_num));
    ret->push_back(slog(fea_set.unroll_prod));
    ret->push_back(slog(fea_set.unroll_len));
    for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) {
      ret->push_back(i == static_cast<int>(fea_set.unroll_type));
    }

    ret->push_back(slog(fea_set.parallel_num));
    ret->push_back(slog(fea_set.parallel_prod));
    ret->push_back(slog(fea_set.parallel_len));
    for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) {
      ret->push_back(i == static_cast<int>(fea_set.parallel_type));
    }

    ret->push_back(fea_set.is_gpu);
    ret->push_back(slog(fea_set.blockIdx_x_len));
    ret->push_back(slog(fea_set.blockIdx_y_len));
    ret->push_back(slog(fea_set.blockIdx_z_len));
    ret->push_back(slog(fea_set.threadIdx_x_len));
    ret->push_back(slog(fea_set.threadIdx_y_len));
    ret->push_back(slog(fea_set.threadIdx_z_len));
    ret->push_back(slog(fea_set.vthread_len));

    /***** Group 2: Buffer access related features *****/
    // sort according to pair (lines, bytes)
    std::vector<std::pair<float, float>> buf_order_key;
    for (const auto& acc_fea : fea_set.access_feas) {
      buf_order_key.emplace_back(acc_fea.lines, acc_fea.bytes);
    }
    std::vector<int> buf_order(buf_order_key.size());
    std::iota(buf_order.begin(), buf_order.end(), 0);

    auto cmp = [&buf_order_key](int l, int r) {
      return buf_order_key[l].first > buf_order_key[r].first ||
             (buf_order_key[l].first == buf_order_key[r].first &&
              buf_order_key[l].second > buf_order_key[r].second);
    };
    std::sort(buf_order.begin(), buf_order.end(), cmp);
    int n_bufs = std::min(max_n_bufs, static_cast<int>(buf_order.size()));
    buf_order.resize(n_bufs);

    for (int idx : buf_order) {
      const auto& acc_fea = fea_set.access_feas[idx];
      for (int j = 0; j <= static_cast<int>(BufferAccessType::kReadWrite); ++j) {
        ret->push_back(j == static_cast<int>(acc_fea.acc_type));
      }
      ret->push_back(slog(acc_fea.bytes));
      ret->push_back(slog(acc_fea.unique_bytes));
      ret->push_back(slog(acc_fea.lines));
      ret->push_back(slog(acc_fea.unique_lines));
      for (int j = 0; j <= static_cast<int>(ReuseType::kNoReuse); ++j) {
        ret->push_back(j == static_cast<int>(acc_fea.reuse_type));
      }
      ret->push_back(slog(acc_fea.reuse_dis_iter));
      ret->push_back(slog(acc_fea.reuse_dis_bytes));
      ret->push_back(slog(acc_fea.reuse_ct));
      ret->push_back(slog(acc_fea.bytes_d_reuse_ct));
      ret->push_back(slog(acc_fea.unique_bytes_d_reuse_ct));
      ret->push_back(slog(acc_fea.lines_d_reuse_ct));
      ret->push_back(slog(acc_fea.unique_lines_d_reuse_ct));
      ret->push_back(slog(acc_fea.stride));
    }
    // - fill padding
    for (int i = 0; i < max_n_bufs - n_bufs; ++i) {
      for (int j = 0; j <= static_cast<int>(BufferAccessType::kReadWrite); ++j) {  // 3
        ret->push_back(0.0f);
      }
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      for (int j = 0; j <= static_cast<int>(ReuseType::kNoReuse); ++j) {  // 3
        ret->push_back(0.0f);
      }
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
    }

    /***** Group 3: Arithmetic intensity related features *****/
    for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) {
      ret->push_back(slog(fea_set.arith_intensity_curve[i]));
    }

    /***** Group 4: Allocation related features *****/
    ret->push_back(slog(fea_set.alloc_size));
    ret->push_back(slog(fea_set.alloc_prod));
    ret->push_back(slog(fea_set.alloc_outer_prod));
    ret->push_back(slog(fea_set.alloc_inner_prod));

    /***** Group 5: Outer scope related features *****/
    ret->push_back(slog(fea_set.outer_prod));
    ret->push_back(slog(fea_set.num_loops));
    ret->push_back(slog(fea_set.auto_unroll_max_step));
  }
}

void GetPerStoreFeatureName(int max_n_bufs, std::vector<std::string>* ret) {
  /***** Group 1: Computation related features *****/
  ret->push_back(("float_mad"));
  ret->push_back(("float_addsub"));
  ret->push_back(("float_mul"));
  ret->push_back(("float_divmod"));
  ret->push_back(("float_cmp"));
  ret->push_back(("float_mathfunc"));
  ret->push_back(("float_otherfunc"));
  ret->push_back(("int_mad"));
  ret->push_back(("int_addsub"));
  ret->push_back(("int_mul"));
  ret->push_back(("int_divmod"));
  ret->push_back(("int_cmp"));
  ret->push_back(("int_mathfunc"));
  ret->push_back(("int_otherfunc"));
  ret->push_back(("bool_op"));
  ret->push_back(("select_op"));
  ret->push_back(("vec_num"));
  ret->push_back(("vec_prod"));
  ret->push_back(("vec_len"));
  ret->push_back(("vec_type.kPosNone"));
  ret->push_back(("vec_type.kPosInnerSpatial"));
  ret->push_back(("vec_type.kPosMiddleSpatial"));
  ret->push_back(("vec_type.kPosOuterSpatial"));
  ret->push_back(("vec_type.kPosInnerReduce"));
  ret->push_back(("vec_type.kPosMiddleReduce"));
  ret->push_back(("vec_type.kPosOuterReduce"));
  ret->push_back(("vec_type.kPosMixed"));
  ret->push_back(("unroll_num"));
  ret->push_back(("unroll_prod"));
  ret->push_back(("unroll_len"));
  ret->push_back(("unroll_type.kPosNone"));
  ret->push_back(("unroll_type.kPosInnerSpatial"));
  ret->push_back(("unroll_type.kPosMiddleSpatial"));
  ret->push_back(("unroll_type.kPosOuterSpatial"));
  ret->push_back(("unroll_type.kPosInnerReduce"));
  ret->push_back(("unroll_type.kPosMiddleReduce"));
  ret->push_back(("unroll_type.kPosOuterReduce"));
  ret->push_back(("unroll_type.kPosMixed"));
  ret->push_back(("parallel_num"));
  ret->push_back(("parallel_prod"));
  ret->push_back(("parallel_len"));
  ret->push_back(("parallel_type.kPosNone"));
  ret->push_back(("parallel_type.kPosInnerSpatial"));
  ret->push_back(("parallel_type.kPosMiddleSpatial"));
  ret->push_back(("parallel_type.kPosOuterSpatial"));
  ret->push_back(("parallel_type.kPosInnerReduce"));
  ret->push_back(("parallel_type.kPosMiddleReduce"));
  ret->push_back(("parallel_type.kPosOuterReduce"));
  ret->push_back(("parallel_type.kPosMixed"));
  ret->push_back(("is_gpu"));
  ret->push_back(("blockIdx_x_len"));
  ret->push_back(("blockIdx_y_len"));
  ret->push_back(("blockIdx_z_len"));
  ret->push_back(("threadIdx_x_len"));
  ret->push_back(("threadIdx_y_len"));
  ret->push_back(("threadIdx_z_len"));
  ret->push_back(("vthread_len"));
  // section total: 57

  /***** Group 2: Buffer access related features *****/
  for (size_t i = 0; i < static_cast<size_t>(max_n_bufs); ++i) {
    std::string prefix = "B" + std::to_string(i) + ".";
    ret->push_back((prefix + "acc_type.kRead"));
    ret->push_back((prefix + "acc_type.kWrite"));
    ret->push_back((prefix + "acc_type.kReadWrite"));
    ret->push_back((prefix + "bytes"));
    ret->push_back((prefix + "unique_bytes"));
    ret->push_back((prefix + "lines"));
    ret->push_back((prefix + "unique_lines"));
    ret->push_back((prefix + "reuse_type.kLoopMultipleRead"));
    ret->push_back((prefix + "reuse_type.kSerialMultipleReadWrite"));
    ret->push_back((prefix + "reuse_type.kNoReuse"));
    ret->push_back((prefix + "reuse_dis_iter"));
    ret->push_back((prefix + "reuse_dis_bytes"));
    ret->push_back((prefix + "reuse_ct"));
    ret->push_back((prefix + "bytes_d_reuse_ct"));
    ret->push_back((prefix + "unique_bytes_d_reuse_ct"));
    ret->push_back((prefix + "lines_d_reuse_ct"));
    ret->push_back((prefix + "unique_lines_d_reuse_ct"));
    ret->push_back((prefix + "stride"));
  }
  // section total: max_n_bufs * 18

  /***** Group 3: Arithmetic intensity related features *****/
  for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) {
    ret->push_back(("arith_intensity_curve_" + std::to_string(i)));
  }
  // section total: ARITH_INTENSITY_CURVE_SAMPLE_N = 10

  /***** Group 4: Allocation related features *****/
  ret->push_back(("alloc_size"));
  ret->push_back(("alloc_prod"));
  ret->push_back(("alloc_outer_prod"));
  ret->push_back(("alloc_inner_prod"));
  // section total: 4

  /***** Group 5: Outer scope related features *****/
  ret->push_back(("outer_prod"));
  ret->push_back(("num_loops"));
  ret->push_back(("auto_unroll_max_step"));
  // section total: 3
1361 | } |
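
// For reference, the section totals above imply that the name list (and thus
// each per-store feature row) has length
//   57 + max_n_bufs * 18 + ARITH_INTENSITY_CURVE_SAMPLE_N + 4 + 3.
// A minimal sanity-check sketch (illustrative only, not part of the build):
//
//   std::vector<std::string> names;
//   GetPerStoreFeatureName(/*max_n_bufs=*/5, &names);
//   ICHECK_EQ(names.size(), 57 + 5 * 18 + ARITH_INTENSITY_CURVE_SAMPLE_N + 4 + 3);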

void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs,
                                   std::vector<float>* feature, std::atomic<int>* error_ct) {
  auto [sch, tensors] = task->compute_dag.ApplySteps(state->transform_steps);

  // When inlining, replace const matrices with const values.
  // Produces wrong IR, but good enough for feature extraction, and
  // can improve the speed of feature extraction/search. Must be
  // called before ScheduleToModule to have an effect.
  sch = sch.normalize_for_feature_extraction();

  try {
    const std::string& name = "main";
    auto pass_ctx = tvm::transform::PassContext::Current();

    auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), tensors.end()}, name,
                                std::unordered_map<te::Tensor, te::Buffer>(),
                                GlobalVarSupply(NameSupply("")));

    bool disable_vectorize =
        pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
    bool instrument_bound_checkers =
        pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();

    if (IsGPUTask(task)) {
      auto pass_list = Array<tvm::transform::Pass>();
      // Phase 0
      pass_list.push_back(tir::transform::InjectPrefetch());
      pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
      // Phase 1
      pass_list.push_back(tir::transform::NarrowDataType(32));
      pass_list.push_back(tir::transform::Simplify());
      pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
      pass_list.push_back(tir::transform::InjectVirtualThread());
      pass_list.push_back(tir::transform::StorageRewrite());
      pass_list.push_back(tir::transform::Simplify());
      tvm::Map<String, tvm::PrimExpr> gpu_params{
          {"max_shared_memory_per_block", task->hardware_params->max_shared_memory_per_block},
          {"max_local_memory_per_block", task->hardware_params->max_local_memory_per_block},
          {"max_threads_per_block", task->hardware_params->max_threads_per_block},
          {"max_vector_bytes", task->hardware_params->vector_unit_bytes},
          {"max_vthread", task->hardware_params->max_vthread_extent},
      };
      pass_list.push_back(tir::transform::VerifyGPUCode(gpu_params));
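      // Note: this pipeline is run only for its verification side effect.
      // VerifyGPUCode raises an error if the schedule exceeds the hardware
      // limits above; the transformed module is discarded, and the catch
      // block at the end of this function counts the failure.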
      const auto& optimize = tir::transform::Sequential(pass_list);
      optimize(mod);
    }
    if (IsHexagonTask(task)) {
      Target target = task->target;
      const auto vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value().IntValue();
      const auto& optimize =
          tir::transform::Sequential({tir::transform::VerifyVTCMLimit(vtcm_capacity)});
      optimize(mod);
    }
    const auto& optimize =
        tir::transform::Sequential(Array<tvm::transform::Pass>{tir::transform::Simplify()});
    mod = optimize(std::move(mod));
    PrimFunc prim_func = Downcast<PrimFunc>(mod->Lookup(name));
    GetPerStoreFeature(prim_func, task->hardware_params->cache_line_bytes, max_n_bufs, feature);
  } catch (Error& e) {
    (*error_ct)++;
  }
}

void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask& task,
                                   int skip_first_n_feature_extraction, int max_n_bufs,
                                   std::vector<std::vector<float>>* features) {
  // extract features
  features->assign(states.size(), std::vector<float>());

  std::atomic<int> error_ct(0);
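
  // States whose lowering fails keep an empty feature vector; error_ct only
  // counts how many extractions failed.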
  support::parallel_for(skip_first_n_feature_extraction, states.size(),
                        [&task, &states, &max_n_bufs, &features, &error_ct](int i) {
                          GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs,
                                                        &(*features)[i], &error_ct);
                        });
}

void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector<SearchTask>& tasks,
                                   int skip_first_n_feature_extraction, int max_n_bufs,
                                   std::vector<std::vector<float>>* features) {
  // extract features
  features->assign(states.size(), std::vector<float>());

  std::atomic<int> error_ct(0);

  support::parallel_for(skip_first_n_feature_extraction, states.size(),
                        [&tasks, &states, &max_n_bufs, &features, &error_ct](int i) {
                          GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs,
                                                        &(*features)[i], &error_ct);
                        });
}

void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int max_n_bufs,
                                 std::vector<std::vector<float>>* features,
                                 std::vector<float>* normalized_throughputs,
                                 std::vector<int>* task_ids) {
  Array<State> states;
  std::vector<SearchTask> tasks;

  normalized_throughputs->clear();
  task_ids->clear();

  // (workload_key, target) -> (search_task, task_id)
  std::unordered_map<std::pair<std::string, std::string>, std::pair<SearchTask, size_t>> task_cache;
  // task_id -> min_cost
  std::vector<float> min_costs;

  const auto* workload_key_to_tensors =
      tvm::runtime::Registry::Get("auto_scheduler.workload_key_to_tensors");
  ICHECK(workload_key_to_tensors != nullptr);

  // read from file
  RecordReader reader(filename);
  auto cur_inp = make_object<MeasureInputNode>();
  auto cur_res = make_object<MeasureResultNode>();
  while (reader->ReadNext(cur_inp.get(), cur_res.get())) {
    float cost = static_cast<float>(FloatArrayMean(cur_res->costs));
    const std::string& workload_key = cur_inp->task->workload_key;

    SearchTask task;
    size_t task_id;
    std::pair<std::string, std::string> key(workload_key, cur_inp->task->target->str());
    auto find_res = task_cache.find(key);
    if (find_res == task_cache.end()) {
      // rebuild task
      Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
      Target target = cur_inp->task->target;
      Target target_host = cur_inp->task->target_host;
      CheckAndUpdateHostConsistency(&target, &target_host);
      task = SearchTask(ComputeDAG(tensors), workload_key, target, target_host,
                        cur_inp->task->hardware_params, cur_inp->task->layout_rewrite_option,
                        cur_inp->task->task_input_names);
      task_id = task_cache.size();

      // compute min cost for each task
      task_cache.insert(std::make_pair(key, std::make_pair(task, task_id)));
      min_costs.push_back(cost);
    } else {
      std::tie(task, task_id) = find_res->second;
      min_costs[task_id] = std::min(min_costs[task_id], cost);
    }

    tasks.push_back(std::move(task));
    task_ids->push_back(task_id);
    states.push_back(cur_inp->state);
    normalized_throughputs->push_back(cost);

    if (max_lines > 0 && static_cast<int>(states.size()) >= max_lines) {
      break;
    }
  }

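  // Normalization: each record's score becomes min_cost(task) / cost(record),
  // so the fastest measured record of a task maps to 1.0 and slower records
  // fall into (0, 1); lower cost means higher normalized throughput.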
  for (size_t i = 0; i < normalized_throughputs->size(); ++i) {
    (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i];
  }

  GetPerStoreFeaturesFromStates(states, tasks, 0, max_n_bufs, features);
}

void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
                                         const Array<MeasureResult>& results,
                                         int skip_first_n_feature_extraction, int max_n_bufs,
                                         std::vector<std::vector<float>>* features,
                                         std::vector<float>* normalized_throughputs,
                                         std::vector<int>* task_ids) {
  Array<State> states;
  std::vector<SearchTask> tasks;

  normalized_throughputs->clear();
  task_ids->clear();

  // (workload_key, target) -> (search_task, task_id)
  std::unordered_map<std::pair<std::string, std::string>, std::pair<SearchTask, size_t>> task_cache;
  // task_id -> min_cost
  std::vector<float> min_costs;

  const auto* workload_key_to_tensors =
      tvm::runtime::Registry::Get("auto_scheduler.workload_key_to_tensors");
  ICHECK(workload_key_to_tensors != nullptr);

  tasks.reserve(inputs.size());
  normalized_throughputs->reserve(inputs.size());
  task_ids->reserve(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    float cost = static_cast<float>(FloatArrayMean(results[i]->costs));
    const std::string& workload_key = inputs[i]->task->workload_key;
    SearchTask task;

    size_t task_id;
    std::pair<std::string, std::string> key(workload_key, inputs[i]->task->target->str());
    auto find_res = task_cache.find(key);
    if (find_res == task_cache.end()) {
      if (inputs[i]->task->compute_dag.defined()) {  // the measure input is complete
        task = inputs[i]->task;
      } else {
        // The measure input is incomplete; rebuild the task for incomplete
        // measure pairs read from file.
        try {
          Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
          Target target = inputs[i]->task->target;
          Target target_host = inputs[i]->task->target_host;
          CheckAndUpdateHostConsistency(&target, &target_host);
          task =
              SearchTask(ComputeDAG(tensors), workload_key, target, target_host,
                         inputs[i]->task->hardware_params, inputs[i]->task->layout_rewrite_option,
                         inputs[i]->task->task_input_names);
        } catch (std::exception& e) {
          // Cannot build ComputeDAG from the workload key; the task may not
          // have been registered in this search round.
          continue;
        }
      }
      task_id = task_cache.size();

      // compute min cost for each task
      task_cache.insert(std::make_pair(key, std::make_pair(task, task_id)));
      min_costs.push_back(cost);
    } else {
      std::tie(task, task_id) = find_res->second;
      min_costs[task_id] = std::min(min_costs[task_id], cost);
    }

    tasks.push_back(std::move(task));
    task_ids->push_back(task_id);
    states.push_back(inputs[i]->state);
    normalized_throughputs->push_back(cost);
  }

  for (size_t i = 0; i < normalized_throughputs->size(); ++i) {
    (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i];
  }

  GetPerStoreFeaturesFromStates(states, tasks, skip_first_n_feature_extraction, max_n_bufs,
                                features);
}

/*
 * \brief Serialize a two-dimensional variable-size feature vector with normalized throughputs
 * and task ids to a one-dimensional flattened byte array.
 *
 * For faster data copy between C++ and Python, the C++ part returns features in a single
 * flattened array using a packed format. The Python part then unpacks the flattened array.
 *
 * The packed format for n records is:
 * {
 *   int   n;
 *   int   sizes[n+2];              // The sizes of the following arrays
 *
 *   float features_0[sizes[0]];    // The features for record 0
 *   float features_1[sizes[1]];    // The features for record 1
 *   ...
 *   float features_i[sizes[i]];    // The features for record i
 *   ...                            // until i == n - 1
 *
 *   float throughputs[sizes[n]];   // The normalized throughputs for n records
 *   int   task_ids[sizes[n+1]];    // The task ids for n records
 * }
 * To implement this format, we also store ints as floats, so all numbers fit
 * into a single float array.
 */
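
/*
 * Illustrative example (not part of the original format description): two
 * records with feature vectors of length 3 and 2 serialize to the stream
 *
 *   int   n        = 2;
 *   int   sizes[4] = {3, 2, 2, 2};  // feature lens, #throughputs, #task_ids
 *   float features_0[3];            // record 0
 *   float features_1[2];            // record 1
 *   float throughputs[2];           // one normalized throughput per record
 *   int   task_ids[2];              // one task id per record
 *
 * i.e. 5 header ints followed by 9 payload values: 56 bytes in total,
 * assuming 4-byte int and float.
 */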
TVMByteArray SerializeFeatures(std::vector<std::vector<float>>&& features,
                               std::vector<float>&& normalized_throughputs,
                               std::vector<int>&& task_ids, std::vector<char>* out_data) {
  size_t total_bytes = 0;
  std::vector<int> size_vector;

  int n = static_cast<int>(features.size());

  // serialize sizes
  size_t size_vector_size = 1 + n + 2;
  total_bytes += size_vector_size * sizeof(int);

  size_vector.reserve(size_vector_size);
  size_vector.push_back(n);
  for (const auto& x : features) {
    size_vector.push_back(static_cast<int>(x.size()));
    total_bytes += sizeof(float) * x.size();
  }
  size_vector.push_back(static_cast<int>(normalized_throughputs.size()));
  total_bytes += sizeof(float) * normalized_throughputs.size();
  size_vector.push_back(static_cast<int>(task_ids.size()));
  total_bytes += sizeof(int) * task_ids.size();

  ICHECK_EQ(size_vector.size(), size_vector_size);

  // allocate memory (resize, not reserve: writing through data() is only
  // defined behavior for elements inside the vector's size)
  out_data->resize(total_bytes);
  char* ptr = out_data->data();

  // serialize size_vector
  memmove(ptr, reinterpret_cast<char*>(size_vector.data()), size_vector.size() * sizeof(int));
  ptr += size_vector.size() * sizeof(int);

  // serialize features
  for (auto& x : features) {
    memmove(ptr, x.data(), sizeof(float) * x.size());
    ptr += sizeof(float) * x.size();
    x.clear();
  }

  // serialize normalized_throughputs
  memmove(ptr, reinterpret_cast<char*>(normalized_throughputs.data()),
          normalized_throughputs.size() * sizeof(float));
  ptr += normalized_throughputs.size() * sizeof(float);

  // serialize task_ids
  memmove(ptr, reinterpret_cast<char*>(task_ids.data()), task_ids.size() * sizeof(int));
  ptr += task_ids.size() * sizeof(int);

  ICHECK_EQ(static_cast<size_t>(ptr - out_data->data()), total_bytes);

  return TVMByteArray{out_data->data(), total_bytes};
}
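
// For reference, a minimal sketch of the inverse unpacking (the real consumer
// is the Python side; this hypothetical helper is illustrative only and
// assumes sizeof(int) == sizeof(float) == 4, as the packed format does):
//
//   void DeserializeFeatures(const char* data,
//                            std::vector<std::vector<float>>* features,
//                            std::vector<float>* throughputs,
//                            std::vector<int>* task_ids) {
//     const int* ip = reinterpret_cast<const int*>(data);
//     int n = *ip++;                              // number of records
//     std::vector<int> sizes(ip, ip + n + 2);     // array sizes
//     const float* fp = reinterpret_cast<const float*>(ip + n + 2);
//     features->clear();
//     for (int i = 0; i < n; ++i) {               // one feature row per record
//       features->emplace_back(fp, fp + sizes[i]);
//       fp += sizes[i];
//     }
//     throughputs->assign(fp, fp + sizes[n]);
//     fp += sizes[n];
//     const int* tp = reinterpret_cast<const int*>(fp);
//     task_ids->assign(tp, tp + sizes[n + 1]);
//   }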

TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromFile")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      std::string filename = args[0];
      int max_lines = args[1];
      int max_n_bufs = args[2];

      std::vector<std::vector<float>> features;
      std::vector<float> normalized_throughputs;
      std::vector<int> task_ids;

      GetPerStoreFeaturesFromFile(filename, max_lines, max_n_bufs, &features,
                                  &normalized_throughputs, &task_ids);

      std::vector<char> byte_data;
      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
                               std::move(task_ids), &byte_data);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromMeasurePairs")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      Array<MeasureInput> inputs = args[0];
      Array<MeasureResult> results = args[1];
      int skip_first_n_feature_extraction = args[2];
      int max_n_bufs = args[3];

      std::vector<std::vector<float>> features;
      std::vector<float> normalized_throughputs;
      std::vector<int> task_ids;

      GetPerStoreFeaturesFromMeasurePairs(inputs, results, skip_first_n_feature_extraction,
                                          max_n_bufs, &features, &normalized_throughputs,
                                          &task_ids);

      std::vector<char> byte_data;
      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
                               std::move(task_ids), &byte_data);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromStates")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      Array<State> states = args[0];
      SearchTask task = args[1];
      int max_n_bufs = args[2];

      std::vector<std::vector<float>> features;
      std::vector<float> normalized_throughputs;
      std::vector<int> task_ids;

      GetPerStoreFeaturesFromStates(states, task, 0, max_n_bufs, &features);
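
      // Note: this entry point computes features only; normalized_throughputs
      // and task_ids stay empty, so SerializeFeatures records their sizes as 0.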
      std::vector<char> byte_data;
      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
                               std::move(task_ids), &byte_data);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeatureNames")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int max_n_bufs = args[0];
      std::vector<std::string> names;

      GetPerStoreFeatureName(max_n_bufs, &names);

      Array<String> arr;
      for (const auto& x : names) {
        arr.push_back(x);
      }
      *ret = arr;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.FeaturesFromPrimFunc")
    .set_body_typed([](const PrimFunc& func, int cache_line_size, int max_n_bufs, bool log_scale) {
      std::vector<float> vec;
      GetPerStoreFeature(func, cache_line_size, max_n_bufs, &vec, log_scale);
      int64_t num_feature_rows = static_cast<int64_t>(vec[0]);  // first element is number of rows
      int64_t row_length = 0;
      if (num_feature_rows != 0) {
        row_length = (vec.size() - 1) / num_feature_rows;
      }
      auto ary =
          runtime::NDArray::Empty({num_feature_rows, row_length}, {kDLFloat, 32, 1}, {kDLCPU, 0});
      // NDArray is row major by default
      ary.CopyFromBytes(vec.data() + 1, sizeof(float) * num_feature_rows * row_length);
      return ary;
    });

}  // namespace auto_scheduler
}  // namespace tvm