1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FUNCTION_RUNTIME_CLIENT_H_
17#define TENSORFLOW_CORE_FUNCTION_RUNTIME_CLIENT_H_
18
19#include <vector>
20
21#include "absl/types/span.h"
22#include "tensorflow/c/eager/abstract_tensor_handle.h"
23#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
24#include "tensorflow/core/common_runtime/eager/context.h"
25#include "tensorflow/core/framework/function.pb.h"
26#include "tensorflow/core/platform/status.h"
27#include "tensorflow/core/platform/statusor.h"
28#include "tensorflow/core/platform/stringpiece.h"
29
30namespace tensorflow {
31namespace core {
32namespace function {
33
// TODO(mdan): Get rid of this once pybind can depend on MLIR headers.
// Opaque stand-in for an actual MLIR FuncOp object. The type is only declared
// here, never defined, so it is incomplete and can be handled solely through
// pointers (e.g. Runtime::CreateFunction(OpaqueTfgGraphFuncOp*)); this keeps
// MLIR types out of the header's public surface.
// NOTE: keep this a `struct` forward declaration; declaring it as `class`
// elsewhere would mismatch (MSVC mangles struct/class differently).
struct OpaqueTfgGraphFuncOp;
37
38// This is the current global context managed by the Python API. For historical
39// reasons, the Python runtime controls this context and all other clients must
40// use it. See tensorflow/python/eager/pywrap_tfe.h and
41// tensorflow/python/eager/context.py.
42//
43// This must always be called after the Python eager context was initialized.
44//
45// If the Python runtime isn't involved, or when writing code that exclusively
46// relies on functions defined in this namespace, users are encouraged to
47// maintain their own EagerContext or use GlobalEagerContext.
48EagerContext& GlobalPythonEagerContext();
49
50// This global context is available for testing and to be shared among various
51// APIs.
52EagerContext& GlobalEagerContext();
53
54using ReturnValues = std::vector<ImmediateTensorHandlePtr>;
55
56// A public API for manipulating and executing functions in a TensorFlow
57// runtime.
58class Runtime {
59 public:
60 explicit Runtime(EagerContext& eager_ctx) : eager_ctx_(eager_ctx) {}
61
62 StatusOr<FunctionDef> GetFunctionProto(StringPiece name);
63
64 // TODO(mdan): Enforce creation or rename to SetFunction.
65 Status CreateFunction(const FunctionDef& fdef);
66 // TODO(mdan): Change to mlir::tfg::GraphFuncOp once pybind can depend on it.
67 Status CreateFunction(OpaqueTfgGraphFuncOp* fop);
68 // Applies a MLIR pipeline to an existing function.
69 // The pipeline may rename the function. If it does so, the old function
70 // remains unchanged. If the new name specifies an existing function, it will
71 // be overwritten.
72 Status TransformFunction(StringPiece name, StringPiece pipeline_name);
73
74 StatusOr<ReturnValues> CallFunction(
75 StringPiece name, absl::Span<AbstractTensorHandle* const> args);
76
77 private:
78 EagerContext& eager_ctx_;
79};
80
81} // namespace function
82} // namespace core
83} // namespace tensorflow
84
85#endif // TENSORFLOW_CORE_FUNCTION_RUNTIME_CLIENT_H_
86