1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ |
18 | |
19 | #include <functional> |
20 | #include <unordered_set> |
21 | #include <utility> |
22 | #include <vector> |
23 | |
24 | #include "absl/time/time.h" |
25 | #include "absl/types/optional.h" |
26 | #include "absl/types/span.h" |
27 | #include "tensorflow/core/framework/allocator.h" |
28 | #include "tensorflow/core/framework/cancellation.h" |
29 | #include "tensorflow/core/framework/control_flow.h" |
30 | #include "tensorflow/core/framework/device_base.h" |
31 | #include "tensorflow/core/framework/graph.pb.h" |
32 | #include "tensorflow/core/framework/kernel_def.pb.h" |
33 | #include "tensorflow/core/framework/kernel_def_builder.h" |
34 | #include "tensorflow/core/framework/node_def.pb.h" |
35 | #include "tensorflow/core/framework/node_def_util.h" |
36 | #include "tensorflow/core/framework/node_properties.h" |
37 | #include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove |
38 | #include "tensorflow/core/framework/op_requires.h" |
39 | #include "tensorflow/core/framework/registration/registration.h" |
40 | #include "tensorflow/core/framework/rendezvous.h" |
41 | #include "tensorflow/core/framework/session_state.h" |
42 | #include "tensorflow/core/framework/tensor.h" |
43 | #include "tensorflow/core/framework/tensor_shape.h" |
44 | #include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove |
45 | #include "tensorflow/core/framework/tracking_allocator.h" |
46 | #include "tensorflow/core/framework/types.h" |
47 | #include "tensorflow/core/framework/types.pb.h" |
48 | #include "tensorflow/core/lib/core/errors.h" |
49 | #include "tensorflow/core/lib/core/status.h" |
50 | #include "tensorflow/core/lib/gtl/array_slice.h" |
51 | #include "tensorflow/core/lib/gtl/manual_constructor.h" |
52 | #include "tensorflow/core/platform/env.h" |
53 | #include "tensorflow/core/platform/logging.h" |
54 | #include "tensorflow/core/platform/macros.h" |
55 | #include "tensorflow/core/platform/mutex.h" |
56 | #include "tensorflow/core/platform/profile_utils/cpu_utils.h" |
57 | #include "tensorflow/core/platform/thread_annotations.h" |
58 | #include "tensorflow/core/platform/types.h" |
59 | #include "tensorflow/core/protobuf/config.pb.h" |
60 | #include "tensorflow/core/util/managed_stack_trace.h" |
61 | |
62 | namespace Eigen { |
63 | struct ThreadPoolDevice; |
64 | struct GpuDevice; |
65 | } // end namespace Eigen |
66 | |
67 | namespace tensorflow { |
68 | |
69 | namespace checkpoint { |
70 | class TensorSliceReaderCacheWrapper; |
71 | } // namespace checkpoint |
72 | |
73 | class AsyncOpKernel; |
74 | class CallFrameInterface; |
75 | class DeviceMgr; |
76 | class FunctionLibraryRuntime; |
77 | class OpKernelConstruction; // declared below |
class OpKernelContext;       // declared below
79 | class OpRegistryInterface; |
80 | class ResourceMgr; |
81 | class ScopedStepContainer; |
82 | class CollectiveExecutor; |
83 | class StepStatsCollectorInterface; |
84 | class CoordinationServiceAgent; |
85 | |
86 | // A label that is added to kernels that are JIT compiled. These labels will be |
87 | // removed before kernels are looked up, so they can be used without specifying |
88 | // the label. This label is a temporary measure to allow JIT kernels to be |
89 | // disabled if needed. |
90 | extern const char* kJitKernelLabel; |
91 | extern const char* kDisableJitKernelsEnvVar; |
92 | |
93 | class OpKernel { |
94 | public: |
95 | // OpKernel won't be instantiated by the scheduler, so you may perform |
96 | // expensive initialization in the descendant's constructor. |
97 | explicit OpKernel(OpKernelConstruction* context); |
98 | |
99 | // Specialized constructor that allows a kernel implementation to mark itself |
// as a "deferred" op. If `is_deferred` is true, the executor will provide
// access to the `OpKernelContext::inc_num_deferred_ops_function()` and
// `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
103 | OpKernel(OpKernelConstruction* context, bool is_deferred); |
104 | |
105 | // Specialized constructor that enables the descendant to provide a custom |
106 | // `NodeDef` value. For example, this constructor can be used to provide a |
107 | // stripped-down `NodeDef` that does not contain the full set of attrs (such |
108 | // as tensor values) if the descendant stores them in a different form. |
109 | OpKernel(OpKernelConstruction* context, NodeDef&& custom_def, |
110 | bool is_deferred); |
111 | |
112 | virtual ~OpKernel(); |
113 | |
114 | // An OpKernel's computation can be either synchronous or |
115 | // asynchronous. All OpKernel Compute() methods must be thread-safe as they |
116 | // may be called concurrently (e.g. by multiple executions of the same graph |
117 | // concurrently). |
118 | // |
119 | // Most OpKernels should compute synchronously. They should |
120 | // subclass OpKernel and override the Compute() method and have it |
121 | // return after completing the supplied work. |
122 | // |
123 | // A synchronous OpKernel *MUST NOT* block the calling thread on a |
124 | // synchronization mechanism (condition variable, Notification, etc.) that |
125 | // will be unblocked by the execution of another OpKernel. Execution may |
126 | // deadlock in that case, because the executor may use a bounded number of |
127 | // threads. |
128 | // |
129 | // If an OpKernel must block on the execution of another OpKernel (e.g. a |
130 | // RecvOp, or a DequeueOp), the implementation *MUST* subclass AsyncOpKernel, |
131 | // and override `AsyncOpKernel::ComputeAsync()`. In addition, because the |
132 | // unblocking kernel may never run (due to an error or cancellation), in most |
133 | // cases the AsyncOpKernel should implement cancellation support via |
134 | // `ctx->cancellation_manager()`. |
135 | // |
136 | // In both cases, implementations of Compute() and ComputeAsync() |
137 | // get inputs and write outputs through the given OpKernelContext |
// and return a status via context->SetStatus(). They must be
139 | // thread-safe. |
140 | |
141 | // Synchronous compute. |
142 | // |
143 | // "context" is guaranteed to be alive until Compute() returns. |
144 | virtual void Compute(OpKernelContext* context) = 0; |
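
// Illustrative sketch (not part of this header): a minimal synchronous
// kernel. The class name MyAddOneOp and its float-only body are
// hypothetical.
//
//   class MyAddOneOp : public OpKernel {
//    public:
//     explicit MyAddOneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
//     void Compute(OpKernelContext* ctx) override {
//       const Tensor& input = ctx->input(0);
//       Tensor* output = nullptr;
//       OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
//       output->flat<float>() = input.flat<float>() + 1.0f;
//     }
//   };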
145 | |
146 | // Returns nullptr iff this op kernel is synchronous. |
147 | virtual AsyncOpKernel* AsAsync() { return nullptr; } |
148 | |
149 | // Returns true iff this op kernel is considered "expensive". The |
// runtime may use this flag to optimize graph execution, for example
// by "inlining" inexpensive kernels.
152 | virtual bool IsExpensive() { return expensive_; } |
153 | |
154 | // Returns a pointer to the tensor stored inside constant ops. |
155 | virtual const Tensor* const_tensor() const { return nullptr; } |
156 | |
157 | // Accessors. |
158 | const NodeDef& def() const { return props_->node_def; } |
159 | const std::string& name() const { return props_->node_def.name(); } |
160 | absl::string_view name_view() const { return name_view_; } |
161 | const std::string& type_string() const { return props_->node_def.op(); } |
162 | absl::string_view type_string_view() const { return type_string_view_; } |
163 | const std::string& requested_input(int i) const { |
164 | return props_->node_def.input(i); |
165 | } |
166 | const std::string& requested_device() const { |
167 | return props_->node_def.device(); |
168 | } |
169 | |
170 | int num_inputs() const { return props_->input_types.size(); } |
171 | DataType input_type(int i) const { return props_->input_types[i]; } |
172 | const DataTypeVector& input_types() const { return props_->input_types; } |
173 | const MemoryTypeVector& input_memory_types() const { |
174 | return input_memory_types_; |
175 | } |
176 | |
177 | int num_outputs() const { return props_->output_types.size(); } |
178 | DataType output_type(int o) const { return props_->output_types[o]; } |
179 | const DataTypeVector& output_types() const { return props_->output_types; } |
180 | const MemoryTypeVector& output_memory_types() const { |
181 | return output_memory_types_; |
182 | } |
183 | |
184 | Status InputRange(StringPiece input_name, int* start, int* stop) const; |
185 | Status OutputRange(StringPiece output_name, int* start, int* stop) const; |
186 | |
187 | // Returns `true` if and only if this kernel uses deferred execution. |
188 | bool is_deferred() const { return is_deferred_; } |
189 | |
// Returns a trace string for the current computation. The op name/type and
// input tensor shapes/dtypes are encoded for profiler cost analysis. Most
// OpKernels should use the default implementation.
193 | virtual std::string TraceString(const OpKernelContext& ctx, |
194 | bool verbose) const; |
195 | |
196 | protected: |
197 | std::string ShapeTraceString(const OpKernelContext& ctx) const; |
198 | |
199 | private: |
200 | const std::shared_ptr<const NodeProperties> props_; |
201 | const MemoryTypeVector input_memory_types_; |
202 | const MemoryTypeVector output_memory_types_; |
203 | NameRangeMap input_name_map_; |
204 | NameRangeMap output_name_map_; |
205 | const absl::string_view name_view_; |
206 | const absl::string_view type_string_view_; |
207 | const int graph_def_version_; |
208 | const bool is_deferred_; |
209 | bool expensive_; |
210 | |
211 | TF_DISALLOW_COPY_AND_ASSIGN(OpKernel); |
212 | }; |
213 | |
214 | class AsyncOpKernel : public OpKernel { |
215 | public: |
216 | using OpKernel::OpKernel; // Lift OpKernel constructors. |
217 | |
218 | // Asynchronous compute. |
219 | // |
220 | // Implementations of ComputeAsync() must ensure that `done` is (eventually) |
221 | // called exactly once to signal the completion of the computation. The |
222 | // implementation of ComputeAsync() must not block on the execution of another |
223 | // OpKernel. `done` may be called by the current thread, or by another thread. |
224 | // `context` is guaranteed to stay alive until the `done` callback starts. |
225 | // |
226 | // Since it is possible that the unblocking kernel may never run (due to an |
227 | // error or cancellation), in most cases the AsyncOpKernel should implement |
228 | // cancellation support via `context->cancellation_manager()`. |
229 | // |
230 | // WARNING: As soon as the `done` callback starts, `context` and `this` may be |
231 | // deleted. No code depending on these objects should execute after the call |
232 | // to `done`. |
233 | typedef std::function<void()> DoneCallback; |
234 | virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0; |
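
// Illustrative sketch (not part of this header): an async kernel that
// registers a cancellation callback before deferring its work. The class
// name MyAsyncOp and the ScheduleWork() helper are hypothetical, and the
// sketch assumes a cancellation manager is present.
//
//   class MyAsyncOp : public AsyncOpKernel {
//    public:
//     explicit MyAsyncOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}
//     void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//       CancellationManager* cm = ctx->cancellation_manager();
//       CancellationToken token = cm->get_cancellation_token();
//       if (!cm->RegisterCallback(token, [ctx, done]() {
//             ctx->SetStatus(errors::Cancelled("MyAsyncOp was cancelled"));
//             done();
//           })) {
//         // Already cancelled before the kernel started.
//         ctx->SetStatus(errors::Cancelled("MyAsyncOp was cancelled"));
//         done();
//         return;
//       }
//       // ScheduleWork() stands in for whatever defers the real work.
//       ScheduleWork(ctx, [cm, token, done]() {
//         // Call `done` only if the cancellation callback did not run.
//         if (cm->DeregisterCallback(token)) done();
//       });
//     }
//   };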
235 | |
236 | AsyncOpKernel* AsAsync() override { return this; } |
237 | |
238 | void Compute(OpKernelContext* context) override; |
239 | }; |
240 | |
241 | class OpKernelConstruction { |
242 | public: |
243 | OpKernelConstruction(DeviceType device_type, DeviceBase* device, |
244 | Allocator* allocator, FunctionLibraryRuntime* flib, |
245 | ResourceMgr* resource_mgr, |
246 | const std::shared_ptr<const NodeProperties>& props, |
247 | const MemoryTypeSlice& input_memory_types, |
248 | const MemoryTypeSlice& output_memory_types, |
249 | int graph_def_version, Status* status); |
250 | |
251 | Env* env() const { return device_->env(); } |
252 | |
253 | // Allocation of tensors during kernel construction: |
254 | // |
255 | // It is legal to temporarily allocate scratch tensor storage during |
256 | // Op kernel construction. Scratch tensors should be allocated using |
257 | // allocate_temp below. Some kernels need to keep tensors in between |
258 | // invocations. If such a Tensor is allocated during kernel |
259 | // construction this also must be done using allocate_temp, and the |
260 | // Op may only store the returned Tensor object. |
261 | |
262 | // Allocates a temporary Tensor of the specified type and shape. The |
263 | // Tensor must not be used after kernel construction is |
264 | // complete. See comment above. |
265 | Status allocate_temp(DataType type, const TensorShape& shape, |
266 | Tensor* out_temp); |
267 | Status allocate_temp(DataType type, const TensorShape& shape, |
268 | Tensor* out_temp, AllocatorAttributes allocator_attr); |
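
// Illustrative sketch (not part of this header): a kernel that keeps a
// tensor between invocations can allocate it at construction time and store
// the returned Tensor object, per the comment above. The attr name "shape"
// and the member scratch_ are hypothetical.
//
//   explicit MyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
//     TensorShape shape;
//     OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape));
//     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, shape, &scratch_));
//   }
//   Tensor scratch_;  // Member of MyOp.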
269 | |
270 | // User-supplied configuration of this operation. |
271 | const NodeDef& def() const { return props_->node_def; } |
272 | |
273 | // For inspecting the inputs to this operation. |
274 | int num_inputs() const { return props_->input_types.size(); } |
275 | DataType input_type(int i) const { return props_->input_types[i]; } |
276 | const DataTypeSlice& input_types() const { return props_->input_types_slice; } |
277 | const MemoryTypeSlice& input_memory_types() const { |
278 | return input_memory_types_; |
279 | } |
280 | |
281 | // For inspecting the outputs expected from this operation. |
282 | int num_outputs() const { return props_->output_types.size(); } |
283 | DataType output_type(int i) const { return props_->output_types[i]; } |
284 | const DataTypeSlice& output_types() const { |
285 | return props_->output_types_slice; |
286 | } |
287 | const MemoryTypeSlice& output_memory_types() const { |
288 | return output_memory_types_; |
289 | } |
290 | |
// If expected_inputs == input_types() and expected_outputs == output_types(),
292 | // returns OK, else returns INVALID_ARGUMENT with an error message. |
293 | // Recommended for Ops with dynamic signatures. |
294 | Status MatchSignature(const DataTypeSlice expected_inputs, |
295 | const DataTypeSlice expected_outputs); |
296 | |
297 | // For recording configuration errors during construction. |
298 | void SetStatus(const Status& status); |
299 | const Status& status() const { return *status_; } |
300 | |
301 | // Look up the attr with name attr_name and set *value to its value. If no |
302 | // attr with attr_name is found in def(), or the attr does not have |
303 | // a matching type, a non-ok status will be returned. |
304 | template <class T> |
305 | Status GetAttr(StringPiece attr_name, T* value) const TF_ATTRIBUTE_NOINLINE; |
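
// For example, a kernel constructor commonly reads attrs like so (the attr
// name "preserve_index" and the member preserve_index_ are illustrative):
//
//   OP_REQUIRES_OK(context,
//                  context->GetAttr("preserve_index", &preserve_index_));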
306 | |
307 | // Return true if the attr_name is defined in def(). |
308 | bool HasAttr(StringPiece attr_name) const; |
309 | |
310 | // Return the device type. |
311 | const DeviceType& device_type() const { return device_type_; } |
312 | |
313 | // If not nullptr, the kernel can instantiate functions defined in |
314 | // the library. E.g., |
315 | // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...). |
316 | FunctionLibraryRuntime* function_library() const { return flib_; } |
317 | |
318 | // Shared resources accessible to this kernel. |
319 | ResourceMgr* resource_manager() const { return resource_mgr_; } |
320 | |
321 | // The GraphDef version whose behavior we should follow. |
322 | int graph_def_version() const { return graph_def_version_; } |
323 | |
324 | // Helper routines for the OP_REQUIRES macros |
325 | void CtxFailure(const Status& s); |
326 | void CtxFailureWithWarning(const Status& s); |
327 | void CtxFailure(const char* file, int line, const Status& s); |
328 | void CtxFailureWithWarning(const char* file, int line, const Status& s); |
329 | |
330 | // Unrecommended functions: these are functions that have some |
331 | // current uses but are not recommended for use, and may go away at |
332 | // some future major version release. |
333 | |
334 | // May be used, e.g., to get GPU handles, etc. |
335 | // |
336 | // Currently only used to call MakeTensorFromProto() for |
337 | // implementing ConstantOp for every device. See comments |
338 | // on Device::MakeTensorFromProto for longer-term replacement |
339 | // ideas. |
340 | DeviceBase* device() const { return device_; } |
341 | |
342 | private: |
343 | const DeviceType device_type_; |
344 | DeviceBase* const device_; |
345 | Allocator* allocator_; |
346 | FunctionLibraryRuntime* flib_; |
347 | ResourceMgr* const resource_mgr_; |
348 | std::shared_ptr<const NodeProperties> props_; |
349 | MemoryTypeSlice input_memory_types_; |
350 | MemoryTypeSlice output_memory_types_; |
351 | const int graph_def_version_; |
352 | Status* status_; |
353 | |
354 | // Allow access from OpKernel ctor. |
355 | friend class OpKernel; |
356 | |
357 | TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction); |
358 | }; |
359 | |
360 | // TODO(mrry): Consider converting to a random_access_iterator, and upgrading |
361 | // tensorflow::gtl::iterator_range to make the below container classes |
362 | // unnecessary. |
363 | template <typename ListType, typename ElementType> |
364 | class OpArgIterator { |
365 | public: |
366 | using iterator_category = std::forward_iterator_tag; |
367 | using value_type = ElementType; |
368 | using pointer = ElementType*; |
369 | using const_pointer = const ElementType*; |
370 | using reference = ElementType&; |
371 | using const_reference = const ElementType&; |
372 | using difference_type = ptrdiff_t; |
373 | |
374 | OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {} |
375 | |
376 | bool operator==(const OpArgIterator& rhs) { |
377 | DCHECK(list_ == rhs.list_); |
378 | return i_ == rhs.i_; |
379 | } |
380 | |
381 | bool operator!=(const OpArgIterator& rhs) { |
382 | DCHECK(list_ == rhs.list_); |
383 | return i_ != rhs.i_; |
384 | } |
385 | |
386 | OpArgIterator operator++() { // prefix ++it |
387 | ++i_; |
388 | return *this; |
389 | } |
390 | |
391 | OpArgIterator operator++(int) { // postfix it++ |
392 | OpArgIterator old_value = *this; |
393 | ++i_; |
394 | return old_value; |
395 | } |
396 | |
397 | reference operator*() { return (*list_)[i_]; } |
398 | pointer operator->() { return &(*list_)[i_]; } |
399 | |
400 | const_reference operator*() const { return (*list_)[i_]; } |
401 | const_pointer operator->() const { return &(*list_)[i_]; } |
402 | |
403 | private: |
404 | const ListType* const list_; |
405 | int i_; |
406 | }; |
407 | |
408 | // Utility class for representing a list of immutable input tensors |
409 | // that are passed to the op as a single named argument. |
410 | class OpInputList { |
411 | public: |
412 | typedef OpArgIterator<OpInputList, const Tensor> Iterator; |
413 | OpInputList() : ctx_(nullptr), start_(0), stop_(0) {} |
414 | OpInputList(OpKernelContext* ctx, int start, int stop) |
415 | : ctx_(ctx), start_(start), stop_(stop) {} |
416 | OpInputList& operator=(const OpInputList& other) = default; |
417 | const Tensor& operator[](int i) const; |
418 | int size() const { return stop_ - start_; } |
419 | Iterator begin() const { return Iterator(this, 0); } |
420 | Iterator end() const { return Iterator(this, size()); } |
421 | |
422 | private: |
423 | OpKernelContext* ctx_; // not owned |
424 | int start_; |
425 | int stop_; |
426 | }; |
427 | |
428 | // Utility class for representing a list of mutable ("ref") input tensors |
429 | // that are passed to the op as a single named argument. |
430 | class OpMutableInputList { |
431 | public: |
432 | typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator; |
433 | OpMutableInputList(OpKernelContext* ctx, int start, int stop) |
434 | : ctx_(ctx), start_(start), stop_(stop) {} |
435 | OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {} |
436 | OpMutableInputList& operator=(const OpMutableInputList& other) = default; |
437 | Tensor at(int i, bool lock_held); |
438 | mutex* ref_mutex(int i); |
439 | int size() const { return stop_ - start_; } |
440 | Iterator begin() const { return Iterator(this, 0); } |
441 | Iterator end() const { return Iterator(this, size()); } |
442 | |
443 | private: |
444 | OpKernelContext* ctx_; // not owned |
445 | int start_; |
446 | int stop_; |
447 | }; |
448 | |
449 | // Utility class for representing a list of output tensors that are |
450 | // grouped as a single named output. |
451 | class OpOutputList { |
452 | public: |
453 | typedef OpArgIterator<OpOutputList, const Tensor*> Iterator; |
454 | OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {} |
455 | OpOutputList(OpKernelContext* ctx, int start, int stop) |
456 | : ctx_(ctx), start_(start), stop_(stop) {} |
457 | OpOutputList& operator=(const OpOutputList& other) = default; |
458 | Tensor* operator[](int i); |
459 | bool required(int i) const; |
460 | DataType expected_output_dtype(int i) const; |
461 | Status allocate(int i, const TensorShape& shape, Tensor** output); |
462 | void set(int i, const Tensor& tensor); |
463 | void set(int i, Tensor&& tensor); |
464 | void set_ref(int i, mutex* mu, Tensor* tensor_for_ref); |
465 | int size() const { return stop_ - start_; } |
466 | Iterator begin() const { return Iterator(this, 0); } |
467 | Iterator end() const { return Iterator(this, size()); } |
468 | |
469 | private: |
470 | OpKernelContext* ctx_; // not owned |
471 | int start_; |
472 | int stop_; |
473 | }; |
474 | |
475 | // Holds a tensor or tensor reference. For tensor references, we need |
476 | // a mutex to prevent concurrent access to the tensor. |
477 | struct TensorValue { |
478 | TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {} |
479 | explicit TensorValue(Tensor* t) : mutex_if_ref(nullptr), tensor(t) {} |
480 | TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {} |
481 | Tensor* operator->() const { return tensor; } |
482 | bool is_ref() const { return mutex_if_ref != nullptr; } |
483 | |
// Returns the dtype of the Tensor. For references, returns the corresponding
// ref type (i.e. MakeRefType of the underlying dtype).
485 | DataType dtype() const { |
486 | if (is_ref()) { |
487 | return MakeRefType(tensor->dtype()); |
488 | } else { |
489 | return tensor->dtype(); |
490 | } |
491 | } |
492 | |
// Like dtype(), but for references this variant acquires the ref mutex
// before reading the dtype.
495 | // |
496 | // TODO(b/133843385): Disallow dtype modifications |
497 | DataType dtype_safe() const { |
498 | if (is_ref()) { |
499 | tf_shared_lock ml(*mutex_if_ref); |
500 | return MakeRefType(tensor->dtype()); |
501 | } else { |
502 | return tensor->dtype(); |
503 | } |
504 | } |
505 | |
506 | mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref |
507 | Tensor* tensor; |
508 | }; |
509 | |
510 | // Used to store partitioned graphs from function-calling ops. |
511 | struct GraphCollector { |
512 | mutex mu; |
513 | std::vector<GraphDef> partitioned_graphs TF_GUARDED_BY(mu); |
514 | GraphDef raw_graph TF_GUARDED_BY(mu); |
515 | GraphDef optimized_graph TF_GUARDED_BY(mu); |
516 | |
517 | bool dirty TF_GUARDED_BY(mu); |
518 | |
519 | GraphCollector() : dirty(false) {} |
520 | |
521 | void CollectRawGraph(const GraphDef& graph) { |
522 | mutex_lock ml(mu); |
523 | raw_graph.MergeFrom(graph); |
524 | dirty = true; |
525 | } |
526 | |
527 | void CollectOptimizedGraph(const GraphDef& graph) { |
528 | mutex_lock ml(mu); |
529 | optimized_graph.MergeFrom(graph); |
530 | dirty = true; |
531 | } |
532 | |
533 | void CollectPartitionedGraph(const GraphDef& graph) { |
534 | mutex_lock ml(mu); |
535 | partitioned_graphs.push_back(graph); |
536 | dirty = true; |
537 | } |
538 | |
539 | void ClearGraphs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { |
540 | raw_graph.Clear(); |
541 | optimized_graph.Clear(); |
542 | partitioned_graphs.clear(); |
543 | dirty = false; |
544 | } |
545 | |
546 | bool HasUpdatedGraphs() { |
547 | mutex_lock ml(mu); |
548 | return dirty; |
549 | } |
550 | }; |
551 | |
552 | class OpKernelContext { |
553 | public: |
554 | // The first element of a WrappedAllocator is a "base" Allocator and |
555 | // the second element is that Allocator wrapped by a |
// TrackingAllocator.
557 | typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator; |
558 | |
559 | // TODO(zhifengc): Do some cleanup of Params. |
560 | // The Params struct is passed in to initialize an OpKernelContext, |
561 | // and must outlive the OpKernelContext. |
562 | struct Params { |
563 | ~Params() { delete eigen_gpu_device; } |
564 | |
565 | // The step being executed. |
566 | int64_t step_id = 0; |
567 | |
568 | // Timestamp for the start of graph execution. Used for latency metrics. |
569 | int64_t start_time_usecs = 0; |
570 | |
571 | // The deadline for the session to complete by. Empty if unspecified. |
572 | absl::optional<absl::Time> deadline; |
573 | |
574 | // The op kernel being computed. |
575 | OpKernel* op_kernel = nullptr; |
576 | |
577 | // The device on which the kernel is running. |
578 | DeviceBase* device = nullptr; |
579 | |
580 | // The Eigen GPU device wrapper, which may include a per-op |
581 | // wrapped allocator. The concrete type of this object depends on |
582 | // the type of this->device, so eigen_gpu_device can't be an |
583 | // inline member and must be heap allocated. However, we don't |
584 | // want to allocate a new eigen_gpu_device for every Op that is |
585 | // executed. Instead this member is allocated on first use using |
586 | // ensure_eigen_gpu_device, and then if the Params structure is |
587 | // re-used for subsequent Ops, the eigen_gpu_device is |
588 | // ReInitialized in the OpKernelContext constructor. Unlike the |
589 | // other pointers in Params, this one is owned by Params. |
590 | PerOpGpuDevice* eigen_gpu_device = nullptr; |
591 | |
592 | inline void ensure_eigen_gpu_device() { |
593 | DCHECK(device); |
594 | if (nullptr == eigen_gpu_device) { |
595 | // Surprisingly, MakeGpuDevice will return nullptr if the |
596 | // device is not a GPU device. This is ok, since those devices |
597 | // will never use eigen_gpu_device. It seems better to have |
598 | // ensure_eigen_gpu_device fall through and regenerate the |
599 | // nullptr every time an OpKernelContext is instantiated, than |
600 | // to do an unnecessary allocation of a dummy eigen GPU |
601 | // device for CPU device Ops. |
602 | eigen_gpu_device = device->MakeGpuDevice(); |
603 | } |
604 | } |
605 | |
606 | bool track_allocations = false; |
607 | bool log_memory = false; |
608 | |
609 | // Array indexed by output number for this node |
610 | const AllocatorAttributes* output_attr_array = nullptr; |
611 | |
612 | // Shared resources accessible by this op kernel invocation. |
613 | ResourceMgr* resource_manager = nullptr; |
614 | |
615 | // Per-step resources accessible by this op kernel invocation should be |
// stored in this container.
617 | ScopedStepContainer* step_container = nullptr; |
618 | |
619 | // Mechanism used by this op kernel invocation to communicate with |
620 | // computations running on other devices. |
621 | RendezvousInterface* rendezvous = nullptr; |
622 | |
623 | // Mechanism for executing a collective op that needs to coordinate |
624 | // with parallel instances running on other devices. |
625 | CollectiveExecutor* collective_executor = nullptr; |
626 | |
627 | // The session state for this op. |
628 | SessionState* session_state = nullptr; |
629 | |
630 | // Unique session identifier. Can be empty. |
631 | std::string session_handle; |
632 | |
633 | // Metadata about the session. Can be nullptr. |
634 | const SessionMetadata* session_metadata = nullptr; |
635 | |
636 | // The tensor store for this op. |
637 | TensorStore* tensor_store = nullptr; |
638 | |
639 | // Mechanism used by this op kernel invocation to register a callback |
640 | // for its cancellation. |
641 | CancellationManager* cancellation_manager = nullptr; |
642 | |
643 | // Inputs to this op kernel. |
644 | absl::Span<const TensorValue> inputs; |
645 | bool is_input_dead = false; |
646 | |
647 | absl::Span<const AllocatorAttributes> input_alloc_attrs; |
648 | |
649 | // Device context. |
650 | DeviceContext* op_device_context = nullptr; |
651 | |
// Control-flow op support.
653 | FrameAndIter frame_iter; |
654 | |
// Function call support.
656 | CallFrameInterface* call_frame = nullptr; |
657 | FunctionLibraryRuntime* function_library = nullptr; |
658 | std::function<void(std::function<void()>)>* runner = nullptr; |
659 | StepStatsCollectorInterface* stats_collector = nullptr; |
660 | GraphCollector* graph_collector = nullptr; |
661 | bool run_all_kernels_inline = false; |
662 | const std::string* executor_type = nullptr; |
663 | |
664 | // TensorSliceReaderCache support. |
665 | checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; |
666 | |
667 | // Support for forwarding reservations (used by ScopedAllocator). |
668 | static constexpr int kNeverForward = -2; |
669 | static constexpr int kNoReservation = -1; |
670 | // Values in [0,...) represent reservations for the indexed output. |
671 | const int* forward_from_array = nullptr; |
672 | |
673 | // For tracking actively running deferred ops. |
674 | std::function<void()> inc_num_deferred_ops_function; |
675 | std::function<void()> dec_num_deferred_ops_function; |
676 | |
677 | absl::optional<ManagedStackTrace> stack_trace = {}; |
678 | |
679 | // For implementing `OpKernelContext::output_required()`. If null, all |
680 | // outputs are required. |
681 | bool* outputs_required_array = nullptr; |
682 | |
683 | // For access to distributed coordination service. |
684 | CoordinationServiceAgent* coordination_service_agent = nullptr; |
685 | }; |
686 | |
687 | // params must outlive the OpKernelContext. |
688 | explicit OpKernelContext(Params* params); |
689 | OpKernelContext(Params* params, int num_outputs); |
690 | ~OpKernelContext(); |
691 | |
692 | Env* env() const { return params_->device->env(); } |
693 | |
694 | int64_t step_id() const { return params_->step_id; } |
695 | |
696 | int64_t start_time_usecs() const { return params_->start_time_usecs; } |
697 | |
698 | // The deadline for the session to complete by. Empty if unspecified in |
699 | // RunOptions. |
700 | absl::optional<absl::Time> deadline() const { return params_->deadline; } |
701 | |
702 | const OpKernel& op_kernel() const { return *params_->op_kernel; } |
703 | |
704 | // Stack trace of where the op was defined (if defined in eager mode). |
705 | const absl::optional<ManagedStackTrace>& stack_trace() const { |
706 | return params_->stack_trace; |
707 | } |
708 | |
709 | // Input/output signature. |
710 | |
711 | int num_inputs() const { return params_->inputs.size(); } |
712 | DataType input_dtype(int index) const; |
713 | Status input_dtype(StringPiece name, DataType* dtype) const; |
714 | MemoryType input_memory_type(int index) const; |
715 | |
716 | int num_outputs() const { return outputs_.size(); } |
717 | DataType expected_output_dtype(int index) const; |
718 | MemoryType output_memory_type(int index) const; |
719 | |
720 | // Input |
721 | |
722 | // Returns an immutable input tensor. May only be used for non-Ref |
723 | // inputs. For Ref inputs use mutable_input below. |
724 | // REQUIRES: !IsRefType(input_dtype(index)) |
725 | // TODO(mrry): Convert this to return Status. |
726 | const Tensor& input(int index) const; |
727 | |
728 | // Returns the named immutable input tensor in "tensor", as defined |
729 | // in the OpDef. May only be used for non-Ref inputs. For Ref inputs |
730 | // use mutable_input below. |
731 | // REQUIRES: !IsRefType(input_dtype(index)) |
732 | // REQUIRES: the named input must not be a list. |
733 | Status input(StringPiece name, const Tensor** tensor); |
734 | |
735 | // Returns the named list-valued immutable input in "list", as |
// defined in the OpDef. If the named input is not list-valued,
737 | // returns a one-element list. May only be used for non-Ref |
738 | // inputs. For Ref inputs use mutable_input below. |
739 | // REQUIRES: !IsRefType(input_dtype(index)) |
740 | Status input_list(StringPiece name, OpInputList* list); |
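
// For example (illustrative; the input name "values" is hypothetical):
//
//   OpInputList values;
//   OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
//   for (const Tensor& t : values) {
//     // process t
//   }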
741 | |
742 | // For mutable inputs, use the following together to make sure there |
743 | // is no concurrent access to mutable_input(), e.g.: |
744 | // { |
745 | // Tensor& t = context->mutable_input(index); |
746 | // mutex_lock lock(*context->input_ref_mutex(index)); |
747 | // // modify the values in t |
748 | // } |
749 | // REQUIRES: IsRefType(input_dtype(index)) |
750 | Status input_ref_mutex(StringPiece name, mutex** out_mutex); |
751 | |
752 | // Returns a mutable input tensor. Must be used to access Ref |
753 | // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may |
754 | // modify the values stored in the Tensor buffer, and modifications |
755 | // will be visible to other Ops reading the same ref tensor. If |
756 | // !lock_held the input mutex will be acquired before returning the |
757 | // Tensor. |
758 | // TODO(mrry): Convert this to return Status. |
759 | Tensor mutable_input(int index, bool lock_held); |
760 | |
761 | // Returns the named mutable input tensor in "tensor", as defined in |
762 | // the OpDef. Must be used to access Ref inputs. The values stored |
763 | // in the Tensor buffer may be modified, and modifications will be |
764 | // visible to other Ops reading the same ref tensor. If !lock_held |
765 | // the input mutex will be acquired before returning the Tensor. |
766 | // REQUIRES: the named input must not be a list. |
767 | // REQUIRES: the named input must be a ref tensor. |
768 | Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held); |
769 | |
770 | // Returns the named list-valued mutable input in "list", as defined |
771 | // in the OpDef. If the named input is not list-valued, returns a |
772 | // one-element list. Must be used to access Ref inputs. The values |
773 | // stored in the Tensor buffer may be modified, and modifications |
774 | // will be visible to other Ops reading the same ref tensor. |
775 | // REQUIRES: the named input must be a ref tensor. |
776 | Status mutable_input_list(StringPiece name, OpMutableInputList* list); |
777 | |
778 | // Replace the corresponding Ref Input to use the storage buffer |
779 | // used by tensor. If !lock_held the input mutex will be acquired |
780 | // before returning the Tensor. |
781 | // REQUIRES: IsRefType(input_dtype(index)). |
782 | void replace_ref_input(int index, const Tensor& tensor, bool lock_held); |
783 | |
784 | // Replace the corresponding named Ref Input to use the storage |
785 | // buffer used by tensor. If !lock_held the input mutex will be |
786 | // acquired before returning the Tensor. |
787 | // REQUIRES: IsRefType(input_dtype(index)). |
788 | Status replace_ref_input(StringPiece name, const Tensor& tensor, |
789 | bool lock_held); |
790 | |
791 | // Deletes the Tensor object used as the Ref Input at |
792 | // input_index. This is not usually necessary and should be used |
793 | // with caution. If !lock_held the input mutex will be acquired |
794 | // before returning the Tensor. |
795 | // REQUIRES: IsRefType(input_dtype(input_index)). |
796 | void delete_ref_input(int input_index, bool lock_held); |
797 | |
798 | // Return true if there is input at the given index. An operator has no |
799 | // input at index if its tensor is null. This is primarily used by the |
800 | // merge operator. |
801 | // TODO(mrry): Convert this to return Status. |
802 | bool has_input(int index) const; |
803 | |
804 | // Returns true if all inputs are the same shape, otherwise sets the |
805 | // status to a non-OK value and returns false. |
806 | // Usage: if (!context->ValidateInputsAreSameShape(this)) return; |
807 | bool ValidateInputsAreSameShape(OpKernel* op); |
808 | |
809 | // If non-null, kernels should populate with any partition subgraphs created. |
810 | GraphCollector* graph_collector() { return params_->graph_collector; } |
811 | |
// If true, hints that all kernels in functions called by this kernel should
// be treated as "inexpensive", and hence executed on the scheduling thread.
814 | bool run_all_kernels_inline() const { |
815 | return params_->run_all_kernels_inline; |
816 | } |
817 | |
818 | // Returns the registered name for the executor type that is executing the |
819 | // current kernel. If empty, the default executor is used. |
820 | const std::string& executor_type() const; |
821 | |
822 | // Input to output forwarding. |
823 | |
824 | // Set the output Ref Tensor at output_index to be an alias of the |
825 | // input Ref Tensor at input_index. |
826 | // REQUIRES: IsRefType(input_dtype(input_index)). |
827 | // REQUIRES: IsRefType(output_dtype(output_index)). |
828 | void forward_ref_input_to_ref_output(int input_index, int output_index); |
829 | |
// Returns true when an alias to input[input_index] (reshaped to output_shape)
// that is safe to use for in-place computation has been written to *output.
832 | // Returns false if input[input_index] has a refcount greater than one, or if |
833 | // its type does not match the expected output type of output[output_index], |
834 | // or the number of elements in input[input_index] does not equal the number |
835 | // of elements in output_shape. |
836 | bool forward_input_to_output_with_shape(int input_index, int output_index, |
837 | const TensorShape& output_shape, |
838 | Tensor** output) TF_MUST_USE_RESULT; |
839 | Status forward_input_to_output_with_shape(StringPiece input_name, |
840 | StringPiece output_name, |
841 | const TensorShape& output_shape, |
842 | Tensor** output) TF_MUST_USE_RESULT; |
843 | |
844 | // Returns a pointer to a Tensor aliasing the underlying buffer backing |
845 | // input[input_index] iff |
846 | // * input[input_index] is not a ref, |
847 | // * the data type, shape, memory type, and allocator attributes of |
848 | // input[input_index] are compatible with those given in dtype, shape, |
849 | // memory_type, and attr, |
850 | // * refcount on the underlying buffer is one. |
851 | // * Either there is no forwarding reservation for either input_index |
852 | // or output_index or the specified input is reserved for the specified |
853 | // output. More precisely: |
854 | // |
855 | // These cases mean neither input nor output has a reservation: |
856 | // forward_from_array = nullptr |
857 | // OR (input_index is not in forward_from_array AND |
858 | // (output_index == kNoReservation OR |
859 | // forward_from_array[output_index] == kNoReservation)) |
860 | // |
861 | // This case means that input_index is reserved for output_index: |
862 | // forward_from_array[output_index] == input_index |
863 | // |
864 | // This case means the output is reserved to always be allocated, |
865 | // never assigned a forwarded input: |
866 | // forward_from_array[output_index] == kNeverForward |
867 | // |
868 | // Otherwise returns nullptr. |
// NOTE: For CUDA kernels that read inputs using the __ldg() intrinsic,
870 | // forwarding is only safe if there are no reads via __ldg() after writes |
871 | // to the same address. |
872 | std::unique_ptr<Tensor> forward_input( |
873 | int input_index, int output_index, DataType output_dtype, |
874 | const TensorShape& output_shape, MemoryType output_memory_type, |
875 | const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT; |
876 | |
877 | // Tries to forward one of the inputs given in input_indices to |
878 | // output[output_index]. If none of the given inputs can be forwarded, calls |
879 | // allocate_output() to allocate a new output buffer. The index of the |
// forwarded input will be assigned to output argument forwarded_input (if it's
881 | // not nullptr). If no inputs are forwarded, forwarded_input will be assigned |
882 | // -1. |
883 | Status forward_input_or_allocate_output( |
884 | gtl::ArraySlice<int> candidate_input_indices, int output_index, |
885 | const TensorShape& output_shape, Tensor** output, |
886 | int* forwarded_input = nullptr) TF_MUST_USE_RESULT; |
887 | Status forward_input_or_allocate_output( |
888 | gtl::ArraySlice<StringPiece> candidate_input_names, |
889 | StringPiece output_name, const TensorShape& output_shape, |
890 | Tensor** output) TF_MUST_USE_RESULT; |
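
// For example, an element-wise kernel can try to reuse its input buffer for
// its output (illustrative sketch):
//
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
//                           {0}, 0, ctx->input(0).shape(), &output));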
891 | |
892 | // Tries to reuse one of the inputs given in input_indices as a temporary. |
893 | // If none of the given inputs can be forwarded, calls |
894 | // allocate_temp() to allocate a new temporary buffer. |
895 | Status forward_input_or_allocate_temp( |
896 | gtl::ArraySlice<int> candidate_input_indices, DataType type, |
897 | const TensorShape& shape, const AllocatorAttributes& allocator_attr, |
898 | Tensor* out_temp) TF_MUST_USE_RESULT; |
899 | |
900 | Status forward_input_or_allocate_temp( |
901 | gtl::ArraySlice<int> candidate_input_indices, DataType type, |
902 | const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT { |
903 | return forward_input_or_allocate_temp(candidate_input_indices, type, shape, |
904 | AllocatorAttributes(), out_temp); |
905 | } |
906 | |
907 | // Output |
908 | |
909 | // Returns the named list-valued output in "list", as defined in the OpDef. |
910 | // If the named output is not list-valued, returns a one-element list. |
911 | Status output_list(StringPiece name, OpOutputList* list); |
912 | |
913 | // If output_required(index) returns true, the OpKernel's Compute() method |
914 | // should call allocate_output(index, ...), set_output(index, ...), |
915 | // set_output_ref(index, ...), or set the status to a non-ok value. |
// If it returns false, the kernel may still produce the output, but is not
// required to do so.
917 | bool output_required(int index) const { |
918 | return !params_->outputs_required_array || |
919 | params_->outputs_required_array[index]; |
920 | } |
921 | |
922 | // If output_expects_forwarding returns true, the OpKernel's Compute() method |
923 | // should not allocate the output with allocate_output but instead needs to |
924 | // use forward_input. |
925 | bool output_expects_forwarding(int index) const { |
926 | return params_->forward_from_array != nullptr && |
927 | params_->forward_from_array[index] >= 0; |
928 | } |
929 | |
930 | // Allocation of tensors during kernel execution inside the Compute |
931 | // method: |
932 | // |
933 | // There are two methods to allocate Tensors when an Op kernel |
934 | // executes. |
935 | // |
936 | // 1) allocate_output. This should be used to allocate any tensor |
937 | // that is going to be used as an output from the Op at the end of |
938 | // the current execution. The caller indicates which output the |
939 | // Tensor will be assigned to, and the call returns the |
940 | // newly-allocated Tensor. The Tensor can subsequently be assigned |
941 | // to during kernel execution, and will be used as the designated |
942 | // output when the kernel execution completes. |
943 | // |
944 | // 2) allocate_temp. This should be used to allocate any scratch |
945 | // storage that is needed while the kernel is executing, and will |
946 | // not be retained by the Op. |
947 | // |
948 | // In some cases a Tensor needs to be used as an output even though |
949 | // it was previously allocated elsewhere. The Tensor may have been |
950 | // passed as an input, or stored in a Tensor during a |
951 | // previous kernel execution, or allocated earlier in the kernel |
952 | // execution at a time when it was not known which output it would |
953 | // be assigned to. In this case the kernel can use set_output or |
954 | // set_output_ref to indicate that the tensor should be used as the |
955 | // designated output. It is legal to use any previously-allocated |
956 | // Tensor as an argument to set_output or set_output_ref, including |
957 | // Tensors allocated via allocate_temp. There may be a performance |
958 | // penalty to using a Tensor that was not allocated using |
959 | // allocate_output. This is because allocate_output uses the |
960 | // AllocatorAttributes stored in output_attr_array for the |
961 | // designated output. In some cases, using the wrong attributes may |
962 | // cause an extra copy of the Tensor's buffer. |
963 | |
964 | // Allocates output for the specified output index with shape. |
965 | // OpKernelContext retains ownership of the returned pointer. See |
966 | // comment above. |
967 | // |
968 | // If memory allocation fails, returns an error status. |
969 | // |
970 | // REQUIRES: !IsRefType(expected_output_dtype(index)) |
971 | Status allocate_output(int index, const TensorShape& shape, |
972 | Tensor** tensor) TF_MUST_USE_RESULT; |
973 | Status allocate_output(StringPiece name, const TensorShape& shape, |
974 | Tensor** tensor) TF_MUST_USE_RESULT; |
975 | // The following methods use the supplied attributes instead of |
976 | // those in output_attr_array. The caller is responsible for |
977 | // ensuring that the attributes are "compatible" with the |
978 | // output_attr_array, e.g. the tensor is allocated on the correct |
979 | // device. See comment above. |
980 | Status allocate_output(int index, const TensorShape& shape, Tensor** tensor, |
981 | AllocatorAttributes attr) TF_MUST_USE_RESULT; |
982 | Status allocate_output(StringPiece name, const TensorShape& shape, |
983 | Tensor** tensor, |
984 | AllocatorAttributes attr) TF_MUST_USE_RESULT; |
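
// Typical usage inside Compute() (illustrative sketch; placing output 0 in
// host memory is a hypothetical choice):
//
//   Tensor* out = nullptr;
//   AllocatorAttributes attr;
//   attr.set_on_host(true);
//   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out, attr));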
985 | |
986 | // Allocates a temporary Tensor of the specified type and |
987 | // shape. Devices such as GPUs that enqueue Ops for lazy execution |
988 | // may retain references to the temporary tensors after the Op's |
989 | // Compute method has run. See comment above. |
990 | Status allocate_temp(DataType type, const TensorShape& shape, |
991 | Tensor* out_temp, AllocatorAttributes allocator_attr, |
992 | const AllocationAttributes& allocation_attr); |
993 | Status allocate_temp(DataType type, const TensorShape& shape, |
994 | Tensor* out_temp, AllocatorAttributes allocator_attr); |
995 | Status allocate_temp(DataType type, const TensorShape& shape, |
996 | Tensor* out_temp); |
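
// For example, scratch space that is local to a single Compute() call
// (illustrative sketch):
//
//   Tensor scratch;
//   OP_REQUIRES_OK(
//       ctx, ctx->allocate_temp(DT_FLOAT, TensorShape({64}), &scratch));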
997 | |
998 | // Copies a tensor (allocated by the caller) to the specified output |
999 | // index. REQUIRES: !IsRefType(expected_output_dtype(index)) |
1000 | // REQUIRES: 'tensor' must have the same MemoryType as |
1001 | // output_memory_types[index]. See comment above. |
1002 | Status set_output(StringPiece name, const Tensor& tensor); |
1003 | Status set_output(StringPiece name, Tensor&& tensor); |
1004 | void set_output(int index, const Tensor& tensor); |
1005 | void set_output(int index, Tensor&& tensor); |
1006 | |
1007 | // To output a reference. Caller retains ownership of mu and tensor_for_ref, |
1008 | // and they must outlive all uses within the step. See comment above. |
1009 | // REQUIRES: IsRefType(expected_output_dtype(index)) |
1010 | Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref); |
1011 | |
// Sets *tensor to nullptr if allocate_output() or set_output() has not been
// called for the named output.
1013 | Status mutable_output(StringPiece name, Tensor** tensor); |
1014 | |
1015 | // Return the DeviceContext that should be used for this Op. |
1016 | // |
1017 | // If using the templated function, the type must be a subclass |
1018 | // of DeviceContext. |
1019 | // |
1020 | // Returns nullptr if the device did not provide one. |
1021 | template <typename T> |
1022 | T* op_device_context(); |
1023 | DeviceContext* op_device_context() { |
1024 | DeviceContext* ret = params_->op_device_context; |
1025 | if (ret == nullptr) { |
1026 | auto* dev_info = device()->tensorflow_accelerator_device_info(); |
1027 | if (dev_info) ret = dev_info->default_context; |
1028 | } |
1029 | return ret; |
1030 | } |
1031 | |
1032 | AllocatorAttributes input_alloc_attr(int index) const { |
1033 | if (params_->input_alloc_attrs.empty()) { |
1034 | return AllocatorAttributes(); |
1035 | } else { |
1036 | DCHECK_GE(index, 0); |
1037 | DCHECK_LT(index, params_->input_alloc_attrs.size()); |
1038 | return params_->input_alloc_attrs[index]; |
1039 | } |
1040 | } |
1041 | |
1042 | AllocatorAttributes output_alloc_attr(int index) const { |
1043 | return params_->output_attr_array[index]; |
1044 | } |
1045 | |
1046 | gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() { |
1047 | gtl::InlinedVector<WrappedAllocator, 4> retrieved; |
1048 | if (tracking_state_) { |
1049 | mutex_lock lock(tracking_state_->mu); |
1050 | retrieved.swap(tracking_state_->wrapped_allocators); |
1051 | } |
1052 | return retrieved; |
1053 | } |
1054 | |
1055 | // Communication. |
1056 | // |
1057 | // An op kernel communicates with outside environment through |
1058 | // Rendezvous Send() and Recv(). |
1059 | RendezvousInterface* rendezvous() const { return params_->rendezvous; } |
1060 | |
1061 | CollectiveExecutor* collective_executor() const { |
1062 | return params_->collective_executor; |
1063 | } |
1064 | |
1065 | // An op kernel can access the session state it belongs to. |
1066 | SessionState* session_state() const { return params_->session_state; } |
1067 | |
1068 | // Unique identifier of the session it belongs to. Can be empty. |
1069 | std::string session_handle() const { return params_->session_handle; } |
1070 | |
1071 | // Metadata about the session. Can be nullptr. |
1072 | const SessionMetadata* session_metadata() const { |
1073 | return params_->session_metadata; |
1074 | } |
1075 | |
1076 | // An op kernel can access the tensor store of the run it belongs to. |
1077 | TensorStore* tensor_store() const { return params_->tensor_store; } |
1078 | |
1079 | // Function call support. |
1080 | // |
1081 | // If this kernel invocation is within a function execution, |
1082 | // call_frame() returns the call frame for the function call. |
1083 | CallFrameInterface* call_frame() const { return params_->call_frame; } |
1084 | |
// If not nullptr, the kernel can invoke functions defined in the
1086 | // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...). |
1087 | FunctionLibraryRuntime* function_library() const { |
1088 | return params_->function_library; |
1089 | } |
1090 | |
1091 | std::function<void(std::function<void()>)>* runner() const { |
1092 | return params_->runner; |
1093 | } |
1094 | StepStatsCollectorInterface* stats_collector() const { |
1095 | return params_->stats_collector; |
1096 | } |
1097 | |
1098 | // Shared resources accessible to this kernel. |
1099 | ResourceMgr* resource_manager() const { return params_->resource_manager; } |
1100 | |
1101 | checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const { |
1102 | return params_->slice_reader_cache; |
1103 | } |
1104 | |
1105 | // Execution. |
1106 | // |
1107 | // OpKernels can use these eigen devices to carry out their |
1108 | // numerical computation. |
1109 | const Eigen::ThreadPoolDevice& eigen_cpu_device() const { |
1110 | return *device()->eigen_cpu_device(); |
1111 | } |
1112 | const Eigen::GpuDevice& eigen_gpu_device() const { |
1113 | return params_->eigen_gpu_device->device(); |
1114 | } |
1115 | template <typename EigenDeviceType> |
1116 | const EigenDeviceType& eigen_device() const; |
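
// For example, an element-wise kernel can evaluate its Eigen expression on
// the device's thread pool (illustrative sketch):
//
//   output->flat<float>().device(ctx->eigen_cpu_device()) =
//       input.flat<float>().square();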
1117 | |
1118 | // Error handling. |
1119 | |
1120 | // If expected_inputs == inputs() and expected_outputs == output_types(), |
1121 | // returns OK, else returns INVALID_ARGUMENT with an error message. |
1122 | // Recommended for Ops with dynamic signatures, where validation can only |
1123 | // be performed at runtime. |
1124 | Status MatchSignature(const DataTypeSlice expected_inputs, |
1125 | const DataTypeSlice expected_outputs); |
1126 | |
1127 | // An OpKernel should call SetStatus() if Compute() encounters an |
1128 | // error. |
1129 | void SetStatus(const Status& status); |
1130 | const Status& status() const { return status_; } |
1131 | |
1132 | // Cancellation. |
1133 | // |
1134 | // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an |
1135 | // example of how to use this API. |
1136 | CancellationManager* cancellation_manager() const { |
1137 | return params_->cancellation_manager; |
1138 | } |
1139 | |
1140 | // Other accessors. |
1141 | |
1142 | // For control flow. |
1143 | FrameAndIter frame_iter() const { return params_->frame_iter; } |
1144 | bool is_input_dead() const { return params_->is_input_dead; } |
1145 | |
1146 | // May be used, e.g., to get GPU handles, etc. |
1147 | // TODO(tucker): Add example usage. |
1148 | DeviceBase* device() const { return params_->device; } |
1149 | |
1150 | // Per-step container for use by white-listed internal ops. |
1151 | ScopedStepContainer* step_container() const { |
1152 | return params_->step_container; |
1153 | } |
1154 | |
1155 | // Access to distributed coordination service. |
1156 | CoordinationServiceAgent* coordination_service_agent() const { |
1157 | return params_->coordination_service_agent; |
1158 | } |
1159 | |
1160 | // Helper routines for the OP_REQUIRES macros |
1161 | void CtxFailure(const Status& s); |
1162 | void CtxFailureWithWarning(const Status& s); |
1163 | void CtxFailure(const char* file, int line, const Status& s); |
1164 | void CtxFailureWithWarning(const char* file, int line, const Status& s); |
1165 | |
1166 | // Unrecommended functions: these are functions that have some |
1167 | // current uses but are not recommended for use, and may go away at |
1168 | // some future major version release. |
1169 | // |
1170 | // The following functions all have versions that return Status |
1171 | // to capture error conditions, and are strongly preferred. |
1172 | Tensor* mutable_output(int index); |
1173 | mutex* input_ref_mutex(int index); |
1174 | void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref); |
1175 | TensorValue release_output(int index); |
1176 | |
1177 | bool track_allocations() const { return params_->track_allocations; } |
1178 | |
1179 | // Records temp memory allocation. Tensor object is recorded to identify the |
1180 | // case where temp memory is used as output memory. |
1181 | void record_temp_memory_allocation(int64_t size, const Tensor& t) |
1182 | TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); |
1183 | |
// Returns the recorded size of temporary memory.
1185 | int64_t temp_memory_allocated() const |
1186 | TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); |
1187 | |
// Records persistent memory allocation; a negative size indicates
// deallocation.
1190 | void record_persistent_memory_allocation(int64_t size, int64_t alloc_id = -1) |
1191 | TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); |
1192 | |
1193 | // Returns recorded size and ids of persistent memory. |
1194 | int64_t persistent_memory_allocated() const |
1195 | TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); |
1196 | |
1197 | std::vector<int64_t> persistent_alloc_ids() const |
1198 | TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); |
1199 | |
1200 | // Resets counters for temp and persistent memory and recorded ids. |
1201 | void clear_recorded_memory() TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); |
1202 | |
1203 | bool input_is_ref(int index) const; |
1204 | |
1205 | void set_record_memory_consumption(bool v); |
1206 | |
1207 | // Used by OpKernel implementations to track actively running deferred ops. |
1208 | // |
1209 | // A deferred op is one whose Compute method returns (or whose ComputeAsync |
1210 | // method invokes the callback) when work is scheduled onto a device. At that |
1211 | // point, we don't know when the work will actually complete (or if it has |
1212 | // already completed) on the device. These functions allow the executor to |
1213 | // track the status of deferred ops and act accordingly. |
1214 | // |
// Deferred OpKernel implementations must use these methods to obtain the two
// functions, and then call them in pairs: the increment function before
// scheduling the device work and the decrement function after it completes.
1218 | TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() { |
1219 | DCHECK(params_->op_kernel->is_deferred()); |
1220 | return params_->inc_num_deferred_ops_function |
1221 | ? params_->inc_num_deferred_ops_function |
1222 | : []() {}; |
1223 | } |
1224 | TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() { |
1225 | DCHECK(params_->op_kernel->is_deferred()); |
1226 | return params_->dec_num_deferred_ops_function |
1227 | ? params_->dec_num_deferred_ops_function |
1228 | : []() {}; |
1229 | } |
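
// Illustrative usage in a deferred kernel (sketch; EnqueueOnDevice() and
// `work` stand in for whatever schedules the actual device work):
//
//   auto inc = ctx->inc_num_deferred_ops_function();
//   auto dec = ctx->dec_num_deferred_ops_function();
//   inc();  // Before scheduling the device work.
//   EnqueueOnDevice(work, /*on_completion=*/[dec]() { dec(); });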
1230 | |
1231 | Allocator* get_allocator(AllocatorAttributes attr); |
1232 | |
1233 | private: |
1234 | bool record_memory_consumption_ = false; |
1235 | |
1236 | // Internal common method used when allocating tensor memory |
1237 | Status allocate_tensor(DataType type, const TensorShape& shape, |
1238 | Tensor* out_tensor, |
1239 | AllocatorAttributes allocator_attr) { |
1240 | return allocate_tensor(type, shape, out_tensor, allocator_attr, |
1241 | AllocationAttributes()); |
1242 | } |
1243 | |
1244 | Status allocate_tensor(DataType type, const TensorShape& shape, |
1245 | Tensor* out_tensor, AllocatorAttributes allocator_attr, |
1246 | const AllocationAttributes& allocation_attr); |
1247 | |
1248 | // Helpers for `set_output()`. |
1249 | |
1250 | // Returns `true` if the tensor was copied into an allocated output. |
1251 | bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor); |
1252 | |
1253 | void maybe_track_allocations_for_set_output(const Tensor& tensor); |
1254 | |
1255 | Status get_input_index(StringPiece name, int* out_index) const; |
1256 | Status get_output_index(StringPiece name, int* out_index) const; |
1257 | |
  // Initializes the allocated_scope_ids_ set the first time this method is
  // called.
1260 | void maybe_initialize_scope_id_set(); |
1261 | |
1262 | Status status_; |
1263 | friend class CollectiveExecutor; // for access to params_ |
1264 | Params* params_; // not owned |
1265 | gtl::InlinedVector<TensorValue, 4> outputs_; |
1266 | |
1267 | // Keep track of calls to ScopedAllocator. |
1268 | // TODO(ayushd): change to absl::flat_hash_set. |
1269 | std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_; |
1270 | |
1271 | // The following data members are only used when allocation tracking is |
1272 | // enabled, memory consumption is being recorded, or tensor access is being |
1273 | // recorded. |
1274 | struct TrackingState { |
1275 | mutable mutex mu; |
1276 | gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators |
1277 | TF_GUARDED_BY(mu); |
1278 | |
1279 | mutable mutex stats_mu; |
1280 | int64_t temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0; |
1281 | |
1282 | int64_t persistent_memory_allocated TF_GUARDED_BY(stats_mu) = 0; |
1283 | gtl::InlinedVector<std::pair<const void*, int64_t>, 2> |
1284 | temp_tensor_buffer_and_size TF_GUARDED_BY(stats_mu); |
1285 | gtl::InlinedVector<int64_t, 2> persistent_alloc_ids TF_GUARDED_BY(stats_mu); |
1286 | }; |
1287 | std::unique_ptr<TrackingState> tracking_state_; |
1288 | |
1289 | // For access to `params_->op_kernel`. |
1290 | friend void CheckNotInComputeAsync(OpKernelContext* ctx, |
1291 | const char* correct_macro_name); |
1292 | |
1293 | TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext); |
1294 | }; |
1295 | |
1296 | template <> |
1297 | const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const; |
1298 | |
1299 | template <> |
1300 | const Eigen::GpuDevice& OpKernelContext::eigen_device() const; |
1301 | |
1302 | // Register your OpKernel by specifying the Op's name, the device the |
1303 | // kernel runs on, any type attr constraints for this kernel, any |
1304 | // host-memory args, and the class to instantiate. Examples: |
1305 | // |
1306 | // // A kernel that supports all types. |
1307 | // REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp); |
1308 | // |
1309 | // // The following are equivalent ways of specifying that the kernel only |
1310 | // // works if the "T" type attr is set to DT_FLOAT. |
1311 | // REGISTER_KERNEL_BUILDER( |
1312 | // Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"), |
1313 | // SubOp<float>); |
1314 | // // (You would then repeat this for every type supported by "Sub".) |
1315 | // |
1316 | // // This form allows you to specify a list of types as the constraint. |
1317 | // REGISTER_KERNEL_BUILDER(Name("Sub") |
1318 | // .Device(DEVICE_CPU) |
1319 | // .TypeConstraint("T", {DT_FLOAT}), |
1320 | // SubOp<float>); |
1321 | // |
1322 | // // A kernel that expects one of the input tensors in host memory. |
1323 | // REGISTER_KERNEL_BUILDER( |
1324 | // Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp); |
1325 | // |
// // A kernel that works on any device. Kernels using DEVICE_DEFAULT
// // must always run on the host, and all inputs and outputs must use
// // `HostMemory`.
1328 | // // Kernels for data management, control-flow primitives or working with |
1329 | // // tensor shapes for various devices (including `PluggableDevices`) are |
1330 | // // typical uses. |
1331 | // REGISTER_KERNEL_BUILDER( |
1332 | // Name("TensorListLength").Device(DEVICE_DEFAULT).HostMemory("length"), |
1333 | // TensorListLength); |
1334 | // |
1335 | // See kernel_def_builder for details. |
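//
// For illustration, a minimal kernel that the forms above could register
// might look like the following (a sketch; `PassThroughOp` and the
// "PassThrough" op name are hypothetical and not defined by TensorFlow):
//
//   class PassThroughOp : public OpKernel {
//    public:
//     explicit PassThroughOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
//     void Compute(OpKernelContext* ctx) override {
//       // Forward the single input to the single output.
//       ctx->set_output(0, ctx->input(0));
//     }
//   };
//   REGISTER_KERNEL_BUILDER(Name("PassThrough").Device(DEVICE_CPU),
//                           PassThroughOp);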
1336 | |
// Instantiates an OpKernel that has been registered. Returns nullptr if no
// kernel is registered for that device type / input signature combination
// (and sets a NOT_FOUND *status), or if there is an error in construction
// (and sets an INVALID_ARGUMENT *status). Otherwise, the caller takes
// ownership of the returned pointer.
1342 | // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...); |
1343 | // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()). |
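//
// For example (a sketch; `device`, `cpu_allocator`, and `node_def` are
// assumed to be supplied by the caller, and TF_GRAPH_DEF_VERSION comes from
// tensorflow/core/public/version.h):
//
//   Status status;
//   std::unique_ptr<OpKernel> op = CreateOpKernel(
//       DEVICE_CPU, device, cpu_allocator, node_def,
//       TF_GRAPH_DEF_VERSION, &status);
//   if (!status.ok()) { /* no matching kernel, or construction failed */ }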
1344 | std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type, |
1345 | DeviceBase* device, |
1346 | Allocator* allocator, |
1347 | const NodeDef& node_def, |
1348 | int graph_def_version, Status* status); |
1349 | |
1350 | std::unique_ptr<OpKernel> CreateOpKernel( |
1351 | DeviceType device_type, DeviceBase* device, Allocator* allocator, |
1352 | const std::shared_ptr<const NodeProperties>& props, int graph_def_version, |
1353 | Status* status); |
1354 | |
1355 | Status CreateOpKernel(DeviceType device_type, DeviceBase* device, |
1356 | Allocator* allocator, FunctionLibraryRuntime* flib, |
1357 | const std::shared_ptr<const NodeProperties>& props, |
1358 | int graph_def_version, OpKernel** kernel); |
1359 | |
1360 | Status CreateOpKernel(DeviceType device_type, DeviceBase* device, |
1361 | Allocator* allocator, FunctionLibraryRuntime* flib, |
1362 | ResourceMgr* resource_mgr, |
1363 | const std::shared_ptr<const NodeProperties>& props, |
1364 | int graph_def_version, OpKernel** kernel); |
1365 | |
1366 | // Returns into 'device_types' the subset of prioritized_types that this |
1367 | // binary has registered for the given NodeDef. |
1368 | // |
1369 | // REQUIRES: * 'device_types' is not nullptr. |
1370 | // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). |
1371 | Status SupportedDeviceTypesForNode( |
1372 | const std::vector<DeviceType>& prioritized_types, const NodeDef& def, |
1373 | PrioritizedDeviceTypeVector* device_types, |
1374 | const DeviceNameUtils::ParsedName* local_address_spec = nullptr); |
1375 | |
1376 | // Returns a message with a description of the kernels registered for op |
1377 | // `op_name`. |
1378 | std::string KernelsRegisteredForOp(StringPiece op_name); |
1379 | |
1380 | // Call once after Op registration has completed. |
1381 | Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry); |
1382 | |
1383 | // ----------------------------------------------------------------------------- |
1384 | // OpKernel registration implementation follows, please ignore. |
1385 | |
1386 | // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax. |
1387 | namespace register_kernel { |
1388 | |
1389 | class Name : public KernelDefBuilder { |
1390 | public: |
1391 | explicit Name(const char* op); |
1392 | }; |
1393 | |
1394 | } // namespace register_kernel |
1395 | |
1396 | // Kernel registration appears as: |
1397 | // REGISTER_KERNEL_BUILDER(Name("OpName").Device(DEVICE_CPU)..., OpImpl) |
1398 | // We'd like to have "OpName" as a constant-expression, without requiring that |
1399 | // of the overall KernelDefBuilder expression (beginning with the |
1400 | // register_kernel::Name constructor above). |
1401 | // |
1402 | // So, we pull the "OpName" part to a separate macro-level argument. This |
1403 | // involves treating Name("OpName") as a macro call, via token-pasting (e.g. |
1404 | // M_## => M_Name("OpName")), and having it expand to '"OpName", |
1405 | // Name("OpName")' which is then usable as two arguments. |
#define TF_EXTRACT_KERNEL_NAME_Name(name_str) \
  name_str, ::tensorflow::register_kernel::Name(name_str)
#define TF_EXTRACT_KERNEL_NAME_IMPL(m, ...) m(__VA_ARGS__)
#define TF_EXTRACT_KERNEL_NAME(m, kernel_builder, ...)                    \
  TF_EXTRACT_KERNEL_NAME_IMPL(m, TF_EXTRACT_KERNEL_NAME_##kernel_builder, \
                              __VA_ARGS__)
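
// For example, REGISTER_KERNEL_BUILDER(Name("Foo").Device(DEVICE_CPU), FooOp)
// (with a hypothetical FooOp) routes its kernel_builder argument through the
// token-pasting above, so TF_EXTRACT_KERNEL_NAME_##kernel_builder expands to
// roughly
//   "Foo", ::tensorflow::register_kernel::Name("Foo").Device(DEVICE_CPU)
// and REGISTER_KERNEL_BUILDER_IMPL_2 then receives the op name and the
// builder expression as separate arguments.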
1412 | |
1413 | // REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument. |
1414 | // TODO(dodgen): There are some uses of this macro inside functions, where |
1415 | // kernel_builder refers to (non-const) locals (they should be fixed). To |
1416 | // accommodate those, kernel_builder.Build() appears as an argument to an |
1417 | // immediately-called lambda (not in the lambda itself). |
1418 | #define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr, \ |
1419 | is_system_kernel, ...) \ |
1420 | static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr \ |
1421 | TF_ATTRIBUTE_UNUSED = \ |
1422 | TF_INIT_ON_STARTUP_IF(is_system_kernel || \ |
1423 | (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \ |
1424 | SHOULD_REGISTER_OP(op_name))) \ |
1425 | << ([](::tensorflow::KernelDef const* kernel_def) { \ |
1426 | ::tensorflow::kernel_factory::OpKernelRegistrar registrar( \ |
1427 | kernel_def, #__VA_ARGS__, \ |
1428 | [](::tensorflow::OpKernelConstruction* context) \ |
1429 | -> ::tensorflow::OpKernel* { \ |
1430 | return new __VA_ARGS__(context); \ |
1431 | }); \ |
1432 | (void)registrar; \ |
1433 | return ::tensorflow::InitOnStartupMarker{}; \ |
1434 | })(kernel_builder_expr.Build()); |
1435 | |
// REGISTER_KERNEL_BUILDER_IMPL, but with kernel_builder split into op_name
// and kernel_builder_expr.
1438 | #define REGISTER_KERNEL_BUILDER_IMPL_2(op_name, kernel_builder_expr, \ |
1439 | is_system_kernel, ...) \ |
1440 | TF_NEW_ID_FOR_INIT(REGISTER_KERNEL_BUILDER_IMPL_3, op_name, \ |
1441 | kernel_builder_expr, is_system_kernel, __VA_ARGS__) |
1442 | |
1443 | // REGISTER_KERNEL_BUILDER, but with is_system_kernel bound. |
1444 | #define REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, is_system_kernel, ...) \ |
1445 | TF_EXTRACT_KERNEL_NAME(REGISTER_KERNEL_BUILDER_IMPL_2, kernel_builder, \ |
1446 | is_system_kernel, __VA_ARGS__) |
1447 | |
1448 | #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ |
1449 | TF_ATTRIBUTE_ANNOTATE("tf:kernel") \ |
1450 | REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, false, __VA_ARGS__) |
1451 | |
// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro behaves the same as
// `REGISTER_KERNEL_BUILDER()`, except that the kernel is registered
// unconditionally even when selective registration is used.
1455 | #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \ |
1456 | TF_ATTRIBUTE_ANNOTATE("tf:kernel") \ |
1457 | TF_ATTRIBUTE_ANNOTATE("tf:kernel:system") \ |
1458 | REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, true, __VA_ARGS__) |
1459 | |
1460 | // Checks whether a given kernel is registered on device_type. |
1461 | bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def); |
1462 | |
// If a node with the given node_name, experimental_debug_info, node_op,
// node_device, and node_attrs has a corresponding kernel registered on
// device_type, returns OK and fills in the kernel def and kernel_class_name.
// <def> and <kernel_class_name> may be null.
1467 | Status FindKernelDef( |
1468 | const DeviceType& device_type, StringPiece node_name, |
1469 | bool has_experimental_debug_info, |
1470 | const NodeDef_ExperimentalDebugInfo& experimental_debug_info, |
1471 | StringPiece node_op, StringPiece node_device, AttrSlice node_attrs, |
1472 | const KernelDef** def, std::string* kernel_class_name); |
1473 | |
1474 | // If node_def has a corresponding kernel registered on device_type, |
// returns OK and fills in the kernel def and kernel_class_name. <def> and
1476 | // <kernel_class_name> may be null. |
1477 | Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, |
1478 | const KernelDef** def, std::string* kernel_class_name); |
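
// For illustration (a sketch; `node_def` is assumed to be a fully specified
// NodeDef available to the caller):
//
//   const KernelDef* kernel_def = nullptr;
//   std::string kernel_class_name;
//   Status s = FindKernelDef(DeviceType(DEVICE_CPU), node_def, &kernel_def,
//                            &kernel_class_name);
//   if (!s.ok()) {
//     LOG(INFO) << KernelsRegisteredForOp(node_def.op());
//   }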
1479 | |
1480 | // Writes a list of all registered kernels to LOG(INFO), to help users debug |
1481 | // missing kernel errors. |
1482 | void LogAllRegisteredKernels(); |
1483 | |
1484 | // Gets a list of all registered kernels. |
1485 | KernelList GetAllRegisteredKernels(); |
1486 | |
// Gets a list of all registered kernels for which the predicate returns true.
1488 | KernelList GetFilteredRegisteredKernels( |
1489 | const std::function<bool(const KernelDef&)>& predicate); |
1490 | |
// Gets a list of all registered kernels for a given op.
1492 | KernelList GetRegisteredKernelsForOp(StringPiece op_name); |
1493 | |
1494 | namespace kernel_factory { |
1495 | |
1496 | // OpKernelFactory is responsible for creating OpKernels when TensorFlow needs |
1497 | // them. You register factories with the TensorFlow core by constructing an |
1498 | // OpKernelRegistrar and passing the factory as a constructor parameter. |
1499 | class OpKernelFactory { |
1500 | public: |
1501 | virtual OpKernel* Create(OpKernelConstruction* context) = 0; |
1502 | virtual ~OpKernelFactory() = default; |
1503 | }; |
1504 | |
1505 | class OpKernelRegistrar { |
1506 | public: |
1507 | // Registers the given kernel factory with TensorFlow. TF will call the |
1508 | // factory Create() method when it determines that a kernel matching the given |
1509 | // KernelDef is required. |
1510 | OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, |
1511 | std::unique_ptr<OpKernelFactory> factory) |
1512 | TF_ATTRIBUTE_NOINLINE { |
1513 | InitInternal(kernel_def, kernel_class_name, std::move(factory)); |
1514 | } |
1515 | |
1516 | // Registers the given factory function with TensorFlow. This is equivalent |
1517 | // to registering a factory whose Create function invokes `create_fn`. |
1518 | OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, |
1519 | OpKernel* (*create_fn)(OpKernelConstruction*)) |
1520 | TF_ATTRIBUTE_NOINLINE { |
1521 | InitInternal(kernel_def, kernel_class_name, |
1522 | absl::make_unique<PtrOpKernelFactory>(create_fn)); |
1523 | } |
1524 | |
1525 | private: |
1526 | struct PtrOpKernelFactory : public OpKernelFactory { |
1527 | explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*)) |
1528 | : create_func_(create_func) {} |
1529 | |
1530 | OpKernel* Create(OpKernelConstruction* context) override; |
1531 | |
1532 | OpKernel* (*create_func_)(OpKernelConstruction*); |
1533 | }; |
1534 | |
1535 | void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name, |
1536 | std::unique_ptr<OpKernelFactory> factory); |
1537 | }; |
1538 | |
1539 | } // namespace kernel_factory |
1540 | |
1541 | // ----------------------------------------------------------------------------- |
1542 | // Template and inline method implementations, please ignore |
1543 | |
1544 | template <class T> |
1545 | Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const { |
1546 | return GetNodeAttr(def(), attr_name, value); |
1547 | } |
1548 | |
1549 | inline DataType OpKernelContext::input_dtype(int index) const { |
1550 | DCHECK_GE(index, 0); |
1551 | DCHECK_LT(index, num_inputs()); |
1552 | const TensorValue& value(params_->inputs[index]); |
1553 | return value.dtype(); |
1554 | } |
1555 | |
1556 | inline MemoryType OpKernelContext::input_memory_type(int index) const { |
1557 | DCHECK_GE(index, 0); |
1558 | DCHECK_LT(index, num_inputs()); |
1559 | return op_kernel().input_memory_types()[index]; |
1560 | } |
1561 | |
1562 | inline DataType OpKernelContext::expected_output_dtype(int index) const { |
1563 | DCHECK_GE(index, 0); |
1564 | DCHECK_LT(index, num_outputs()); |
1565 | return params_->op_kernel->output_type(index); |
1566 | } |
1567 | |
1568 | inline MemoryType OpKernelContext::output_memory_type(int index) const { |
1569 | DCHECK_GE(index, 0); |
1570 | DCHECK_LT(index, num_outputs()); |
1571 | return op_kernel().output_memory_types()[index]; |
1572 | } |
1573 | |
1574 | inline bool OpKernelContext::input_is_ref(int index) const { |
1575 | const TensorValue& value(params_->inputs[index]); |
1576 | return value.is_ref(); |
1577 | } |
1578 | |
1579 | // no input if tensor == nullptr. |
1580 | inline bool OpKernelContext::has_input(int index) const { |
1581 | DCHECK_GE(index, 0); |
1582 | DCHECK_LT(index, num_inputs()); |
1583 | return params_->inputs[index].tensor != nullptr; |
1584 | } |
1585 | |
1586 | inline mutex* OpKernelContext::input_ref_mutex(int index) { |
1587 | DCHECK_GE(index, 0); |
1588 | DCHECK_LT(index, num_inputs()); |
1589 | DCHECK(input_is_ref(index)); |
1590 | return params_->inputs[index].mutex_if_ref; |
1591 | } |
1592 | |
1593 | inline Tensor* OpKernelContext::mutable_output(int index) { |
1594 | DCHECK_GE(index, 0); |
1595 | DCHECK_LT(index, num_outputs()); |
1596 | return outputs_[index].tensor; |
1597 | } |
1598 | |
1599 | inline TensorValue OpKernelContext::release_output(int index) { |
1600 | DCHECK_GE(index, 0); |
1601 | DCHECK_LT(index, num_outputs()); |
1602 | TensorValue value = outputs_[index]; |
1603 | outputs_[index] = TensorValue(); |
1604 | return value; |
1605 | } |
1606 | |
1607 | template <typename T> |
1608 | T* OpKernelContext::op_device_context() { |
1609 | static_assert(std::is_base_of<DeviceContext, T>::value, |
1610 | "T is not a subclass of DeviceContext" ); |
1611 | return static_cast<T*>(op_device_context()); |
1612 | } |
1613 | |
1614 | inline const Tensor& OpInputList::operator[](int i) const { |
1615 | DCHECK_GE(i, 0); |
1616 | DCHECK_LT(i, stop_ - start_); |
1617 | return ctx_->input(start_ + i); |
1618 | } |
1619 | |
1620 | inline mutex* OpMutableInputList::ref_mutex(int i) { |
1621 | DCHECK_GE(i, 0); |
1622 | DCHECK_LT(i, stop_ - start_); |
1623 | return ctx_->input_ref_mutex(start_ + i); |
1624 | } |
1625 | |
1626 | inline Tensor OpMutableInputList::at(int i, bool lock_held) { |
1627 | DCHECK_GE(i, 0); |
1628 | DCHECK_LT(i, stop_ - start_); |
1629 | return ctx_->mutable_input(start_ + i, lock_held); |
1630 | } |
1631 | |
1632 | inline Tensor* OpOutputList::operator[](int i) { |
1633 | DCHECK_GE(i, 0); |
1634 | DCHECK_LT(i, stop_ - start_); |
1635 | return ctx_->mutable_output(start_ + i); |
1636 | } |
1637 | |
1638 | inline bool OpOutputList::required(int i) const { |
1639 | DCHECK_GE(i, 0); |
1640 | DCHECK_LT(i, stop_ - start_); |
1641 | return ctx_->output_required(start_ + i); |
1642 | } |
1643 | |
1644 | inline DataType OpOutputList::expected_output_dtype(int i) const { |
1645 | DCHECK_GE(i, 0); |
1646 | DCHECK_LT(i, stop_ - start_); |
1647 | return ctx_->expected_output_dtype(start_ + i); |
1648 | } |
1649 | |
1650 | inline Status OpOutputList::allocate(int i, const TensorShape& shape, |
1651 | Tensor** output) { |
1652 | DCHECK_GE(i, 0); |
1653 | DCHECK_LT(i, stop_ - start_); |
1654 | return ctx_->allocate_output(start_ + i, shape, output); |
1655 | } |
1656 | |
1657 | inline void OpOutputList::set(int i, const Tensor& tensor) { |
1658 | DCHECK_GE(i, 0); |
1659 | DCHECK_LT(i, stop_ - start_); |
1660 | ctx_->set_output(start_ + i, tensor); |
1661 | } |
1662 | |
1663 | inline void OpOutputList::set(int i, Tensor&& tensor) { |
1664 | DCHECK_GE(i, 0); |
1665 | DCHECK_LT(i, stop_ - start_); |
1666 | ctx_->set_output(start_ + i, std::move(tensor)); |
1667 | } |
1668 | |
1669 | inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { |
1670 | DCHECK_GE(i, 0); |
1671 | DCHECK_LT(i, stop_ - start_); |
  ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
1673 | } |
1674 | |
1675 | // Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in |
1676 | // AsyncOpKernel implementations. If these macros are used and the condition |
1677 | // does not hold, the `done` callback will never be called and the system will |
1678 | // deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros |
1679 | // are legal to use in AsyncOpKernel constructors, we use overload resolution |
1680 | // to distinguish between OpKernelConstruction* and OpKernelContext* context |
1681 | // types. |
1682 | class XlaOpKernelContext; |
1683 | inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {} |
1684 | inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {} |
1685 | void CheckNotInComputeAsync(OpKernelContext* ctx, |
1686 | const char* correct_macro_name); |
1687 | |
1688 | } // namespace tensorflow |
1689 | |
1690 | #endif // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ |
1691 | |