1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_BACKENDS_COMPILEDFUNCTION_H
17#define GLOW_BACKENDS_COMPILEDFUNCTION_H
18
19#include "glow/Backend/BackendUtils.h"
20#include "glow/Backend/BlockStreamBase.h"
21#include "glow/ExecutionContext/ExecutionContext.h"
22#include "glow/Graph/Nodes.h"
23#include "glow/Support/Error.h"
24
25namespace glow {
26
27class PlaceholderBindings;
28/// Interface for executing a compiled function.
29class CompiledFunction {
30public:
31 /// Default Ctor.
32 CompiledFunction() = delete;
33
34 /// Ctor that accepts runtimeBundle.
35 CompiledFunction(runtime::RuntimeBundle &&bundle);
36
37 /// Dtor.
38 virtual ~CompiledFunction();
39 /// Execute the network and allocate Placeholder memory with given
40 /// \p bindings providing mapping between Placeholder and populated tensor.
41 /// \returns an Error if an error ocurred during execution.
42 virtual Error execute(ExecutionContext *context) = 0;
43
44 /// Getter for the runtimeBundle.
45 runtime::RuntimeBundle &getRuntimeBundle() { return runtimeBundle_; }
46
47 /// Collects constants for runtime.
48 virtual void collectConstants(const Module *){};
49
50 /// Setter for TraceEvent lookup. Note: does not enable tracing automatically.
51 void setTraceInfo(TraceInfo &&info) { traceInfo_ = std::move(info); }
52
53 /// Getter for the TraceEvent lookup.
54 TraceInfo &getTraceInfo() { return traceInfo_; }
55 const TraceInfo &getTraceInfo() const { return traceInfo_; }
56
57 /// Read trace events out of this func and write them into /p bindings
58 virtual void translateTraceEvents(ExecutionContext *bindings) const {}
59
60 /// \returns the backend name used to compile this function.
61 virtual std::string getCompileBackendName() const = 0;
62
63 /// Once the compiledFunction is done being added to devices calling this
64 /// method will free any resources needed to load the network on the device
65 /// but not needed for running on the device.
66 virtual void freeCompilationResources(){};
67
68 /// \returns a JSON representation of the result of compilation. Structure of
69 /// the JSON is dependent on the backend.
70 virtual const std::string toJSON() const { return ""; }
71
72 /// Dumps a JSON representation of the result of compilation to the specified
73 /// path \p fname.
74 void dumpJSON(llvm::StringRef fname) const;
75
76 /// Return the ptr of serialized string of this compiled function.
77 /// Serialization is dependent on the backend.
78 /// If backend does not support serialization, return null.
79 /// Specifically, serialize() will take the ownership of BlockStream, as
80 /// unique_ptr is used.
81 virtual std::unique_ptr<BlockStreamBase> serialize() { return nullptr; }
82
83 /// Overwrite this compiled function through input \p serializedData.
84 /// Deserialization is dependent on the backend.
85 /// Return true if backend support deserialization, and deserialization is
86 /// successed.
87 virtual Error deserialize(const std::vector<char> &serializedData) {
88 return Error::success();
89 }
90
91protected:
92 /// Contains symbol offsets and allocation sizes.
93 runtime::RuntimeBundle runtimeBundle_;
94
95 /// Information regarding runtime trace instrumentation present in this
96 /// function.
97 TraceInfo traceInfo_;
98};
99} // end namespace glow
100
101#endif // GLOW_BACKENDS_COMPILEDFUNCTION_H
102