1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
16#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
17
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
26
27namespace tensorflow {
28
29// Abstract interface to an operation.
30// This interface allows building and executing an operation in either
31// tracing or immediate execution mode.
32class AbstractOperation {
33 protected:
34 enum AbstractOperationKind {
35 kGraph,
36 kMlir,
37 kEager,
38 kTfrt,
39 kTape,
40 kOpHandler
41 };
42 explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
43 virtual ~AbstractOperation() {}
44
45 public:
46 AbstractOperationKind getKind() const { return kind_; }
47
48 // Release any underlying resources, including the interface object.
49 //
50 // WARNING: The destructor of this class is marked as protected to disallow
51 // clients from directly destroying this object since it may manage it's own
52 // lifetime through ref counting. Thus this must be allocated on the heap and
53 // clients MUST call Release() in order to destroy an instance of this class.
54 virtual void Release() = 0;
55
56 virtual Status Reset(const char* op, const char* raw_device_name) = 0;
57
58 virtual const string& Name() const = 0;
59
60 // Returns the operation's device name.
61 //
62 // The value returned may be different from the one set by SetDeviceName, but
63 // it will be compatible with it: the name will be updated by device placement
64 // logic to refer to the specific device chosen.
65 //
66 // Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
67 // returned by DeviceName should be "/device:GPU:*" until a particular GPU is
68 // chosen for the operation by the device placement logic in the
69 // executor. After that, the value returned by DeviceName will be a full
70 // device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
71 virtual const string& DeviceName() const = 0;
72
73 // Sets the operation device name.
74 //
75 // The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
76 // the result will be used as a constraint for device placement. See the
77 // documentation for DeviceName for more details.
78 //
79 // The value will override the previous value - that is, no "merging" of
80 // existing and given constraints will be performed.
81 virtual Status SetDeviceName(const char* name) = 0;
82
83 virtual Status AddInput(AbstractTensorHandle* input) = 0;
84 virtual Status AddInputList(
85 absl::Span<AbstractTensorHandle* const> inputs) = 0;
86 virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
87 int* num_retvals) = 0;
88
89 virtual Status SetAttrString(const char* attr_name, const char* data,
90 size_t length) = 0;
91 virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
92 virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
93 virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
94 virtual Status SetAttrType(const char* attr_name, DataType value) = 0;
95 virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
96 const int num_dims) = 0;
97 virtual Status SetAttrShape(const char* attr_name,
98 const PartialTensorShape shape);
99 virtual Status SetAttrFunction(const char* attr_name,
100 const AbstractOperation* value) = 0;
101 virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
102 size_t length) = 0;
103 virtual Status SetAttrTensor(const char* attr_name,
104 AbstractTensorInterface* tensor) = 0;
105 virtual Status SetAttrStringList(const char* attr_name,
106 const void* const* values,
107 const size_t* lengths, int num_values) = 0;
108 virtual Status SetAttrStringList(const char* attr_name,
109 absl::Span<string const> values);
110 virtual Status SetAttrFloatList(const char* attr_name, const float* values,
111 int num_values) = 0;
112 virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
113 int num_values) = 0;
114 virtual Status SetAttrTypeList(const char* attr_name, const DataType* values,
115 int num_values) = 0;
116 virtual Status SetAttrBoolList(const char* attr_name,
117 const unsigned char* values,
118 int num_values) = 0;
119 virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
120 const int* num_dims, int num_values) = 0;
121 virtual Status SetAttrFunctionList(
122 const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
123
124 private:
125 const AbstractOperationKind kind_;
126};
127
128// TODO(b/193656009): Defining these in a cc file causes linker errors with
129// fastbuild.
130inline Status AbstractOperation::SetAttrShape(const char* attr_name,
131 const PartialTensorShape shape) {
132 return SetAttrShape(attr_name, shape.dim_sizes().data(), shape.dims());
133}
134
135inline Status AbstractOperation::SetAttrStringList(
136 const char* attr_name, absl::Span<string const> values) {
137 std::vector<const char*> raw_strs;
138 std::vector<size_t> lengths;
139 raw_strs.reserve(values.size());
140 lengths.reserve(values.size());
141 for (const auto& s : values) {
142 raw_strs.emplace_back(s.data());
143 lengths.emplace_back(s.size());
144 }
145 return SetAttrStringList(attr_name,
146 reinterpret_cast<const void**>(raw_strs.data()),
147 lengths.data(), values.size());
148}
149
150namespace internal {
151struct AbstractOperationDeleter {
152 void operator()(AbstractOperation* p) const {
153 if (p != nullptr) {
154 p->Release();
155 }
156 }
157};
158} // namespace internal
159
160using AbstractOperationPtr =
161 std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
162
163} // namespace tensorflow
164
165#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
166