1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/batch_kernels.h" |
17 | |
18 | #include "absl/strings/str_cat.h" |
19 | #include "tensorflow/core/common_runtime/device_mgr.h" |
20 | #include "tensorflow/core/framework/device.h" |
21 | #include "tensorflow/core/framework/function.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/op_requires.h" |
24 | #include "tensorflow/core/framework/resource_mgr.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/tensor_shape.h" |
27 | #include "tensorflow/core/framework/tensor_util.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" |
30 | #include "tensorflow/core/kernels/batching_util/batch_resource_base.h" |
31 | #include "tensorflow/core/kernels/batching_util/bounded_executor.h" |
32 | #include "tensorflow/core/kernels/batching_util/concat_split_util.h" |
33 | #include "tensorflow/core/kernels/batching_util/periodic_function.h" |
34 | #include "tensorflow/core/kernels/ops_util.h" |
35 | #include "tensorflow/core/lib/monitoring/gauge.h" |
36 | #include "tensorflow/core/lib/random/random.h" |
37 | #include "tensorflow/core/platform/errors.h" |
38 | #include "tensorflow/core/platform/logging.h" |
39 | #include "tensorflow/core/platform/macros.h" |
40 | #include "tensorflow/core/platform/numbers.h" |
41 | #include "tensorflow/core/platform/threadpool.h" |
42 | |
43 | namespace tensorflow { |
44 | namespace { |
45 | // Op attributes. |
constexpr char kEnableAdaptiveSchedulerAttr[] = "_enable_adaptive_scheduler";
constexpr char kMinInflightBatchesAttr[] = "_min_inflight_batches";
constexpr char kInitialInflightBatchesAttr[] = "_initial_inflight_batches";
constexpr char kMaxInflightBatchesAttr[] = "_max_inflight_batches";
constexpr char kBatchesToAverageOverAttr[] = "_batches_to_average_over";
51 | |
52 | // Default thread count in the per-process batching thread pool. |
53 | constexpr int64_t kBatchThreadPoolSize = 128; |
54 | } // namespace |
55 | |
56 | // Per-model inflight batches parameters. |
57 | const int64_t kMinInflightBatches = 16; |
58 | const int64_t kInitialInflightBatches = 16; |
59 | const int64_t kBatchesToAverageOver = 10; |
60 | const int64_t kMaxInflightBatches = 64; |
61 | |
auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New(
    "/tensorflow/serving/batching/enable_large_batch_splitting",
    "Tracks the usage of attribute `enable_large_batch_splitting` for "
    "BatchFunction kernel in a saved model.",
    "model_name");
67 | |
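// Records the per-model value of the `enable_large_batch_splitting` attribute
// ("true", "false", or "unset") in the gauge above.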
68 | void RecordBatchSplitUsage( |
69 | absl::optional<bool> maybe_enable_large_batch_splitting, |
70 | const string& model_name) { |
71 | if (maybe_enable_large_batch_splitting.has_value()) { |
72 | if (maybe_enable_large_batch_splitting.value()) { |
      batch_op_split_usage->GetCell(model_name)->Set("true");
    } else {
      batch_op_split_usage->GetCell(model_name)->Set("false");
    }
  } else {
    batch_op_split_usage->GetCell(model_name)->Set("unset");
79 | } |
80 | } |
81 | |
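// Records the configured number of batch threads for a model in a gauge keyed
// by model name.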
82 | void RecordBatchParamNumBatchThreads(int64_t num_batch_threads, |
83 | const string& model_name) { |
  static auto* cell = monitoring::Gauge<int64_t, 1>::New(
      "/tensorflow/serving/batching/num_batch_threads",
      "Tracks the number of batch threads of a model.", "model_name");
87 | cell->GetCell(model_name)->Set(num_batch_threads); |
88 | } |
89 | |
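// Returns the model name from the session metadata, or a fixed
// "model_name_unset" placeholder when no metadata or name is available.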
90 | const string& GetModelName(OpKernelContext* ctx) { |
  static string* kModelNameUnset = new string("model_name_unset");
92 | if (!ctx->session_metadata()) return *kModelNameUnset; |
93 | if (ctx->session_metadata()->name().empty()) return *kModelNameUnset; |
94 | return ctx->session_metadata()->name(); |
95 | } |
96 | |
97 | using ::tensorflow::concat_split_util::Concat; |
98 | using ::tensorflow::concat_split_util::Split; |
99 | |
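// Returns the batch thread count from the TF_NUM_BATCH_THREADS environment
// variable, falling back to `default_num_batch_threads` when the variable is
// unset or not parsable as an int32. For example, with TF_NUM_BATCH_THREADS=8
// a call with default 128 returns 8; with the variable unset it returns 128.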
100 | int32 NumBatchThreadsFromEnvironmentWithDefault(int default_num_batch_threads) { |
101 | int32_t num; |
  const char* val = std::getenv("TF_NUM_BATCH_THREADS");
103 | |
104 | return (val && strings::safe_strto32(val, &num)) ? num |
105 | : default_num_batch_threads; |
106 | } |
107 | |
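// Returns a process-wide thread pool shared by all adaptive batch schedulers.
// The pool is created lazily on first use, sized by
// NumBatchThreadsFromEnvironmentWithDefault, and is never destroyed; returns
// nullptr if the underlying executor could not be created.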
108 | static thread::ThreadPool* GetOrCreateBatchThreadsPool() { |
109 | static thread::ThreadPool* shared_thread_pool = [&]() -> thread::ThreadPool* { |
110 | serving::BoundedExecutor::Options options; |
111 | |
112 | options.num_threads = |
113 | NumBatchThreadsFromEnvironmentWithDefault(kBatchThreadPoolSize); |
114 | |
    options.thread_name = std::string("adaptive_batch_threads");
116 | |
117 | auto status_or_executor = serving::BoundedExecutor::Create(options); |
118 | if (!status_or_executor.ok()) { |
119 | LOG(WARNING) << "Failed to create a batch threads pool with error " |
120 | << status_or_executor.status(); |
121 | return nullptr; |
122 | } |
123 | static serving::BoundedExecutor* executor = |
124 | status_or_executor.value().release(); |
125 | return new thread::ThreadPool(executor); |
126 | }(); |
127 | return shared_thread_pool; |
128 | } |
129 | |
130 | // A class encapsulating the state and logic for batching tensors. |
131 | class BatchResource : public serving::BatchResourceBase { |
132 | public: |
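  // Creates a BatchResource backed by a `BatcherT` scheduler configured with
  // an explicit number of batch threads.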
133 | static Status Create(int32_t num_batch_threads, |
134 | int32_t max_execution_batch_size, |
135 | int32_t batch_timeout_micros, |
136 | int32_t max_enqueued_batches, |
137 | const std::vector<int32>& allowed_batch_sizes, |
138 | FunctionLibraryRuntime::Handle fhandle, |
139 | FunctionLibraryRuntime* flib, |
140 | bool enable_large_batch_splitting, |
141 | std::unique_ptr<BatchResource>* resource) { |
142 | BatcherT::Options batcher_options; |
143 | batcher_options.num_batch_threads = num_batch_threads; |
144 | std::shared_ptr<BatcherT> batcher; |
145 | TF_RETURN_IF_ERROR(BatcherT::Create(batcher_options, &batcher)); |
146 | |
147 | resource->reset(new BatchResource( |
148 | fhandle, flib, std::move(batcher), |
149 | GetBatcherQueueOptions(num_batch_threads, max_execution_batch_size, |
150 | batch_timeout_micros, max_enqueued_batches, |
151 | allowed_batch_sizes, |
152 | enable_large_batch_splitting), |
153 | allowed_batch_sizes)); |
154 | return OkStatus(); |
155 | } |
156 | |
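  // Creates a BatchResource backed by an `AdaptiveBatcherT` scheduler; large
  // batch splitting is always enabled for its queues.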
157 | static Status Create( |
158 | AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options, |
159 | int32_t max_batch_size, int32_t batch_timeout_micros, |
160 | int32_t max_enqueued_batches, |
161 | const std::vector<int32>& allowed_batch_sizes, |
162 | FunctionLibraryRuntime::Handle fhandle, FunctionLibraryRuntime* flib, |
163 | std::unique_ptr<BatchResource>* resource) { |
164 | std::shared_ptr<AdaptiveBatcherT> batcher; |
165 | TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create( |
166 | adaptive_shared_batch_scheduler_options, &batcher)); |
167 | |
168 | resource->reset(new BatchResource( |
169 | fhandle, flib, std::move(batcher), |
170 | GetAdaptiveBatcherQueueOptions( |
171 | max_batch_size, batch_timeout_micros, max_enqueued_batches, |
172 | true /* enable large batch split */, allowed_batch_sizes), |
173 | allowed_batch_sizes)); |
174 | return OkStatus(); |
175 | } |
176 | |
  string DebugString() const final { return "BatchResource"; }
178 | |
179 | private: |
180 | BatchResource(FunctionLibraryRuntime::Handle fhandle, |
181 | FunctionLibraryRuntime* flib, std::shared_ptr<BatcherT> batcher, |
182 | const BatcherT::QueueOptions& batcher_queue_options, |
183 | std::vector<int32> allowed_batch_sizes) |
184 | : BatchResourceBase( |
185 | /*has_process_batch_function=*/fhandle != kInvalidHandle, |
186 | std::move(batcher), batcher_queue_options, |
187 | std::move(allowed_batch_sizes)), |
188 | fhandle_(fhandle), |
189 | flib_(flib) {} |
190 | |
191 | BatchResource(FunctionLibraryRuntime::Handle fhandle, |
192 | FunctionLibraryRuntime* flib, |
193 | std::shared_ptr<AdaptiveBatcherT> batcher, |
194 | const AdaptiveBatcherT::QueueOptions& batcher_queue_options, |
195 | std::vector<int32> allowed_batch_sizes) |
196 | : BatchResourceBase( |
197 | /*has_process_batch_function=*/fhandle != kInvalidHandle, |
198 | std::move(batcher), batcher_queue_options, |
199 | std::move(allowed_batch_sizes)), |
200 | fhandle_(fhandle), |
201 | flib_(flib) {} |
202 | |
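  // Runs the batched computation by invoking `fhandle_` through the function
  // library runtime, forwarding most options from the last task's
  // OpKernelContext, and blocks until the function's done callback fires.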
203 | void ProcessFuncBatchImpl( |
204 | const BatchTask& last_task, absl::Span<const Tensor> inputs, |
205 | std::vector<Tensor>* combined_outputs, |
206 | std::function<void(const Status&)> done) const override { |
207 | auto* last_task_context = last_task.context; |
208 | FunctionLibraryRuntime::Options opts; |
209 | opts.step_container = last_task_context->step_container(); |
210 | opts.cancellation_manager = last_task_context->cancellation_manager(); |
211 | opts.collective_executor = last_task_context->collective_executor(); |
212 | opts.stats_collector = last_task_context->stats_collector(); |
213 | opts.runner = last_task_context->runner(); |
214 | opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline(); |
215 | // We do not set 'opts.rendezvous', since if the function is run multiple |
216 | // times in parallel with the same rendezvous, a _Send node from one run |
217 | // might be matched with a _Recv node of a different run. Not setting the |
218 | // rendezvous causes a new rendezvous to be used for each run. |
219 | Notification done_notif; |
220 | |
221 | flib_->Run(opts, fhandle_, inputs, combined_outputs, |
222 | [&](const Status& run_status) { |
223 | done(run_status); |
224 | done_notif.Notify(); |
225 | }); |
226 | // By waiting for the notification we are ensuring that this thread isn't |
227 | // used for processing other batches, which gives the batches time to |
228 | // coalesce upstream. So overall the number of batches going through the |
229 | // devices goes down, improving latency and throughput in most cases. |
230 | done_notif.WaitForNotification(); |
231 | } |
232 | |
233 | FunctionLibraryRuntime::Handle fhandle_; |
234 | FunctionLibraryRuntime* flib_; |
235 | }; |
236 | |
237 | BatchFunctionKernel::BatchFunctionKernel(OpKernelConstruction* c) |
238 | : AsyncOpKernel(c) { |
  OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
  OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
  OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
  OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
  OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
  OP_REQUIRES_OK(c, c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
  OP_REQUIRES_OK(c, c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
  OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));

  OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
249 | flib_ = c->function_library(); |
250 | |
  if (c->HasAttr("enable_large_batch_splitting")) {
    OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
253 | &enable_large_batch_splitting_)); |
254 | has_attribute_enable_large_batch_splitting_ = true; |
255 | } else { |
256 | enable_large_batch_splitting_ = false; |
257 | has_attribute_enable_large_batch_splitting_ = false; |
258 | } |
259 | |
  // The helper `SetAdaptiveBatchSchedulerOptions` calls `OP_REQUIRES_OK`,
  // which on error returns from that helper rather than from this
  // constructor, so explicitly check the construction status here.
263 | SetAdaptiveBatchSchedulerOptions(c, num_batch_threads_); |
264 | if (!c->status().ok()) { |
265 | return; |
266 | } |
267 | |
268 | if (enable_adaptive_batch_threads_) { |
    // One scheduler instance contains several queue instances;
    // `batcher_queue_` is the key used to find the queue for this batch op in
    // the graph.
    // Prefix `batcher_queue_` with name() and `shared_name_`; note that
    // name() is unique per session (from session metadata).
274 | batcher_queue_ = name() + "/" + shared_name_ + batcher_queue_; |
275 | } |
276 | |
277 | if (shared_name_.empty()) { |
278 | // If shared_name is not supplied, use name instead (prevent collisions by |
279 | // default). |
280 | shared_name_ = name(); |
281 | } |
282 | |
283 | OP_REQUIRES_OK(c, ValidateAllowedBatchSizes()); |
284 | } |
285 | |
286 | bool BatchFunctionKernel::IsExpensive() { return false; } |
287 | |
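// Looks up (or lazily creates) the shared BatchResource for this kernel and
// registers the current invocation's inputs with it; the resource invokes
// `done` once the batched computation covering this input has finished.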
288 | void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { |
289 | RecordBatchSplitUsage(has_attribute_enable_large_batch_splitting_ |
290 | ? absl::make_optional(enable_large_batch_splitting_) |
291 | : absl::nullopt, |
292 | GetModelName(c)); |
293 | // TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel. |
294 | RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c)); |
295 | |
296 | std::function<Status(BatchResource**)> creator; |
297 | |
298 | FunctionLibraryRuntime::Handle handle; |
299 | OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done); |
300 | |
301 | if (adaptive_batch_scheduler_options_ != absl::nullopt) { |
302 | creator = [this, handle](BatchResource** r) { |
303 | serving::AdaptiveSharedBatchScheduler< |
304 | serving::BatchResourceBase::BatchTask>::Options |
305 | adaptive_shared_batch_scheduler_options; |
      adaptive_shared_batch_scheduler_options.thread_pool_name =
          "adaptive_batch_threads";
308 | adaptive_shared_batch_scheduler_options.num_batch_threads = |
309 | adaptive_batch_scheduler_options_->max_in_flight_batches_limit; |
310 | adaptive_shared_batch_scheduler_options.thread_pool = |
311 | GetOrCreateBatchThreadsPool(); |
      // adaptive_shared_batch_scheduler_options.full_batch_scheduling_boost_micros
      // is intentionally left at its default value of zero, so tasks are
      // scheduled in FIFO order. Two rationales for keeping the default:
      // 1) FIFO scheduling, compared with the round-robin policy of the
      //    shared batch scheduler, ensures that low-QPS models (i.e., models
      //    that enqueue fewer tasks into the shared queue) are still
      //    processed in a timely manner.
      // 2) If set, `full_batch_scheduling_boost_micros` should be on the
      //    order of the batch-processing latency, which varies per model; a
      //    poorly chosen non-zero value harms tail latency.
324 | adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit = |
325 | adaptive_batch_scheduler_options_->min_in_flight_batches_limit; |
326 | adaptive_shared_batch_scheduler_options.initial_in_flight_batches_limit = |
327 | adaptive_batch_scheduler_options_->initial_in_flight_batches_limit; |
328 | adaptive_shared_batch_scheduler_options.batches_to_average_over = |
329 | adaptive_batch_scheduler_options_->batches_to_average_over; |
330 | adaptive_shared_batch_scheduler_options.fifo_scheduling = true; |
331 | std::unique_ptr<BatchResource> new_resource; |
332 | TF_RETURN_IF_ERROR(BatchResource::Create( |
333 | adaptive_shared_batch_scheduler_options, max_batch_size_, |
334 | batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_, |
335 | handle, flib_, &new_resource)); |
336 | *r = new_resource.release(); |
337 | return OkStatus(); |
338 | }; |
339 | } else { |
340 | creator = [this, handle](BatchResource** r) { |
341 | std::unique_ptr<BatchResource> new_resource; |
342 | TF_RETURN_IF_ERROR(BatchResource::Create( |
343 | num_batch_threads_, max_batch_size_, batch_timeout_micros_, |
344 | max_enqueued_batches_, allowed_batch_sizes_, handle, flib_, |
345 | enable_large_batch_splitting_, &new_resource)); |
346 | *r = new_resource.release(); |
347 | return OkStatus(); |
348 | }; |
349 | } |
350 | |
351 | BatchResource* br; |
352 | OP_REQUIRES_OK_ASYNC(c, |
353 | c->resource_manager()->LookupOrCreate( |
354 | container_, shared_name_, &br, creator), |
355 | done); |
356 | const Status status = |
357 | br->RegisterInput(random::New64(), c, batcher_queue_, done); |
358 | br->Unref(); |
359 | OP_REQUIRES_OK_ASYNC(c, status, done); |
360 | // Assume br calls done, so nothing to do here. |
361 | } |
362 | |
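// Instantiates `func_` as a multi-device function. Non-resource inputs and
// all outputs are pinned to the CPU device, since batching concatenates
// inputs and splits outputs on the host; captured resource inputs keep the
// devices recorded in their resource handles.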
363 | Status BatchFunctionKernel::InstantiateFunction( |
364 | OpKernelContext* c, FunctionLibraryRuntime::Handle* handle) const { |
365 | // TODO(b/173748062): Merge this instantiation logic with PartitionedCall. |
366 | if (!flib_) { |
    return errors::Internal("No function library");
368 | } |
369 | |
370 | FunctionLibraryRuntime::InstantiateOptions opts; |
371 | opts.target = flib_->device() == nullptr ? "" : flib_->device()->name(); |
372 | opts.is_multi_device_function = true; |
373 | const ConfigProto* config = flib_->config_proto(); |
374 | if (config) { |
375 | opts.config_proto = *config; |
376 | } |
377 | |
378 | Device* cpu_device; |
  TF_RETURN_IF_ERROR(flib_->device_mgr()->LookupDevice("CPU:0", &cpu_device));
380 | |
381 | const FunctionDef* fdef = |
382 | flib_->GetFunctionLibraryDefinition()->Find(func_.name()); |
383 | if (!fdef) { |
    return errors::NotFound("Failed to find definition for function \"",
                            func_.name(), "\"");
386 | } |
387 | OpInputList in_tensors; |
  TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
389 | for (int i = 0; i < in_tensors.size(); i++) { |
390 | if (in_tensors[i].dtype() == DT_RESOURCE) { |
      return errors::InvalidArgument(
          "BatchFunction cannot take resource inputs but input ", i,
          " is a resource.");
394 | } else { |
395 | // Currently, inputs are on CPU since they are concatenated on CPU |
396 | opts.input_devices.push_back(cpu_device->name()); |
397 | } |
398 | } |
399 | OpInputList captured_tensors; |
  TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
401 | for (const Tensor& t : captured_tensors) { |
402 | if (t.dtype() == DT_RESOURCE) { |
403 | const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0); |
404 | opts.input_devices.push_back(rhandle.device()); |
405 | } else { |
406 | opts.input_devices.push_back(cpu_device->name()); |
407 | } |
408 | } |
409 | const OpDef& signature = fdef->signature(); |
410 | for (int i = 0; i < signature.output_arg_size(); i++) { |
411 | // Currently, outputs must be on CPU since they are split on CPU. |
412 | opts.output_devices.push_back(cpu_device->name()); |
413 | } |
414 | if (opts.input_devices.size() != signature.input_arg_size()) { |
    return errors::InvalidArgument(
        "Function takes ", signature.input_arg_size(), " argument(s) but ",
        opts.input_devices.size(), " argument(s) were passed");
418 | } |
419 | return flib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts, |
420 | handle); |
421 | } |
422 | |
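// Returns the cached function handle, instantiating the function on the first
// call while holding `mu_`.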
423 | Status BatchFunctionKernel::GetOrCreateFunctionHandle( |
424 | OpKernelContext* c, FunctionLibraryRuntime::Handle* handle) { |
425 | mutex_lock ml(mu_); |
426 | if (!fhandle_) { |
427 | TF_RETURN_IF_ERROR(InstantiateFunction(c, handle)); |
428 | fhandle_ = *handle; |
429 | } else { |
430 | *handle = fhandle_.value(); |
431 | } |
432 | return OkStatus(); |
433 | } |
434 | |
// Validates 'allowed_batch_sizes_'. The entries must increase monotonically.
// If large batch splitting is not enabled, the last entry must equal
// `max_batch_size_`; otherwise the last entry must be less than or equal to
// `max_batch_size_`.
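// For example, with `max_batch_size_` = 8, {2, 4, 8} passes this check in
// both modes, while {2, 4} passes only when `enable_large_batch_splitting`
// is true.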
439 | Status BatchFunctionKernel::ValidateAllowedBatchSizes() const { |
440 | if (allowed_batch_sizes_.empty()) { |
441 | return OkStatus(); |
442 | } |
443 | int32_t last_size = 0; |
444 | for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) { |
445 | const int32_t size = allowed_batch_sizes_.at(i); |
446 | if (i > 0 && size <= last_size) { |
      return errors::InvalidArgument(
          "allowed_batch_sizes entries must be monotonically increasing");
449 | } |
450 | |
451 | if ((!enable_large_batch_splitting_) && |
452 | (i == allowed_batch_sizes_.size() - 1) && (size != max_batch_size_)) { |
      return errors::InvalidArgument(
          "final entry in allowed_batch_sizes must equal max_batch_size when "
          "enable_large_batch_splitting is False");
456 | } |
457 | |
458 | last_size = size; |
459 | } |
460 | return OkStatus(); |
461 | } |
462 | |
// Initializes member variables from op-kernel-construction attributes:
// - enable_adaptive_batch_threads_
//   True if the `kEnableAdaptiveSchedulerAttr` attribute is true, or if
//   `num_batch_threads` is not positive.
// - adaptive_batch_scheduler_options_
//   Populated from the corresponding attributes whenever they are set.
470 | void BatchFunctionKernel::SetAdaptiveBatchSchedulerOptions( |
471 | OpKernelConstruction* c, int32_t num_batch_threads) { |
472 | if (c->HasAttr(kEnableAdaptiveSchedulerAttr)) { |
473 | OP_REQUIRES_OK(c, c->GetAttr(kEnableAdaptiveSchedulerAttr, |
474 | &enable_adaptive_batch_threads_)); |
475 | } |
476 | |
477 | if (num_batch_threads <= 0) { |
478 | enable_adaptive_batch_threads_ = true; |
479 | } |
480 | |
481 | if (!enable_adaptive_batch_threads_) { |
482 | // adaptive_batch_scheduler_options_ is nullopt. |
483 | return; |
484 | } |
485 | |
486 | // adaptive_batch_scheduler_options_ is not nullopt |
487 | AdaptiveBatchSchedulerOptions options; |
488 | |
489 | if (c->HasAttr(kBatchesToAverageOverAttr)) { |
490 | OP_REQUIRES_OK(c, c->GetAttr(kBatchesToAverageOverAttr, |
491 | &options.batches_to_average_over)); |
492 | } |
493 | |
494 | if (c->HasAttr(kMinInflightBatchesAttr)) { |
495 | OP_REQUIRES_OK(c, c->GetAttr(kMinInflightBatchesAttr, |
496 | &options.min_in_flight_batches_limit)); |
497 | } |
498 | |
499 | if (c->HasAttr(kInitialInflightBatchesAttr)) { |
500 | OP_REQUIRES_OK(c, c->GetAttr(kInitialInflightBatchesAttr, |
501 | &options.initial_in_flight_batches_limit)); |
502 | } |
503 | |
504 | if (c->HasAttr(kMaxInflightBatchesAttr)) { |
505 | OP_REQUIRES_OK(c, c->GetAttr(kMaxInflightBatchesAttr, |
506 | &options.max_in_flight_batches_limit)); |
507 | } |
508 | |
  // At this point, the batch kernel is configured to use adaptive scheduling.
  // To surface any error at kernel-construction time, invoke
  // `GetOrCreateBatchThreadsPool` now and verify that the returned
  // `thread_pool` is valid.
  // Note that `GetOrCreateBatchThreadsPool` creates the thread pool once and
  // reuses the same instance afterwards.
515 | thread::ThreadPool* thread_pool = GetOrCreateBatchThreadsPool(); |
516 | OP_REQUIRES( |
517 | c, thread_pool != nullptr, |
      errors::FailedPrecondition("Failed to create batch threads pool"));
519 | |
520 | adaptive_batch_scheduler_options_ = options; |
521 | } |
REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
523 | BatchFunctionKernel); |
524 | // Currently all inputs and outputs are on the host. |
525 | // TODO(b/173748277): Accept inputs/outputs on the device. |
REGISTER_KERNEL_BUILDER(Name("BatchFunction")
                            .Device(DEVICE_GPU)
                            .HostMemory("in_tensors")
                            .HostMemory("captured_tensors")
                            .HostMemory("out_tensors"),
531 | BatchFunctionKernel); |
REGISTER_KERNEL_BUILDER(Name("BatchFunction")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("in_tensors")
                            .HostMemory("captured_tensors")
                            .HostMemory("out_tensors"),
537 | BatchFunctionKernel); |
538 | |
539 | class BatchKernel : public AsyncOpKernel { |
540 | public: |
541 | explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) { |
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
544 | // If shared_name is not supplied, use name instead (prevent collisions by |
545 | // default). |
546 | if (shared_name_.empty()) { |
547 | shared_name_ = name(); |
548 | } |
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
557 | OP_REQUIRES_OK(c, ValidateAllowedBatchSizes()); |
558 | } |
559 | |
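  // Looks up (or lazily creates) the shared BatchResource (with no batch
  // function attached) and registers this invocation's inputs with it; the
  // resource is expected to invoke `done`.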
560 | void ComputeAsync(OpKernelContext* c, DoneCallback done) final { |
561 | BatchResource* br; |
562 | std::function<Status(BatchResource**)> creator = [this](BatchResource** r) { |
563 | std::unique_ptr<BatchResource> new_resource; |
564 | TF_RETURN_IF_ERROR(BatchResource::Create( |
565 | num_batch_threads_, max_batch_size_, batch_timeout_micros_, |
566 | max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle, |
567 | /*flib=*/nullptr, false, &new_resource)); |
568 | *r = new_resource.release(); |
569 | return OkStatus(); |
570 | }; |
571 | OP_REQUIRES_OK_ASYNC(c, |
572 | c->resource_manager()->LookupOrCreate( |
573 | container_, shared_name_, &br, creator), |
574 | done); |
575 | const Status status = |
576 | br->RegisterInput(random::New64(), c, batcher_queue_, done); |
577 | br->Unref(); |
578 | OP_REQUIRES_OK_ASYNC(c, status, done); |
579 | // Assume br calls done, so nothing to do here. |
580 | } |
581 | |
582 | // Validates 'allowed_batch_sizes_'. The entries must increase |
583 | // monotonically, and the last one must equal 'max_batch_size_'. |
584 | Status ValidateAllowedBatchSizes() const { |
585 | if (allowed_batch_sizes_.empty()) { |
586 | return OkStatus(); |
587 | } |
588 | int32_t last_size = 0; |
589 | for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) { |
590 | const int32_t size = allowed_batch_sizes_.at(i); |
591 | if (i > 0 && size <= last_size) { |
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
594 | } |
595 | if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) { |
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
598 | } |
599 | last_size = size; |
600 | } |
601 | return OkStatus(); |
602 | } |
603 | |
604 | private: |
605 | string container_; |
606 | string shared_name_; |
607 | string batcher_queue_; |
608 | int32 num_batch_threads_; |
609 | int32 max_batch_size_; |
610 | int32 batch_timeout_micros_; |
611 | int32 max_enqueued_batches_; |
612 | std::vector<int32> allowed_batch_sizes_; |
613 | }; |
614 | |
REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);
616 | |
617 | // A class encapsulating the state and logic for unbatching tensors. |
618 | // |
619 | // UnbatchResource keeps two data structures indexed by batch-key: one which has |
620 | // the continuations for all concurrent kernels which are waiting for tensors |
621 | // and another which has tensors which are waiting for their corresponding |
622 | // kernels to run. Whenever a kernel runs, we either grab its tensor if it's |
623 | // waiting already, or we insert it in the queue and then look at its tensor to |
624 | // see if it can be used to dispatch any stored continuations. |
625 | class UnbatchResource : public ResourceBase { |
626 | public: |
627 | explicit UnbatchResource(int32_t timeout_micros) |
628 | : timeout_micros_(timeout_micros), |
629 | timeout_enforcer_(new serving::PeriodicFunction( |
630 | [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {} |
631 | |
632 | ~UnbatchResource() override { |
633 | // Tear down 'timeout_enforcer_' first, since it accesses other state in |
634 | // this class. |
635 | timeout_enforcer_ = nullptr; |
636 | } |
637 | |
  string DebugString() const final { return "UnbatchResource"; }
639 | |
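  // Processes one Unbatch invocation: validates the batch-index and id
  // inputs, splits a non-empty batched tensor into per-key slices, then under
  // `mu_` either emits this key's tensor immediately if it is already
  // available or parks the callback until the tensor arrives (subject to the
  // timeout).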
640 | Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) { |
641 | const Tensor& data_t = context->input(0); |
642 | const Tensor& batch_index_t = context->input(1); |
643 | |
644 | if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) { |
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 0th dimension size to be no "
          "greater than ",
          data_t.shape().dim_size(0),
          "; Got: ", batch_index_t.shape().dim_size(0), ".");
650 | } |
651 | if (batch_index_t.shape().dim_size(1) != 3) { |
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 1st dimension size to be 3 ; "
          "Got: ",
          batch_index_t.shape().dim_size(1), ".");
656 | } |
657 | |
658 | if (!TensorShapeUtils::IsScalar(context->input(2).shape())) { |
      return errors::InvalidArgument(
          "Input id should be scalar; "
          "Got: ",
          context->input(2).DebugString(), ".");
663 | } |
664 | const int64_t batch_key = context->input(2).scalar<int64_t>()(); |
665 | const bool nonempty_input = batch_index_t.dim_size(0) > 0; |
666 | |
667 | // If we have a non-empty tensor, slice it up. |
668 | // (It is important to do this outside of the critical section below.) |
669 | // The following variables are populated iff 'nonempty_input==true'. |
670 | std::vector<int64_t> sizes; |
671 | std::vector<int64_t> batch_keys; |
672 | std::vector<Tensor> split_inputs; |
673 | if (nonempty_input) { |
674 | auto batch_indices = |
675 | batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3}); |
676 | for (int i = 0; i < batch_index_t.dim_size(0); ++i) { |
677 | sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1)); |
678 | batch_keys.push_back(batch_indices(i, 0)); |
679 | } |
680 | |
681 | TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs)); |
682 | } |
683 | |
684 | // Critical section. |
685 | std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call; |
686 | Status status = [&]() -> Status { |
687 | mutex_lock ml(mu_); |
688 | |
689 | // Check to see whether the tensor we want is already ready. |
690 | auto tensor_it = waiting_tensors_.find(batch_key); |
691 | if (tensor_it != waiting_tensors_.end()) { |
692 | context->set_output(0, tensor_it->second.tensor); |
693 | waiting_tensors_.erase(tensor_it); |
694 | done_callbacks_to_call.push_back(done); |
695 | return OkStatus(); |
696 | } |
697 | |
698 | const uint64 deadline_micros = |
699 | Env::Default()->NowMicros() + timeout_micros_; |
700 | |
701 | // Add ourselves to the waitlist for tensors. |
702 | if (!waiting_callbacks_ |
703 | .emplace(batch_key, |
704 | WaitingCallback{deadline_micros, context, done}) |
705 | .second) { |
        return errors::AlreadyExists(
            "Multiple session runs with the same batch key.");
708 | } |
709 | |
710 | // If we have a non-empty tensor, finish the waitlisted runs, |
711 | // and store any remaining pieces. |
712 | if (nonempty_input) { |
713 | for (size_t i = 0; i < batch_keys.size(); ++i) { |
714 | auto runs_it = waiting_callbacks_.find(batch_keys[i]); |
715 | if (runs_it != waiting_callbacks_.end()) { |
716 | runs_it->second.context->set_output(0, split_inputs[i]); |
717 | done_callbacks_to_call.push_back(runs_it->second.done); |
718 | waiting_callbacks_.erase(runs_it); |
719 | } else { |
720 | // Note: the deadline here is in case we are arriving late and the |
721 | // kernel that should rendezvous with this tensor has already waited |
722 | // and timed out. |
723 | if (!waiting_tensors_ |
724 | .emplace(batch_keys[i], |
725 | WaitingTensor{deadline_micros, split_inputs[i]}) |
726 | .second) { |
              return errors::AlreadyExists(
                  "Multiple tensors returned for same batch key.");
729 | } |
730 | } |
731 | } |
732 | } |
733 | |
734 | return OkStatus(); |
735 | }(); |
736 | |
737 | for (const AsyncOpKernel::DoneCallback& done_callback : |
738 | done_callbacks_to_call) { |
739 | done_callback(); |
740 | } |
741 | |
742 | return status; |
743 | } |
744 | |
745 | private: |
746 | // Evicts waiting tensors and callbacks that have exceeded their deadline. |
747 | void EnforceTimeout() { |
748 | const uint64 now = Env::Default()->NowMicros(); |
749 | std::vector<WaitingCallback> evicted_callbacks; |
750 | |
751 | { |
752 | mutex_lock ml(mu_); |
753 | |
754 | for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) { |
755 | const WaitingTensor& waiting_tensor = it->second; |
756 | if (waiting_tensor.deadline_micros < now) { |
757 | it = waiting_tensors_.erase(it); |
758 | } else { |
759 | ++it; |
760 | } |
761 | } |
762 | |
763 | for (auto it = waiting_callbacks_.begin(); |
764 | it != waiting_callbacks_.end();) { |
765 | const WaitingCallback& waiting_callback = it->second; |
766 | if (waiting_callback.deadline_micros < now) { |
767 | evicted_callbacks.push_back(waiting_callback); |
768 | it = waiting_callbacks_.erase(it); |
769 | } else { |
770 | ++it; |
771 | } |
772 | } |
773 | } |
774 | |
775 | for (const WaitingCallback& evicted_callback : evicted_callbacks) { |
      evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
          "Batched data did not arrive within timeout window."));
778 | evicted_callback.done(); |
779 | } |
780 | } |
781 | |
782 | struct WaitingTensor { |
783 | uint64 deadline_micros; |
784 | Tensor tensor; |
785 | }; |
786 | |
787 | struct WaitingCallback { |
788 | uint64 deadline_micros; |
789 | OpKernelContext* context; |
790 | AsyncOpKernel::DoneCallback done; |
791 | }; |
792 | |
793 | const int32 timeout_micros_; |
794 | |
795 | mutex mu_; |
796 | |
797 | // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks |
798 | // waiting for tensors. |
799 | std::unordered_map<int64_t, WaitingTensor> waiting_tensors_ |
800 | TF_GUARDED_BY(mu_); |
801 | std::unordered_map<int64_t, WaitingCallback> waiting_callbacks_ |
802 | TF_GUARDED_BY(mu_); |
803 | |
804 | // A thread that evicts waiting tensors and callbacks that have exceeded their |
805 | // deadline. |
806 | std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_; |
807 | }; |
808 | |
809 | class UnbatchKernel : public AsyncOpKernel { |
810 | public: |
811 | explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) { |
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
814 | // If shared_name is not supplied, use name instead (prevent collisions by |
815 | // default). |
816 | if (shared_name_.empty()) { |
817 | shared_name_ = name(); |
818 | } |
    OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
820 | } |
821 | |
822 | void ComputeAsync(OpKernelContext* c, DoneCallback done) final { |
823 | UnbatchResource* ubr; |
824 | std::function<Status(UnbatchResource**)> creator = |
825 | [this](UnbatchResource** r) { |
826 | *r = new UnbatchResource(timeout_micros_); |
827 | return OkStatus(); |
828 | }; |
829 | OP_REQUIRES_OK_ASYNC(c, |
830 | c->resource_manager()->LookupOrCreate( |
831 | container_, shared_name_, &ubr, creator), |
832 | done); |
833 | auto status = ubr->Compute(c, done); |
834 | ubr->Unref(); |
835 | OP_REQUIRES_OK_ASYNC(c, status, done); |
836 | // Assume ubr calls done, so nothing to do here. |
837 | } |
838 | |
839 | private: |
840 | string container_; |
841 | string shared_name_; |
842 | int32 timeout_micros_; |
843 | }; |
REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);
845 | |
846 | // A class encapsulating the state and logic for batching tensors |
847 | // deterministically for the gradient of unbatch. |
848 | class UnbatchGradResource : public ResourceBase { |
849 | public: |
850 | UnbatchGradResource() {} |
851 | |
  string DebugString() const final { return "UnbatchGradResource"; }
853 | |
854 | // Flushes the information for one batch, given its context and done |
855 | // callback. Clears all information about it from the available_tensors_. |
856 | Status OutputBatch(OpKernelContext* context, |
857 | const AsyncOpKernel::DoneCallback& done) |
858 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
859 | const Tensor& batch_index_t = context->input(1); |
860 | auto batch_index = |
861 | batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3}); |
862 | std::vector<Tensor> tensors; |
863 | for (int i = 0; i < batch_index_t.dim_size(0); ++i) { |
864 | auto available_it = available_tensors_.find(batch_index(i, 0)); |
865 | if (available_it == available_tensors_.end()) { |
        return errors::Internal("bad bookkeeping of available tensors.");
867 | } |
868 | tensors.push_back(available_it->second); |
869 | available_tensors_.erase(available_it); |
870 | } |
871 | |
872 | const DataType type = tensors[0].dtype(); |
873 | Tensor concatenated_tensor; |
874 | switch (type) { |
875 | #define CASE(type) \ |
876 | case DataTypeToEnum<type>::value: \ |
877 | TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \ |
878 | context->set_output(0, concatenated_tensor); \ |
879 | break; |
880 | TF_CALL_ALL_TYPES(CASE); |
881 | #undef CASE |
882 | default: |
        return errors::InvalidArgument("Unsupported data type: ", type);
884 | } |
885 | done(); |
886 | return OkStatus(); |
887 | } |
888 | |
889 | // Ingests data from one invocation of the op. |
890 | Status Compute(OpKernelContext* context, |
891 | const AsyncOpKernel::DoneCallback& done) { |
892 | const Tensor& data_t = context->input(0); |
893 | const Tensor& batch_index_t = context->input(1); |
894 | const Tensor& grad_t = context->input(2); |
895 | const Tensor& batch_key_t = context->input(3); |
896 | |
897 | mutex_lock ml(mu_); |
898 | if (batch_key_t.NumElements() != 1) { |
      return errors::InvalidArgument("Expected `id` to be scalar. Received ",
900 | batch_key_t.DebugString()); |
901 | } |
902 | |
903 | const int64_t batch_key = context->input(3).scalar<int64_t>()(); |
904 | // Mark our tensor as available. |
905 | if (!available_tensors_.emplace(batch_key, grad_t).second) { |
      return errors::InvalidArgument("Two runs with the same batch key.");
907 | } |
908 | |
909 | // Check whether we have a valid input tensor and, if so, create its |
910 | // dispatch logic. |
911 | if (data_t.NumElements() > 0) { |
912 | if (batch_index_t.NumElements() == 0) { |
        return errors::InvalidArgument(
            "batch_index is empty while the tensor isn't.");
915 | } |
916 | std::unordered_set<int64_t> missing_tensors; |
917 | if (batch_index_t.NumElements() != batch_index_t.dim_size(0) * 3) { |
        return errors::InvalidArgument(
            "batch_index should contain ", batch_index_t.dim_size(0) * 3,
            " elements. Received ", batch_index_t.NumElements());
921 | } |
922 | const auto batch_index = |
923 | batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3}); |
924 | for (int i = 0; i < batch_index_t.dim_size(0); ++i) { |
925 | const int64_t batch_key = batch_index(i, 0); |
926 | if (available_tensors_.find(batch_key) == available_tensors_.end()) { |
927 | missing_tensors.emplace(batch_key); |
928 | } |
929 | } |
930 | if (missing_tensors.empty()) { |
931 | return OutputBatch(context, done); |
932 | } |
933 | if (!available_batches_ |
934 | .emplace(batch_key, Batch{missing_tensors, context, done}) |
935 | .second) { |
        return errors::InvalidArgument(
            "Batch key with valid batch used twice.");
938 | } |
939 | for (const int64_t i : missing_tensors) { |
940 | if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) { |
          return errors::InvalidArgument(
              "Missing tensor wanted by more than one batch.");
943 | } |
944 | } |
945 | } else { |
946 | // If we don't have a valid input tensor we can output an empty tensor and |
947 | // call our done closure. |
948 | TensorShape output_shape(grad_t.shape()); |
949 | output_shape.set_dim(0, 0); |
950 | Tensor* output = nullptr; |
951 | TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output)); |
952 | done(); |
953 | } |
954 | |
955 | // Search to see whether our tensor is desired by any existing batch. |
956 | auto desire_it = desired_tensor_to_batch_map_.find(batch_key); |
957 | if (desire_it != desired_tensor_to_batch_map_.end()) { |
958 | // Mark our tensor as no longer missing. |
959 | auto batch_it = available_batches_.find(desire_it->second); |
960 | desired_tensor_to_batch_map_.erase(desire_it); |
961 | if (batch_it == available_batches_.end()) { |
        return errors::InvalidArgument("Batch no longer exists.");
963 | } |
964 | batch_it->second.missing_tensors.erase(batch_key); |
965 | // If all tensors are available we should concatenate them and dispatch |
966 | // the batch. |
967 | if (batch_it->second.missing_tensors.empty()) { |
968 | TF_RETURN_IF_ERROR( |
969 | OutputBatch(batch_it->second.context, batch_it->second.done)); |
970 | available_batches_.erase(batch_it); |
971 | } |
972 | } |
973 | return OkStatus(); |
974 | } |
975 | |
976 | private: |
977 | mutex mu_; |
978 | |
979 | // Represents a still-incomplete batch of tensors. When all tensors become |
980 | // available they will be concatenated in the right order and sent through the |
981 | // context. |
982 | struct Batch { |
983 | // Batch keys for tensors which are still missing from this batch. When this |
984 | // is empty the Tensors can be concatenated and forwarded. |
985 | std::unordered_set<int64_t> missing_tensors; |
986 | |
987 | // Context and callback for the session responsible for finishing this |
988 | // batch. |
989 | OpKernelContext* context; |
990 | AsyncOpKernel::DoneCallback done; |
991 | }; |
992 | |
993 | // Map from batch key of the session which will output the batched gradients |
994 | // to still-incomplete batches. |
995 | std::unordered_map<int64_t, Batch> available_batches_; |
996 | |
997 | // Map from batch key to tensors which are waiting for their batches to be |
998 | // available. |
999 | std::unordered_map<int64_t, Tensor> available_tensors_; |
1000 | |
1001 | // Map from batch key of a tensor which is not yet available to the batch key |
1002 | // of the batch to which it belongs. |
1003 | std::unordered_map<int64_t, int64_t> desired_tensor_to_batch_map_; |
1004 | }; |
1005 | |
1006 | class UnbatchGradKernel : public AsyncOpKernel { |
1007 | public: |
1008 | explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) { |
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
1011 | // If shared_name is not supplied, use name instead (prevent collisions by |
1012 | // default). |
1013 | if (shared_name_.empty()) { |
1014 | shared_name_ = name(); |
1015 | } |
1016 | } |
1017 | |
1018 | void ComputeAsync(OpKernelContext* c, DoneCallback done) final { |
1019 | UnbatchGradResource* ubr; |
1020 | std::function<Status(UnbatchGradResource**)> creator = |
1021 | [](UnbatchGradResource** r) { |
1022 | *r = new UnbatchGradResource(); |
1023 | return OkStatus(); |
1024 | }; |
1025 | OP_REQUIRES_OK_ASYNC(c, |
1026 | c->resource_manager()->LookupOrCreate( |
1027 | container_, shared_name_, &ubr, creator), |
1028 | done); |
1029 | Status status = ubr->Compute(c, done); |
1030 | ubr->Unref(); |
1031 | OP_REQUIRES_OK_ASYNC(c, status, done); |
1032 | // Assume ubr calls done, so nothing to do here. |
1033 | } |
1034 | |
1035 | private: |
1036 | string container_; |
1037 | string shared_name_; |
1038 | }; |
REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
1040 | UnbatchGradKernel); |
1041 | |
1042 | } // namespace tensorflow |
1043 | |