1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Runtime/InputSanitizer.h"
18#include "glow/Flags/Flags.h"
19
20#include <folly/Random.h>
21#include <glog/logging.h>
22#include <llvm/Support/Casting.h>
23
24namespace glow {
25namespace runtime {
26
27namespace {
28
29template <class T>
30static Error sanitizeIndices(const Tensor *indicesTensor, size_t tableHeight,
31 llvm::StringRef tensorName) {
32 auto indices = indicesTensor->getHandle<T>();
33 size_t indicesLen = indices.getRealNumElements();
34 // indices in [0, tableHeight)
35 for (auto i = 0; i < indicesLen; i++) {
36 RETURN_ERR_IF_NOT(indices.raw(i) >= 0 && indices.raw(i) < tableHeight,
37 "Indices sanitization failed on tensor " +
38 tensorName.str() + ": index " +
39 std::to_string(indices.raw(i)) + " at pos " +
40 std::to_string(i) + " is out of range [0, " +
41 std::to_string(tableHeight) + ")");
42 }
43
44 return Error::success();
45}
46
47template <class T>
48static Error sanitizeLengths(const Tensor *lengthsTensor,
49 const size_t indicesLen,
50 llvm::StringRef tensorName) {
51 auto lengths = lengthsTensor->getHandle<T>();
52
53 size_t totalLensSum = 0;
54 for (auto i = 0; i < lengths.getRealNumElements(); ++i) {
55 auto length = lengths.raw(i);
56 RETURN_ERR_IF_NOT(length >= 0,
57 "SLS lengths sanitization failed on tensor " +
58 tensorName.str() + ": length " +
59 std::to_string(length) + " at pos " +
60 std::to_string(i) + " is negative");
61 totalLensSum += length;
62 }
63
64 RETURN_ERR_IF_NOT(
65 indicesLen == totalLensSum,
66 strFormat("SLS lengths sanitization failed on tensor %s: indices "
67 "length %lu is not equal to sum of lengths %lu",
68 tensorName.str().c_str(), indicesLen, totalLensSum));
69
70 return Error::success();
71}
72
73template <class T>
74static Error sanitizeOffsets(const Tensor *offsetsTensor,
75 const size_t numberOfIndices,
76 llvm::StringRef tensorName) {
77 auto offsets = offsetsTensor->getHandle<T>();
78
79 RETURN_ERR_IF_NOT(offsets.raw(0) == 0,
80 "EBB offsets sanitization failed on tensor " +
81 tensorName.str() + ": the first offset is not zero " +
82 std::to_string(offsets.raw(0)));
83
84 bool zeroTensor = true;
85 size_t offsetsLen = offsets.getRealNumElements();
86 for (auto i = 0; i < offsetsLen - 1; i++) {
87 RETURN_ERR_IF_NOT(offsets.raw(i) <= offsets.raw(i + 1),
88 "EBB offsets sanitization failed on tensor " +
89 tensorName.str() + ": decreasing offsets " +
90 std::to_string(offsets.raw(i)) + " and " +
91 std::to_string(offsets.raw(i + 1)) + " at pos " +
92 std::to_string(i));
93
94 if (zeroTensor && offsets.raw(i + 1) != 0) {
95 zeroTensor = false;
96 }
97 }
98
99 size_t lastOffset = offsets.raw(offsetsLen - 1);
100 RETURN_ERR_IF_NOT(
101 zeroTensor || lastOffset == numberOfIndices,
102 strFormat("EBB offsets sanitization failed on tensor %s: "
103 "the last offset %lu is not equal to the number of indices %lu",
104 tensorName.str().c_str(), lastOffset, numberOfIndices));
105
106 return Error::success();
107}
108
109} // namespace
110
111//
112// SparseLengthsSum input sanitization
113//
114SparseLengthsSumInputSanitizer::SparseLengthsSumInputSanitizer(
115 const size_t tableHeight, Placeholder *indicesPH, Placeholder *weightsPH,
116 Placeholder *lengthsPH)
117 : tableHeight_{tableHeight}, indicesPH_{indicesPH}, weightsPH_{weightsPH},
118 lengthsPH_{lengthsPH} {}
119
120Error SparseLengthsSumInputSanitizer::sanitize(
121 const PlaceholderBindings &bindings) {
122 auto *indices = bindings.get(indicesPH_);
123
124 // Either a constant or some node internal to the function, skip
125 if (indices == nullptr) {
126 return Error::success();
127 }
128
129 size_t indicesLen = indices->getRealNumElements();
130
131 if (weightsPH_) {
132 auto *weights = bindings.get(weightsPH_);
133 // If this is a weigthed one and the placeholder is real (not a constant
134 // or internal to the function, then sanitize
135 if (weights != nullptr) {
136 size_t weightsLen = weights->getRealNumElements();
137 RETURN_ERR_IF_NOT(
138 indicesLen == weightsLen,
139 strFormat("SLS weights sanitization failed on %s: number of indices "
140 "%lu is not equal to number of weights %lu",
141 weightsPH_->getName().str().c_str(), indicesLen,
142 weightsLen));
143 }
144 }
145
146 // Sanitize indices
147 if (indices->getElementType() == ElemKind::Int64ITy) {
148 RETURN_IF_ERR(
149 sanitizeIndices<int64_t>(indices, tableHeight_, indicesPH_->getName()));
150 } else if (indices->getElementType() == ElemKind::Int32ITy) {
151 RETURN_IF_ERR(
152 sanitizeIndices<int32_t>(indices, tableHeight_, indicesPH_->getName()));
153 } else {
154 return MAKE_ERR(strFormat(
155 "SLS indices sanitization failed on tensor %s: unsupported "
156 "element type %s",
157 indicesPH_->getName().str().c_str(),
158 Type::getElementName(indices->getElementType()).str().c_str()));
159 }
160
161 // Sanitize SLS lengths
162 auto *lengths = bindings.get(lengthsPH_);
163
164 // Either a constant or some node internal to the function, skip
165 if (lengths == nullptr) {
166 return Error::success();
167 }
168
169 if (lengths->getElementType() == ElemKind::Int32ITy) {
170 RETURN_IF_ERR(
171 sanitizeLengths<int32_t>(lengths, indicesLen, lengthsPH_->getName()));
172 } else if (lengths->getElementType() == ElemKind::Int64ITy) {
173 RETURN_IF_ERR(
174 sanitizeLengths<int64_t>(lengths, indicesLen, lengthsPH_->getName()));
175 } else {
176 return MAKE_ERR(strFormat(
177 "SLS lengths sanitization failed on tensor %s: unsupported "
178 "element type %s",
179 lengthsPH_->getName().str().c_str(),
180 Type::getElementName(lengths->getElementType()).str().c_str()));
181 }
182
183 return Error::success();
184}
185
186std::string SparseLengthsSumInputSanitizer::toString() {
187 std::ostringstream ss;
188 ss << "SparseLengthsSumInputSanitizer[";
189 ss << "tableHeight=" << tableHeight_;
190 ss << ", indices=";
191 if (indicesPH_) {
192 ss << indicesPH_->getName().str();
193 }
194 ss << ", weigths=";
195 if (weightsPH_) {
196 ss << weightsPH_->getName().str();
197 }
198 ss << ", lengths=";
199 if (lengthsPH_) {
200 ss << lengthsPH_->getName().str();
201 }
202 ss << "]";
203 return ss.str();
204}
205
206//
207// EmbeddingBag input sanitization
208//
209EmbeddingBagInputSanitizer::EmbeddingBagInputSanitizer(size_t tableHeight,
210 Placeholder *indicesPH,
211 Placeholder *weightsPH,
212 Placeholder *offsetsPH)
213 : tableHeight_{tableHeight}, indicesPH_{indicesPH}, weightsPH_{weightsPH},
214 offsetsPH_{offsetsPH} {}
215
216Error EmbeddingBagInputSanitizer::sanitize(
217 const PlaceholderBindings &bindings) {
218 auto *indices = bindings.get(indicesPH_);
219
220 // Either a constant or some node internal to the function, skip
221 if (indices == nullptr) {
222 return Error::success();
223 }
224
225 size_t indicesLen = indices->getRealNumElements();
226
227 if (weightsPH_) {
228 auto *weights = bindings.get(weightsPH_);
229 // If this is a weigthed one and the placeholder is real (not a constant
230 // or internal to the function, then sanitize
231 if (weights != nullptr) {
232 size_t weightsLen = weights->getRealNumElements();
233 RETURN_ERR_IF_NOT(
234 indicesLen == weightsLen,
235 strFormat("EBB weights sanitization failed on %s: number of indices "
236 "%lu is not equal to number of weights %lu",
237 weightsPH_->getName().str().c_str(), indicesLen,
238 weightsLen));
239 }
240 }
241
242 // Sanitize indices
243 if (indices->getElementType() == ElemKind::Int64ITy) {
244 RETURN_IF_ERR(
245 sanitizeIndices<int64_t>(indices, tableHeight_, indicesPH_->getName()));
246 } else if (indices->getElementType() == ElemKind::Int32ITy) {
247 RETURN_IF_ERR(
248 sanitizeIndices<int32_t>(indices, tableHeight_, indicesPH_->getName()));
249 } else {
250 return MAKE_ERR(strFormat(
251 "EBB indices sanitization failed on tensor %s: unsupported "
252 "element type %s",
253 indicesPH_->getName().str().c_str(),
254 Type::getElementName(indices->getElementType()).str().c_str()));
255 }
256
257 // Sanitize offsets
258 auto *offsets = bindings.get(offsetsPH_);
259
260 // Either a constant or some node internal to the function, skip
261 if (offsets == nullptr) {
262 return Error::success();
263 }
264
265 if (offsets->getElementType() == ElemKind::Int32ITy) {
266 RETURN_IF_ERR(
267 sanitizeOffsets<int32_t>(offsets, indicesLen, offsetsPH_->getName()));
268 } else if (offsets->getElementType() == ElemKind::Int64ITy) {
269 RETURN_IF_ERR(
270 sanitizeOffsets<int64_t>(offsets, indicesLen, offsetsPH_->getName()));
271 } else {
272 return MAKE_ERR(strFormat(
273 "EBB offsets sanitization failed on tensor %s: unsupported "
274 "element type %s",
275 offsetsPH_->getName().str().c_str(),
276 Type::getElementName(offsets->getElementType()).str().c_str()));
277 }
278
279 return Error::success();
280}
281
282std::string EmbeddingBagInputSanitizer::toString() {
283 std::ostringstream ss;
284 ss << "EmbeddingBagInputSanitizer[";
285 ss << "tableHeight=" << tableHeight_;
286 ss << ", indices=";
287 if (indicesPH_) {
288 ss << indicesPH_->getName().str();
289 }
290 ss << ", weigths=";
291 if (weightsPH_) {
292 ss << weightsPH_->getName().str();
293 }
294 ss << ", offsets=";
295 if (offsetsPH_) {
296 ss << offsetsPH_->getName().str();
297 }
298 ss << "]";
299 return ss.str();
300}
301
302//
303// Public utility functions
304//
305std::vector<InputSanitizerPtr> getInputSanitizers(const Function &function) {
306 std::vector<InputSanitizerPtr> result;
307
308 for (const auto &node : function.getNodes()) {
309 if (auto *SLS =
310 llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
311 &node)) {
312 VLOG(1) << SLS->getIndices() << " " << SLS->getWeights() << " "
313 << SLS->getLengths() << " " << SLS->getData().dims()[0];
314 result.push_back(std::make_shared<SparseLengthsSumInputSanitizer>(
315 SLS->getData().dims()[0],
316 llvm::dyn_cast<Placeholder>(SLS->getIndices()),
317 llvm::dyn_cast<Placeholder>(SLS->getWeights()),
318 llvm::dyn_cast<Placeholder>(SLS->getLengths())));
319 } else if (auto *SLS =
320 llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsSumNode>(
321 &node)) {
322 VLOG(1) << SLS->getIndices() << " " << SLS->getLengths() << " "
323 << SLS->getData().dims()[0];
324 result.push_back(std::make_shared<SparseLengthsSumInputSanitizer>(
325 SLS->getData().dims()[0],
326 llvm::dyn_cast<Placeholder>(SLS->getIndices()),
327 /* weights */ nullptr,
328 llvm::dyn_cast<Placeholder>(SLS->getLengths())));
329 } else if (auto *SLS = llvm::dyn_cast<SparseLengthsSumNode>(&node)) {
330 VLOG(1) << SLS->getIndices() << " " << SLS->getLengths() << " "
331 << SLS->getData().dims()[0];
332 result.push_back(std::make_shared<SparseLengthsSumInputSanitizer>(
333 SLS->getData().dims()[0],
334 llvm::dyn_cast<Placeholder>(SLS->getIndices()),
335 /* weights */ nullptr,
336 llvm::dyn_cast<Placeholder>(SLS->getLengths())));
337 } else if (auto *SLS =
338 llvm::dyn_cast<SparseLengthsWeightedSumNode>(&node)) {
339 VLOG(1) << SLS->getIndices() << " " << SLS->getWeights() << " "
340 << SLS->getLengths() << " " << SLS->getData().dims()[0];
341 result.push_back(std::make_shared<SparseLengthsSumInputSanitizer>(
342 SLS->getData().dims()[0],
343 llvm::dyn_cast<Placeholder>(SLS->getIndices()),
344 llvm::dyn_cast<Placeholder>(SLS->getWeights()),
345 llvm::dyn_cast<Placeholder>(SLS->getLengths())));
346 } else if (auto *EBB = llvm::dyn_cast<EmbeddingBagNode>(&node)) {
347 VLOG(1) << EBB->getIndices() << " " << EBB->getOffsets() << " "
348 << EBB->getData().dims()[0];
349 result.push_back(std::make_shared<EmbeddingBagInputSanitizer>(
350 EBB->getData().dims()[0],
351 llvm::dyn_cast<Placeholder>(EBB->getIndices()),
352 /* weights */ nullptr,
353 llvm::dyn_cast<Placeholder>(EBB->getOffsets())));
354 } else if (auto *EBB =
355 llvm::dyn_cast<EmbeddingBagByteRowwiseOffsetsNode>(&node)) {
356 VLOG(1) << EBB->getIndices() << " " << EBB->getWeights() << " "
357 << EBB->getOffsets() << " " << EBB->getData().dims()[0];
358 result.push_back(std::make_shared<EmbeddingBagInputSanitizer>(
359 EBB->getData().dims()[0],
360 llvm::dyn_cast<Placeholder>(EBB->getIndices()),
361 llvm::dyn_cast<Placeholder>(EBB->getWeights()),
362 llvm::dyn_cast<Placeholder>(EBB->getOffsets())));
363 }
364 }
365
366 return result;
367}
368
369Error sanitizeInputs(const std::vector<InputSanitizerPtr> &sanitizers,
370 const PlaceholderBindings &bindings) {
371 if (flags::SanitizeInputsPercent == 0 ||
372 folly::Random::rand32() % 100 > flags::SanitizeInputsPercent) {
373 return Error::success();
374 }
375
376 for (auto &sanitizer : sanitizers) {
377 RETURN_IF_ERR(sanitizer->sanitize(bindings));
378 }
379
380 return Error::success();
381}
382
383} // namespace runtime
384} // namespace glow
385