1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FUNCTION_RUNTIME_CLIENT_H_ |
17 | #define TENSORFLOW_CORE_FUNCTION_RUNTIME_CLIENT_H_ |
18 | |
19 | #include <vector> |
20 | |
21 | #include "absl/types/span.h" |
22 | #include "tensorflow/c/eager/abstract_tensor_handle.h" |
23 | #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" |
24 | #include "tensorflow/core/common_runtime/eager/context.h" |
25 | #include "tensorflow/core/framework/function.pb.h" |
26 | #include "tensorflow/core/platform/status.h" |
27 | #include "tensorflow/core/platform/statusor.h" |
28 | #include "tensorflow/core/platform/stringpiece.h" |
29 | |
30 | namespace tensorflow { |
31 | namespace core { |
32 | namespace function { |
33 | |
// TODO(mdan): Get rid of this once pybind can depend on MLIR headers.
// Opaque stand-in used to hide a pointer to an actual MLIR FuncOp object. The
// type is only declared here, without a definition, so code that includes this
// header handles it exclusively through pointers.
struct OpaqueTfgGraphFuncOp;
37 | |
38 | // This is the current global context managed by the Python API. For historical |
39 | // reasons, the Python runtime controls this context and all other clients must |
40 | // use it. See tensorflow/python/eager/pywrap_tfe.h and |
41 | // tensorflow/python/eager/context.py. |
42 | // |
43 | // This must always be called after the Python eager context was initialized. |
44 | // |
45 | // If the Python runtime isn't involved, or when writing code that exclusively |
46 | // relies on functions defined in this namespace, users are encouraged to |
47 | // maintain their own EagerContext or use GlobalEagerContext. |
48 | EagerContext& GlobalPythonEagerContext(); |
49 | |
50 | // This global context is available for testing and to be shared among various |
51 | // APIs. |
52 | EagerContext& GlobalEagerContext(); |
53 | |
54 | using ReturnValues = std::vector<ImmediateTensorHandlePtr>; |
55 | |
56 | // A public API for manipulating and executing functions in a TensorFlow |
57 | // runtime. |
58 | class Runtime { |
59 | public: |
60 | explicit Runtime(EagerContext& eager_ctx) : eager_ctx_(eager_ctx) {} |
61 | |
62 | StatusOr<FunctionDef> GetFunctionProto(StringPiece name); |
63 | |
64 | // TODO(mdan): Enforce creation or rename to SetFunction. |
65 | Status CreateFunction(const FunctionDef& fdef); |
66 | // TODO(mdan): Change to mlir::tfg::GraphFuncOp once pybind can depend on it. |
67 | Status CreateFunction(OpaqueTfgGraphFuncOp* fop); |
68 | // Applies a MLIR pipeline to an existing function. |
69 | // The pipeline may rename the function. If it does so, the old function |
70 | // remains unchanged. If the new name specifies an existing function, it will |
71 | // be overwritten. |
72 | Status TransformFunction(StringPiece name, StringPiece pipeline_name); |
73 | |
74 | StatusOr<ReturnValues> CallFunction( |
75 | StringPiece name, absl::Span<AbstractTensorHandle* const> args); |
76 | |
77 | private: |
78 | EagerContext& eager_ctx_; |
79 | }; |
80 | |
81 | } // namespace function |
82 | } // namespace core |
83 | } // namespace tensorflow |
84 | |
85 | #endif // TENSORFLOW_CORE_FUNCTION_RUNTIME_CLIENT_H_ |
86 | |