1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_
17#define TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_
18
19#include "absl/types/optional.h"
20#include "tensorflow/core/framework/function.h"
21#include "tensorflow/core/framework/op_kernel.h"
22#include "tensorflow/core/platform/mutex.h"
23#include "tensorflow/core/platform/status.h"
24
25namespace tensorflow {
26
// Per-model inflight batches parameters.
// Tuning knobs for the adaptive batch scheduler: they bound and seed the
// number of batches a model may have in flight concurrently, and control how
// many batches latency is averaged over when adapting that limit. The
// concrete values are defined in the corresponding .cc file.
ABSL_CONST_INIT extern const int64_t kMinInflightBatches;
ABSL_CONST_INIT extern const int64_t kInitialInflightBatches;
ABSL_CONST_INIT extern const int64_t kBatchesToAverageOver;
ABSL_CONST_INIT extern const int64_t kMaxInflightBatches;
32
namespace internal {
// Test-only peer of `BatchFunctionKernel`; forward-declared here so the
// kernel can befriend it without exposing private members in this header.
class BatchFunctionKernelTestAccess;
}  // namespace internal
36
37// `BatchFunctionKernel` is the implementation of op `BatchFunction`.
38//
39// `BatchFunctionKernel` will batch (tensor) inputs by concatenating them
40// along the 0-th dimension, schedule a user-defined computation, and then
41// splits the returned tensors as batch output.
42//
43// In particular, an instance of `BatchFunctionKernel` creates or re-uses a
44// a batch scheduler instance based on op attributes, pre-processes and enqueues
45// concatenated inputs to the scheduler which invokes user-defined function,
46// and then splits function output as op output.
47//
48// User defined function is named by attribute `f` and defined in the graph.
49class BatchFunctionKernel : public AsyncOpKernel {
50 public:
51 explicit BatchFunctionKernel(OpKernelConstruction* c);
52
53 bool IsExpensive() override;
54
55 void ComputeAsync(OpKernelContext* c, DoneCallback done) final;
56
57 private:
58 friend class internal::BatchFunctionKernelTestAccess;
59
60 // Validates 'allowed_batch_sizes_'. The entries must increase monotonically.
61 // If large batch split is not enabled, the last one must equal
62 // `max_batch_size_`. otherwise the last element must be smaller than or equal
63 // to `max_batch_size_`.
64 Status ValidateAllowedBatchSizes() const;
65
66 // Creates the function handle if it isn't initialized yet; and re-use it
67 // afterwards.
68 Status GetOrCreateFunctionHandle(OpKernelContext* c,
69 FunctionLibraryRuntime::Handle* handle);
70
71 // Instantiate the user-defined function and emits `handle`.
72 Status InstantiateFunction(OpKernelContext* c,
73 FunctionLibraryRuntime::Handle* handle) const;
74
75 // Initialize vars by reading from op-kernel-construction.
76 // Vars
77 // - enable_adaptive_batch_threads_
78 // true if value of attribute `kEnableAdaptiveSchedulerAttr` is true, or
79 // if `num_batch_threads` is not positive.
80 // - adaptive_batch_scheduler_options_
81 // Read from corresponding attributes as long as they are set.
82 void SetAdaptiveBatchSchedulerOptions(OpKernelConstruction* c,
83 int32_t num_batch_threads);
84 string container_;
85 string shared_name_;
86 string batcher_queue_;
87 int32 num_batch_threads_;
88 int32 max_batch_size_;
89 int32 batch_timeout_micros_;
90 int32 max_enqueued_batches_;
91 std::vector<int32> allowed_batch_sizes_;
92 NameAttrList func_;
93 absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
94 FunctionLibraryRuntime* flib_;
95 bool enable_large_batch_splitting_;
96 bool has_attribute_enable_large_batch_splitting_;
97 bool enable_adaptive_batch_threads_ = false;
98
99 mutex mu_;
100
101 // Parameters for adaptive batch scheduler only.
102 // Note 'num_batch_threads_' above is shared by two implementations of batch
103 // scheduler.
104 struct AdaptiveBatchSchedulerOptions {
105 int32 min_in_flight_batches_limit = kMinInflightBatches;
106 int32 initial_in_flight_batches_limit = kInitialInflightBatches;
107 int32 max_in_flight_batches_limit = kMaxInflightBatches;
108 int32 batches_to_average_over = kBatchesToAverageOver;
109 };
110 absl::optional<AdaptiveBatchSchedulerOptions>
111 adaptive_batch_scheduler_options_ = absl::nullopt;
112};
113
114} // namespace tensorflow
115
116#endif // TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_
117