1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file feature_visitor.cc |
22 | * \brief Base class for feature extractor. |
 * These features are used by the machine learning cost model.
24 | */ |
25 | |
26 | #include "feature_visitor.h" |
27 | |
28 | namespace tvm { |
29 | namespace autotvm { |
30 | |
31 | // for loop |
32 | void FeatureVisitor::VisitStmt_(const ForNode* op) { |
33 | const auto* extent = op->extent.as<IntImmNode>(); |
34 | int64_t loop_extent = -1; |
35 | if (extent != nullptr) loop_extent = extent->value; |
36 | AnnotationType ann = kSerial; |
37 | switch (op->kind) { |
38 | case ForKind ::kParallel: |
39 | ann = kParallel; |
40 | break; |
41 | case ForKind::kUnrolled: |
42 | ann = kUnrolled; |
43 | break; |
44 | case ForKind::kVectorized: |
45 | ann = kVectorized; |
46 | break; |
47 | case ForKind::kSerial: |
48 | ann = kSerial; |
49 | break; |
50 | case ForKind::kThreadBinding: |
51 | LOG(FATAL) << "Loop ThreadBinding is reserved for future used and " |
52 | << "not yet supported in TIR" ; |
53 | break; |
54 | } |
55 | |
56 | if (EnterItervar_(op->loop_var, loop_extent, ann)) { |
57 | StmtExprVisitor::VisitStmt_(op); |
58 | ExitItervar_(); |
59 | } |
60 | } |
61 | |
62 | // parallel axis, virtual thread |
63 | void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) { |
64 | if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { |
65 | Var var = op->node.as<tir::IterVarNode>()->var; |
66 | const auto* extent = op->value.as<IntImmNode>(); |
67 | ICHECK(extent); |
68 | |
69 | std::string name = var.get()->name_hint; |
70 | AnnotationType ann = kParallel; |
71 | if (op->attr_key == tir::attr::thread_extent) { |
72 | if (name == "blockIdx.x" ) |
73 | ann = kBlockX; |
74 | else if (name == "blockIdx.y" ) |
75 | ann = kBlockY; |
76 | else if (name == "blockIdx.z" ) |
77 | ann = kBlockZ; |
78 | else if (name == "threadIdx.x" ) |
79 | ann = kThreadX; |
80 | else if (name == "threadIdx.y" ) |
81 | ann = kThreadY; |
82 | else if (name == "threadIdx.z" ) |
83 | ann = kThreadZ; |
84 | else |
85 | LOG(FATAL) << "invalid thread itervar " + name; |
86 | } else { |
87 | ann = kVirtualThread; |
88 | } |
89 | |
90 | if (EnterItervar_(var, extent->value, ann)) { |
91 | StmtExprVisitor::VisitStmt_(op); |
92 | ExitItervar_(); |
93 | } |
94 | } else { |
95 | StmtExprVisitor::VisitStmt_(op); |
96 | } |
97 | } |
98 | |
99 | // memory access |
100 | void FeatureVisitor::VisitExpr_(const BufferLoadNode* op) { |
101 | ICHECK_EQ(op->indices.size(), 1) << "FeatureVisitor can only be used on flattened buffers" ; |
102 | EnterMem_(op->buffer->data, op->indices[0]); |
103 | StmtExprVisitor::VisitExpr_(op); |
104 | ExitMem_(); |
105 | } |
106 | |
107 | void FeatureVisitor::VisitStmt_(const BufferStoreNode* op) { |
108 | ICHECK_EQ(op->indices.size(), 1) << "FeatureVisitor can only be used on flattened buffers" ; |
109 | EnterMem_(op->buffer->data, op->indices[0]); |
110 | StmtExprVisitor::VisitStmt_(op); |
111 | ExitMem_(); |
112 | } |
113 | |
114 | } // namespace autotvm |
115 | } // namespace tvm |
116 | |