1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_ |
18 | |
19 | #include "absl/types/optional.h" |
20 | #include "tensorflow/core/framework/function.h" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/platform/mutex.h" |
23 | #include "tensorflow/core/platform/status.h" |
24 | |
25 | namespace tensorflow { |
26 | |
// Per-model inflight batches parameters.
// Tuning knobs for the per-model adaptive batch scheduler; `extern` here,
// with the actual values defined in the corresponding .cc file.
ABSL_CONST_INIT extern const int64_t kMinInflightBatches;
ABSL_CONST_INIT extern const int64_t kInitialInflightBatches;
ABSL_CONST_INIT extern const int64_t kBatchesToAverageOver;
ABSL_CONST_INIT extern const int64_t kMaxInflightBatches;
32 | |
namespace internal {
// Forward-declared here so `BatchFunctionKernel` below can name it in a
// `friend` declaration; gives tests access to the kernel's private state.
class BatchFunctionKernelTestAccess;
}  // namespace internal
36 | |
// `BatchFunctionKernel` is the implementation of op `BatchFunction`.
//
// `BatchFunctionKernel` will batch (tensor) inputs by concatenating them
// along the 0-th dimension, schedule a user-defined computation, and then
// split the returned tensors as batch output.
//
// In particular, an instance of `BatchFunctionKernel` creates or re-uses a
// batch scheduler instance based on op attributes, pre-processes and enqueues
// concatenated inputs to the scheduler which invokes the user-defined
// function, and then splits the function output as op output.
//
// The user-defined function is named by attribute `f` and defined in the graph.
class BatchFunctionKernel : public AsyncOpKernel {
 public:
  // Reads op attributes from `c` and initializes the fields below.
  explicit BatchFunctionKernel(OpKernelConstruction* c);

  bool IsExpensive() override;

  // Enqueues the inputs for batching; `done` is invoked once the batched
  // user-defined function has run and this op's outputs are ready.
  void ComputeAsync(OpKernelContext* c, DoneCallback done) final;

 private:
  // Grants tests access to the private state of this kernel.
  friend class internal::BatchFunctionKernelTestAccess;

  // Validates 'allowed_batch_sizes_'. The entries must increase
  // monotonically. If large batch split is not enabled, the last one must
  // equal `max_batch_size_`; otherwise the last element must be smaller
  // than or equal to `max_batch_size_`.
  Status ValidateAllowedBatchSizes() const;

  // Creates the function handle if it isn't initialized yet, and re-uses it
  // afterwards.
  Status GetOrCreateFunctionHandle(OpKernelContext* c,
                                   FunctionLibraryRuntime::Handle* handle);

  // Instantiates the user-defined function and emits `handle`.
  Status InstantiateFunction(OpKernelContext* c,
                             FunctionLibraryRuntime::Handle* handle) const;

  // Initializes vars by reading from op-kernel-construction.
  // Vars
  // - enable_adaptive_batch_threads_
  //   true if value of attribute `kEnableAdaptiveSchedulerAttr` is true, or
  //   if `num_batch_threads` is not positive.
  // - adaptive_batch_scheduler_options_
  //   Read from corresponding attributes as long as they are set.
  void SetAdaptiveBatchSchedulerOptions(OpKernelConstruction* c,
                                        int32_t num_batch_threads);

  // Batcher-resource identification and batching parameters.
  // NOTE(review): these presumably mirror op attributes of the same names —
  // confirm against the constructor in the .cc file.
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  // Batch sizes the scheduler may form; see ValidateAllowedBatchSizes() for
  // the invariants this list must satisfy.
  std::vector<int32> allowed_batch_sizes_;
  // Name/attrs of the user-defined function (op attribute `f`).
  NameAttrList func_;
  // Lazily created by GetOrCreateFunctionHandle(); absent until first use.
  absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
  FunctionLibraryRuntime* flib_;
  bool enable_large_batch_splitting_;
  // Whether the `enable_large_batch_splitting` attribute was present at all,
  // as opposed to defaulted.
  bool has_attribute_enable_large_batch_splitting_;
  bool enable_adaptive_batch_threads_ = false;

  // Guards `fhandle_`.
  mutex mu_;

  // Parameters for adaptive batch scheduler only.
  // Note 'num_batch_threads_' above is shared by two implementations of batch
  // scheduler.
  struct AdaptiveBatchSchedulerOptions {
    int32 min_in_flight_batches_limit = kMinInflightBatches;
    int32 initial_in_flight_batches_limit = kInitialInflightBatches;
    int32 max_in_flight_batches_limit = kMaxInflightBatches;
    int32 batches_to_average_over = kBatchesToAverageOver;
  };
  // Engaged only when the adaptive scheduler is enabled; see
  // SetAdaptiveBatchSchedulerOptions().
  absl::optional<AdaptiveBatchSchedulerOptions>
      adaptive_batch_scheduler_options_ = absl::nullopt;
};
113 | |
114 | } // namespace tensorflow |
115 | |
116 | #endif // TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_ |
117 | |