1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
16#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
17
18#include "absl/container/inlined_vector.h"
19#include "absl/types/optional.h"
20#include "absl/types/span.h"
21#include "absl/types/variant.h"
22#include "tensorflow/c/eager/abstract_tensor_handle.h"
23#include "tensorflow/c/eager/immediate_execution_operation.h"
24#include "tensorflow/core/common_runtime/eager/attr_builder.h"
25#include "tensorflow/core/common_runtime/eager/context.h"
26#include "tensorflow/core/common_runtime/eager/eager_executor.h"
27#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
28#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
29#include "tensorflow/core/framework/cancellation.h"
30#include "tensorflow/core/framework/device_attributes.pb.h"
31#include "tensorflow/core/framework/op_def.pb.h"
32#include "tensorflow/core/util/device_name_utils.h"
33#include "tensorflow/core/util/managed_stack_trace.h"
34
35namespace tensorflow {
36
37class EagerOperation : public ImmediateExecutionOperation {
38 public:
39 explicit EagerOperation(tensorflow::EagerContext* ctx)
40 : ImmediateExecutionOperation(kEager), ctx_(*ctx), is_function_(false) {}
41 ~EagerOperation() override {
42 for (ImmediateExecutionTensorHandle* h : inputs_) {
43 h->Unref();
44 }
45 }
46
47 void Release() override { delete this; }
48
49 void Clear() override;
50 Status Reset(const char* op, const char* raw_device_name) override {
51 return Reset(op, raw_device_name, false, nullptr);
52 }
53
54 const string& Name() const override { return attrs_.op_name(); }
55
56 const string& DeviceName() const override { return device_name_; }
57
58 ImmediateExecutionContext* GetContext() const override { return &ctx_; }
59
60 const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
61 return device_parsed_name_;
62 }
63
64 // Replaces the previous device name with the given one (see
65 // AbstractOperation::SetDeviceName for more details).
66 //
67 // This also resets the internal device pointer, unless the given name refers
68 // to a known custom device, in which case the internal device pointer is
69 // updated to that device.
70 Status SetDeviceName(const char* name) override;
71
72 void SetDevice(VariantDevice device) {
73 device_ = device;
74 device_name_ = absl::visit(
75 [](auto* device) { return device == nullptr ? "" : device->name(); },
76 device);
77 DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
78 // TODO(b/154133594): Due to intricacies of external logic, we can not
79 // set this do device_name_ as it would be natural, because we need the
80 // next call to SetDeviceName to reset the device pointer.
81 last_set_device_name_ = "\177"; // DEL (an invalid value)
82 }
83
84 Status SetAttrValue(const char* attr_name, const AttrValue& value);
85
86 Status AddInput(AbstractTensorHandle* input) override;
87 Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
88 Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override;
89 absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override;
90 bool HasCustomDeviceInput() const override {
91 return custom_device_tensor_handles_count_ > 0;
92 }
93 Status Execute(absl::Span<AbstractTensorHandle*> retvals,
94 int* num_retvals) override;
95 const tensorflow::OpDef* OpDef() const override { return op_def_; };
96
97 Status SetAttrString(const char* attr_name, const char* data,
98 size_t length) override;
99 Status SetAttrInt(const char* attr_name, int64_t value) override;
100 Status SetAttrFloat(const char* attr_name, float value) override;
101 Status SetAttrBool(const char* attr_name, bool value) override;
102 Status SetAttrType(const char* attr_name, DataType value) override;
103 Status SetAttrShape(const char* attr_name, const int64_t* dims,
104 const int num_dims) override;
105 Status SetAttrFunction(const char* attr_name,
106 const AbstractOperation* value) override;
107 Status SetAttrFunctionName(const char* attr_name, const char* data,
108 size_t length) override;
109 Status SetAttrTensor(const char* attr_name,
110 AbstractTensorInterface* tensor) override;
111 Status SetAttrStringList(const char* attr_name, const void* const* values,
112 const size_t* lengths, int num_values) override;
113 Status SetAttrFloatList(const char* attr_name, const float* values,
114 int num_values) override;
115 Status SetAttrIntList(const char* attr_name, const int64_t* values,
116 int num_values) override;
117 Status SetAttrTypeList(const char* attr_name, const DataType* values,
118 int num_values) override;
119 Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
120 int num_values) override;
121 Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
122 const int* num_dims, int num_values) override;
123 Status SetAttrFunctionList(
124 const char* attr_name,
125 absl::Span<const AbstractOperation*> values) override;
126
127 Status InputLength(const char* input_name, int* length) override;
128 Status OutputLength(const char* output_name, int* length) override;
129
130 const AbstractOpAttrs* GetOpAttrs() const override;
131 void AddAttrs(const AbstractOpAttrs* op_attrs) override;
132
133 void SetStackTrace(ManagedStackTrace stack_trace) override {
134 stack_trace_ = stack_trace;
135 }
136
137 absl::optional<ManagedStackTrace> GetStackTrace() override {
138 return stack_trace_;
139 }
140
141 Status Reset(const char* op, const char* device_name, bool remote,
142 EagerExecutor* executor,
143 const absl::optional<EagerFunctionParams> remote_func_params =
144 absl::nullopt);
145
146 bool is_function() const { return is_function_; }
147 bool colocation_exempt() const { return colocation_exempt_; }
148
149 tensorflow::EagerContext& EagerContext() const { return ctx_; }
150
151 AttrBuilder* MutableAttrs() { return &attrs_; }
152 const AttrBuilder& Attrs() const { return attrs_; }
153
154 // TensorHandleInputs and MutableTensorHandleInputs first check that all
155 // inputs are TensorHandles, i.e. that there are no custom device inputs. They
156 // return a bad status otherwise.
157 Status TensorHandleInputs(
158 const absl::InlinedVector<TensorHandle*, 4>** inputs) const;
159 Status MutableTensorHandleInputs(
160 absl::InlinedVector<TensorHandle*, 4>** inputs);
161
162 const absl::InlinedVector<ImmediateExecutionTensorHandle*, 4>& Inputs()
163 const {
164 return inputs_;
165 }
166
167 void UpdateInput(int i, TensorHandle* h);
168
169 // Like TensorHandles, EagerOperations may be placed either on a virtual
170 // CustomDevice or on a physical Device.
171 VariantDevice Device() const { return device_; }
172
173 // Indicates whether the op is assigned to a device that is local to the
174 // current host.
175 bool IsLocal() const;
176
177 CancellationManager* GetCancellationManager() const {
178 return cancellation_manager_;
179 }
180 void SetCancellationManager(
181 CancellationManager* cancellation_manager) override {
182 cancellation_manager_ = cancellation_manager;
183 }
184
185 // Assign step_id value only if op has valid step id.
186 // When eager_func_params.has_value() returns true, we can directly overwrite
187 // its step id according to Op's step id (if not default value). However, when
188 // eager_func_params.has_value() returns false, we need to first create a new
189 // EagerFuncParams object for it before assigning step_id; otherwise,
190 // directly assigning step_id in this case leaves eager_func_params to be
191 // in a weird state where:
192 // (1) eager_func_params.has_value() returns false, but
193 // (2) eager_func_params->step_id.has_value() returns true.
194 void SetStepId(int64_t step_id) override {
195 assert(is_function());
196 if (step_id != EagerContext::kGlobalRendezvousId) {
197 if (eager_func_params_.has_value()) {
198 eager_func_params_->step_id = step_id;
199 } else {
200 eager_func_params_ = EagerFunctionParams{
201 kInvalidOpId, /*is_component_function=*/false, step_id};
202 }
203 } else {
204 LOG(WARNING) << "SetStepId() should not receive a gloabl rendezvous id.";
205 }
206 }
207
208 EagerExecutor& Executor() { return *executor_; }
209
210 string DebugString() const;
211
212 const absl::optional<EagerFunctionParams>& eager_func_params() const {
213 return eager_func_params_;
214 }
215
216 // Op name recorded for memory debugging purpose.
217 const char* op_name() const { return op_name_; }
218
219 // For LLVM style RTTI.
220 static bool classof(const AbstractOperation* ptr) {
221 return ptr->getKind() == kEager;
222 }
223
224 private:
225 void AddTensorHandle(ImmediateExecutionTensorHandle* h);
226
227 const tensorflow::OpDef* GetOpDef(Status* status);
228
229 void ClearInferenceState() {
230 op_def_ = nullptr;
231 inference_arg_idx_ = 0;
232 inference_attrs_.clear_no_resize();
233 }
234
235 Status MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle* handle);
236 Status InferInputListAttrs(int num_inputs);
237
238 void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
239 const DataType dtype, int num_inputs);
240 void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
241 const std::vector<DataType>& dtypes);
242
243 tensorflow::EagerContext& ctx_;
244 const char* op_name_ = nullptr;
245 AttrBuilder attrs_;
246 const AttrTypeMap* attr_types_;
247
248 // The number of custom device TensorHandle inputs. These inputs need to be
249 // processed by CustomDeviceOpHandler first.
250 int custom_device_tensor_handles_count_ = 0;
251 absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;
252
253 // The last device name given to SetDeviceName.
254 // This is used to avoid having to re-process the same device in repeated
255 // calls to SetDeviceName.
256 string last_set_device_name_;
257
258 // The operation's device name.
259 // This contains the named passed to SetDeviceName until device_ is set,
260 // at which point it contains the device_ name.
261 string device_name_;
262
263 // The parsed device name.
264 // This will always contain the result of
265 // DeviceNameUtils::ParseFullName(device_name_).
266 DeviceNameUtils::ParsedName device_parsed_name_;
267
268 // The operation's device.
269 // This is set by the execution device placement logic, and should conform
270 // with the contents of device_name_. Once it is set, the device_name_ is
271 // updated accordingly.
272 VariantDevice device_;
273
274 absl::optional<ManagedStackTrace> stack_trace_;
275 bool is_function_; // Conceptually const, but can't be because of Reset
276 bool colocation_exempt_;
277 CancellationManager* cancellation_manager_ = nullptr; // Not owned.
278 EagerExecutor* executor_; // Not owned.
279
280 absl::optional<EagerFunctionParams> eager_func_params_;
281
282 // Inference information
283 const tensorflow::OpDef* op_def_; // op definition from protobuf
284 int inference_arg_idx_; // arg definition index for the next input to be
285 // added
286 gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far
287};
288
289inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
290 ImmediateExecutionTensorHandle** slot = &inputs_[i];
291 ImmediateExecutionTensorHandle* existing = *slot;
292 if (existing != h) {
293 h->Ref();
294 existing->Unref();
295 *slot = h; // Update inputs_[i] to h
296 }
297}
298
299inline EagerOperation* OperationFromInterface(
300 ImmediateExecutionOperation* operation) {
301 return down_cast<EagerOperation*>(operation);
302}
303
304inline const EagerOperation* OperationFromInterface(
305 const ImmediateExecutionOperation* operation) {
306 return down_cast<const EagerOperation*>(operation);
307}
308
309} // namespace tensorflow
310
311#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
312