1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op_kernel.h" |
17 | |
18 | #include <cstdlib> |
19 | #include <cstring> |
20 | #include <mutex> // NOLINT |
21 | #include <string> |
22 | #include <unordered_map> |
23 | #include <utility> |
24 | #include <vector> |
25 | |
26 | #include "absl/base/call_once.h" |
27 | #include "absl/container/flat_hash_set.h" |
28 | #include "absl/strings/match.h" |
29 | #include "tensorflow/core/framework/allocation_description.pb.h" |
30 | #include "tensorflow/core/framework/attr_value.pb.h" |
31 | #include "tensorflow/core/framework/attr_value_util.h" |
32 | #include "tensorflow/core/framework/device_attributes.pb.h" |
33 | #include "tensorflow/core/framework/device_factory.h" |
34 | #include "tensorflow/core/framework/graph.pb.h" |
35 | #include "tensorflow/core/framework/kernel_def.pb.h" |
36 | #include "tensorflow/core/framework/kernel_def_util.h" |
37 | #include "tensorflow/core/framework/log_memory.h" |
38 | #include "tensorflow/core/framework/memory_types.h" |
39 | #include "tensorflow/core/framework/node_def.pb.h" |
40 | #include "tensorflow/core/framework/node_def_util.h" |
41 | #include "tensorflow/core/framework/node_properties.h" |
42 | #include "tensorflow/core/framework/op_def_util.h" |
43 | #include "tensorflow/core/framework/tensor_reference.h" |
44 | #include "tensorflow/core/framework/types.h" |
45 | #include "tensorflow/core/lib/core/errors.h" |
46 | #include "tensorflow/core/lib/core/notification.h" |
47 | #include "tensorflow/core/lib/core/stringpiece.h" |
48 | #include "tensorflow/core/lib/gtl/map_util.h" |
49 | #include "tensorflow/core/lib/io/path.h" |
50 | #include "tensorflow/core/lib/strings/str_util.h" |
51 | #include "tensorflow/core/lib/strings/strcat.h" |
52 | #include "tensorflow/core/platform/cpu_info.h" |
53 | #include "tensorflow/core/platform/env.h" |
54 | #include "tensorflow/core/platform/logging.h" |
55 | #include "tensorflow/core/platform/mutex.h" |
56 | #include "tensorflow/core/platform/platform_strings.h" |
57 | #include "tensorflow/core/platform/types.h" |
58 | #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h" |
59 | #include "tensorflow/core/profiler/lib/traceme.h" |
60 | #include "tensorflow/core/util/ptr_util.h" |
61 | |
62 | namespace tensorflow { |
63 | |
64 | const char* kJitKernelLabel = "JITCompiledKernel" ; |
65 | const char* kDisableJitKernelsEnvVar = "TF_DISABLE_JIT_KERNELS" ; |
66 | |
67 | namespace { |
68 | |
69 | Status MatchSignatureHelper(const DataTypeSlice expected_inputs, |
70 | const DataTypeSlice expected_outputs, |
71 | const DataTypeSlice inputs, |
72 | const DataTypeSlice outputs) { |
73 | bool signature_mismatch = false; |
74 | |
75 | if (inputs.size() != expected_inputs.size()) signature_mismatch = true; |
76 | for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) { |
77 | if (!TypesCompatible(expected_inputs[i], inputs[i])) { |
78 | signature_mismatch = true; |
79 | } |
80 | } |
81 | |
82 | if (outputs.size() != expected_outputs.size()) signature_mismatch = true; |
83 | for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) { |
84 | if (!TypesCompatible(expected_outputs[i], outputs[i])) { |
85 | signature_mismatch = true; |
86 | } |
87 | } |
88 | |
89 | if (signature_mismatch) { |
90 | return errors::InvalidArgument( |
91 | "Signature mismatch, have: " , DataTypeSliceString(inputs), "->" , |
92 | DataTypeSliceString(outputs), |
93 | " expected: " , DataTypeSliceString(expected_inputs), "->" , |
94 | DataTypeSliceString(expected_outputs)); |
95 | } |
96 | return OkStatus(); |
97 | } |
98 | |
99 | const absl::flat_hash_set<std::string>* GetOpNodeDefsToLogFromEnv() { |
100 | auto* result = new absl::flat_hash_set<std::string>; |
101 | const char* env = getenv("TF_DEBUG_OPS_TO_LOG_NODEDEFS" ); |
102 | if (!env) { |
103 | return result; |
104 | } |
105 | |
106 | std::vector<absl::string_view> ops = absl::StrSplit(env, ','); |
107 | LOG(INFO) << "Will log NodeDefs from the following ops: " ; |
108 | for (absl::string_view op : ops) { |
109 | result->insert(std::string(op)); |
110 | LOG(INFO) << " |" << op << "|" ; |
111 | } |
112 | |
113 | return result; |
114 | } |
115 | |
116 | // Returns true if the NodeDef for the OpKernel should be logged. The |
117 | // envionrmental variable TF_DEBUG_OPS_TO_LOG_NODEDEFS can be set to a |
118 | // comma-separated list of op types. The NodeDef for each is printed, which is |
119 | // useful for debugging purposes. |
120 | bool ShouldLogNodeDef(OpKernel* op_kernel) { |
121 | static const absl::flat_hash_set<std::string>& ops_to_log_nodedefs = |
122 | *GetOpNodeDefsToLogFromEnv(); |
123 | return ops_to_log_nodedefs.count(op_kernel->type_string()); |
124 | } |
125 | |
126 | } // namespace |
127 | |
128 | // OpKernel ------------------------------------------------------------------ |
129 | |
130 | OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {} |
131 | |
132 | OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred) |
133 | : props_(context->props_), |
134 | input_memory_types_(context->input_memory_types().begin(), |
135 | context->input_memory_types().end()), |
136 | output_memory_types_(context->output_memory_types().begin(), |
137 | context->output_memory_types().end()), |
138 | input_name_map_(context->num_inputs()), |
139 | output_name_map_(context->num_outputs()), |
140 | name_view_(props_->node_def.name()), |
141 | type_string_view_(props_->node_def.op()), |
142 | graph_def_version_(context->graph_def_version()), |
143 | is_deferred_(is_deferred) { |
144 | OP_REQUIRES_OK(context, |
145 | NameRangesForNode(props_->node_def, *props_->op_def, |
146 | &input_name_map_, &output_name_map_)); |
147 | OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def, |
148 | context->graph_def_version())); |
149 | |
150 | // Kernels executing on GPU tie very few resources on the CPU where the |
151 | // scheduler runs: we consider them as inexpensive. |
152 | expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && |
153 | !DeviceFactory::IsPluggableDevice( |
154 | DeviceTypeString(context->device_type())); |
155 | |
156 | if (ShouldLogNodeDef(this)) { |
157 | LOG(INFO) << "NodeDef for " << name() << ":\n" << def().ShortDebugString(); |
158 | } |
159 | } |
160 | |
161 | OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def, |
162 | bool is_deferred) |
163 | : props_(std::make_shared<const NodeProperties>( |
164 | context->props_->op_def, std::move(custom_def), |
165 | context->props_->input_types, context->props_->output_types)), |
166 | input_memory_types_(context->input_memory_types().begin(), |
167 | context->input_memory_types().end()), |
168 | output_memory_types_(context->output_memory_types().begin(), |
169 | context->output_memory_types().end()), |
170 | input_name_map_(context->num_inputs()), |
171 | output_name_map_(context->num_outputs()), |
172 | name_view_(props_->node_def.name()), |
173 | type_string_view_(props_->node_def.op()), |
174 | graph_def_version_(context->graph_def_version()), |
175 | is_deferred_(is_deferred) { |
176 | OP_REQUIRES_OK(context, |
177 | NameRangesForNode(props_->node_def, *props_->op_def, |
178 | &input_name_map_, &output_name_map_)); |
179 | OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def, |
180 | context->graph_def_version())); |
181 | |
182 | // Kernels executing on GPU tie very few resources on the CPU where the |
183 | // scheduler runs: we consider them as inexpensive. |
184 | expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && |
185 | !DeviceFactory::IsPluggableDevice( |
186 | DeviceTypeString(context->device_type())); |
187 | } |
188 | |
189 | OpKernel::~OpKernel() {} |
190 | |
191 | Status OpKernel::InputRange(StringPiece input_name, int* start, |
192 | int* stop) const { |
193 | const auto result = input_name_map_.find(input_name); |
194 | if (result == input_name_map_.end()) { |
195 | return errors::InvalidArgument("Unknown input name: " , input_name); |
196 | } else { |
197 | *start = result->second.first; |
198 | *stop = result->second.second; |
199 | return OkStatus(); |
200 | } |
201 | } |
202 | |
203 | Status OpKernel::OutputRange(StringPiece output_name, int* start, |
204 | int* stop) const { |
205 | const auto result = output_name_map_.find(output_name); |
206 | if (result == output_name_map_.end()) { |
207 | return errors::InvalidArgument("Unknown output name: " , output_name); |
208 | } else { |
209 | *start = result->second.first; |
210 | *stop = result->second.second; |
211 | return OkStatus(); |
212 | } |
213 | } |
214 | |
215 | string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const { |
216 | int num_inputs = ctx.num_inputs(); |
217 | if (num_inputs == 0) return "" ; |
218 | std::vector<string> tensor_shapes; |
219 | tensor_shapes.reserve(num_inputs); |
220 | for (int i = 0; i < num_inputs; i++) { |
221 | if (!ctx.has_input(i)) { |
222 | tensor_shapes.emplace_back(); // Placeholder |
223 | continue; |
224 | } |
225 | DataType input_dtype = ctx.input_dtype(i); |
226 | if (input_dtype == DataType::DT_RESOURCE || |
227 | input_dtype == DataType::DT_VARIANT || IsRefType(input_dtype)) { |
228 | tensor_shapes.emplace_back(); // Placeholder |
229 | continue; |
230 | } |
231 | tensor_shapes.emplace_back(strings::StrCat( |
232 | DataTypeString(input_dtype), ctx.input(i).shape().DebugString())); |
233 | } |
234 | return strings::StrCat("(" , absl::StrJoin(tensor_shapes, ";" ), ")" ); |
235 | } |
236 | |
237 | string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const { |
238 | string trace_string = profiler::TraceMeOp(name_view(), type_string_view()); |
239 | if (verbose) { |
240 | string shape = ShapeTraceString(ctx); |
241 | if (!shape.empty()) { |
242 | trace_string = |
243 | profiler::TraceMeEncode(std::move(trace_string), {{"shape" , shape}}); |
244 | } |
245 | } |
246 | return trace_string; |
247 | } |
248 | |
249 | void AsyncOpKernel::Compute(OpKernelContext* context) { |
250 | Notification n; |
251 | ComputeAsync(context, [&n]() { n.Notify(); }); |
252 | n.WaitForNotification(); |
253 | } |
254 | |
255 | // OpKernelConstruction ------------------------------------------------------ |
256 | |
257 | OpKernelConstruction::OpKernelConstruction( |
258 | DeviceType device_type, DeviceBase* device, Allocator* allocator, |
259 | FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr, |
260 | const std::shared_ptr<const NodeProperties>& props, |
261 | const MemoryTypeSlice& input_memory_types, |
262 | const MemoryTypeSlice& output_memory_types, int graph_def_version, |
263 | Status* status) |
264 | : device_type_(std::move(device_type)), |
265 | device_(device), |
266 | allocator_(allocator), |
267 | flib_(flib), |
268 | resource_mgr_(resource_mgr), |
269 | props_(props), |
270 | input_memory_types_(input_memory_types), |
271 | output_memory_types_(output_memory_types), |
272 | graph_def_version_(graph_def_version), |
273 | status_(status) {} |
274 | |
275 | bool OpKernelConstruction::HasAttr(StringPiece attr_name) const { |
276 | return HasNodeAttr(def(), attr_name); |
277 | } |
278 | |
279 | void OpKernelConstruction::SetStatus(const Status& status) { |
280 | status_->Update(status); |
281 | } |
282 | |
283 | Status OpKernelConstruction::MatchSignature( |
284 | const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { |
285 | return MatchSignatureHelper(expected_inputs, expected_outputs, |
286 | props_->input_types, props_->output_types); |
287 | } |
288 | |
289 | Status OpKernelConstruction::allocate_temp(DataType type, |
290 | const TensorShape& shape, |
291 | Tensor* out_temp) { |
292 | AllocationAttributes attr; |
293 | attr.allocation_will_be_logged = true; |
294 | Tensor new_temp(allocator_, type, shape, attr); |
295 | |
296 | if (!new_temp.IsInitialized()) { |
297 | return errors::ResourceExhausted( |
298 | "OOM when allocating temporary tensor with shape" , shape.DebugString()); |
299 | } |
300 | if (LogMemory::IsEnabled()) { |
301 | LogMemory::RecordTensorAllocation( |
302 | def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); |
303 | } |
304 | *out_temp = new_temp; |
305 | return OkStatus(); |
306 | } |
307 | |
308 | Status OpKernelConstruction::allocate_temp(DataType type, |
309 | const TensorShape& shape, |
310 | Tensor* out_temp, |
311 | AllocatorAttributes allocator_attr) { |
312 | if (allocator_attr.scope_id != 0) { |
313 | return errors::InvalidArgument( |
314 | "ScopedAllocator cannot be used via OpKernelConstruction." ); |
315 | } |
316 | Allocator* a = device_->GetAllocator(allocator_attr); |
317 | AllocationAttributes attr; |
318 | attr.allocation_will_be_logged = true; |
319 | Tensor new_temp(a, type, shape, attr); |
320 | |
321 | if (!new_temp.IsInitialized()) { |
322 | return errors::ResourceExhausted( |
323 | "OOM when allocating temporary tensor with shape" , shape.DebugString()); |
324 | } |
325 | if (LogMemory::IsEnabled()) { |
326 | LogMemory::RecordTensorAllocation( |
327 | def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); |
328 | } |
329 | *out_temp = new_temp; |
330 | return OkStatus(); |
331 | } |
332 | |
333 | // OpKernelContext ----------------------------------------------------------- |
334 | |
335 | const int OpKernelContext::Params::kNeverForward; |
336 | const int OpKernelContext::Params::kNoReservation; |
337 | |
338 | OpKernelContext::OpKernelContext(Params* params) |
339 | : OpKernelContext( |
340 | params, static_cast<int>(params->op_kernel->output_types().size())) {} |
341 | |
342 | OpKernelContext::OpKernelContext(Params* params, int num_outputs) |
343 | : params_(params), outputs_(num_outputs) { |
344 | if (params_->track_allocations) { |
345 | tracking_state_ = absl::make_unique<TrackingState>(); |
346 | } |
347 | |
348 | params_->ensure_eigen_gpu_device(); |
349 | if (params_->eigen_gpu_device != nullptr) { |
350 | Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); |
351 | Status s = params_->device->ReinitializeGpuDevice( |
352 | this, params_->eigen_gpu_device, params_->op_device_context, |
353 | eigen_gpu_allocator); |
354 | if (!s.ok()) { |
355 | SetStatus(s); |
356 | } |
357 | } |
358 | } |
359 | |
360 | OpKernelContext::~OpKernelContext() { |
361 | for (TensorValue& value : outputs_) { |
362 | if (!value.is_ref()) { |
363 | delete value.tensor; |
364 | } |
365 | } |
366 | if (params_->track_allocations && |
367 | !tracking_state_->wrapped_allocators.empty()) { |
368 | LOG(WARNING) << "OpKernelContext is tracking allocations but they are not " |
369 | << "being consumed by the StepStatsCollector." ; |
370 | for (auto& wrapped_allocator : tracking_state_->wrapped_allocators) { |
371 | wrapped_allocator.second->GetRecordsAndUnRef(); |
372 | } |
373 | } |
374 | } |
375 | |
376 | Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) { |
377 | Allocator* allocator = nullptr; |
378 | if (TF_PREDICT_FALSE(attr.scope_id > 0)) { |
379 | allocator = params_->device->GetScopedAllocator(attr, step_id()); |
380 | CHECK(allocator); |
381 | } else { |
382 | allocator = params_->device->GetAllocator(attr); |
383 | } |
384 | if (TF_PREDICT_FALSE(track_allocations())) { |
385 | DCHECK(tracking_state_); |
386 | mutex_lock lock(tracking_state_->mu); |
387 | for (const auto& wrapped : tracking_state_->wrapped_allocators) { |
388 | if (wrapped.first == allocator) { |
389 | return wrapped.second; |
390 | } |
391 | } |
392 | TrackingAllocator* wrapped_allocator = |
393 | new TrackingAllocator(allocator, params_->track_allocations); |
394 | tracking_state_->wrapped_allocators.push_back( |
395 | std::make_pair(allocator, wrapped_allocator)); |
396 | return wrapped_allocator; |
397 | } else { |
398 | return allocator; |
399 | } |
400 | } |
401 | |
402 | void OpKernelContext::SetStatus(const Status& status) { |
403 | status_.Update(status); |
404 | } |
405 | |
406 | Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { |
407 | int index; |
408 | TF_RETURN_IF_ERROR(get_input_index(name, &index)); |
409 | if (input_is_ref(index)) { |
410 | return errors::InvalidArgument("OpKernel used ref input name '" , name, |
411 | "' when non-ref input was expected" ); |
412 | } |
413 | *tensor = params_->inputs[index].tensor; |
414 | return OkStatus(); |
415 | } |
416 | |
417 | Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { |
418 | int index; |
419 | TF_RETURN_IF_ERROR(get_input_index(name, &index)); |
420 | const TensorValue& value(params_->inputs[index]); |
421 | *dtype = value.dtype(); |
422 | return OkStatus(); |
423 | } |
424 | |
425 | Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { |
426 | int index; |
427 | TF_RETURN_IF_ERROR(get_input_index(name, &index)); |
428 | *out_mutex = input_ref_mutex(index); |
429 | return OkStatus(); |
430 | } |
431 | |
432 | const Tensor& OpKernelContext::input(int index) const { |
433 | CHECK_GE(index, 0); |
434 | CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name(); |
435 | CHECK(!input_is_ref(index)); |
436 | const Tensor& tensor = *params_->inputs[index].tensor; |
437 | return tensor; |
438 | } |
439 | |
440 | Tensor OpKernelContext::mutable_input(int index, bool lock_held) { |
441 | CHECK_GE(index, 0); |
442 | CHECK_LT(index, num_inputs()); |
443 | CHECK(input_is_ref(index)); |
444 | // return a copy of the Ref acquired while holding the mutex |
445 | if (lock_held) { |
446 | Tensor& tensor = *params_->inputs[index].tensor; |
447 | return tensor; |
448 | } else { |
449 | tf_shared_lock l(*input_ref_mutex(index)); |
450 | Tensor& tensor = *params_->inputs[index].tensor; |
451 | return tensor; |
452 | } |
453 | } |
454 | |
455 | void OpKernelContext::replace_ref_input(int index, const Tensor& tensor, |
456 | bool lock_held) { |
457 | CHECK_GE(index, 0); |
458 | CHECK_LT(index, num_inputs()); |
459 | CHECK(input_is_ref(index)); |
460 | // should only modify the tensor while holding the mutex |
461 | if (lock_held) { |
462 | *params_->inputs[index].tensor = tensor; |
463 | } else { |
464 | mutex_lock l(*input_ref_mutex(index)); |
465 | *params_->inputs[index].tensor = tensor; |
466 | } |
467 | } |
468 | |
469 | void OpKernelContext::forward_ref_input_to_ref_output(int input_index, |
470 | int output_index) { |
471 | CHECK_GE(input_index, 0); |
472 | CHECK_LT(input_index, num_inputs()); |
473 | CHECK(input_is_ref(input_index)); |
474 | set_output_ref(output_index, params_->inputs[input_index].mutex_if_ref, |
475 | params_->inputs[input_index].tensor); |
476 | } |
477 | |
478 | bool OpKernelContext::forward_input_to_output_with_shape( |
479 | int input_index, int output_index, const TensorShape& output_shape, |
480 | Tensor** output) { |
481 | const auto output_attr = params_->output_attr_array == nullptr |
482 | ? AllocatorAttributes() |
483 | : output_alloc_attr(output_index); |
484 | std::unique_ptr<Tensor> new_tensor = forward_input( |
485 | input_index, output_index, expected_output_dtype(output_index), |
486 | output_shape, output_memory_type(output_index), output_attr); |
487 | if (new_tensor != nullptr) { |
488 | // Transfer ownership to the output slot in OpKernelContext. |
489 | outputs_[output_index] = TensorValue(new_tensor.release()); |
490 | *output = outputs_[output_index].tensor; |
491 | return true; |
492 | } else { |
493 | return false; |
494 | } |
495 | } |
496 | |
497 | Status OpKernelContext::forward_input_to_output_with_shape( |
498 | StringPiece input_name, StringPiece output_name, |
499 | const TensorShape& output_shape, Tensor** output) { |
500 | int input_index, output_index; |
501 | TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index)); |
502 | TF_RETURN_IF_ERROR(get_output_index(output_name, &output_index)); |
503 | if (!forward_input_to_output_with_shape(input_index, output_index, |
504 | output_shape, output)) { |
505 | return errors::FailedPrecondition("OpKernel could not forward input '" , |
506 | input_name, "' to output '" , output_name); |
507 | } |
508 | return OkStatus(); |
509 | } |
510 | |
511 | std::unique_ptr<Tensor> OpKernelContext::forward_input( |
512 | int input_index, int output_index, DataType output_dtype, |
513 | const TensorShape& output_shape, MemoryType output_memory_type, |
514 | const AllocatorAttributes& output_attr) { |
515 | CHECK_GE(input_index, 0); |
516 | CHECK_LT(input_index, num_inputs()); |
517 | const TensorValue& input = params_->inputs[input_index]; |
518 | // Check whether at graph construction time this output was marked |
519 | // either for no forwarding or with a reservation for this input. |
520 | // If it's reserved for this input we'll skip the refcount and |
521 | // AllocatorAttribute checks. |
522 | // TODO(tucker): Maybe we should skip all of the checks? |
523 | bool never_forward = |
524 | (params_->forward_from_array != nullptr && output_index >= 0 && |
525 | params_->forward_from_array[output_index] == Params::kNeverForward); |
526 | if (never_forward) return nullptr; |
527 | bool forward_expected = |
528 | (params_->forward_from_array != nullptr && output_index >= 0 && |
529 | params_->forward_from_array[output_index] == input_index); |
530 | if (!forward_expected && params_->forward_from_array != nullptr) { |
531 | // Check for possibly conflicting forward. |
532 | for (int i = 0; i < num_outputs(); ++i) { |
533 | if (params_->forward_from_array[i] == input_index) { |
534 | // This input is reserved for output i. |
535 | return nullptr; |
536 | } |
537 | } |
538 | } |
539 | // Check that input tensor exists and is not a ref. |
540 | if (input.tensor == nullptr || input.is_ref()) { |
541 | CHECK(!forward_expected); |
542 | return nullptr; |
543 | } |
544 | // Check that input type matches. |
545 | if (input_dtype(input_index) != output_dtype) { |
546 | CHECK(!forward_expected); |
547 | return nullptr; |
548 | } |
549 | // Check that the input and output sizes are compatible. |
550 | if (input.tensor->shape().num_elements() != output_shape.num_elements()) { |
551 | CHECK(!forward_expected); |
552 | return nullptr; |
553 | } |
554 | // Check that input and output memory types match, i.e. |
555 | // that they either both live in host or both live in device memory. |
556 | if (input_memory_type(input_index) != output_memory_type) { |
557 | CHECK(!forward_expected); |
558 | return nullptr; |
559 | } |
560 | if (!forward_expected) { |
561 | if (!input->RefCountIsOne()) { |
562 | return nullptr; |
563 | } |
564 | // Check that output allocator attributes are not more restrictive than |
565 | // input allocator attributes. |
566 | const auto input_attr = params_->input_alloc_attrs.empty() |
567 | ? AllocatorAttributes() |
568 | : input_alloc_attr(input_index); |
569 | if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) { |
570 | return nullptr; |
571 | } |
572 | } |
573 | |
574 | auto output_tensor = MakeUnique<Tensor>(); |
575 | CHECK(output_tensor->CopyFrom(*input.tensor, output_shape)); |
576 | return output_tensor; |
577 | } |
578 | |
579 | Status OpKernelContext::forward_input_or_allocate_temp( |
580 | gtl::ArraySlice<int> candidate_input_indices, DataType type, |
581 | const TensorShape& shape, const AllocatorAttributes& allocator_attr, |
582 | Tensor* out_temp) { |
583 | for (int input_index : candidate_input_indices) { |
584 | std::unique_ptr<Tensor> new_tensor = |
585 | forward_input(input_index, Params::kNoReservation /*output_index*/, |
586 | type, shape, DEVICE_MEMORY, allocator_attr); |
587 | if (new_tensor != nullptr) { |
588 | *out_temp = std::move(*new_tensor); |
589 | return OkStatus(); |
590 | } |
591 | } |
592 | return allocate_temp(type, shape, out_temp, allocator_attr); |
593 | } |
594 | |
595 | Status OpKernelContext::forward_input_or_allocate_output( |
596 | gtl::ArraySlice<int> candidate_input_indices, int output_index, |
597 | const TensorShape& output_shape, Tensor** output, int* forwarded_input) { |
598 | for (int input_index : candidate_input_indices) { |
599 | if (forward_input_to_output_with_shape(input_index, output_index, |
600 | output_shape, output)) { |
601 | if (forwarded_input != nullptr) { |
602 | *forwarded_input = input_index; |
603 | } |
604 | return OkStatus(); |
605 | } |
606 | } |
607 | if (forwarded_input != nullptr) { |
608 | *forwarded_input = -1; |
609 | } |
610 | return allocate_output(output_index, output_shape, output); |
611 | } |
612 | |
613 | Status OpKernelContext::forward_input_or_allocate_output( |
614 | gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name, |
615 | const TensorShape& output_shape, Tensor** output) { |
616 | for (const StringPiece& input_name : candidate_input_names) { |
617 | if (forward_input_to_output_with_shape(input_name, output_name, |
618 | output_shape, output) |
619 | .ok()) { |
620 | return OkStatus(); |
621 | } |
622 | } |
623 | return allocate_output(output_name, output_shape, output); |
624 | } |
625 | |
626 | void OpKernelContext::delete_ref_input(int index, bool lock_held) { |
627 | CHECK_GE(index, 0); |
628 | CHECK_LT(index, num_inputs()); |
629 | CHECK(input_is_ref(index)); |
630 | // should only modify the tensor while holding the mutex |
631 | if (lock_held) { |
632 | delete params_->inputs[index].tensor; |
633 | } else { |
634 | mutex_lock l(*input_ref_mutex(index)); |
635 | delete params_->inputs[index].tensor; |
636 | } |
637 | } |
638 | |
639 | Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, |
640 | bool lock_held) { |
641 | int index; |
642 | TF_RETURN_IF_ERROR(get_input_index(name, &index)); |
643 | if (!input_is_ref(index)) { |
644 | return errors::InvalidArgument("OpKernel used non-ref input name '" , name, |
645 | "' when ref input was expected" ); |
646 | } |
647 | // return a copy of the Ref acquired while holding the mutex |
648 | if (lock_held) { |
649 | *tensor = *params_->inputs[index].tensor; |
650 | } else { |
651 | tf_shared_lock l(*input_ref_mutex(index)); |
652 | *tensor = *params_->inputs[index].tensor; |
653 | } |
654 | return OkStatus(); |
655 | } |
656 | |
657 | Status OpKernelContext::replace_ref_input(StringPiece name, |
658 | const Tensor& tensor, |
659 | bool lock_held) { |
660 | int index; |
661 | TF_RETURN_IF_ERROR(get_input_index(name, &index)); |
662 | if (!input_is_ref(index)) { |
663 | return errors::InvalidArgument("OpKernel used immutable input name '" , name, |
664 | "' when ref input was expected" ); |
665 | } |
666 | replace_ref_input(index, tensor, lock_held); |
667 | return OkStatus(); |
668 | } |
669 | |
670 | Status OpKernelContext::input_list(StringPiece name, OpInputList* list) { |
671 | int start, stop; |
672 | TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); |
673 | *list = OpInputList(this, start, stop); |
674 | return OkStatus(); |
675 | } |
676 | |
677 | Status OpKernelContext::mutable_input_list(StringPiece name, |
678 | OpMutableInputList* list) { |
679 | int start, stop; |
680 | TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); |
681 | *list = OpMutableInputList(this, start, stop); |
682 | return OkStatus(); |
683 | } |
684 | |
685 | Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) { |
686 | int start, stop; |
687 | TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); |
688 | *list = OpOutputList(this, start, stop); |
689 | return OkStatus(); |
690 | } |
691 | |
692 | void OpKernelContext::maybe_initialize_scope_id_set() { |
693 | if (allocated_scope_ids_ == nullptr) { |
694 | allocated_scope_ids_ = absl::make_unique<std::unordered_set<int32>>(); |
695 | } |
696 | } |
697 | |
698 | Status OpKernelContext::allocate_output(int index, const TensorShape& shape, |
699 | Tensor** tensor) { |
700 | if (index < 0) { |
701 | return errors::Internal("allocate_output with bad index=" , index, |
702 | " kernel=" , params_->op_kernel->name()); |
703 | } |
704 | if (index >= num_outputs()) { |
705 | return errors::Internal("allocate_output with bad index=" , index, |
706 | " num_outputs=" , num_outputs(), |
707 | " kernel=" , params_->op_kernel->name()); |
708 | } |
709 | bool forward_expected = |
710 | (params_->forward_from_array != nullptr && index >= 0 && |
711 | params_->forward_from_array[index] >= 0); |
712 | if (forward_expected) { |
713 | return errors::Internal( |
714 | "Explicit allocate_output call where input forwarding required. Try " |
715 | "turning off the ScopedAllocator optimizer." ); |
716 | } |
717 | AllocatorAttributes attr = output_alloc_attr(index); |
718 | return allocate_output(index, shape, tensor, attr); |
719 | } |
720 | |
721 | Status OpKernelContext::allocate_output(StringPiece name, |
722 | const TensorShape& shape, |
723 | Tensor** tensor) { |
724 | int start, stop; |
725 | TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); |
726 | if (stop != start + 1) { |
727 | return errors::InvalidArgument("OpKernel used list-valued output name '" , |
728 | name, |
729 | "' when single-valued output was " |
730 | "expected" ); |
731 | } |
732 | return allocate_output(start, shape, tensor); |
733 | } |
734 | |
735 | Status OpKernelContext::allocate_output(StringPiece name, |
736 | const TensorShape& shape, |
737 | Tensor** tensor, |
738 | AllocatorAttributes attr) { |
739 | int start, stop; |
740 | TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); |
741 | if (stop != start + 1) { |
742 | return errors::InvalidArgument("OpKernel used list-valued output name '" , |
743 | name, |
744 | "' when single-valued output was " |
745 | "expected" ); |
746 | } |
747 | return allocate_output(start, shape, tensor, attr); |
748 | } |
749 | |
750 | Status OpKernelContext::allocate_tensor( |
751 | DataType type, const TensorShape& shape, Tensor* out_tensor, |
752 | AllocatorAttributes attr, const AllocationAttributes& allocation_attr) { |
753 | Allocator* a = get_allocator(attr); |
754 | Tensor new_tensor( |
755 | a, type, shape, |
756 | AllocationAttributes( |
757 | /*retry_on_failure=*/allocation_attr.retry_on_failure, |
758 | /*allocation_will_be_logged=*/true, allocation_attr.freed_by_func)); |
759 | |
760 | if (!new_tensor.IsInitialized()) { |
761 | return errors::ResourceExhausted( |
762 | "OOM when allocating tensor with shape" , shape.DebugString(), |
763 | " and type " , DataTypeString(type), " on " , params_->device->name(), |
764 | " by allocator " , a->Name()); |
765 | } |
766 | if (params_->log_memory) { |
767 | LogMemory::RecordTensorAllocation(params_->op_kernel->name(), |
768 | params_->step_id, new_tensor); |
769 | } |
770 | *out_tensor = std::move(new_tensor); |
771 | return OkStatus(); |
772 | } |
773 | |
774 | Status OpKernelContext::allocate_output(int index, const TensorShape& shape, |
775 | Tensor** output, |
776 | AllocatorAttributes attr) { |
777 | if (index < 0) { |
778 | return errors::Internal("allocate_output with bad index=" , index, |
779 | " kernel=" , params_->op_kernel->name()); |
780 | } |
781 | if (index >= num_outputs()) { |
782 | return errors::Internal("allocate_output with bad index=" , index, |
783 | " num_outputs=" , outputs_.size(), |
784 | " kernel=" , params_->op_kernel->name()); |
785 | } |
786 | const DataType type = params_->op_kernel->output_type(index); |
787 | if (IsRefType(type)) { |
788 | return errors::Internal("allocate_output with ref type. index=" , index, |
789 | " type=" , type, |
790 | " kernel=" , params_->op_kernel->name()); |
791 | } |
792 | if (mutable_output(index) != nullptr) { |
793 | return errors::Internal("allocate_output on same index multiple times." , |
794 | " index = " , index, |
795 | " mutable_output(index) = " , mutable_output(index), |
796 | " kernel=" , params_->op_kernel->name()); |
797 | } |
798 | if (attr.scope_id > 0) { |
799 | maybe_initialize_scope_id_set(); |
800 | if (!allocated_scope_ids_->insert(attr.scope_id).second) { |
801 | return errors::Internal( |
802 | "OpKernel " , params_->op_kernel->name(), |
803 | " called allocate_output at index " , index, " with scope_id " , |
804 | attr.scope_id, |
805 | " more than once. Try turning off the ScopedAllocator optimizer." ); |
806 | } |
807 | } |
808 | profiler::ScopedMemoryDebugAnnotation op_annotation( |
809 | op_kernel().name_view().data(), step_id(), "output" , type, |
810 | [&shape]() { return shape.DebugString(); }); |
811 | auto output_tensor = MakeUnique<Tensor>(); |
812 | Status s = allocate_tensor(type, shape, output_tensor.get(), attr); |
813 | if (s.ok()) { |
814 | outputs_[index] = TensorValue(output_tensor.release()); |
815 | *output = outputs_[index].tensor; |
816 | } |
817 | return s; |
818 | } |
819 | |
820 | Status OpKernelContext::allocate_temp( |
821 | DataType type, const TensorShape& shape, Tensor* out_temp, |
822 | AllocatorAttributes allocator_attr, |
823 | const AllocationAttributes& allocation_attr) { |
824 | if (allocator_attr.scope_id > 0) { |
825 | // We do not allow ScopedAllocator calls from allocate_temp. |
826 | // Here we clear the scope_id and return a temporary buffer. |
827 | // This is because it is legal for a kernel to call allocate_temp |
828 | // and then set_output with the temp tensor. |
829 | // |
830 | // We achieve memory correctness by forcing an allocation in set_output and |
831 | // copying over the tensor from the temp buffer. Kernels which would like |
832 | // to avoid this performance penalty should switch to calling |
833 | // allocate_output. |
834 | VLOG(2) << "Warning: OpKernel " << params_->op_kernel->name() |
835 | << " called allocate_temp with scope_id " << allocator_attr.scope_id |
836 | << ". Switch to allocate_output to avoid performance penalty." ; |
837 | allocator_attr.scope_id = -1; |
838 | } |
839 | profiler::ScopedMemoryDebugAnnotation op_annotation( |
840 | op_kernel().name_view().data(), step_id(), "temp" , type, |
841 | [&shape]() { return shape.DebugString(); }); |
842 | Status s = |
843 | allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr); |
844 | if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) { |
845 | Allocator* a = get_allocator(allocator_attr); |
846 | if (a->TracksAllocationSizes()) { |
847 | int64_t alloc_size = a->AllocatedSize(out_temp->tensor_data().data()); |
848 | record_temp_memory_allocation(alloc_size, *out_temp); |
849 | } |
850 | } else if (record_memory_consumption_) { |
851 | DCHECK(tracking_state_); |
852 | mutex_lock l(tracking_state_->stats_mu); |
853 | tracking_state_->temp_memory_allocated += out_temp->TotalBytes(); |
854 | } |
855 | return s; |
856 | } |
857 | |
858 | Status OpKernelContext::allocate_temp(DataType type, const TensorShape& shape, |
859 | Tensor* out_temp, |
860 | AllocatorAttributes allocator_attr) { |
861 | return allocate_temp(type, shape, out_temp, allocator_attr, |
862 | AllocationAttributes()); |
863 | } |
864 | |
865 | Status OpKernelContext::allocate_temp(DataType type, const TensorShape& shape, |
866 | Tensor* out_temp) { |
867 | return allocate_temp(type, shape, out_temp, AllocatorAttributes()); |
868 | } |
869 | |
870 | Status OpKernelContext::get_input_index(StringPiece name, |
871 | int* out_index) const { |
872 | int start, stop; |
873 | TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); |
874 | if (stop != start + 1) { |
875 | return errors::InvalidArgument("OpKernel used list-valued input name '" , |
876 | name, |
877 | "' when single-valued input was " |
878 | "expected" ); |
879 | } |
880 | *out_index = start; |
881 | return OkStatus(); |
882 | } |
883 | |
884 | Status OpKernelContext::get_output_index(StringPiece name, |
885 | int* out_index) const { |
886 | int start, stop; |
887 | TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); |
888 | if (stop != start + 1) { |
889 | return errors::InvalidArgument("OpKernel used list-valued output name '" , |
890 | name, |
891 | "' when single-valued output was " |
892 | "expected" ); |
893 | } |
894 | *out_index = start; |
895 | return OkStatus(); |
896 | } |
897 | |
898 | Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) { |
899 | int index; |
900 | TF_RETURN_IF_ERROR(get_output_index(name, &index)); |
901 | set_output(index, tensor); |
902 | return OkStatus(); |
903 | } |
904 | |
905 | Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) { |
906 | int index; |
907 | TF_RETURN_IF_ERROR(get_output_index(name, &index)); |
908 | set_output(index, std::move(tensor)); |
909 | return OkStatus(); |
910 | } |
911 | |
912 | bool OpKernelContext::maybe_set_output_by_allocate_and_copy( |
913 | int index, const Tensor& tensor) { |
914 | bool allocate_and_copy = false; |
915 | const bool never_forward = |
916 | (params_->forward_from_array != nullptr && |
917 | params_->forward_from_array[index] == Params::kNeverForward); |
918 | if (TF_PREDICT_FALSE(never_forward)) { |
919 | maybe_initialize_scope_id_set(); |
920 | if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) == |
921 | allocated_scope_ids_->end()) { |
922 | allocate_and_copy = true; |
923 | } else { |
924 | // The output at `index` must have been previously allocated via a call to |
925 | // `allocate_output(index, ...)`. That call would ensure that we return |
926 | // the correct slice of the ScopedAllocated buffer, so we do not |
927 | // re-allocate and copy here. |
928 | LOG(WARNING) |
929 | << "OpKernel " << params_->op_kernel->name() |
930 | << " called both allocate_output and set_output with scope_id " |
931 | << output_alloc_attr(index).scope_id; |
932 | } |
933 | } |
934 | |
935 | if (TF_PREDICT_FALSE(allocate_and_copy)) { |
936 | // This output was marked to not be forwarded either during graph |
937 | // construction or grappler passes. Force an allocation and copy input to |
938 | // output. |
939 | VLOG(1) << "OpKernelContext set_output index " << index << " tensor " |
940 | << tensor.DebugString() << " never_forward " << never_forward |
941 | << " params_->forward_from_array[index] " |
942 | << params_->forward_from_array[index] << " alloc_attr.scope_id " |
943 | << output_alloc_attr(index).scope_id; |
944 | profiler::ScopedMemoryDebugAnnotation op_annotation( |
945 | op_kernel().name_view().data(), step_id(), "output" , tensor.dtype(), |
946 | [&tensor]() { return tensor.shape().DebugString(); }); |
947 | auto new_tensor = MakeUnique<Tensor>(); |
948 | Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(), |
949 | output_alloc_attr(index)); |
950 | TF_CHECK_OK(s); |
951 | device()->CopyTensorInSameDevice(&tensor, new_tensor.get(), |
952 | op_device_context(), [](const Status&) {}); |
953 | outputs_[index] = TensorValue(new_tensor.release()); |
954 | } |
955 | return allocate_and_copy; |
956 | } |
957 | |
958 | void OpKernelContext::maybe_track_allocations_for_set_output( |
959 | const Tensor& tensor) { |
960 | if (TF_PREDICT_FALSE(track_allocations()) && tensor.TotalBytes() > 0) { |
961 | DCHECK(tracking_state_); |
962 | mutex_lock l(tracking_state_->stats_mu); |
963 | const auto it = std::find_if( |
964 | tracking_state_->temp_tensor_buffer_and_size.begin(), |
965 | tracking_state_->temp_tensor_buffer_and_size.end(), |
966 | [&tensor](const std::pair<const void*, int64>& e) { |
967 | return e.first == |
968 | static_cast<const void*>(tensor.tensor_data().data()); |
969 | }); |
970 | if (it != tracking_state_->temp_tensor_buffer_and_size.end()) { |
971 | tracking_state_->temp_memory_allocated -= it->second; |
972 | tracking_state_->temp_tensor_buffer_and_size.erase(it); |
973 | } |
974 | } |
975 | } |
976 | |
977 | void OpKernelContext::set_output(int index, const Tensor& tensor) { |
978 | CHECK_GE(index, 0); |
979 | CHECK_LT(index, outputs_.size()); |
980 | const DataType type = params_->op_kernel->output_type(index); |
981 | CHECK(!IsRefType(type)); |
982 | CHECK_EQ(outputs_[index].tensor, nullptr); |
983 | if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) { |
984 | // Input can be forwarded to output; incref on `tensor` and set output at |
985 | // `index` to this tensor. |
986 | outputs_[index] = TensorValue(new Tensor(tensor)); |
987 | maybe_track_allocations_for_set_output(*outputs_[index].tensor); |
988 | } |
989 | } |
990 | |
991 | void OpKernelContext::set_output(int index, Tensor&& tensor) { |
992 | CHECK_GE(index, 0); |
993 | CHECK_LT(index, outputs_.size()); |
994 | const DataType type = params_->op_kernel->output_type(index); |
995 | CHECK(!IsRefType(type)); |
996 | CHECK_EQ(outputs_[index].tensor, nullptr); |
997 | if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) { |
998 | // Input can be forwarded to output; set output at `index` to this tensor. |
999 | outputs_[index] = TensorValue(new Tensor(std::move(tensor))); |
1000 | maybe_track_allocations_for_set_output(*outputs_[index].tensor); |
1001 | } |
1002 | } |
1003 | |
1004 | void OpKernelContext::set_output_ref(int index, mutex* mu, |
1005 | Tensor* tensor_for_ref) { |
1006 | CHECK_GE(index, 0); |
1007 | CHECK_LT(index, outputs_.size()); |
1008 | CHECK(IsRefType(params_->op_kernel->output_type(index))); |
1009 | outputs_[index] = TensorValue(mu, tensor_for_ref); |
1010 | } |
1011 | |
1012 | Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu, |
1013 | Tensor* tensor_for_ref) { |
1014 | int index; |
1015 | TF_RETURN_IF_ERROR(get_output_index(name, &index)); |
1016 | set_output_ref(index, mu, tensor_for_ref); |
1017 | return OkStatus(); |
1018 | } |
1019 | |
1020 | Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) { |
1021 | int index; |
1022 | TF_RETURN_IF_ERROR(get_output_index(name, &index)); |
1023 | *tensor = mutable_output(index); |
1024 | return OkStatus(); |
1025 | } |
1026 | |
1027 | bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { |
1028 | const auto& inputs = params_->inputs; |
1029 | for (size_t i = 1; i < inputs.size(); ++i) { |
1030 | if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) { |
1031 | SetStatus(errors::InvalidArgument( |
1032 | "Inputs to operation " , op->name(), " of type " , op->type_string(), |
1033 | " must have the same size and shape. Input 0: " , |
1034 | inputs[0]->shape().DebugString(), " != input " , i, ": " , |
1035 | inputs[i]->shape().DebugString())); |
1036 | return false; |
1037 | } |
1038 | } |
1039 | return true; |
1040 | } |
1041 | |
1042 | Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, |
1043 | const DataTypeSlice expected_outputs) { |
1044 | DataTypeVector inputs; |
1045 | for (const TensorValue& t : params_->inputs) { |
1046 | inputs.push_back(t.dtype()); |
1047 | } |
1048 | DataTypeVector outputs = params_->op_kernel->output_types(); |
1049 | return MatchSignatureHelper(expected_inputs, expected_outputs, inputs, |
1050 | outputs); |
1051 | } |
1052 | |
1053 | void OpKernelContext::record_temp_memory_allocation(int64_t size, |
1054 | const Tensor& t) { |
1055 | if (tracking_state_) { |
1056 | mutex_lock l(tracking_state_->stats_mu); |
1057 | tracking_state_->temp_memory_allocated += size; |
1058 | tracking_state_->temp_tensor_buffer_and_size.emplace_back( |
1059 | static_cast<const void*>(t.tensor_data().data()), size); |
1060 | } |
1061 | } |
1062 | |
1063 | int64_t OpKernelContext::temp_memory_allocated() const { |
1064 | if (tracking_state_) { |
1065 | mutex_lock l(tracking_state_->stats_mu); |
1066 | return tracking_state_->temp_memory_allocated; |
1067 | } else { |
1068 | return 0; |
1069 | } |
1070 | } |
1071 | |
1072 | void OpKernelContext::record_persistent_memory_allocation(int64_t size, |
1073 | int64_t alloc_id) { |
1074 | if (tracking_state_) { |
1075 | mutex_lock l(tracking_state_->stats_mu); |
1076 | tracking_state_->persistent_memory_allocated += size; |
1077 | if (alloc_id >= 0) { |
1078 | tracking_state_->persistent_alloc_ids.push_back(alloc_id); |
1079 | } |
1080 | } |
1081 | } |
1082 | |
1083 | int64_t OpKernelContext::persistent_memory_allocated() const { |
1084 | if (tracking_state_) { |
1085 | mutex_lock l(tracking_state_->stats_mu); |
1086 | return tracking_state_->persistent_memory_allocated; |
1087 | } else { |
1088 | return 0; |
1089 | } |
1090 | } |
1091 | |
1092 | std::vector<int64_t> OpKernelContext::persistent_alloc_ids() const { |
1093 | if (tracking_state_) { |
1094 | mutex_lock l(tracking_state_->stats_mu); |
1095 | return std::vector<int64_t>(tracking_state_->persistent_alloc_ids.begin(), |
1096 | tracking_state_->persistent_alloc_ids.end()); |
1097 | } else { |
1098 | return std::vector<int64_t>(); |
1099 | } |
1100 | } |
1101 | |
1102 | void OpKernelContext::clear_recorded_memory() { |
1103 | if (tracking_state_) { |
1104 | mutex_lock l(tracking_state_->stats_mu); |
1105 | tracking_state_->temp_memory_allocated = 0; |
1106 | tracking_state_->persistent_memory_allocated = 0; |
1107 | tracking_state_->temp_tensor_buffer_and_size.clear(); |
1108 | tracking_state_->persistent_alloc_ids.clear(); |
1109 | } |
1110 | } |
1111 | |
1112 | void OpKernelContext::set_record_memory_consumption(bool v) { |
1113 | record_memory_consumption_ = v; |
1114 | if (v && !tracking_state_) { |
1115 | tracking_state_ = absl::make_unique<TrackingState>(); |
1116 | } |
1117 | } |
1118 | |
1119 | const string& OpKernelContext::executor_type() const { |
1120 | if (params_->executor_type) { |
1121 | return *params_->executor_type; |
1122 | } else { |
1123 | static const string& kEmptyString = *new string("" ); |
1124 | return kEmptyString; |
1125 | } |
1126 | } |
1127 | |
1128 | // OpKernel registration ------------------------------------------------------ |
1129 | |
1130 | struct KernelRegistration { |
1131 | KernelRegistration(const KernelDef& d, StringPiece c, |
1132 | std::unique_ptr<kernel_factory::OpKernelFactory> f) |
1133 | : def(d), kernel_class_name(c), factory(std::move(f)) {} |
1134 | |
1135 | const KernelDef def; |
1136 | const string kernel_class_name; |
1137 | std::unique_ptr<kernel_factory::OpKernelFactory> factory; |
1138 | }; |
1139 | |
1140 | // This maps from 'op_type' + DeviceType to the set of KernelDefs and |
1141 | // factory functions for instantiating the OpKernel that matches the |
1142 | // KernelDef. |
1143 | struct KernelRegistry { |
1144 | mutex mu; |
1145 | std::unordered_multimap<string, KernelRegistration> registry |
1146 | TF_GUARDED_BY(mu); |
1147 | }; |
1148 | |
1149 | #if defined(_WIN32) |
1150 | static const char kKernelLibPattern[] = "libtfkernel*.dll" ; |
1151 | #elif defined(__APPLE__) |
1152 | static const char kKernelLibPattern[] = "libtfkernel*.dylib" ; |
1153 | #else |
1154 | static const char kKernelLibPattern[] = "libtfkernel*.so" ; |
1155 | #endif |
1156 | |
1157 | #define FEATURE(x) \ |
1158 | { x, #x } |
1159 | |
1160 | // Returns Status::OK if the dynamic library at the given path is safe to |
1161 | // load with some level of confidence. |
1162 | static Status IsProbablySafeToLoad(const string& path) { |
1163 | // A map of platform string to required CPU feature. |
1164 | using port::CPUFeature; |
1165 | static const auto* feature_map = |
1166 | new std::map<string, std::pair<CPUFeature, string>>{ |
1167 | {"__AVX512VL__=1" , FEATURE(CPUFeature::AVX512VL)}, |
1168 | }; |
1169 | |
1170 | std::vector<std::string> platform_strings; |
1171 | int result = GetPlatformStrings(path, &platform_strings); |
1172 | if (result) { |
1173 | return Status(error::Code::UNKNOWN, strerror(result)); |
1174 | } |
1175 | if (platform_strings.empty()) { |
1176 | return Status(error::Code::FAILED_PRECONDITION, |
1177 | "Didn't find any platform strings" ); |
1178 | } |
1179 | std::vector<std::string> missing_features; |
1180 | for (const auto& platform_string : platform_strings) { |
1181 | const auto& entry = feature_map->find(platform_string); |
1182 | if (entry != feature_map->end() && |
1183 | !port::TestCPUFeature(entry->second.first)) { |
1184 | missing_features.emplace_back(entry->second.second); |
1185 | } |
1186 | } |
1187 | if (!missing_features.empty()) { |
1188 | string errmsg = "Missing CPU features: " ; |
1189 | errmsg.append(absl::StrJoin(missing_features, ", " )); |
1190 | return errors::FailedPrecondition(errmsg); |
1191 | } |
1192 | return OkStatus(); |
1193 | } |
1194 | |
1195 | void LoadDynamicKernelsInternal() { |
1196 | Env* env = Env::Default(); |
1197 | |
1198 | // Override to allow loading unsafe packages for development. |
1199 | // DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER. |
1200 | char* _abi_check_env_var = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES" ); |
1201 | bool override_abi_check = false; |
1202 | if (_abi_check_env_var != nullptr) { |
1203 | override_abi_check = strcmp(_abi_check_env_var, "1" ) == 0; |
1204 | } |
1205 | |
1206 | string bazel_kernel_dir = |
1207 | io::JoinPath(env->GetRunfilesDir(), "tensorflow" , "core" , "kernels" ); |
1208 | std::vector<string> files; |
1209 | Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files); |
1210 | if (s_kernel_dir.ok()) { |
1211 | string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern); |
1212 | for (const auto& file : files) { |
1213 | string fullpath = io::JoinPath(bazel_kernel_dir, file); |
1214 | if (env->MatchPath(fullpath, dll_spec)) { |
1215 | Status s = IsProbablySafeToLoad(fullpath); |
1216 | if (!s.ok() && override_abi_check) { |
1217 | LOG(WARNING) << "Loading UNSAFE library " << fullpath |
1218 | << " because ABI check override is set: " |
1219 | << s.error_message(); |
1220 | } |
1221 | if (s.ok() || override_abi_check) { |
1222 | // TODO(gunan): Store the handles to the opened files. |
1223 | void* unused_filehandle; |
1224 | TF_CHECK_OK( |
1225 | env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle)); |
1226 | } else { |
1227 | LOG(WARNING) << "Not loading plugin library " << fullpath << ": " |
1228 | << s.error_message(); |
1229 | } |
1230 | } |
1231 | } |
1232 | } |
1233 | } |
1234 | |
1235 | // Mechanism for loading existing kernel libraries. |
1236 | void LoadDynamicKernels() { |
1237 | // TODO(gunan): As more features are available, add intelligent kernel |
1238 | // selection, and dropping unsuitable kernel logic here. |
1239 | static absl::once_flag dll_loader_flag; |
1240 | absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal); |
1241 | } |
1242 | |
1243 | static string Key(StringPiece op_type, const DeviceType& device_type, |
1244 | StringPiece label) { |
1245 | return strings::StrCat(op_type, ":" , DeviceTypeString(device_type), ":" , |
1246 | label); |
1247 | } |
1248 | |
1249 | // Provide a way for users to disable JIT kernels for a transitional period. |
1250 | // Until this is removed, this function also removes the JIT label that is added |
1251 | // to JIT kernels during the static registration, to allow them to be found |
1252 | // during lookup as normal kernels. |
1253 | void SetupOrDisableJit(KernelRegistry* registry) { |
1254 | std::unordered_multimap<string, KernelRegistration> jit_kernels; |
1255 | bool remove_jit_kernels = absl::StrContains( |
1256 | absl::NullSafeStringView(getenv(kDisableJitKernelsEnvVar)), "1" ); |
1257 | |
1258 | mutex_lock l(registry->mu); |
1259 | std::unordered_multimap<string, KernelRegistration>& all_kernels = |
1260 | registry->registry; |
1261 | auto it = all_kernels.begin(); |
1262 | while (it != all_kernels.end()) { |
1263 | if (absl::StrContains(it->second.def.label(), kJitKernelLabel)) { |
1264 | // Remove all kernels that have the jit label. They will be added back |
1265 | // without the label if they are not to be disabled. |
1266 | KernelDef def_without_label = it->second.def; |
1267 | def_without_label.set_label("" ); |
1268 | |
1269 | if (!remove_jit_kernels) { |
1270 | jit_kernels.emplace( |
1271 | Key(def_without_label.op(), |
1272 | DeviceType(def_without_label.device_type()), |
1273 | def_without_label.label()), |
1274 | KernelRegistration(def_without_label, it->second.kernel_class_name, |
1275 | std::move(it->second.factory))); |
1276 | } |
1277 | |
1278 | it = all_kernels.erase(it); |
1279 | } else { |
1280 | it++; |
1281 | } |
1282 | } |
1283 | |
1284 | // Add back kernels if they are not disabled. This new key-value pair have all |
1285 | // references to the label removed. |
1286 | for (auto& jit_kernel : jit_kernels) { |
1287 | all_kernels.insert(std::move(jit_kernel)); |
1288 | } |
1289 | } |
1290 | |
1291 | namespace register_kernel { |
1292 | |
1293 | // Defined out of line to save code space |
1294 | Name::Name(const char* op) : KernelDefBuilder(op) {} |
1295 | |
1296 | } // namespace register_kernel |
1297 | |
1298 | void* GlobalKernelRegistry() { |
1299 | static KernelRegistry* global_kernel_registry = []() { |
1300 | KernelRegistry* registry = new KernelRegistry; |
1301 | OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations); |
1302 | return registry; |
1303 | }(); |
1304 | return global_kernel_registry; |
1305 | } |
1306 | |
1307 | static KernelRegistry* GlobalKernelRegistryTyped() { |
1308 | #ifdef AUTOLOAD_DYNAMIC_KERNELS |
1309 | LoadDynamicKernels(); |
1310 | #endif // AUTOLOAD_DYNAMIC_KERNELS |
1311 | auto* registry = reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry()); |
1312 | // Update or disable JIT kernels based on user configuration. This is a |
1313 | // temporary fallback as part of the initial release of JIT kernels. |
1314 | static absl::once_flag setup_or_disable_jit; |
1315 | absl::call_once(setup_or_disable_jit, SetupOrDisableJit, registry); |
1316 | return registry; |
1317 | } |
1318 | |
1319 | namespace kernel_factory { |
1320 | |
1321 | void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def, |
1322 | StringPiece kernel_class_name, |
1323 | std::unique_ptr<OpKernelFactory> factory) { |
1324 | const string key = |
1325 | Key(kernel_def->op(), DeviceType(kernel_def->device_type()), |
1326 | kernel_def->label()); |
1327 | |
1328 | // To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped |
1329 | // here. |
1330 | // InitInternal gets called by static initializers, so it ends up executing |
1331 | // before main. This causes LoadKernelLibraries function to get called |
1332 | // before some file libraries can initialize, which in turn crashes the |
1333 | // program flakily. Until we get rid of static initializers in kernel |
1334 | // registration mechanism, we have this workaround here. |
1335 | auto global_registry = |
1336 | reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry()); |
1337 | mutex_lock l(global_registry->mu); |
1338 | global_registry->registry.emplace( |
1339 | key, |
1340 | KernelRegistration(*kernel_def, kernel_class_name, std::move(factory))); |
1341 | delete kernel_def; |
1342 | } |
1343 | |
1344 | OpKernel* OpKernelRegistrar::PtrOpKernelFactory::Create( |
1345 | OpKernelConstruction* context) { |
1346 | return (*create_func_)(context); |
1347 | } |
1348 | |
1349 | } // namespace kernel_factory |
1350 | |
1351 | namespace { |
1352 | |
1353 | // Label defaults to empty if not found in NodeDef. |
1354 | const string& GetKernelLabelAttr(const AttrSlice& node_attrs) { |
1355 | static const string& kKernelAttr = *new string("_kernel" ); |
1356 | static const string& kEmptyString = *new string("" ); |
1357 | |
1358 | // NOTE: We inline the implementation of `GetNodeAttrString()` here in order |
1359 | // to use the `AttrSlice::FindByString()` overload, which does a more |
1360 | // efficient map lookup (instead of a linear scan) when the attribute name is |
1361 | // already a `const string&`. |
1362 | const AttrValue* attr_value = node_attrs.FindByString(kKernelAttr); |
1363 | if (attr_value == nullptr || attr_value->value_case() != AttrValue::kS) |
1364 | return kEmptyString; |
1365 | else |
1366 | return attr_value->s(); |
1367 | } |
1368 | |
1369 | // TODO(irving): Replace with const Node& version below. |
1370 | Status FindKernelRegistration( |
1371 | const DeviceType& device_type, StringPiece node_name, |
1372 | bool has_experimental_debug_info, |
1373 | const NodeDef_ExperimentalDebugInfo& experimental_debug_info, |
1374 | StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg, |
1375 | bool* was_attr_mismatch) { |
1376 | *reg = nullptr; |
1377 | *was_attr_mismatch = false; |
1378 | |
1379 | const string& label = GetKernelLabelAttr(node_attrs); |
1380 | |
1381 | const string key = Key(node_op, device_type, label); |
1382 | auto typed_registry = GlobalKernelRegistryTyped(); |
1383 | tf_shared_lock lock(typed_registry->mu); |
1384 | auto regs = typed_registry->registry.equal_range(key); |
1385 | for (auto iter = regs.first; iter != regs.second; ++iter) { |
1386 | // If there is a kernel registered for the op and device_type, |
1387 | // check that the attrs match. |
1388 | bool match; |
1389 | TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_attrs, &match)); |
1390 | if (match) { |
1391 | if (*reg != nullptr) { |
1392 | if ((*reg)->def.priority() == iter->second.def.priority()) { |
1393 | return errors::InvalidArgument( |
1394 | "Multiple OpKernel registrations match NodeDef at the same " |
1395 | "priority '" , |
1396 | FormatNodeDefForError(node_name, has_experimental_debug_info, |
1397 | experimental_debug_info), |
1398 | "': '" , (*reg)->def.ShortDebugString(), "' and '" , |
1399 | iter->second.def.ShortDebugString(), "'" ); |
1400 | } else if ((*reg)->def.priority() > iter->second.def.priority()) { |
1401 | continue; |
1402 | } |
1403 | // iter->second's priority is higher than *reg. |
1404 | } |
1405 | *reg = &iter->second; |
1406 | } else { |
1407 | *was_attr_mismatch = true; |
1408 | } |
1409 | } |
1410 | // Check if no device specific registrations found. If not, try finding a |
1411 | // default kernel. |
1412 | if (*reg == nullptr && |
1413 | !IsSymbolicExecutionDevice(device_type.type_string())) { |
1414 | const string default_key = Key(node_op, DEVICE_DEFAULT, label); |
1415 | auto regs = typed_registry->registry.equal_range(default_key); |
1416 | for (auto iter = regs.first; iter != regs.second; ++iter) { |
1417 | // If there is a kernel registered for the op and device_type, |
1418 | // check that the attrs match. |
1419 | bool match; |
1420 | TF_RETURN_IF_ERROR( |
1421 | KernelAttrsMatch(iter->second.def, node_attrs, &match)); |
1422 | if (match) { |
1423 | if (*reg != nullptr) { |
1424 | return errors::InvalidArgument( |
1425 | "Multiple Default OpKernel registrations match NodeDef '" , |
1426 | FormatNodeDefForError(node_name, has_experimental_debug_info, |
1427 | experimental_debug_info), |
1428 | "': '" , (*reg)->def.ShortDebugString(), "' and '" , |
1429 | iter->second.def.ShortDebugString(), "'" ); |
1430 | } |
1431 | *reg = &iter->second; |
1432 | } else { |
1433 | *was_attr_mismatch = true; |
1434 | } |
1435 | } |
1436 | |
1437 | if (*reg != nullptr) { |
1438 | VLOG(1) << "No device-specific kernels found for NodeDef '" |
1439 | << FormatNodeDefForError(node_name, has_experimental_debug_info, |
1440 | experimental_debug_info) |
1441 | << "'" |
1442 | << "Will fall back to a default kernel." << std::endl; |
1443 | } |
1444 | } |
1445 | |
1446 | return OkStatus(); |
1447 | } |
1448 | |
1449 | Status FindKernelRegistration(const DeviceType& device_type, |
1450 | const NodeDef& node_def, |
1451 | const KernelRegistration** reg, |
1452 | bool* was_attr_mismatch) { |
1453 | return FindKernelRegistration( |
1454 | device_type, node_def.name(), node_def.has_experimental_debug_info(), |
1455 | node_def.experimental_debug_info(), node_def.op(), |
1456 | AttrSlice(&node_def.attr()), reg, was_attr_mismatch); |
1457 | } |
1458 | |
1459 | } // namespace |
1460 | |
1461 | bool KernelDefAvailable(const DeviceType& device_type, |
1462 | const NodeDef& node_def) { |
1463 | const KernelRegistration* reg = nullptr; |
1464 | bool was_attr_mismatch; |
1465 | Status result = |
1466 | FindKernelRegistration(device_type, node_def, ®, &was_attr_mismatch); |
1467 | return result.ok() && reg != nullptr; |
1468 | } |
1469 | |
1470 | // TODO(irving): Change const NodeDef& to const Node& |
1471 | Status FindKernelDef( |
1472 | const DeviceType& device_type, StringPiece node_name, |
1473 | bool has_experimental_debug_info, |
1474 | const NodeDef_ExperimentalDebugInfo& experimental_debug_info, |
1475 | StringPiece node_op, StringPiece node_device, AttrSlice node_attrs, |
1476 | const KernelDef** def, string* kernel_class_name) { |
1477 | const KernelRegistration* reg = nullptr; |
1478 | bool was_attr_mismatch; |
1479 | TF_RETURN_IF_ERROR(FindKernelRegistration( |
1480 | device_type, node_name, has_experimental_debug_info, |
1481 | experimental_debug_info, node_op, node_attrs, ®, &was_attr_mismatch)); |
1482 | if (reg == nullptr) { |
1483 | const std::string device_str = DeviceTypeString(device_type); |
1484 | Status s = errors::NotFound( |
1485 | "No registered '" , node_op, "' OpKernel for " , device_str, |
1486 | " devices compatible with node " , |
1487 | FormatNodeDefForError(node_name, has_experimental_debug_info, |
1488 | experimental_debug_info)); |
1489 | if (was_attr_mismatch) { |
1490 | errors::AppendToMessage( |
1491 | &s, " (OpKernel was found, but attributes didn't match) " , |
1492 | "Requested Attributes: " , |
1493 | SummarizeAttrsHelper(node_attrs, node_device)); |
1494 | } |
1495 | |
1496 | // Do not print kernel registrations for other devices when using _JIT |
1497 | // devices for compilation or for MKL ops. |
1498 | // TODO (intel-tf) : Remove the check for MKL ops when support for |
1499 | // block format is removed. |
1500 | if (!absl::StrContains(device_str, "JIT" ) && |
1501 | !absl::StartsWith(node_name, "_Mkl" )) { |
1502 | errors::AppendToMessage( |
1503 | &s, ". Registered:" , KernelsRegisteredForOp(node_op)); |
1504 | } |
1505 | |
1506 | return s; |
1507 | } |
1508 | if (def != nullptr) *def = ®->def; |
1509 | if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name; |
1510 | return OkStatus(); |
1511 | } |
1512 | |
1513 | Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, |
1514 | const KernelDef** def, string* kernel_class_name) { |
1515 | return FindKernelDef( |
1516 | device_type, node_def.name(), node_def.has_experimental_debug_info(), |
1517 | node_def.experimental_debug_info(), node_def.op(), node_def.device(), |
1518 | AttrSlice(&node_def.attr()), def, kernel_class_name); |
1519 | } |
1520 | |
1521 | Status SupportedDeviceTypesForNode( |
1522 | const std::vector<DeviceType>& prioritized_types, const NodeDef& def, |
1523 | PrioritizedDeviceTypeVector* prioritized_device_types, |
1524 | const DeviceNameUtils::ParsedName* local_address_spec) { |
1525 | // TODO(zhifengc): Changes the callers (SimplePlacer and |
1526 | // DynamicPlacer) to consider the possibility that 'def' is call to |
1527 | // a user-defined function and only calls this |
1528 | // SupportedDeviceTypesForNode for primitive ops. |
1529 | const OpRegistrationData* op_reg_data; |
1530 | const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data); |
1531 | if (s.ok()) { |
1532 | bool exists_attr_mismatch = false; |
1533 | for (const DeviceType& device_type : prioritized_types) { |
1534 | const KernelRegistration* reg = nullptr; |
1535 | bool was_attr_mismatch = false; |
1536 | TF_RETURN_IF_ERROR( |
1537 | FindKernelRegistration(device_type, def, ®, &was_attr_mismatch)); |
1538 | exists_attr_mismatch = exists_attr_mismatch || was_attr_mismatch; |
1539 | if (reg != nullptr) { |
1540 | int32_t priority = reg->def.priority(); |
1541 | prioritized_device_types->emplace_back(device_type, priority); |
1542 | } |
1543 | } |
1544 | // Add extra supported device types if the following conditions are |
1545 | // satisfied: |
1546 | // 1) No kernel is defined for the given op (e.g. PyFunc on worker process) |
1547 | // 2) A device is requested for this node which specifies job/replica/task |
1548 | // 3) A local device is provided which specifies job/replica/task |
1549 | // 4) The local device does not have the same (job, replica, task) as the |
1550 | // requested device |
1551 | // |
1552 | // The goal is to address the issue where a graph includes op (e.g. PyFunc) |
1553 | // whose kernel is known to a remote process but not to the current process. |
1554 | if (prioritized_device_types->empty() && !exists_attr_mismatch && |
1555 | local_address_spec != nullptr) { |
1556 | DeviceNameUtils::ParsedName requested_device_name; |
1557 | DeviceNameUtils::ParseFullName(def.device(), &requested_device_name); |
1558 | if (DeviceNameUtils::IsDifferentAddressSpace(*local_address_spec, |
1559 | requested_device_name)) { |
1560 | if (requested_device_name.has_type) { |
1561 | prioritized_device_types->push_back( |
1562 | std::make_pair(DeviceType(requested_device_name.type), 0)); |
1563 | } else { |
1564 | for (const DeviceType& device_type : prioritized_types) { |
1565 | prioritized_device_types->push_back(std::make_pair(device_type, 0)); |
1566 | } |
1567 | } |
1568 | } |
1569 | } |
1570 | |
1571 | // If we were unable to find any valid devices let's validate if the node is |
1572 | // even valid. |
1573 | if (prioritized_device_types->empty()) { |
1574 | TF_RETURN_IF_ERROR(ValidateNodeDef(def, op_reg_data->op_def)); |
1575 | } |
1576 | |
1577 | std::stable_sort(prioritized_device_types->begin(), |
1578 | prioritized_device_types->end(), |
1579 | [](const std::pair<DeviceType, int32>& a, |
1580 | const std::pair<DeviceType, int32>& b) { |
1581 | return a.second > b.second; |
1582 | }); |
1583 | } else { |
1584 | // Assumes that all device types support this node. |
1585 | for (const DeviceType& device_type : prioritized_types) { |
1586 | prioritized_device_types->push_back(std::make_pair(device_type, 0)); |
1587 | } |
1588 | } |
1589 | return OkStatus(); |
1590 | } |
1591 | |
1592 | void LogAllRegisteredKernels() { |
1593 | KernelList kernel_list = GetAllRegisteredKernels(); |
1594 | for (const auto& kernel_def : kernel_list.kernel()) { |
1595 | LOG(INFO) << "OpKernel ('" << kernel_def.ShortDebugString() << "')" ; |
1596 | } |
1597 | } |
1598 | |
1599 | KernelList GetAllRegisteredKernels() { |
1600 | return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; }); |
1601 | } |
1602 | |
1603 | KernelList GetFilteredRegisteredKernels( |
1604 | const std::function<bool(const KernelDef&)>& predicate) { |
1605 | KernelRegistry* const typed_registry = GlobalKernelRegistryTyped(); |
1606 | KernelList kernel_list; |
1607 | tf_shared_lock lock(typed_registry->mu); |
1608 | kernel_list.mutable_kernel()->Reserve(typed_registry->registry.size()); |
1609 | for (const auto& p : typed_registry->registry) { |
1610 | const KernelDef& kernel_def = p.second.def; |
1611 | if (predicate(kernel_def)) { |
1612 | *kernel_list.add_kernel() = kernel_def; |
1613 | } |
1614 | } |
1615 | return kernel_list; |
1616 | } |
1617 | |
1618 | KernelList GetRegisteredKernelsForOp(StringPiece op_name) { |
1619 | auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; }; |
1620 | return GetFilteredRegisteredKernels(op_pred); |
1621 | } |
1622 | |
1623 | string KernelsRegisteredForOp(StringPiece op_name) { |
1624 | KernelList kernel_list = GetRegisteredKernelsForOp(op_name); |
1625 | if (kernel_list.kernel_size() == 0) return " <no registered kernels>\n" ; |
1626 | string ret; |
1627 | for (const auto& kernel_def : kernel_list.kernel()) { |
1628 | strings::StrAppend(&ret, " device='" , kernel_def.device_type(), "'" ); |
1629 | if (!kernel_def.label().empty()) { |
1630 | strings::StrAppend(&ret, "; label='" , kernel_def.label(), "'" ); |
1631 | } |
1632 | for (int i = 0; i < kernel_def.constraint_size(); ++i) { |
1633 | strings::StrAppend( |
1634 | &ret, "; " , kernel_def.constraint(i).name(), " in " , |
1635 | SummarizeAttrValue(kernel_def.constraint(i).allowed_values())); |
1636 | } |
1637 | strings::StrAppend(&ret, "\n" ); |
1638 | } |
1639 | return ret; |
1640 | } |
1641 | |
1642 | /* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid |
1643 | * copying the NodeDef. */ |
1644 | std::unique_ptr<OpKernel> CreateOpKernel( |
1645 | DeviceType device_type, DeviceBase* device, Allocator* allocator, |
1646 | const NodeDef& node_def, int graph_def_version, Status* status) { |
1647 | // Look up the Op registered for this op name. |
1648 | std::shared_ptr<const NodeProperties> props; |
1649 | status->Update(NodeProperties::CreateFromNodeDef( |
1650 | node_def, OpRegistry::Global(), &props)); |
1651 | if (!status->ok()) { |
1652 | errors::AppendToMessage(status, |
1653 | " for node: " , FormatNodeDefForError(node_def)); |
1654 | return nullptr; |
1655 | } |
1656 | return CreateOpKernel(device_type, device, allocator, props, |
1657 | graph_def_version, status); |
1658 | } |
1659 | |
1660 | std::unique_ptr<OpKernel> CreateOpKernel( |
1661 | DeviceType device_type, DeviceBase* device, Allocator* allocator, |
1662 | const std::shared_ptr<const NodeProperties>& props, int graph_def_version, |
1663 | Status* status) { |
1664 | OpKernel* kernel = nullptr; |
1665 | *status = CreateOpKernel(std::move(device_type), device, allocator, |
1666 | /*flib=*/nullptr, props, graph_def_version, &kernel); |
1667 | return std::unique_ptr<OpKernel>(kernel); |
1668 | } |
1669 | |
1670 | Status CreateOpKernel(DeviceType device_type, DeviceBase* device, |
1671 | Allocator* allocator, FunctionLibraryRuntime* flib, |
1672 | const std::shared_ptr<const NodeProperties>& props, |
1673 | int graph_def_version, OpKernel** kernel) { |
1674 | return CreateOpKernel(std::move(device_type), device, allocator, flib, |
1675 | /* resource_mgr= */ nullptr, props, graph_def_version, |
1676 | kernel); |
1677 | } |
1678 | |
1679 | Status CreateOpKernel(DeviceType device_type, DeviceBase* device, |
1680 | Allocator* allocator, FunctionLibraryRuntime* flib, |
1681 | ResourceMgr* resource_mgr, |
1682 | const std::shared_ptr<const NodeProperties>& props, |
1683 | int graph_def_version, OpKernel** kernel) { |
1684 | const NodeDef& node_def = props->node_def; |
1685 | bool was_attr_mismatch; |
1686 | const KernelRegistration* registration = nullptr; |
1687 | Status s; |
1688 | if (props != nullptr) { |
1689 | VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); |
1690 | |
1691 | // Validate node_def against OpDef. |
1692 | TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def)); |
1693 | |
1694 | // Look up kernel registration. |
1695 | s = FindKernelRegistration(device_type, node_def, ®istration, |
1696 | &was_attr_mismatch); |
1697 | if (!s.ok()) { |
1698 | errors::AppendToMessage(&s, " when instantiating " , node_def.op()); |
1699 | return s; |
1700 | } |
1701 | } |
1702 | if (registration == nullptr) { |
1703 | s.Update(errors::NotFound("No registered '" , node_def.op(), |
1704 | "' OpKernel for '" , DeviceTypeString(device_type), |
1705 | "' devices compatible with node " , |
1706 | FormatNodeDefForError(node_def))); |
1707 | if (was_attr_mismatch) { |
1708 | errors::AppendToMessage( |
1709 | &s, " (OpKernel was found, but attributes didn't match) " , |
1710 | "Requested Attributes: " , SummarizeAttrs(node_def)); |
1711 | } |
1712 | errors::AppendToMessage( |
1713 | &s, ". Registered:" , KernelsRegisteredForOp(node_def.op())); |
1714 | return s; |
1715 | } |
1716 | |
1717 | // We are creating a kernel for an op registered in |
1718 | // OpRegistry::Global(), we consult the kernel registry to decide |
1719 | // the kernel's input and output memory types. |
1720 | MemoryTypeVector input_memory_types; |
1721 | MemoryTypeVector output_memory_types; |
1722 | TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type, |
1723 | node_def, &input_memory_types, |
1724 | &output_memory_types)); |
1725 | |
1726 | // Everything needed for OpKernel construction. |
1727 | OpKernelConstruction context(std::move(device_type), device, allocator, flib, |
1728 | resource_mgr, props, input_memory_types, |
1729 | output_memory_types, graph_def_version, &s); |
1730 | *kernel = registration->factory->Create(&context); |
1731 | if (!s.ok()) { |
1732 | delete *kernel; |
1733 | *kernel = nullptr; |
1734 | } |
1735 | return s; |
1736 | } |
1737 | |
1738 | namespace { |
1739 | |
1740 | bool FindArgInOp(StringPiece arg_name, |
1741 | const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { |
1742 | for (const auto& arg : args) { |
1743 | if (arg_name == arg.name()) { |
1744 | return true; |
1745 | } |
1746 | } |
1747 | return false; |
1748 | } |
1749 | |
1750 | } // namespace |
1751 | |
1752 | Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) { |
1753 | auto typed_registry = GlobalKernelRegistryTyped(); |
1754 | tf_shared_lock lock(typed_registry->mu); |
1755 | for (const auto& key_registration : typed_registry->registry) { |
1756 | const KernelDef& kernel_def(key_registration.second.def); |
1757 | const OpRegistrationData* op_reg_data; |
1758 | const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data); |
1759 | if (!status.ok()) { |
1760 | // TODO(josh11b): Make this a hard error. |
1761 | LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString() |
1762 | << "') for unknown op: " << kernel_def.op(); |
1763 | continue; |
1764 | } |
1765 | const OpDef& op_def = op_reg_data->op_def; |
1766 | for (const auto& host_memory_arg : kernel_def.host_memory_arg()) { |
1767 | if (!FindArgInOp(host_memory_arg, op_def.input_arg()) && |
1768 | !FindArgInOp(host_memory_arg, op_def.output_arg())) { |
1769 | return errors::InvalidArgument( |
1770 | "HostMemory arg '" , host_memory_arg, |
1771 | "' not found in OpDef: " , SummarizeOpDef(op_def)); |
1772 | } |
1773 | } |
1774 | } |
1775 | return OkStatus(); |
1776 | } |
1777 | |
1778 | template <> |
1779 | const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const { |
1780 | return eigen_cpu_device(); |
1781 | } |
1782 | |
1783 | template <> |
1784 | const Eigen::GpuDevice& OpKernelContext::eigen_device() const { |
1785 | return eigen_gpu_device(); |
1786 | } |
1787 | |
1788 | void OpKernelConstruction::CtxFailure(const Status& s) { |
1789 | VLOG(1) << s; |
1790 | SetStatus(s); |
1791 | } |
1792 | |
1793 | void OpKernelConstruction::CtxFailureWithWarning(const Status& s) { |
1794 | LOG(WARNING) << s; |
1795 | SetStatus(s); |
1796 | } |
1797 | |
1798 | void OpKernelConstruction::CtxFailure(const char* file, int line, |
1799 | const Status& s) { |
1800 | VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line |
1801 | << " : " << s; |
1802 | SetStatus(s); |
1803 | } |
1804 | |
1805 | void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line, |
1806 | const Status& s) { |
1807 | LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line |
1808 | << " : " << s; |
1809 | SetStatus(s); |
1810 | } |
1811 | |
1812 | void OpKernelContext::CtxFailure(const Status& s) { |
1813 | VLOG(1) << s; |
1814 | SetStatus(s); |
1815 | } |
1816 | |
1817 | void OpKernelContext::CtxFailureWithWarning(const Status& s) { |
1818 | LOG(WARNING) << s; |
1819 | SetStatus(s); |
1820 | } |
1821 | |
1822 | void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) { |
1823 | VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line |
1824 | << " : " << s; |
1825 | SetStatus(s); |
1826 | } |
1827 | |
1828 | void OpKernelContext::CtxFailureWithWarning(const char* file, int line, |
1829 | const Status& s) { |
1830 | LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line |
1831 | << " : " << s; |
1832 | SetStatus(s); |
1833 | } |
1834 | |
1835 | void CheckNotInComputeAsync(OpKernelContext* ctx, |
1836 | const char* correct_macro_name) { |
1837 | CHECK_EQ(nullptr, ctx->params_->op_kernel->AsAsync()) |
1838 | << "Use " << correct_macro_name << " in AsyncOpKernel implementations." ; |
1839 | } |
1840 | |
1841 | } // namespace tensorflow |
1842 | |