1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_RUNTIME_INPUTSANITIZER_H |
17 | #define GLOW_RUNTIME_INPUTSANITIZER_H |
18 | |
#include <memory>
#include <string>
#include <vector>

#include "glow/Graph/Graph.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/PlaceholderBindings.h"
#include "glow/Support/Error.h"
26 | |
27 | namespace glow { |
28 | namespace runtime { |
29 | |
30 | /* |
31 | * Base abstract class for input sanitizers. |
32 | * Each operator type can have its own specialization. |
33 | */ |
34 | class InputSanitizer { |
35 | public: |
36 | virtual ~InputSanitizer() = default; |
37 | virtual Error sanitize(const PlaceholderBindings &bindings) = 0; |
38 | virtual std::string toString() = 0; |
39 | }; |
40 | |
41 | using InputSanitizerPtr = std::shared_ptr<InputSanitizer>; |
42 | |
43 | class SparseLengthsSumInputSanitizer : public InputSanitizer { |
44 | public: |
45 | SparseLengthsSumInputSanitizer(const size_t tableHeight, |
46 | Placeholder *indicesPH, Placeholder *weightsPH, |
47 | Placeholder *lengthsPH); |
48 | |
49 | Error sanitize(const PlaceholderBindings &bindings) override; |
50 | |
51 | std::string toString() override; |
52 | |
53 | private: |
54 | size_t tableHeight_{0}; |
55 | Placeholder *indicesPH_{nullptr}; |
56 | Placeholder *weightsPH_{nullptr}; |
57 | Placeholder *lengthsPH_{nullptr}; |
58 | }; |
59 | |
60 | class EmbeddingBagInputSanitizer : public InputSanitizer { |
61 | public: |
62 | EmbeddingBagInputSanitizer(size_t tableHeight, Placeholder *indicesPH, |
63 | Placeholder *weightsPH, Placeholder *offsetsPH); |
64 | |
65 | Error sanitize(const PlaceholderBindings &bindings) override; |
66 | |
67 | std::string toString() override; |
68 | |
69 | private: |
70 | size_t tableHeight_{0}; |
71 | Placeholder *indicesPH_{nullptr}; |
72 | Placeholder *weightsPH_{nullptr}; |
73 | Placeholder *offsetsPH_{nullptr}; |
74 | }; |
75 | |
76 | // |
77 | // Public utility functions |
78 | // |
79 | std::vector<InputSanitizerPtr> getInputSanitizers(const Function &function); |
80 | Error sanitizeInputs(const std::vector<InputSanitizerPtr> &sanitizers, |
81 | const PlaceholderBindings &bindings); |
82 | |
83 | } // namespace runtime |
84 | } // namespace glow |
85 | |
86 | #endif // GLOW_RUNTIME_INPUTSANITIZER_H |
87 | |