1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/session_state.h" |
17 | #include "tensorflow/core/graph/tensor_id.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | // Adjust value in third_party/tensorflow/python/client/tf_session_wrapper.cc |
22 | // in the get_tensor_handle_key function if adjusting the value for |
23 | // kTensorHandleResourceTypeName. |
24 | const char* SessionState::kTensorHandleResourceTypeName = "TensorHandle" ; |
25 | |
26 | Status SessionState::GetTensor(const string& handle, Tensor* tensor) { |
27 | mutex_lock l(state_lock_); |
28 | auto it = tensors_.find(handle); |
29 | if (it == tensors_.end()) { |
30 | return errors::InvalidArgument("The tensor with handle '" , handle, |
31 | "' is not in the session store." ); |
32 | } |
33 | *tensor = it->second; |
34 | return OkStatus(); |
35 | } |
36 | |
37 | Status SessionState::AddTensor(const string& handle, const Tensor& tensor) { |
38 | mutex_lock l(state_lock_); |
39 | if (!tensors_.insert({handle, tensor}).second) { |
40 | return errors::InvalidArgument("Failed to add a tensor with handle '" , |
41 | handle, "' to the session store." ); |
42 | } |
43 | return OkStatus(); |
44 | } |
45 | |
46 | Status SessionState::DeleteTensor(const string& handle) { |
47 | mutex_lock l(state_lock_); |
48 | if (tensors_.erase(handle) == 0) { |
49 | return errors::InvalidArgument("Failed to delete a tensor with handle '" , |
50 | handle, "' in the session store." ); |
51 | } |
52 | return OkStatus(); |
53 | } |
54 | |
55 | int64_t SessionState::GetNewId() { |
56 | mutex_lock l(state_lock_); |
57 | return tensor_id_++; |
58 | } |
59 | |
60 | Status TensorStore::AddTensor(const string& name, const TensorAndKey& tk) { |
61 | mutex_lock l(lock_); |
62 | if (!tensors_.insert({name, tk}).second) { |
63 | return errors::InvalidArgument("Failed to add a tensor with name '" , name, |
64 | "' to the tensor store." ); |
65 | } |
66 | dirty_ = true; |
67 | return OkStatus(); |
68 | } |
69 | |
70 | Status TensorStore::SaveTensors(const std::vector<string>& output_names, |
71 | SessionState* session_state) { |
72 | mutex_lock l(lock_); |
73 | if (!tensors_.empty()) { |
74 | // Save only the tensors in output_names in the session. |
75 | for (const string& name : output_names) { |
76 | TensorId id(ParseTensorName(name)); |
77 | const string op_name(id.first); |
78 | auto it = tensors_.find(op_name); |
79 | if (it != tensors_.end()) { |
80 | // Save the tensor to the session state. |
81 | string key = it->second.GetHandle(op_name); |
82 | TF_RETURN_IF_ERROR(session_state->AddTensor(key, it->second.tensor)); |
83 | } |
84 | } |
85 | } |
86 | return OkStatus(); |
87 | } |
88 | |
89 | } // namespace tensorflow |
90 | |