1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_ |
18 | |
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
26 | |
27 | namespace tensorflow { |
28 | |
29 | // The session state remembers the tensors we choose to keep across |
30 | // multiple run calls. |
31 | class SessionState { |
32 | public: |
33 | // Get a tensor from the session state. |
34 | Status GetTensor(const std::string& handle, Tensor* tensor); |
35 | |
36 | // Store a tensor in the session state. |
37 | Status AddTensor(const std::string& handle, const Tensor& tensor); |
38 | |
39 | // Delete a tensor from the session state. |
40 | Status DeleteTensor(const std::string& handle); |
41 | |
42 | int64_t GetNewId(); |
43 | |
44 | static const char* kTensorHandleResourceTypeName; |
45 | |
46 | private: |
47 | mutex state_lock_; |
48 | |
49 | // For generating unique ids for tensors stored in the session. |
50 | int64_t tensor_id_ = 0; |
51 | |
52 | // The live tensors in the session. A map from tensor handle to tensor. |
53 | std::unordered_map<string, Tensor> tensors_; |
54 | }; |
55 | |
56 | // The tensor store remembers the tensors we choose to keep for the |
57 | // current run call. It is available to every op kernel. |
58 | class TensorStore { |
59 | public: |
60 | struct TensorAndKey { |
61 | Tensor tensor; |
62 | int64_t id; |
63 | std::string device_name; |
64 | |
65 | std::string GetHandle(const std::string& tensor_name) { |
66 | return strings::StrCat(tensor_name, ";" , id, ";" , device_name); |
67 | } |
68 | }; |
69 | |
70 | // Add the named tensor to the tensor store for this run. |
71 | Status AddTensor(const std::string& name, const TensorAndKey& tk); |
72 | |
73 | // Save the tensors in the tensor store of this run to the session. |
74 | Status SaveTensors(const std::vector<string>& output_names, |
75 | SessionState* session_state); |
76 | |
77 | // Returns true if no tensors have been added to this store. |
78 | bool empty() TF_NO_THREAD_SAFETY_ANALYSIS { return !dirty_; } |
79 | |
80 | private: |
81 | mutex lock_; |
82 | std::atomic<bool> dirty_ TF_GUARDED_BY(lock_){false}; |
83 | |
84 | // The tensors that will be saved to session state when this run completes. |
85 | // A map from tensor string name to tensor. |
86 | std::unordered_map<string, TensorAndKey> tensors_ TF_GUARDED_BY(lock_); |
87 | }; |
88 | |
89 | } // namespace tensorflow |
90 | |
91 | #endif // TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_ |
92 | |