1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ |
17 | #define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ |
18 | |
#include <memory>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
32 | |
33 | namespace tensorflow { |
34 | |
35 | // Represents the results of the execution of an operation. |
36 | struct OutputList { |
37 | std::vector<AbstractTensorHandle*> outputs; |
38 | int expected_num_outputs = -1; |
39 | }; |
40 | |
41 | namespace tracing { |
42 | |
43 | // ============================================================================= |
44 | // Implementation detail for the unified execution APIs for Eager and tracing |
45 | // backends (graph/MLIR). |
46 | // |
47 | // This defines a set of abstract classes that are intended to provide the |
48 | // functionality of the opaque C types exposed in the public APIs defined in the |
49 | // `c_api_unified_experimental.h` header. |
50 | // ============================================================================= |
51 | |
52 | // Represents either a MlirTensor or a GraphTensor. |
53 | // This base class does not expose any public methods other than to distinguish |
54 | // which subclass it actually is. The user is responsible to use the right |
55 | // type of AbstractTensor in their context (do not pass an MlirTensor to a |
56 | // GraphContext and vice-versa). |
57 | class TracingTensorHandle : public AbstractTensorHandle { |
58 | protected: |
59 | explicit TracingTensorHandle(AbstractTensorHandleKind kind) |
60 | : AbstractTensorHandle(kind) {} |
61 | |
62 | public: |
63 | // For LLVM style RTTI. |
64 | static bool classof(const AbstractTensorHandle* ptr) { |
65 | return ptr->getKind() == kGraph || ptr->getKind() == kMlir; |
66 | } |
67 | }; |
68 | |
69 | // An abstract operation describes an operation by its type, name, and |
70 | // attributes. It can be "executed" by the context with some input tensors. |
71 | // It is allowed to reusing the same abstract operation for multiple execution |
72 | // on a given context, with the same or different input tensors. |
73 | class TracingOperation : public AbstractOperation { |
74 | protected: |
75 | explicit TracingOperation(AbstractOperationKind kind) |
76 | : AbstractOperation(kind) {} |
77 | |
78 | public: |
79 | // Sets the name of the operation: this is an optional identifier that is |
80 | // not intended to carry semantics and preserved/propagated without |
81 | // guarantees. |
82 | virtual Status SetOpName(const char* op_name) = 0; |
83 | |
84 | // For LLVM style RTTI. |
85 | static bool classof(const AbstractOperation* ptr) { |
86 | return ptr->getKind() == kGraph || ptr->getKind() == kMlir; |
87 | } |
88 | }; |
89 | |
90 | namespace internal { |
91 | struct TracingOperationDeleter { |
92 | void operator()(TracingOperation* p) const { |
93 | if (p != nullptr) { |
94 | p->Release(); |
95 | } |
96 | } |
97 | }; |
98 | } // namespace internal |
99 | |
100 | using TracingOperationPtr = |
101 | std::unique_ptr<TracingOperation, internal::TracingOperationDeleter>; |
102 | |
103 | // This holds the context for the execution: dispatching operations either to an |
104 | // MLIR implementation or to a graph implementation. |
105 | class TracingContext : public AbstractContext { |
106 | protected: |
107 | explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {} |
108 | |
109 | public: |
110 | // Add a function parameter and return the corresponding tensor. |
111 | virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape, |
112 | TracingTensorHandle**) = 0; |
113 | |
114 | // Finalize this context and make a function out of it. The context is in a |
115 | // invalid state after this call and must be destroyed. |
116 | virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0; |
117 | |
118 | // For LLVM style RTTI. |
119 | static bool classof(const AbstractContext* ptr) { |
120 | return ptr->getKind() == kGraph || ptr->getKind() == kMlir; |
121 | } |
122 | }; |
123 | |
124 | typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*); |
125 | Status SetDefaultTracingEngine(const char* name); |
126 | void RegisterTracingEngineFactory(const ::tensorflow::string& name, |
127 | FactoryFunction factory); |
128 | } // namespace tracing |
129 | |
130 | DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext) |
131 | DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor) |
132 | DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction) |
133 | DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp) |
134 | DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList) |
135 | } // namespace tensorflow |
136 | |
137 | #endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ |
138 | |