1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_RUNTIME_INPUTSANITIZER_H
17#define GLOW_RUNTIME_INPUTSANITIZER_H
18
#include <memory>
#include <string>
#include <vector>

#include "glow/Graph/Graph.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/PlaceholderBindings.h"
#include "glow/Support/Error.h"
26
27namespace glow {
28namespace runtime {
29
30/*
31 * Base abstract class for input sanitizers.
32 * Each operator type can have its own specialization.
33 */
34class InputSanitizer {
35public:
36 virtual ~InputSanitizer() = default;
37 virtual Error sanitize(const PlaceholderBindings &bindings) = 0;
38 virtual std::string toString() = 0;
39};
40
41using InputSanitizerPtr = std::shared_ptr<InputSanitizer>;
42
43class SparseLengthsSumInputSanitizer : public InputSanitizer {
44public:
45 SparseLengthsSumInputSanitizer(const size_t tableHeight,
46 Placeholder *indicesPH, Placeholder *weightsPH,
47 Placeholder *lengthsPH);
48
49 Error sanitize(const PlaceholderBindings &bindings) override;
50
51 std::string toString() override;
52
53private:
54 size_t tableHeight_{0};
55 Placeholder *indicesPH_{nullptr};
56 Placeholder *weightsPH_{nullptr};
57 Placeholder *lengthsPH_{nullptr};
58};
59
60class EmbeddingBagInputSanitizer : public InputSanitizer {
61public:
62 EmbeddingBagInputSanitizer(size_t tableHeight, Placeholder *indicesPH,
63 Placeholder *weightsPH, Placeholder *offsetsPH);
64
65 Error sanitize(const PlaceholderBindings &bindings) override;
66
67 std::string toString() override;
68
69private:
70 size_t tableHeight_{0};
71 Placeholder *indicesPH_{nullptr};
72 Placeholder *weightsPH_{nullptr};
73 Placeholder *offsetsPH_{nullptr};
74};
75
76//
77// Public utility functions
78//
79std::vector<InputSanitizerPtr> getInputSanitizers(const Function &function);
80Error sanitizeInputs(const std::vector<InputSanitizerPtr> &sanitizers,
81 const PlaceholderBindings &bindings);
82
83} // namespace runtime
84} // namespace glow
85
86#endif // GLOW_RUNTIME_INPUTSANITIZER_H
87