1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_RUNTIME_DEFERREDWEIGHTLOADER_H |
17 | #define GLOW_RUNTIME_DEFERREDWEIGHTLOADER_H |
18 | |
19 | #include "glog/logging.h" |
20 | |
21 | #include "glow/Base/Tensor.h" |
22 | #include "glow/Support/Error.h" |
23 | #include "glow/Support/Register.h" |
24 | #include "glow/Support/Support.h" |
25 | |
26 | namespace glow { |
27 | namespace runtime { |
28 | |
29 | /// A base class for deferred weight loaders. This allows for large weights to |
30 | /// be skipped during compilation and loaded after compilation one at a time. |
31 | class DeferredWeightLoader { |
32 | public: |
33 | /// Loads the next weight, returns an Error indicating success/failure. Frees |
34 | /// any resources used by the current deferred weight. |
35 | virtual Error loadNextWeight() = 0; |
36 | |
37 | /// Accepts a void * \p loaderObject with is passed in from the interface |
38 | /// library, e.g. deferredBlobReader in onnxifi. |
39 | virtual Error setSrc(void *loaderObject) = 0; |
40 | |
41 | /// Accepts a map from string to Type \p info. This is used by the loader when |
42 | /// converted the loaded weight into a glow Tensor. |
43 | virtual void setTypeInfo(std::map<std::string, Type> info) = 0; |
44 | |
45 | /// \returns a reference to \ref typeInfo_, a map from string to Type info. |
46 | std::map<std::string, glow::Type> &getTypeInfo() { return typeInfo_; } |
47 | |
48 | /// Gets the name of the currently loaded weight. |
49 | virtual std::string getName() = 0; |
50 | |
51 | /// Gets the Tensor for the current weight. |
52 | virtual Tensor *getTensor() = 0; |
53 | |
54 | virtual ~DeferredWeightLoader() = default; |
55 | |
56 | protected: |
57 | std::map<std::string, glow::Type> typeInfo_; |
58 | }; |
59 | |
60 | class DeferredWeightLoaderRegistry final { |
61 | public: |
62 | void registerLoader(DeferredWeightLoader *loader); |
63 | DeferredWeightLoader *getLoader(); |
64 | |
65 | private: |
66 | DeferredWeightLoader *loader_{nullptr}; |
67 | }; |
68 | |
69 | /// Global singleton. |
70 | DeferredWeightLoaderRegistry *DeferredLoader(); |
71 | |
72 | } // namespace runtime |
73 | } // namespace glow |
74 | |
75 | #endif // GLOW_RUNTIME_DEFERREDWEIGHTLOADER_H |
76 | |