1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/IR/Instrs.h" |
18 | #include "glow/IR/IR.h" |
19 | #include "glow/Support/Support.h" |
20 | |
21 | #include "llvm/Support/Casting.h" |
22 | |
23 | #include <cassert> |
24 | |
25 | using namespace glow; |
26 | using llvm::cast; |
27 | using llvm::isa; |
28 | |
29 | //===----------------------------------------------------------------------===// |
30 | // Instruction textual printers |
31 | //===----------------------------------------------------------------------===// |
32 | |
33 | const char *WeightVar::getMutabilityStr(MutabilityKind kind) { |
34 | const char *names[] = {"const" , "mutable" , nullptr}; |
35 | return names[static_cast<int>(kind)]; |
36 | } |
37 | |
38 | const char *WeightVar::getMutabilityStr() const { |
39 | return getMutabilityStr(mut_); |
40 | } |
41 | |
42 | void WeightVar::dump(llvm::raw_ostream &os) const { |
43 | os << "%" << getName() << " = WeightVar " ; |
44 | os << *getType() << " " << getMutabilityStr(); |
45 | } |
46 | |
47 | //===----------------------------------------------------------------------===// |
48 | // Instruction verification |
49 | //===----------------------------------------------------------------------===// |
50 | |
51 | void CopyInst::verify() const { |
52 | auto *dest = getDest(); |
53 | auto *src = getSrc(); |
54 | (void)dest; |
55 | (void)src; |
56 | assert(dest->getType() == src->getType() && "Invalid type." ); |
57 | } |
58 | |
59 | void TensorViewInst::verify() const { |
60 | assert(getSrc()->getType()->size() >= getType()->size() && |
61 | "TensorView view size should be no larger than Src size" ); |
62 | assert(getSrc()->getElementType() == getType()->getElementType() && |
63 | "TensorView view element type should be the same as Src type" ); |
64 | assert(getSrc()->getType()->dims().size() == getOffsets().size() && |
65 | "TensorView offsets should have the same number of dims as Src type " |
66 | "shape" ); |
67 | } |
68 | |
69 | void AllocActivationInst::verify() const { |
70 | unsigned numDealloc = 0; |
71 | for (const Use &U : getUsers()) { |
72 | numDealloc += isa<DeallocActivationInst>(U.get()); |
73 | } |
74 | |
75 | // Make sure that there is exactly one user is a deallocation. |
76 | assert(numDealloc == 1 && "Invalid number of tensor deallocation" ); |
77 | } |
78 | |
79 | void DeallocActivationInst::verify() const { |
80 | // The operand of this instruction needs to be an AllocActivationInst. |
81 | assert(isa<AllocActivationInst>(getSrc()) && "Invalid operand" ); |
82 | } |
83 | |
84 | void InsertTensorInst::verify() const { |
85 | assert(getSrc()->getElementType() == getDest()->getElementType() && |
86 | "InsertTensor dest element type should be the same as Src type." ); |
87 | assert(getCount() > 0 && "Count must be non-zero." ); |
88 | assert(getAxis() >= 0 && getAxis() < getDest()->dims().size() && |
89 | "Axis must fit inside Dest dims." ); |
90 | assert( |
91 | getDest()->getType()->dims().size() == getOffsets().size() && |
92 | "InsertTensor offsets should have the same number of dims as Dest type " |
93 | "shape" ); |
94 | } |
95 | |
96 | void ExtractTensorInst::() const { |
97 | assert(getSrc()->getElementType() == getDest()->getElementType() && |
98 | "ExtractTensor dest element type should be the same as Src type." ); |
99 | assert( |
100 | getSrc()->getType()->dims().size() == getOffsets().size() && |
101 | "ExtractTensor offsets should have the same number of dims as Src type " |
102 | "shape" ); |
103 | } |
104 | |
105 | static void verifyRelu(TypeRef srcTy, TypeRef destTy) { |
106 | if (srcTy->isQuantizedType()) { |
107 | assert(srcTy->isQuantizedType() == destTy->isQuantizedType() && |
108 | "Mismatching isQuantized" ); |
109 | assert(srcTy->dims() == destTy->dims() && "Mismatching dimensions" ); |
110 | assert(srcTy->getElementType() == destTy->getElementType() && |
111 | "Mismatching element type" ); |
112 | return; |
113 | } |
114 | assert(destTy->isEqual(*srcTy) && "Mismatching types" ); |
115 | } |
116 | |
117 | void ReluInst::verify() const { |
118 | verifyRelu(getSrc()->getType(), getDest()->getType()); |
119 | } |
120 | |
121 | void ReluGradInst::verify() const { |
122 | verifyRelu(getSrcGrad()->getType(), getDest()->getType()); |
123 | verifyRelu(getSrcGrad()->getType(), getDestGrad()->getType()); |
124 | } |
125 | |
126 | //===----------------------------------------------------------------------===// |
127 | // Instruction scratch requirements |
128 | //===----------------------------------------------------------------------===// |
129 | dim_t TopKInst::getScratchSize() const { |
130 | // Allocate enough scratch space to hold N values and N indices. |
131 | dim_t N = getInput()->dims().back(); |
132 | dim_t elemSize = getIndices()->getType()->getElementSize(); |
133 | return (2 * N * elemSize); |
134 | } |
135 | |
136 | dim_t AudioSpectrogramInst::getWinOutScratchSize() const { |
137 | dim_t spectrogramLen = getSpectrogram()->dims()[1]; |
138 | dim_t fftLen = (spectrogramLen - 1) * 2; |
139 | return fftLen * sizeof(float); |
140 | } |
141 | |
142 | dim_t AudioSpectrogramInst::getFftOutScratchSize() const { |
143 | dim_t spectrogramLen = getSpectrogram()->dims()[1]; |
144 | dim_t fftLen = (spectrogramLen - 1) * 2; |
145 | return (fftLen + 2) * sizeof(float); |
146 | } |
147 | |
148 | dim_t MFCCInst::getScratchSize() const { |
149 | return getFilterBankCount() * sizeof(float); |
150 | } |
151 | |
152 | dim_t TFLiteDetectionPostProcessInst::getScratchSize() const { |
153 | |
154 | dim_t numBoxes = getAnchors()->dims()[0]; |
155 | dim_t numClasses = getNumClasses(); |
156 | dim_t maxDetections = getMaxDetections(); |
157 | dim_t maxDetectionsPerClass = getMaxDetectionsPerClass(); |
158 | |
159 | dim_t scratchSize = 0; |
160 | if (getRegularNMS()) { |
161 | // Compute scratch size for regular NMS. |
162 | scratchSize += numBoxes * sizeof(float); |
163 | scratchSize += (numBoxes + maxDetections) * sizeof(int32_t); |
164 | scratchSize += (numBoxes + maxDetections) * sizeof(float); |
165 | scratchSize += (numBoxes + maxDetections) * sizeof(int32_t); |
166 | scratchSize += std::min(numBoxes, maxDetectionsPerClass) * sizeof(float); |
167 | scratchSize += numBoxes * sizeof(int32_t); |
168 | scratchSize += numBoxes * sizeof(int32_t); |
169 | scratchSize += numBoxes * sizeof(float); |
170 | scratchSize += numBoxes * sizeof(int32_t); |
171 | } else { |
172 | // Compute scratch size for fast NMS. |
173 | scratchSize += numBoxes * sizeof(float); |
174 | scratchSize += |
175 | numBoxes * std::min(maxDetections, numClasses) * sizeof(int32_t); |
176 | scratchSize += numBoxes * sizeof(int32_t); |
177 | scratchSize += numBoxes * sizeof(int32_t); |
178 | scratchSize += numBoxes * sizeof(float); |
179 | scratchSize += numBoxes * sizeof(int32_t); |
180 | } |
181 | return scratchSize; |
182 | } |
183 | |