1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
17#define TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
18
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
26
27namespace tensorflow {
28
29// The session state remembers the tensors we choose to keep across
30// multiple run calls.
31class SessionState {
32 public:
33 // Get a tensor from the session state.
34 Status GetTensor(const std::string& handle, Tensor* tensor);
35
36 // Store a tensor in the session state.
37 Status AddTensor(const std::string& handle, const Tensor& tensor);
38
39 // Delete a tensor from the session state.
40 Status DeleteTensor(const std::string& handle);
41
42 int64_t GetNewId();
43
44 static const char* kTensorHandleResourceTypeName;
45
46 private:
47 mutex state_lock_;
48
49 // For generating unique ids for tensors stored in the session.
50 int64_t tensor_id_ = 0;
51
52 // The live tensors in the session. A map from tensor handle to tensor.
53 std::unordered_map<string, Tensor> tensors_;
54};
55
56// The tensor store remembers the tensors we choose to keep for the
57// current run call. It is available to every op kernel.
58class TensorStore {
59 public:
60 struct TensorAndKey {
61 Tensor tensor;
62 int64_t id;
63 std::string device_name;
64
65 std::string GetHandle(const std::string& tensor_name) {
66 return strings::StrCat(tensor_name, ";", id, ";", device_name);
67 }
68 };
69
70 // Add the named tensor to the tensor store for this run.
71 Status AddTensor(const std::string& name, const TensorAndKey& tk);
72
73 // Save the tensors in the tensor store of this run to the session.
74 Status SaveTensors(const std::vector<string>& output_names,
75 SessionState* session_state);
76
77 // Returns true if no tensors have been added to this store.
78 bool empty() TF_NO_THREAD_SAFETY_ANALYSIS { return !dirty_; }
79
80 private:
81 mutex lock_;
82 std::atomic<bool> dirty_ TF_GUARDED_BY(lock_){false};
83
84 // The tensors that will be saved to session state when this run completes.
85 // A map from tensor string name to tensor.
86 std::unordered_map<string, TensorAndKey> tensors_ TF_GUARDED_BY(lock_);
87};
88
89} // namespace tensorflow
90
91#endif // TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
92