1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_ |
17 | |
18 | #include "absl/container/inlined_vector.h" |
19 | #include "absl/types/optional.h" |
20 | #include "absl/types/span.h" |
21 | #include "absl/types/variant.h" |
22 | #include "tensorflow/c/eager/abstract_tensor_handle.h" |
23 | #include "tensorflow/c/eager/immediate_execution_operation.h" |
24 | #include "tensorflow/core/common_runtime/eager/attr_builder.h" |
25 | #include "tensorflow/core/common_runtime/eager/context.h" |
26 | #include "tensorflow/core/common_runtime/eager/eager_executor.h" |
27 | #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" |
28 | #include "tensorflow/core/common_runtime/eager/tensor_handle.h" |
29 | #include "tensorflow/core/framework/cancellation.h" |
30 | #include "tensorflow/core/framework/device_attributes.pb.h" |
31 | #include "tensorflow/core/framework/op_def.pb.h" |
32 | #include "tensorflow/core/util/device_name_utils.h" |
33 | #include "tensorflow/core/util/managed_stack_trace.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | class EagerOperation : public ImmediateExecutionOperation { |
38 | public: |
39 | explicit EagerOperation(tensorflow::EagerContext* ctx) |
40 | : ImmediateExecutionOperation(kEager), ctx_(*ctx), is_function_(false) {} |
41 | ~EagerOperation() override { |
42 | for (ImmediateExecutionTensorHandle* h : inputs_) { |
43 | h->Unref(); |
44 | } |
45 | } |
46 | |
47 | void Release() override { delete this; } |
48 | |
49 | void Clear() override; |
50 | Status Reset(const char* op, const char* raw_device_name) override { |
51 | return Reset(op, raw_device_name, false, nullptr); |
52 | } |
53 | |
54 | const string& Name() const override { return attrs_.op_name(); } |
55 | |
56 | const string& DeviceName() const override { return device_name_; } |
57 | |
58 | ImmediateExecutionContext* GetContext() const override { return &ctx_; } |
59 | |
60 | const DeviceNameUtils::ParsedName& GetDeviceParsedName() const { |
61 | return device_parsed_name_; |
62 | } |
63 | |
64 | // Replaces the previous device name with the given one (see |
65 | // AbstractOperation::SetDeviceName for more details). |
66 | // |
67 | // This also resets the internal device pointer, unless the given name refers |
68 | // to a known custom device, in which case the internal device pointer is |
69 | // updated to that device. |
70 | Status SetDeviceName(const char* name) override; |
71 | |
72 | void SetDevice(VariantDevice device) { |
73 | device_ = device; |
74 | device_name_ = absl::visit( |
75 | [](auto* device) { return device == nullptr ? "" : device->name(); }, |
76 | device); |
77 | DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_); |
78 | // TODO(b/154133594): Due to intricacies of external logic, we can not |
79 | // set this do device_name_ as it would be natural, because we need the |
80 | // next call to SetDeviceName to reset the device pointer. |
81 | last_set_device_name_ = "\177" ; // DEL (an invalid value) |
82 | } |
83 | |
84 | Status SetAttrValue(const char* attr_name, const AttrValue& value); |
85 | |
86 | Status AddInput(AbstractTensorHandle* input) override; |
87 | Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override; |
88 | Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override; |
89 | absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override; |
90 | bool HasCustomDeviceInput() const override { |
91 | return custom_device_tensor_handles_count_ > 0; |
92 | } |
93 | Status Execute(absl::Span<AbstractTensorHandle*> retvals, |
94 | int* num_retvals) override; |
95 | const tensorflow::OpDef* OpDef() const override { return op_def_; }; |
96 | |
97 | Status SetAttrString(const char* attr_name, const char* data, |
98 | size_t length) override; |
99 | Status SetAttrInt(const char* attr_name, int64_t value) override; |
100 | Status SetAttrFloat(const char* attr_name, float value) override; |
101 | Status SetAttrBool(const char* attr_name, bool value) override; |
102 | Status SetAttrType(const char* attr_name, DataType value) override; |
103 | Status SetAttrShape(const char* attr_name, const int64_t* dims, |
104 | const int num_dims) override; |
105 | Status SetAttrFunction(const char* attr_name, |
106 | const AbstractOperation* value) override; |
107 | Status SetAttrFunctionName(const char* attr_name, const char* data, |
108 | size_t length) override; |
109 | Status SetAttrTensor(const char* attr_name, |
110 | AbstractTensorInterface* tensor) override; |
111 | Status SetAttrStringList(const char* attr_name, const void* const* values, |
112 | const size_t* lengths, int num_values) override; |
113 | Status SetAttrFloatList(const char* attr_name, const float* values, |
114 | int num_values) override; |
115 | Status SetAttrIntList(const char* attr_name, const int64_t* values, |
116 | int num_values) override; |
117 | Status SetAttrTypeList(const char* attr_name, const DataType* values, |
118 | int num_values) override; |
119 | Status SetAttrBoolList(const char* attr_name, const unsigned char* values, |
120 | int num_values) override; |
121 | Status SetAttrShapeList(const char* attr_name, const int64_t** dims, |
122 | const int* num_dims, int num_values) override; |
123 | Status SetAttrFunctionList( |
124 | const char* attr_name, |
125 | absl::Span<const AbstractOperation*> values) override; |
126 | |
127 | Status InputLength(const char* input_name, int* length) override; |
128 | Status OutputLength(const char* output_name, int* length) override; |
129 | |
130 | const AbstractOpAttrs* GetOpAttrs() const override; |
131 | void AddAttrs(const AbstractOpAttrs* op_attrs) override; |
132 | |
133 | void SetStackTrace(ManagedStackTrace stack_trace) override { |
134 | stack_trace_ = stack_trace; |
135 | } |
136 | |
137 | absl::optional<ManagedStackTrace> GetStackTrace() override { |
138 | return stack_trace_; |
139 | } |
140 | |
141 | Status Reset(const char* op, const char* device_name, bool remote, |
142 | EagerExecutor* executor, |
143 | const absl::optional<EagerFunctionParams> remote_func_params = |
144 | absl::nullopt); |
145 | |
146 | bool is_function() const { return is_function_; } |
147 | bool colocation_exempt() const { return colocation_exempt_; } |
148 | |
149 | tensorflow::EagerContext& EagerContext() const { return ctx_; } |
150 | |
151 | AttrBuilder* MutableAttrs() { return &attrs_; } |
152 | const AttrBuilder& Attrs() const { return attrs_; } |
153 | |
154 | // TensorHandleInputs and MutableTensorHandleInputs first check that all |
155 | // inputs are TensorHandles, i.e. that there are no custom device inputs. They |
156 | // return a bad status otherwise. |
157 | Status TensorHandleInputs( |
158 | const absl::InlinedVector<TensorHandle*, 4>** inputs) const; |
159 | Status MutableTensorHandleInputs( |
160 | absl::InlinedVector<TensorHandle*, 4>** inputs); |
161 | |
162 | const absl::InlinedVector<ImmediateExecutionTensorHandle*, 4>& Inputs() |
163 | const { |
164 | return inputs_; |
165 | } |
166 | |
167 | void UpdateInput(int i, TensorHandle* h); |
168 | |
169 | // Like TensorHandles, EagerOperations may be placed either on a virtual |
170 | // CustomDevice or on a physical Device. |
171 | VariantDevice Device() const { return device_; } |
172 | |
173 | // Indicates whether the op is assigned to a device that is local to the |
174 | // current host. |
175 | bool IsLocal() const; |
176 | |
177 | CancellationManager* GetCancellationManager() const { |
178 | return cancellation_manager_; |
179 | } |
180 | void SetCancellationManager( |
181 | CancellationManager* cancellation_manager) override { |
182 | cancellation_manager_ = cancellation_manager; |
183 | } |
184 | |
185 | // Assign step_id value only if op has valid step id. |
186 | // When eager_func_params.has_value() returns true, we can directly overwrite |
187 | // its step id according to Op's step id (if not default value). However, when |
188 | // eager_func_params.has_value() returns false, we need to first create a new |
189 | // EagerFuncParams object for it before assigning step_id; otherwise, |
190 | // directly assigning step_id in this case leaves eager_func_params to be |
191 | // in a weird state where: |
192 | // (1) eager_func_params.has_value() returns false, but |
193 | // (2) eager_func_params->step_id.has_value() returns true. |
194 | void SetStepId(int64_t step_id) override { |
195 | assert(is_function()); |
196 | if (step_id != EagerContext::kGlobalRendezvousId) { |
197 | if (eager_func_params_.has_value()) { |
198 | eager_func_params_->step_id = step_id; |
199 | } else { |
200 | eager_func_params_ = EagerFunctionParams{ |
201 | kInvalidOpId, /*is_component_function=*/false, step_id}; |
202 | } |
203 | } else { |
204 | LOG(WARNING) << "SetStepId() should not receive a gloabl rendezvous id." ; |
205 | } |
206 | } |
207 | |
208 | EagerExecutor& Executor() { return *executor_; } |
209 | |
210 | string DebugString() const; |
211 | |
212 | const absl::optional<EagerFunctionParams>& eager_func_params() const { |
213 | return eager_func_params_; |
214 | } |
215 | |
216 | // Op name recorded for memory debugging purpose. |
217 | const char* op_name() const { return op_name_; } |
218 | |
219 | // For LLVM style RTTI. |
220 | static bool classof(const AbstractOperation* ptr) { |
221 | return ptr->getKind() == kEager; |
222 | } |
223 | |
224 | private: |
225 | void AddTensorHandle(ImmediateExecutionTensorHandle* h); |
226 | |
227 | const tensorflow::OpDef* GetOpDef(Status* status); |
228 | |
229 | void ClearInferenceState() { |
230 | op_def_ = nullptr; |
231 | inference_arg_idx_ = 0; |
232 | inference_attrs_.clear_no_resize(); |
233 | } |
234 | |
235 | Status MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle* handle); |
236 | Status InferInputListAttrs(int num_inputs); |
237 | |
238 | void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def, |
239 | const DataType dtype, int num_inputs); |
240 | void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, |
241 | const std::vector<DataType>& dtypes); |
242 | |
243 | tensorflow::EagerContext& ctx_; |
244 | const char* op_name_ = nullptr; |
245 | AttrBuilder attrs_; |
246 | const AttrTypeMap* attr_types_; |
247 | |
248 | // The number of custom device TensorHandle inputs. These inputs need to be |
249 | // processed by CustomDeviceOpHandler first. |
250 | int custom_device_tensor_handles_count_ = 0; |
251 | absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_; |
252 | |
253 | // The last device name given to SetDeviceName. |
254 | // This is used to avoid having to re-process the same device in repeated |
255 | // calls to SetDeviceName. |
256 | string last_set_device_name_; |
257 | |
258 | // The operation's device name. |
259 | // This contains the named passed to SetDeviceName until device_ is set, |
260 | // at which point it contains the device_ name. |
261 | string device_name_; |
262 | |
263 | // The parsed device name. |
264 | // This will always contain the result of |
265 | // DeviceNameUtils::ParseFullName(device_name_). |
266 | DeviceNameUtils::ParsedName device_parsed_name_; |
267 | |
268 | // The operation's device. |
269 | // This is set by the execution device placement logic, and should conform |
270 | // with the contents of device_name_. Once it is set, the device_name_ is |
271 | // updated accordingly. |
272 | VariantDevice device_; |
273 | |
274 | absl::optional<ManagedStackTrace> stack_trace_; |
275 | bool is_function_; // Conceptually const, but can't be because of Reset |
276 | bool colocation_exempt_; |
277 | CancellationManager* cancellation_manager_ = nullptr; // Not owned. |
278 | EagerExecutor* executor_; // Not owned. |
279 | |
280 | absl::optional<EagerFunctionParams> eager_func_params_; |
281 | |
282 | // Inference information |
283 | const tensorflow::OpDef* op_def_; // op definition from protobuf |
284 | int inference_arg_idx_; // arg definition index for the next input to be |
285 | // added |
286 | gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far |
287 | }; |
288 | |
289 | inline void EagerOperation::UpdateInput(int i, TensorHandle* h) { |
290 | ImmediateExecutionTensorHandle** slot = &inputs_[i]; |
291 | ImmediateExecutionTensorHandle* existing = *slot; |
292 | if (existing != h) { |
293 | h->Ref(); |
294 | existing->Unref(); |
295 | *slot = h; // Update inputs_[i] to h |
296 | } |
297 | } |
298 | |
// Downcasts an ImmediateExecutionOperation to the concrete EagerOperation.
// Caller must ensure the operation really is an EagerOperation (see classof).
inline EagerOperation* OperationFromInterface(
    ImmediateExecutionOperation* operation) {
  return down_cast<EagerOperation*>(operation);
}
303 | |
// Const overload of the downcast above.
inline const EagerOperation* OperationFromInterface(
    const ImmediateExecutionOperation* operation) {
  return down_cast<const EagerOperation*>(operation);
}
308 | |
309 | } // namespace tensorflow |
310 | |
311 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_ |
312 | |