1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file feature_visitor.cc
22 * \brief Base class for feature extractor.
23 * These features are used for machine learning cost model
24 */
25
26#include "feature_visitor.h"
27
28namespace tvm {
29namespace autotvm {
30
31// for loop
32void FeatureVisitor::VisitStmt_(const ForNode* op) {
33 const auto* extent = op->extent.as<IntImmNode>();
34 int64_t loop_extent = -1;
35 if (extent != nullptr) loop_extent = extent->value;
36 AnnotationType ann = kSerial;
37 switch (op->kind) {
38 case ForKind ::kParallel:
39 ann = kParallel;
40 break;
41 case ForKind::kUnrolled:
42 ann = kUnrolled;
43 break;
44 case ForKind::kVectorized:
45 ann = kVectorized;
46 break;
47 case ForKind::kSerial:
48 ann = kSerial;
49 break;
50 case ForKind::kThreadBinding:
51 LOG(FATAL) << "Loop ThreadBinding is reserved for future used and "
52 << "not yet supported in TIR";
53 break;
54 }
55
56 if (EnterItervar_(op->loop_var, loop_extent, ann)) {
57 StmtExprVisitor::VisitStmt_(op);
58 ExitItervar_();
59 }
60}
61
62// parallel axis, virtual thread
63void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
64 if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) {
65 Var var = op->node.as<tir::IterVarNode>()->var;
66 const auto* extent = op->value.as<IntImmNode>();
67 ICHECK(extent);
68
69 std::string name = var.get()->name_hint;
70 AnnotationType ann = kParallel;
71 if (op->attr_key == tir::attr::thread_extent) {
72 if (name == "blockIdx.x")
73 ann = kBlockX;
74 else if (name == "blockIdx.y")
75 ann = kBlockY;
76 else if (name == "blockIdx.z")
77 ann = kBlockZ;
78 else if (name == "threadIdx.x")
79 ann = kThreadX;
80 else if (name == "threadIdx.y")
81 ann = kThreadY;
82 else if (name == "threadIdx.z")
83 ann = kThreadZ;
84 else
85 LOG(FATAL) << "invalid thread itervar " + name;
86 } else {
87 ann = kVirtualThread;
88 }
89
90 if (EnterItervar_(var, extent->value, ann)) {
91 StmtExprVisitor::VisitStmt_(op);
92 ExitItervar_();
93 }
94 } else {
95 StmtExprVisitor::VisitStmt_(op);
96 }
97}
98
99// memory access
100void FeatureVisitor::VisitExpr_(const BufferLoadNode* op) {
101 ICHECK_EQ(op->indices.size(), 1) << "FeatureVisitor can only be used on flattened buffers";
102 EnterMem_(op->buffer->data, op->indices[0]);
103 StmtExprVisitor::VisitExpr_(op);
104 ExitMem_();
105}
106
107void FeatureVisitor::VisitStmt_(const BufferStoreNode* op) {
108 ICHECK_EQ(op->indices.size(), 1) << "FeatureVisitor can only be used on flattened buffers";
109 EnterMem_(op->buffer->data, op->indices[0]);
110 StmtExprVisitor::VisitStmt_(op);
111 ExitMem_();
112}
113
114} // namespace autotvm
115} // namespace tvm
116