1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
17#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
18
#include <memory>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
32
33namespace tensorflow {
34
35// Represents the results of the execution of an operation.
36struct OutputList {
37 std::vector<AbstractTensorHandle*> outputs;
38 int expected_num_outputs = -1;
39};
40
41namespace tracing {
42
43// =============================================================================
44// Implementation detail for the unified execution APIs for Eager and tracing
45// backends (graph/MLIR).
46//
47// This defines a set of abstract classes that are intended to provide the
48// functionality of the opaque C types exposed in the public APIs defined in the
49// `c_api_unified_experimental.h` header.
50// =============================================================================
51
52// Represents either a MlirTensor or a GraphTensor.
53// This base class does not expose any public methods other than to distinguish
54// which subclass it actually is. The user is responsible to use the right
55// type of AbstractTensor in their context (do not pass an MlirTensor to a
56// GraphContext and vice-versa).
57class TracingTensorHandle : public AbstractTensorHandle {
58 protected:
59 explicit TracingTensorHandle(AbstractTensorHandleKind kind)
60 : AbstractTensorHandle(kind) {}
61
62 public:
63 // For LLVM style RTTI.
64 static bool classof(const AbstractTensorHandle* ptr) {
65 return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
66 }
67};
68
69// An abstract operation describes an operation by its type, name, and
70// attributes. It can be "executed" by the context with some input tensors.
71// It is allowed to reusing the same abstract operation for multiple execution
72// on a given context, with the same or different input tensors.
73class TracingOperation : public AbstractOperation {
74 protected:
75 explicit TracingOperation(AbstractOperationKind kind)
76 : AbstractOperation(kind) {}
77
78 public:
79 // Sets the name of the operation: this is an optional identifier that is
80 // not intended to carry semantics and preserved/propagated without
81 // guarantees.
82 virtual Status SetOpName(const char* op_name) = 0;
83
84 // For LLVM style RTTI.
85 static bool classof(const AbstractOperation* ptr) {
86 return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
87 }
88};
89
90namespace internal {
91struct TracingOperationDeleter {
92 void operator()(TracingOperation* p) const {
93 if (p != nullptr) {
94 p->Release();
95 }
96 }
97};
98} // namespace internal
99
100using TracingOperationPtr =
101 std::unique_ptr<TracingOperation, internal::TracingOperationDeleter>;
102
103// This holds the context for the execution: dispatching operations either to an
104// MLIR implementation or to a graph implementation.
105class TracingContext : public AbstractContext {
106 protected:
107 explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {}
108
109 public:
110 // Add a function parameter and return the corresponding tensor.
111 virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape,
112 TracingTensorHandle**) = 0;
113
114 // Finalize this context and make a function out of it. The context is in a
115 // invalid state after this call and must be destroyed.
116 virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0;
117
118 // For LLVM style RTTI.
119 static bool classof(const AbstractContext* ptr) {
120 return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
121 }
122};
123
124typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
125Status SetDefaultTracingEngine(const char* name);
126void RegisterTracingEngineFactory(const ::tensorflow::string& name,
127 FactoryFunction factory);
128} // namespace tracing
129
// Conversions between the internal C++ classes (first argument) and the
// opaque C types of the unified C API (second argument).
// NOTE(review): what exactly gets generated is defined by
// DEFINE_CONVERSION_FUNCTIONS in tensorflow/c/conversion_macros.h, which is not
// visible here — presumably a wrap()/unwrap() pair per type; confirm there.
DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext)
DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor)
DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction)
DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp)
DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList)
135} // namespace tensorflow
136
137#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
138