/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batch_kernels.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/kernels/batching_util/bounded_executor.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/numbers.h"
#include "tensorflow/core/platform/threadpool.h"

namespace tensorflow {
namespace {
// Op attributes.
constexpr char kEnableAdaptiveSchedulerAttr[] = "_enable_adaptive_scheduler";
constexpr char kMinInflightBatchesAttr[] = "_min_inflight_batches";
constexpr char kInitialInflightBatchesAttr[] = "_initial_inflight_batches";
constexpr char kMaxInflightBatchesAttr[] = "_max_inflight_batches";
constexpr char kBatchesToAverageOverAttr[] = "_batches_to_average_over";

// Default thread count in the per-process batching thread pool.
constexpr int64_t kBatchThreadPoolSize = 128;
}  // namespace

// Per-model inflight batches parameters.
const int64_t kMinInflightBatches = 16;
const int64_t kInitialInflightBatches = 16;
const int64_t kBatchesToAverageOver = 10;
const int64_t kMaxInflightBatches = 64;

auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New(
    "/tensorflow/serving/batching/enable_large_batch_splitting",
    "Tracks the usage of attribute `enable_large_batch_splitting` for "
    "BatchFunction kernel in a saved model.",
    "model_name");

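// Records in the gauge above whether `enable_large_batch_splitting` was
// explicitly set to true/false or left unset for the given model.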
void RecordBatchSplitUsage(
    absl::optional<bool> maybe_enable_large_batch_splitting,
    const string& model_name) {
  if (maybe_enable_large_batch_splitting.has_value()) {
    if (maybe_enable_large_batch_splitting.value()) {
      batch_op_split_usage->GetCell(model_name)->Set("true");
    } else {
      batch_op_split_usage->GetCell(model_name)->Set("false");
    }
  } else {
    batch_op_split_usage->GetCell(model_name)->Set("unset");
  }
}

void RecordBatchParamNumBatchThreads(int64_t num_batch_threads,
                                     const string& model_name) {
  static auto* cell = monitoring::Gauge<int64_t, 1>::New(
      "/tensorflow/serving/batching/num_batch_threads",
      "Tracks the number of batch threads of a model.", "model_name");
  cell->GetCell(model_name)->Set(num_batch_threads);
}

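// Returns the model name from the session metadata, or "model_name_unset"
// when no metadata is present or the recorded name is empty.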
const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;
int32 NumBatchThreadsFromEnvironmentWithDefault(
    int default_num_batch_threads) {
  int32_t num;
  const char* val = std::getenv("TF_NUM_BATCH_THREADS");

  return (val && strings::safe_strto32(val, &num)) ? num
                                                   : default_num_batch_threads;
}

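// Lazily creates a process-wide thread pool shared by all adaptive batch
// schedulers; returns nullptr if the underlying BoundedExecutor cannot be
// created.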
static thread::ThreadPool* GetOrCreateBatchThreadsPool() {
  static thread::ThreadPool* shared_thread_pool =
      [&]() -> thread::ThreadPool* {
    serving::BoundedExecutor::Options options;

    options.num_threads =
        NumBatchThreadsFromEnvironmentWithDefault(kBatchThreadPoolSize);

    options.thread_name = std::string("adaptive_batch_threads");

    auto status_or_executor = serving::BoundedExecutor::Create(options);
    if (!status_or_executor.ok()) {
      LOG(WARNING) << "Failed to create a batch threads pool with error "
                   << status_or_executor.status();
      return nullptr;
    }
    static serving::BoundedExecutor* executor =
        status_or_executor.value().release();
    return new thread::ThreadPool(executor);
  }();
  return shared_thread_pool;
}

// A class encapsulating the state and logic for batching tensors.
class BatchResource : public serving::BatchResourceBase {
 public:
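  // Creates a BatchResource whose queues are served by the (non-adaptive)
  // shared batch scheduler with a fixed number of batch threads.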
  static Status Create(int32_t num_batch_threads,
                       int32_t max_execution_batch_size,
                       int32_t batch_timeout_micros,
                       int32_t max_enqueued_batches,
                       const std::vector<int32>& allowed_batch_sizes,
                       FunctionLibraryRuntime::Handle fhandle,
                       FunctionLibraryRuntime* flib,
                       bool enable_large_batch_splitting,
                       std::unique_ptr<BatchResource>* resource) {
    BatcherT::Options batcher_options;
    batcher_options.num_batch_threads = num_batch_threads;
    std::shared_ptr<BatcherT> batcher;
    TF_RETURN_IF_ERROR(BatcherT::Create(batcher_options, &batcher));

    resource->reset(new BatchResource(
        fhandle, flib, std::move(batcher),
        GetBatcherQueueOptions(num_batch_threads, max_execution_batch_size,
                               batch_timeout_micros, max_enqueued_batches,
                               allowed_batch_sizes,
                               enable_large_batch_splitting),
        allowed_batch_sizes));
    return OkStatus();
  }

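  // Creates a BatchResource whose queues are served by the adaptive shared
  // batch scheduler, which adjusts the number of in-flight batches at runtime.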
  static Status Create(
      AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options,
      int32_t max_batch_size, int32_t batch_timeout_micros,
      int32_t max_enqueued_batches,
      const std::vector<int32>& allowed_batch_sizes,
      FunctionLibraryRuntime::Handle fhandle, FunctionLibraryRuntime* flib,
      std::unique_ptr<BatchResource>* resource) {
    std::shared_ptr<AdaptiveBatcherT> batcher;
    TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create(
        adaptive_shared_batch_scheduler_options, &batcher));

    resource->reset(new BatchResource(
        fhandle, flib, std::move(batcher),
        GetAdaptiveBatcherQueueOptions(
            max_batch_size, batch_timeout_micros, max_enqueued_batches,
            true /* enable large batch split */, allowed_batch_sizes),
        allowed_batch_sizes));
    return OkStatus();
  }

  string DebugString() const final { return "BatchResource"; }

 private:
  BatchResource(FunctionLibraryRuntime::Handle fhandle,
                FunctionLibraryRuntime* flib, std::shared_ptr<BatcherT> batcher,
                const BatcherT::QueueOptions& batcher_queue_options,
                std::vector<int32> allowed_batch_sizes)
      : BatchResourceBase(
            /*has_process_batch_function=*/fhandle != kInvalidHandle,
            std::move(batcher), batcher_queue_options,
            std::move(allowed_batch_sizes)),
        fhandle_(fhandle),
        flib_(flib) {}

  BatchResource(FunctionLibraryRuntime::Handle fhandle,
                FunctionLibraryRuntime* flib,
                std::shared_ptr<AdaptiveBatcherT> batcher,
                const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
                std::vector<int32> allowed_batch_sizes)
      : BatchResourceBase(
            /*has_process_batch_function=*/fhandle != kInvalidHandle,
            std::move(batcher), batcher_queue_options,
            std::move(allowed_batch_sizes)),
        fhandle_(fhandle),
        flib_(flib) {}

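  // Runs the instantiated batch function `fhandle_` via `flib_` on the
  // concatenated `inputs` of a batch; invoked by BatchResourceBase once a
  // batch is ready to process.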
  void ProcessFuncBatchImpl(
      const BatchTask& last_task, absl::Span<const Tensor> inputs,
      std::vector<Tensor>* combined_outputs,
      std::function<void(const Status&)> done) const override {
    auto* last_task_context = last_task.context;
    FunctionLibraryRuntime::Options opts;
    opts.step_container = last_task_context->step_container();
    opts.cancellation_manager = last_task_context->cancellation_manager();
    opts.collective_executor = last_task_context->collective_executor();
    opts.stats_collector = last_task_context->stats_collector();
    opts.runner = last_task_context->runner();
    opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();
    // We do not set 'opts.rendezvous', since if the function is run multiple
    // times in parallel with the same rendezvous, a _Send node from one run
    // might be matched with a _Recv node of a different run. Not setting the
    // rendezvous causes a new rendezvous to be used for each run.
    Notification done_notif;

    flib_->Run(opts, fhandle_, inputs, combined_outputs,
               [&](const Status& run_status) {
                 done(run_status);
                 done_notif.Notify();
               });
    // By waiting for the notification we are ensuring that this thread isn't
    // used for processing other batches, which gives the batches time to
    // coalesce upstream. So overall the number of batches going through the
    // devices goes down, improving latency and throughput in most cases.
    done_notif.WaitForNotification();
  }

  FunctionLibraryRuntime::Handle fhandle_;
  FunctionLibraryRuntime* flib_;
};

BatchFunctionKernel::BatchFunctionKernel(OpKernelConstruction* c)
    : AsyncOpKernel(c) {
  OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
  OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
  OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
  OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
  OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
  OP_REQUIRES_OK(c, c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
  OP_REQUIRES_OK(c, c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
  OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));

  OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
  flib_ = c->function_library();

  if (c->HasAttr("enable_large_batch_splitting")) {
    OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
                                 &enable_large_batch_splitting_));
    has_attribute_enable_large_batch_splitting_ = true;
  } else {
    enable_large_batch_splitting_ = false;
    has_attribute_enable_large_batch_splitting_ = false;
  }

  // The helper function `SetAdaptiveBatchSchedulerOptions` calls
  // `OP_REQUIRES_OK`, which returns from the current function on error,
  // so check the construction status afterwards.
  SetAdaptiveBatchSchedulerOptions(c, num_batch_threads_);
  if (!c->status().ok()) {
    return;
  }

  if (enable_adaptive_batch_threads_) {
    // One scheduler instance can contain multiple queue instances;
    // `batcher_queue_` is the key used to find the queue for this batch op
    // in the graph.
    // Use `shared_name_` and name() as the prefix for `batcher_queue_`.
    // Note that name() is unique per session (from session metadata).
    batcher_queue_ = name() + "/" + shared_name_ + batcher_queue_;
  }

  if (shared_name_.empty()) {
    // If shared_name is not supplied, use name instead (prevents collisions
    // by default).
    shared_name_ = name();
  }

  OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
}

bool BatchFunctionKernel::IsExpensive() { return false; }

void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) {
  RecordBatchSplitUsage(has_attribute_enable_large_batch_splitting_
                            ? absl::make_optional(enable_large_batch_splitting_)
                            : absl::nullopt,
                        GetModelName(c));
  // TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel.
  RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c));

  std::function<Status(BatchResource**)> creator;

  FunctionLibraryRuntime::Handle handle;
  OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done);

  if (adaptive_batch_scheduler_options_ != absl::nullopt) {
    creator = [this, handle](BatchResource** r) {
      serving::AdaptiveSharedBatchScheduler<
          serving::BatchResourceBase::BatchTask>::Options
          adaptive_shared_batch_scheduler_options;
      adaptive_shared_batch_scheduler_options.thread_pool_name =
          "adaptive_batch_threads";
      adaptive_shared_batch_scheduler_options.num_batch_threads =
          adaptive_batch_scheduler_options_->max_in_flight_batches_limit;
      adaptive_shared_batch_scheduler_options.thread_pool =
          GetOrCreateBatchThreadsPool();
      // `full_batch_scheduling_boost_micros` is intentionally left at its
      // default value of 0, so tasks are scheduled in FIFO order.
      // Two rationales for keeping the default (zero):
      // 1) The scheduling policy becomes FIFO. Compared with round robin
      //    (what the shared batch scheduler does), FIFO ensures that models
      //    with low QPS (i.e., models that enqueue fewer tasks in the shared
      //    queue) are still processed in a timely manner.
      // 2) If set, `full_batch_scheduling_boost_micros` should be on the
      //    order of the batch processing latency (which varies per model);
      //    a poorly chosen non-zero value hurts tail latency.
      adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit =
          adaptive_batch_scheduler_options_->min_in_flight_batches_limit;
      adaptive_shared_batch_scheduler_options.initial_in_flight_batches_limit =
          adaptive_batch_scheduler_options_->initial_in_flight_batches_limit;
      adaptive_shared_batch_scheduler_options.batches_to_average_over =
          adaptive_batch_scheduler_options_->batches_to_average_over;
      adaptive_shared_batch_scheduler_options.fifo_scheduling = true;
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          adaptive_shared_batch_scheduler_options, max_batch_size_,
          batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
          handle, flib_, &new_resource));
      *r = new_resource.release();
      return OkStatus();
    };
  } else {
    creator = [this, handle](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
          max_enqueued_batches_, allowed_batch_sizes_, handle, flib_,
          enable_large_batch_splitting_, &new_resource));
      *r = new_resource.release();
      return OkStatus();
    };
  }

  BatchResource* br;
  OP_REQUIRES_OK_ASYNC(c,
                       c->resource_manager()->LookupOrCreate(
                           container_, shared_name_, &br, creator),
                       done);
  const Status status =
      br->RegisterInput(random::New64(), c, batcher_queue_, done);
  br->Unref();
  OP_REQUIRES_OK_ASYNC(c, status, done);
  // Assume br calls done, so nothing to do here.
}

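// Instantiates `func_` as a multi-device function. Data inputs and all
// outputs are placed on the host CPU (they are concatenated/split there),
// while captured resource inputs are placed on their resource's device.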
Status BatchFunctionKernel::InstantiateFunction(
    OpKernelContext* c, FunctionLibraryRuntime::Handle* handle) const {
  // TODO(b/173748062): Merge this instantiation logic with PartitionedCall.
  if (!flib_) {
    return errors::Internal("No function library");
  }

  FunctionLibraryRuntime::InstantiateOptions opts;
  opts.target = flib_->device() == nullptr ? "" : flib_->device()->name();
  opts.is_multi_device_function = true;
  const ConfigProto* config = flib_->config_proto();
  if (config) {
    opts.config_proto = *config;
  }

  Device* cpu_device;
  TF_RETURN_IF_ERROR(flib_->device_mgr()->LookupDevice("CPU:0", &cpu_device));

  const FunctionDef* fdef =
      flib_->GetFunctionLibraryDefinition()->Find(func_.name());
  if (!fdef) {
    return errors::NotFound("Failed to find definition for function \"",
                            func_.name(), "\"");
  }
  OpInputList in_tensors;
  TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
  for (int i = 0; i < in_tensors.size(); i++) {
    if (in_tensors[i].dtype() == DT_RESOURCE) {
      return errors::InvalidArgument(
          "BatchFunction cannot take resource inputs but input ", i,
          " is a resource.");
    } else {
      // Currently, inputs are on CPU since they are concatenated on CPU.
      opts.input_devices.push_back(cpu_device->name());
    }
  }
  OpInputList captured_tensors;
  TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
  for (const Tensor& t : captured_tensors) {
    if (t.dtype() == DT_RESOURCE) {
      const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0);
      opts.input_devices.push_back(rhandle.device());
    } else {
      opts.input_devices.push_back(cpu_device->name());
    }
  }
  const OpDef& signature = fdef->signature();
  for (int i = 0; i < signature.output_arg_size(); i++) {
    // Currently, outputs must be on CPU since they are split on CPU.
    opts.output_devices.push_back(cpu_device->name());
  }
  if (opts.input_devices.size() != signature.input_arg_size()) {
    return errors::InvalidArgument(
        "Function takes ", signature.input_arg_size(), " argument(s) but ",
        opts.input_devices.size(), " argument(s) were passed");
  }
  return flib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
                            handle);
}

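// Returns the cached function handle, instantiating `func_` once under `mu_`
// on the first call.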
Status BatchFunctionKernel::GetOrCreateFunctionHandle(
    OpKernelContext* c, FunctionLibraryRuntime::Handle* handle) {
  mutex_lock ml(mu_);
  if (!fhandle_) {
    TF_RETURN_IF_ERROR(InstantiateFunction(c, handle));
    fhandle_ = *handle;
  } else {
    *handle = fhandle_.value();
  }
  return OkStatus();
}

// Validates 'allowed_batch_sizes_'. The entries must increase monotonically.
// If large batch splitting is not enabled, the last entry must equal
// `max_batch_size_`; otherwise the last entry must be less than or equal to
// `max_batch_size_`.
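// For example, with `max_batch_size_` == 32, {8, 16, 32} is accepted;
// {8, 32, 16} is rejected (not monotonically increasing), and {8, 16} is
// rejected unless `enable_large_batch_splitting` is true.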
Status BatchFunctionKernel::ValidateAllowedBatchSizes() const {
  if (allowed_batch_sizes_.empty()) {
    return OkStatus();
  }
  int32_t last_size = 0;
  for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
    const int32_t size = allowed_batch_sizes_.at(i);
    if (i > 0 && size <= last_size) {
      return errors::InvalidArgument(
          "allowed_batch_sizes entries must be monotonically increasing");
    }

    if ((!enable_large_batch_splitting_) &&
        (i == allowed_batch_sizes_.size() - 1) && (size != max_batch_size_)) {
      return errors::InvalidArgument(
          "final entry in allowed_batch_sizes must equal max_batch_size when "
          "enable_large_batch_splitting is False");
    }

    last_size = size;
  }
  return OkStatus();
}

// Initialize vars by reading from op-kernel-construction.
// Vars
// - enable_adaptive_batch_threads_
//   true if value of attribute `kEnableAdaptiveSchedulerAttr` is true, or
//   if `num_batch_threads` is not positive.
// - adaptive_batch_scheduler_options_
//   Read from corresponding attributes as long as they are set.
void BatchFunctionKernel::SetAdaptiveBatchSchedulerOptions(
    OpKernelConstruction* c, int32_t num_batch_threads) {
  if (c->HasAttr(kEnableAdaptiveSchedulerAttr)) {
    OP_REQUIRES_OK(c, c->GetAttr(kEnableAdaptiveSchedulerAttr,
                                 &enable_adaptive_batch_threads_));
  }

  if (num_batch_threads <= 0) {
    enable_adaptive_batch_threads_ = true;
  }

  if (!enable_adaptive_batch_threads_) {
    // adaptive_batch_scheduler_options_ is nullopt.
    return;
  }

  // adaptive_batch_scheduler_options_ is not nullopt
  AdaptiveBatchSchedulerOptions options;

  if (c->HasAttr(kBatchesToAverageOverAttr)) {
    OP_REQUIRES_OK(c, c->GetAttr(kBatchesToAverageOverAttr,
                                 &options.batches_to_average_over));
  }

  if (c->HasAttr(kMinInflightBatchesAttr)) {
    OP_REQUIRES_OK(c, c->GetAttr(kMinInflightBatchesAttr,
                                 &options.min_in_flight_batches_limit));
  }

  if (c->HasAttr(kInitialInflightBatchesAttr)) {
    OP_REQUIRES_OK(c, c->GetAttr(kInitialInflightBatchesAttr,
                                 &options.initial_in_flight_batches_limit));
  }

  if (c->HasAttr(kMaxInflightBatchesAttr)) {
    OP_REQUIRES_OK(c, c->GetAttr(kMaxInflightBatchesAttr,
                                 &options.max_in_flight_batches_limit));
  }

  // At this point, the batch kernel is configured to use adaptive scheduling.
  // To validate this (and surface an error at kernel construction time),
  // invoke `GetOrCreateBatchThreadsPool` and check that the returned
  // `thread_pool` is valid.
  // Note that `GetOrCreateBatchThreadsPool` creates the thread pool once and
  // reuses the same instance afterwards.
  thread::ThreadPool* thread_pool = GetOrCreateBatchThreadsPool();
  OP_REQUIRES(
      c, thread_pool != nullptr,
      errors::FailedPrecondition("Failed to create batch threads pool"));

  adaptive_batch_scheduler_options_ = options;
}
REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
                        BatchFunctionKernel);
// Currently all inputs and outputs are on the host.
// TODO(b/173748277): Accept inputs/outputs on the device.
REGISTER_KERNEL_BUILDER(Name("BatchFunction")
                            .Device(DEVICE_GPU)
                            .HostMemory("in_tensors")
                            .HostMemory("captured_tensors")
                            .HostMemory("out_tensors"),
                        BatchFunctionKernel);
REGISTER_KERNEL_BUILDER(Name("BatchFunction")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("in_tensors")
                            .HostMemory("captured_tensors")
                            .HostMemory("out_tensors"),
                        BatchFunctionKernel);

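// AsyncOpKernel implementing the `Batch` op: batches input tensors through a
// shared BatchResource that is created without a batch function
// (kInvalidHandle).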
class BatchKernel : public AsyncOpKernel {
 public:
  explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    BatchResource* br;
    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
          /*flib=*/nullptr, false, &new_resource));
      *r = new_resource.release();
      return OkStatus();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase
  // monotonically, and the last one must equal 'max_batch_size_'.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return OkStatus();
    }
    int32_t last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32_t size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }
      if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
      }
      last_size = size;
    }
    return OkStatus();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
};

REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);

// A class encapsulating the state and logic for unbatching tensors.
//
// UnbatchResource keeps two data structures indexed by batch-key: one which has
// the continuations for all concurrent kernels which are waiting for tensors
// and another which has tensors which are waiting for their corresponding
// kernels to run. Whenever a kernel runs, we either grab its tensor if it's
// waiting already, or we insert it in the queue and then look at its tensor to
// see if it can be used to dispatch any stored continuations.
class UnbatchResource : public ResourceBase {
 public:
  explicit UnbatchResource(int32_t timeout_micros)
      : timeout_micros_(timeout_micros),
        timeout_enforcer_(new serving::PeriodicFunction(
            [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}

  ~UnbatchResource() override {
    // Tear down 'timeout_enforcer_' first, since it accesses other state in
    // this class.
    timeout_enforcer_ = nullptr;
  }

  string DebugString() const final { return "UnbatchResource"; }

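  // Handles one invocation of the Unbatch op: emits the slice already waiting
  // under this op's batch key if present; otherwise registers this op's done
  // callback and, when a non-empty batched input is provided, dispatches its
  // slices to waiting callbacks and stores the rest for later arrivals.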
  Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);

    if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 0th dimension size to be no "
          "greater than ",
          data_t.shape().dim_size(0),
          "; Got: ", batch_index_t.shape().dim_size(0), ".");
    }
    if (batch_index_t.shape().dim_size(1) != 3) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 1st dimension size to be 3; "
          "Got: ", batch_index_t.shape().dim_size(1), ".");
    }

    if (!TensorShapeUtils::IsScalar(context->input(2).shape())) {
      return errors::InvalidArgument(
          "Input id should be scalar; "
          "Got: ",
          context->input(2).DebugString(), ".");
    }
    const int64_t batch_key = context->input(2).scalar<int64_t>()();
    const bool nonempty_input = batch_index_t.dim_size(0) > 0;

    // If we have a non-empty tensor, slice it up.
    // (It is important to do this outside of the critical section below.)
    // The following variables are populated iff 'nonempty_input==true'.
    std::vector<int64_t> sizes;
    std::vector<int64_t> batch_keys;
    std::vector<Tensor> split_inputs;
    if (nonempty_input) {
      auto batch_indices =
          batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
        batch_keys.push_back(batch_indices(i, 0));
      }

      TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
    }

    // Critical section.
    std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
    Status status = [&]() -> Status {
      mutex_lock ml(mu_);

      // Check to see whether the tensor we want is already ready.
      auto tensor_it = waiting_tensors_.find(batch_key);
      if (tensor_it != waiting_tensors_.end()) {
        context->set_output(0, tensor_it->second.tensor);
        waiting_tensors_.erase(tensor_it);
        done_callbacks_to_call.push_back(done);
        return OkStatus();
      }

      const uint64 deadline_micros =
          Env::Default()->NowMicros() + timeout_micros_;

      // Add ourselves to the waitlist for tensors.
      if (!waiting_callbacks_
               .emplace(batch_key,
                        WaitingCallback{deadline_micros, context, done})
               .second) {
        return errors::AlreadyExists(
            "Multiple session runs with the same batch key.");
      }

      // If we have a non-empty tensor, finish the waitlisted runs,
      // and store any remaining pieces.
      if (nonempty_input) {
        for (size_t i = 0; i < batch_keys.size(); ++i) {
          auto runs_it = waiting_callbacks_.find(batch_keys[i]);
          if (runs_it != waiting_callbacks_.end()) {
            runs_it->second.context->set_output(0, split_inputs[i]);
            done_callbacks_to_call.push_back(runs_it->second.done);
            waiting_callbacks_.erase(runs_it);
          } else {
            // Note: the deadline here is in case we are arriving late and the
            // kernel that should rendezvous with this tensor has already
            // waited and timed out.
            if (!waiting_tensors_
                     .emplace(batch_keys[i],
                              WaitingTensor{deadline_micros, split_inputs[i]})
                     .second) {
              return errors::AlreadyExists(
                  "Multiple tensors returned for same batch key.");
            }
          }
        }
      }

      return OkStatus();
    }();

    for (const AsyncOpKernel::DoneCallback& done_callback :
         done_callbacks_to_call) {
      done_callback();
    }

    return status;
  }

 private:
  // Evicts waiting tensors and callbacks that have exceeded their deadline.
  void EnforceTimeout() {
    const uint64 now = Env::Default()->NowMicros();
    std::vector<WaitingCallback> evicted_callbacks;

    {
      mutex_lock ml(mu_);

      for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
        const WaitingTensor& waiting_tensor = it->second;
        if (waiting_tensor.deadline_micros < now) {
          it = waiting_tensors_.erase(it);
        } else {
          ++it;
        }
      }

      for (auto it = waiting_callbacks_.begin();
           it != waiting_callbacks_.end();) {
        const WaitingCallback& waiting_callback = it->second;
        if (waiting_callback.deadline_micros < now) {
          evicted_callbacks.push_back(waiting_callback);
          it = waiting_callbacks_.erase(it);
        } else {
          ++it;
        }
      }
    }

    for (const WaitingCallback& evicted_callback : evicted_callbacks) {
      evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
          "Batched data did not arrive within timeout window."));
      evicted_callback.done();
    }
  }

  struct WaitingTensor {
    uint64 deadline_micros;
    Tensor tensor;
  };

  struct WaitingCallback {
    uint64 deadline_micros;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  const int32 timeout_micros_;

  mutex mu_;

  // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
  // waiting for tensors.
  std::unordered_map<int64_t, WaitingTensor> waiting_tensors_
      TF_GUARDED_BY(mu_);
  std::unordered_map<int64_t, WaitingCallback> waiting_callbacks_
      TF_GUARDED_BY(mu_);

  // A thread that evicts waiting tensors and callbacks that have exceeded
  // their deadline.
  std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
};

class UnbatchKernel : public AsyncOpKernel {
 public:
  explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchResource* ubr;
    std::function<Status(UnbatchResource**)> creator =
        [this](UnbatchResource** r) {
          *r = new UnbatchResource(timeout_micros_);
          return OkStatus();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    auto status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
  int32 timeout_micros_;
};
REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);

// A class encapsulating the state and logic for batching tensors
// deterministically for the gradient of unbatch.
class UnbatchGradResource : public ResourceBase {
 public:
  UnbatchGradResource() {}

  string DebugString() const final { return "UnbatchGradResource"; }

  // Flushes the information for one batch, given its context and done
  // callback. Clears all information about it from the available_tensors_.
  Status OutputBatch(OpKernelContext* context,
                     const AsyncOpKernel::DoneCallback& done)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const Tensor& batch_index_t = context->input(1);
    auto batch_index =
        batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3});
    std::vector<Tensor> tensors;
    for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
      auto available_it = available_tensors_.find(batch_index(i, 0));
      if (available_it == available_tensors_.end()) {
        return errors::Internal("bad bookkeeping of available tensors.");
      }
      tensors.push_back(available_it->second);
      available_tensors_.erase(available_it);
    }

    const DataType type = tensors[0].dtype();
    Tensor concatenated_tensor;
    switch (type) {
#define CASE(type)                                                            \
  case DataTypeToEnum<type>::value:                                           \
    TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
    context->set_output(0, concatenated_tensor);                              \
    break;
      TF_CALL_ALL_TYPES(CASE);
#undef CASE
      default:
        return errors::InvalidArgument("Unsupported data type: ", type);
    }
    done();
    return OkStatus();
  }

  // Ingests data from one invocation of the op.
  Status Compute(OpKernelContext* context,
                 const AsyncOpKernel::DoneCallback& done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);
    const Tensor& grad_t = context->input(2);
    const Tensor& batch_key_t = context->input(3);

    mutex_lock ml(mu_);
    if (batch_key_t.NumElements() != 1) {
      return errors::InvalidArgument("Expected `id` to be scalar. Received ",
                                     batch_key_t.DebugString());
    }

    const int64_t batch_key = context->input(3).scalar<int64_t>()();
    // Mark our tensor as available.
    if (!available_tensors_.emplace(batch_key, grad_t).second) {
      return errors::InvalidArgument("Two runs with the same batch key.");
    }

    // Check whether we have a valid input tensor and, if so, create its
    // dispatch logic.
    if (data_t.NumElements() > 0) {
      if (batch_index_t.NumElements() == 0) {
        return errors::InvalidArgument(
            "batch_index is empty while the tensor isn't.");
      }
      std::unordered_set<int64_t> missing_tensors;
      if (batch_index_t.NumElements() != batch_index_t.dim_size(0) * 3) {
        return errors::InvalidArgument(
            "batch_index should contain ", batch_index_t.dim_size(0) * 3,
            " elements. Received ", batch_index_t.NumElements());
      }
      const auto batch_index =
          batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        const int64_t batch_key = batch_index(i, 0);
        if (available_tensors_.find(batch_key) == available_tensors_.end()) {
          missing_tensors.emplace(batch_key);
        }
      }
      if (missing_tensors.empty()) {
        return OutputBatch(context, done);
      }
      if (!available_batches_
               .emplace(batch_key, Batch{missing_tensors, context, done})
               .second) {
        return errors::InvalidArgument(
            "Batch key with valid batch used twice.");
      }
      for (const int64_t i : missing_tensors) {
        if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
          return errors::InvalidArgument(
              "Missing tensor wanted by more than one batch.");
        }
      }
    } else {
      // If we don't have a valid input tensor we can output an empty tensor
      // and call our done closure.
      TensorShape output_shape(grad_t.shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
      done();
    }

    // Search to see whether our tensor is desired by any existing batch.
    auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
    if (desire_it != desired_tensor_to_batch_map_.end()) {
      // Mark our tensor as no longer missing.
      auto batch_it = available_batches_.find(desire_it->second);
      desired_tensor_to_batch_map_.erase(desire_it);
      if (batch_it == available_batches_.end()) {
        return errors::InvalidArgument("Batch no longer exists.");
      }
      batch_it->second.missing_tensors.erase(batch_key);
      // If all tensors are available we should concatenate them and dispatch
      // the batch.
      if (batch_it->second.missing_tensors.empty()) {
        TF_RETURN_IF_ERROR(
            OutputBatch(batch_it->second.context, batch_it->second.done));
        available_batches_.erase(batch_it);
      }
    }
    return OkStatus();
  }

 private:
  mutex mu_;

  // Represents a still-incomplete batch of tensors. When all tensors become
  // available they will be concatenated in the right order and sent through
  // the context.
  struct Batch {
    // Batch keys for tensors which are still missing from this batch. When
    // this is empty the Tensors can be concatenated and forwarded.
    std::unordered_set<int64_t> missing_tensors;

    // Context and callback for the session responsible for finishing this
    // batch.
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  // Map from batch key of the session which will output the batched gradients
  // to still-incomplete batches.
  std::unordered_map<int64_t, Batch> available_batches_;

  // Map from batch key to tensors which are waiting for their batches to be
  // available.
  std::unordered_map<int64_t, Tensor> available_tensors_;

  // Map from batch key of a tensor which is not yet available to the batch key
  // of the batch to which it belongs.
  std::unordered_map<int64_t, int64_t> desired_tensor_to_batch_map_;
};

class UnbatchGradKernel : public AsyncOpKernel {
 public:
  explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchGradResource* ubr;
    std::function<Status(UnbatchGradResource**)> creator =
        [](UnbatchGradResource** r) {
          *r = new UnbatchGradResource();
          return OkStatus();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    Status status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
};
REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
                        UnbatchGradKernel);

}  // namespace tensorflow