1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_RUNTIME_DEFERREDWEIGHTLOADER_H
17#define GLOW_RUNTIME_DEFERREDWEIGHTLOADER_H
18
#include "glog/logging.h"

#include "glow/Base/Tensor.h"
#include "glow/Support/Error.h"
#include "glow/Support/Register.h"
#include "glow/Support/Support.h"

#include <map>
#include <string>
25
26namespace glow {
27namespace runtime {
28
29/// A base class for deferred weight loaders. This allows for large weights to
30/// be skipped during compilation and loaded after compilation one at a time.
31class DeferredWeightLoader {
32public:
33 /// Loads the next weight, returns an Error indicating success/failure. Frees
34 /// any resources used by the current deferred weight.
35 virtual Error loadNextWeight() = 0;
36
37 /// Accepts a void * \p loaderObject with is passed in from the interface
38 /// library, e.g. deferredBlobReader in onnxifi.
39 virtual Error setSrc(void *loaderObject) = 0;
40
41 /// Accepts a map from string to Type \p info. This is used by the loader when
42 /// converted the loaded weight into a glow Tensor.
43 virtual void setTypeInfo(std::map<std::string, Type> info) = 0;
44
45 /// \returns a reference to \ref typeInfo_, a map from string to Type info.
46 std::map<std::string, glow::Type> &getTypeInfo() { return typeInfo_; }
47
48 /// Gets the name of the currently loaded weight.
49 virtual std::string getName() = 0;
50
51 /// Gets the Tensor for the current weight.
52 virtual Tensor *getTensor() = 0;
53
54 virtual ~DeferredWeightLoader() = default;
55
56protected:
57 std::map<std::string, glow::Type> typeInfo_;
58};
59
60class DeferredWeightLoaderRegistry final {
61public:
62 void registerLoader(DeferredWeightLoader *loader);
63 DeferredWeightLoader *getLoader();
64
65private:
66 DeferredWeightLoader *loader_{nullptr};
67};
68
69/// Global singleton.
70DeferredWeightLoaderRegistry *DeferredLoader();
71
72} // namespace runtime
73} // namespace glow
74
75#endif // GLOW_RUNTIME_DEFERREDWEIGHTLOADER_H
76