1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op_segment.h" |
17 | |
18 | #include "tensorflow/core/framework/function.h" |
19 | #include "tensorflow/core/framework/op_kernel.h" |
20 | #include "tensorflow/core/lib/core/errors.h" |
21 | #include "tensorflow/core/lib/gtl/map_util.h" |
22 | #include "tensorflow/core/platform/logging.h" |
23 | #include "tensorflow/core/platform/mutex.h" |
24 | #include "tensorflow/core/platform/types.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | OpSegment::Item::~Item() { |
29 | for (const auto& kv : name_kernel) delete kv.second; |
30 | } |
31 | |
32 | OpSegment::OpSegment() {} |
33 | |
34 | OpSegment::~OpSegment() { |
35 | for (const auto& kv : sessions_) delete kv.second; |
36 | } |
37 | |
38 | Status OpSegment::FindOrCreate(const string& session_handle, |
39 | const string& node_name, OpKernel** kernel, |
40 | CreateKernelFn create_fn) { |
41 | { |
42 | mutex_lock l(mu_); |
43 | auto item = gtl::FindPtrOrNull(sessions_, session_handle); |
44 | if (item == nullptr) { |
45 | return errors::NotFound("Session " , session_handle, " is not found." ); |
46 | } |
47 | *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name); |
48 | if (*kernel != nullptr) { |
49 | return OkStatus(); |
50 | } |
51 | } |
52 | Status s = create_fn(kernel); |
53 | if (!s.ok()) { |
54 | LOG(ERROR) << "Create kernel failed: " << s; |
55 | return s; |
56 | } |
57 | { |
58 | mutex_lock l(mu_); |
59 | auto item = gtl::FindPtrOrNull(sessions_, session_handle); |
60 | if (item == nullptr) { |
61 | return errors::NotFound("Session " , session_handle, " is not found." ); |
62 | } |
63 | OpKernel** p_kernel = &(item->name_kernel[node_name]); |
64 | if (*p_kernel == nullptr) { |
65 | *p_kernel = *kernel; // Inserts 'kernel' in the map. |
66 | } else { |
67 | delete *kernel; |
68 | *kernel = *p_kernel; |
69 | } |
70 | } |
71 | return OkStatus(); |
72 | } |
73 | |
74 | void OpSegment::AddHold(const string& session_handle) { |
75 | mutex_lock l(mu_); |
76 | Item** item = &sessions_[session_handle]; |
77 | if (*item == nullptr) { |
78 | *item = new Item; // num_holds == 1 |
79 | } else { |
80 | ++((*item)->num_holds); |
81 | } |
82 | } |
83 | |
84 | void OpSegment::RemoveHold(const string& session_handle) { |
85 | Item* item = nullptr; |
86 | { |
87 | mutex_lock l(mu_); |
88 | auto siter = sessions_.find(session_handle); |
89 | if (siter == sessions_.end()) { |
90 | VLOG(1) << "Session " << session_handle << " is not found." ; |
91 | return; |
92 | } |
93 | item = siter->second; |
94 | if (--(item->num_holds) > 0) { |
95 | return; |
96 | } else { |
97 | sessions_.erase(siter); |
98 | } |
99 | } |
100 | delete item; |
101 | } |
102 | |
103 | bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib, |
104 | const string& node_op) { |
105 | // OpSegment should not own kernel if the node is stateless, or a function. |
106 | return lib->IsStateful(node_op) && |
107 | lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr && |
108 | node_op != "PartitionedCall" && node_op != "StatefulPartitionedCall" ; |
109 | } |
110 | |
111 | } // end namespace tensorflow |
112 | |