1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Runtime/InputSanitizer.h" |
18 | #include "glow/Flags/Flags.h" |
19 | |
20 | #include <folly/Random.h> |
21 | #include <glog/logging.h> |
22 | #include <llvm/Support/Casting.h> |
23 | |
24 | namespace glow { |
25 | namespace runtime { |
26 | |
27 | namespace { |
28 | |
29 | template <class T> |
30 | static Error sanitizeIndices(const Tensor *indicesTensor, size_t tableHeight, |
31 | llvm::StringRef tensorName) { |
32 | auto indices = indicesTensor->getHandle<T>(); |
33 | size_t indicesLen = indices.getRealNumElements(); |
34 | // indices in [0, tableHeight) |
35 | for (auto i = 0; i < indicesLen; i++) { |
36 | RETURN_ERR_IF_NOT(indices.raw(i) >= 0 && indices.raw(i) < tableHeight, |
37 | "Indices sanitization failed on tensor " + |
38 | tensorName.str() + ": index " + |
39 | std::to_string(indices.raw(i)) + " at pos " + |
40 | std::to_string(i) + " is out of range [0, " + |
41 | std::to_string(tableHeight) + ")" ); |
42 | } |
43 | |
44 | return Error::success(); |
45 | } |
46 | |
47 | template <class T> |
48 | static Error sanitizeLengths(const Tensor *lengthsTensor, |
49 | const size_t indicesLen, |
50 | llvm::StringRef tensorName) { |
51 | auto lengths = lengthsTensor->getHandle<T>(); |
52 | |
53 | size_t totalLensSum = 0; |
54 | for (auto i = 0; i < lengths.getRealNumElements(); ++i) { |
55 | auto length = lengths.raw(i); |
56 | RETURN_ERR_IF_NOT(length >= 0, |
57 | "SLS lengths sanitization failed on tensor " + |
58 | tensorName.str() + ": length " + |
59 | std::to_string(length) + " at pos " + |
60 | std::to_string(i) + " is negative" ); |
61 | totalLensSum += length; |
62 | } |
63 | |
64 | RETURN_ERR_IF_NOT( |
65 | indicesLen == totalLensSum, |
66 | strFormat("SLS lengths sanitization failed on tensor %s: indices " |
67 | "length %lu is not equal to sum of lengths %lu" , |
68 | tensorName.str().c_str(), indicesLen, totalLensSum)); |
69 | |
70 | return Error::success(); |
71 | } |
72 | |
73 | template <class T> |
74 | static Error sanitizeOffsets(const Tensor *offsetsTensor, |
75 | const size_t numberOfIndices, |
76 | llvm::StringRef tensorName) { |
77 | auto offsets = offsetsTensor->getHandle<T>(); |
78 | |
79 | RETURN_ERR_IF_NOT(offsets.raw(0) == 0, |
80 | "EBB offsets sanitization failed on tensor " + |
81 | tensorName.str() + ": the first offset is not zero " + |
82 | std::to_string(offsets.raw(0))); |
83 | |
84 | bool zeroTensor = true; |
85 | size_t offsetsLen = offsets.getRealNumElements(); |
86 | for (auto i = 0; i < offsetsLen - 1; i++) { |
87 | RETURN_ERR_IF_NOT(offsets.raw(i) <= offsets.raw(i + 1), |
88 | "EBB offsets sanitization failed on tensor " + |
89 | tensorName.str() + ": decreasing offsets " + |
90 | std::to_string(offsets.raw(i)) + " and " + |
91 | std::to_string(offsets.raw(i + 1)) + " at pos " + |
92 | std::to_string(i)); |
93 | |
94 | if (zeroTensor && offsets.raw(i + 1) != 0) { |
95 | zeroTensor = false; |
96 | } |
97 | } |
98 | |
99 | size_t lastOffset = offsets.raw(offsetsLen - 1); |
100 | RETURN_ERR_IF_NOT( |
101 | zeroTensor || lastOffset == numberOfIndices, |
102 | strFormat("EBB offsets sanitization failed on tensor %s: " |
103 | "the last offset %lu is not equal to the number of indices %lu" , |
104 | tensorName.str().c_str(), lastOffset, numberOfIndices)); |
105 | |
106 | return Error::success(); |
107 | } |
108 | |
109 | } // namespace |
110 | |
111 | // |
112 | // SparseLengthsSum input sanitization |
113 | // |
114 | SparseLengthsSumInputSanitizer::SparseLengthsSumInputSanitizer( |
115 | const size_t tableHeight, Placeholder *indicesPH, Placeholder *weightsPH, |
116 | Placeholder *lengthsPH) |
117 | : tableHeight_{tableHeight}, indicesPH_{indicesPH}, weightsPH_{weightsPH}, |
118 | lengthsPH_{lengthsPH} {} |
119 | |
120 | Error SparseLengthsSumInputSanitizer::sanitize( |
121 | const PlaceholderBindings &bindings) { |
122 | auto *indices = bindings.get(indicesPH_); |
123 | |
124 | // Either a constant or some node internal to the function, skip |
125 | if (indices == nullptr) { |
126 | return Error::success(); |
127 | } |
128 | |
129 | size_t indicesLen = indices->getRealNumElements(); |
130 | |
131 | if (weightsPH_) { |
132 | auto *weights = bindings.get(weightsPH_); |
133 | // If this is a weigthed one and the placeholder is real (not a constant |
134 | // or internal to the function, then sanitize |
135 | if (weights != nullptr) { |
136 | size_t weightsLen = weights->getRealNumElements(); |
137 | RETURN_ERR_IF_NOT( |
138 | indicesLen == weightsLen, |
139 | strFormat("SLS weights sanitization failed on %s: number of indices " |
140 | "%lu is not equal to number of weights %lu" , |
141 | weightsPH_->getName().str().c_str(), indicesLen, |
142 | weightsLen)); |
143 | } |
144 | } |
145 | |
146 | // Sanitize indices |
147 | if (indices->getElementType() == ElemKind::Int64ITy) { |
148 | RETURN_IF_ERR( |
149 | sanitizeIndices<int64_t>(indices, tableHeight_, indicesPH_->getName())); |
150 | } else if (indices->getElementType() == ElemKind::Int32ITy) { |
151 | RETURN_IF_ERR( |
152 | sanitizeIndices<int32_t>(indices, tableHeight_, indicesPH_->getName())); |
153 | } else { |
154 | return MAKE_ERR(strFormat( |
155 | "SLS indices sanitization failed on tensor %s: unsupported " |
156 | "element type %s" , |
157 | indicesPH_->getName().str().c_str(), |
158 | Type::getElementName(indices->getElementType()).str().c_str())); |
159 | } |
160 | |
161 | // Sanitize SLS lengths |
162 | auto *lengths = bindings.get(lengthsPH_); |
163 | |
164 | // Either a constant or some node internal to the function, skip |
165 | if (lengths == nullptr) { |
166 | return Error::success(); |
167 | } |
168 | |
169 | if (lengths->getElementType() == ElemKind::Int32ITy) { |
170 | RETURN_IF_ERR( |
171 | sanitizeLengths<int32_t>(lengths, indicesLen, lengthsPH_->getName())); |
172 | } else if (lengths->getElementType() == ElemKind::Int64ITy) { |
173 | RETURN_IF_ERR( |
174 | sanitizeLengths<int64_t>(lengths, indicesLen, lengthsPH_->getName())); |
175 | } else { |
176 | return MAKE_ERR(strFormat( |
177 | "SLS lengths sanitization failed on tensor %s: unsupported " |
178 | "element type %s" , |
179 | lengthsPH_->getName().str().c_str(), |
180 | Type::getElementName(lengths->getElementType()).str().c_str())); |
181 | } |
182 | |
183 | return Error::success(); |
184 | } |
185 | |
186 | std::string SparseLengthsSumInputSanitizer::toString() { |
187 | std::ostringstream ss; |
188 | ss << "SparseLengthsSumInputSanitizer[" ; |
189 | ss << "tableHeight=" << tableHeight_; |
190 | ss << ", indices=" ; |
191 | if (indicesPH_) { |
192 | ss << indicesPH_->getName().str(); |
193 | } |
194 | ss << ", weigths=" ; |
195 | if (weightsPH_) { |
196 | ss << weightsPH_->getName().str(); |
197 | } |
198 | ss << ", lengths=" ; |
199 | if (lengthsPH_) { |
200 | ss << lengthsPH_->getName().str(); |
201 | } |
202 | ss << "]" ; |
203 | return ss.str(); |
204 | } |
205 | |
206 | // |
207 | // EmbeddingBag input sanitization |
208 | // |
209 | EmbeddingBagInputSanitizer::EmbeddingBagInputSanitizer(size_t tableHeight, |
210 | Placeholder *indicesPH, |
211 | Placeholder *weightsPH, |
212 | Placeholder *offsetsPH) |
213 | : tableHeight_{tableHeight}, indicesPH_{indicesPH}, weightsPH_{weightsPH}, |
214 | offsetsPH_{offsetsPH} {} |
215 | |
216 | Error EmbeddingBagInputSanitizer::sanitize( |
217 | const PlaceholderBindings &bindings) { |
218 | auto *indices = bindings.get(indicesPH_); |
219 | |
220 | // Either a constant or some node internal to the function, skip |
221 | if (indices == nullptr) { |
222 | return Error::success(); |
223 | } |
224 | |
225 | size_t indicesLen = indices->getRealNumElements(); |
226 | |
227 | if (weightsPH_) { |
228 | auto *weights = bindings.get(weightsPH_); |
229 | // If this is a weigthed one and the placeholder is real (not a constant |
230 | // or internal to the function, then sanitize |
231 | if (weights != nullptr) { |
232 | size_t weightsLen = weights->getRealNumElements(); |
233 | RETURN_ERR_IF_NOT( |
234 | indicesLen == weightsLen, |
235 | strFormat("EBB weights sanitization failed on %s: number of indices " |
236 | "%lu is not equal to number of weights %lu" , |
237 | weightsPH_->getName().str().c_str(), indicesLen, |
238 | weightsLen)); |
239 | } |
240 | } |
241 | |
242 | // Sanitize indices |
243 | if (indices->getElementType() == ElemKind::Int64ITy) { |
244 | RETURN_IF_ERR( |
245 | sanitizeIndices<int64_t>(indices, tableHeight_, indicesPH_->getName())); |
246 | } else if (indices->getElementType() == ElemKind::Int32ITy) { |
247 | RETURN_IF_ERR( |
248 | sanitizeIndices<int32_t>(indices, tableHeight_, indicesPH_->getName())); |
249 | } else { |
250 | return MAKE_ERR(strFormat( |
251 | "EBB indices sanitization failed on tensor %s: unsupported " |
252 | "element type %s" , |
253 | indicesPH_->getName().str().c_str(), |
254 | Type::getElementName(indices->getElementType()).str().c_str())); |
255 | } |
256 | |
257 | // Sanitize offsets |
258 | auto *offsets = bindings.get(offsetsPH_); |
259 | |
260 | // Either a constant or some node internal to the function, skip |
261 | if (offsets == nullptr) { |
262 | return Error::success(); |
263 | } |
264 | |
265 | if (offsets->getElementType() == ElemKind::Int32ITy) { |
266 | RETURN_IF_ERR( |
267 | sanitizeOffsets<int32_t>(offsets, indicesLen, offsetsPH_->getName())); |
268 | } else if (offsets->getElementType() == ElemKind::Int64ITy) { |
269 | RETURN_IF_ERR( |
270 | sanitizeOffsets<int64_t>(offsets, indicesLen, offsetsPH_->getName())); |
271 | } else { |
272 | return MAKE_ERR(strFormat( |
273 | "EBB offsets sanitization failed on tensor %s: unsupported " |
274 | "element type %s" , |
275 | offsetsPH_->getName().str().c_str(), |
276 | Type::getElementName(offsets->getElementType()).str().c_str())); |
277 | } |
278 | |
279 | return Error::success(); |
280 | } |
281 | |
282 | std::string EmbeddingBagInputSanitizer::toString() { |
283 | std::ostringstream ss; |
284 | ss << "EmbeddingBagInputSanitizer[" ; |
285 | ss << "tableHeight=" << tableHeight_; |
286 | ss << ", indices=" ; |
287 | if (indicesPH_) { |
288 | ss << indicesPH_->getName().str(); |
289 | } |
290 | ss << ", weigths=" ; |
291 | if (weightsPH_) { |
292 | ss << weightsPH_->getName().str(); |
293 | } |
294 | ss << ", offsets=" ; |
295 | if (offsetsPH_) { |
296 | ss << offsetsPH_->getName().str(); |
297 | } |
298 | ss << "]" ; |
299 | return ss.str(); |
300 | } |
301 | |
302 | // |
303 | // Public utility functions |
304 | // |
305 | std::vector<InputSanitizerPtr> getInputSanitizers(const Function &function) { |
306 | std::vector<InputSanitizerPtr> result; |
307 | |
308 | for (const auto &node : function.getNodes()) { |
309 | if (auto *SLS = |
310 | llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>( |
311 | &node)) { |
312 | VLOG(1) << SLS->getIndices() << " " << SLS->getWeights() << " " |
313 | << SLS->getLengths() << " " << SLS->getData().dims()[0]; |
314 | result.push_back(std::make_shared<SparseLengthsSumInputSanitizer>( |
315 | SLS->getData().dims()[0], |
316 | llvm::dyn_cast<Placeholder>(SLS->getIndices()), |
317 | llvm::dyn_cast<Placeholder>(SLS->getWeights()), |
318 | llvm::dyn_cast<Placeholder>(SLS->getLengths()))); |
319 | } else if (auto *SLS = |
320 | llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsSumNode>( |
321 | &node)) { |
322 | VLOG(1) << SLS->getIndices() << " " << SLS->getLengths() << " " |
323 | << SLS->getData().dims()[0]; |
324 | result.push_back(std::make_shared<SparseLengthsSumInputSanitizer>( |
325 | SLS->getData().dims()[0], |
326 | llvm::dyn_cast<Placeholder>(SLS->getIndices()), |
327 | /* weights */ nullptr, |
328 | llvm::dyn_cast<Placeholder>(SLS->getLengths()))); |
329 | } else if (auto *SLS = llvm::dyn_cast<SparseLengthsSumNode>(&node)) { |
330 | VLOG(1) << SLS->getIndices() << " " << SLS->getLengths() << " " |
331 | << SLS->getData().dims()[0]; |
332 | result.push_back(std::make_shared<SparseLengthsSumInputSanitizer>( |
333 | SLS->getData().dims()[0], |
334 | llvm::dyn_cast<Placeholder>(SLS->getIndices()), |
335 | /* weights */ nullptr, |
336 | llvm::dyn_cast<Placeholder>(SLS->getLengths()))); |
337 | } else if (auto *SLS = |
338 | llvm::dyn_cast<SparseLengthsWeightedSumNode>(&node)) { |
339 | VLOG(1) << SLS->getIndices() << " " << SLS->getWeights() << " " |
340 | << SLS->getLengths() << " " << SLS->getData().dims()[0]; |
341 | result.push_back(std::make_shared<SparseLengthsSumInputSanitizer>( |
342 | SLS->getData().dims()[0], |
343 | llvm::dyn_cast<Placeholder>(SLS->getIndices()), |
344 | llvm::dyn_cast<Placeholder>(SLS->getWeights()), |
345 | llvm::dyn_cast<Placeholder>(SLS->getLengths()))); |
346 | } else if (auto *EBB = llvm::dyn_cast<EmbeddingBagNode>(&node)) { |
347 | VLOG(1) << EBB->getIndices() << " " << EBB->getOffsets() << " " |
348 | << EBB->getData().dims()[0]; |
349 | result.push_back(std::make_shared<EmbeddingBagInputSanitizer>( |
350 | EBB->getData().dims()[0], |
351 | llvm::dyn_cast<Placeholder>(EBB->getIndices()), |
352 | /* weights */ nullptr, |
353 | llvm::dyn_cast<Placeholder>(EBB->getOffsets()))); |
354 | } else if (auto *EBB = |
355 | llvm::dyn_cast<EmbeddingBagByteRowwiseOffsetsNode>(&node)) { |
356 | VLOG(1) << EBB->getIndices() << " " << EBB->getWeights() << " " |
357 | << EBB->getOffsets() << " " << EBB->getData().dims()[0]; |
358 | result.push_back(std::make_shared<EmbeddingBagInputSanitizer>( |
359 | EBB->getData().dims()[0], |
360 | llvm::dyn_cast<Placeholder>(EBB->getIndices()), |
361 | llvm::dyn_cast<Placeholder>(EBB->getWeights()), |
362 | llvm::dyn_cast<Placeholder>(EBB->getOffsets()))); |
363 | } |
364 | } |
365 | |
366 | return result; |
367 | } |
368 | |
369 | Error sanitizeInputs(const std::vector<InputSanitizerPtr> &sanitizers, |
370 | const PlaceholderBindings &bindings) { |
371 | if (flags::SanitizeInputsPercent == 0 || |
372 | folly::Random::rand32() % 100 > flags::SanitizeInputsPercent) { |
373 | return Error::success(); |
374 | } |
375 | |
376 | for (auto &sanitizer : sanitizers) { |
377 | RETURN_IF_ERR(sanitizer->sanitize(bindings)); |
378 | } |
379 | |
380 | return Error::success(); |
381 | } |
382 | |
383 | } // namespace runtime |
384 | } // namespace glow |
385 | |