1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_ |
18 | |
19 | #include <memory> |
20 | |
21 | #include "tensorflow/core/framework/variant_tensor_data.h" |
22 | #include "tensorflow/core/platform/logging.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | template <typename T> |
27 | struct SharedPtrVariant { |
28 | std::shared_ptr<T> shared_ptr; |
29 | |
30 | SharedPtrVariant() : shared_ptr() {} |
31 | |
32 | explicit SharedPtrVariant(std::shared_ptr<T>&& ptr) |
33 | : shared_ptr(std::forward<decltype(ptr)>(ptr)) { |
34 | VLOG(3) << "Creating shared_ptr of " << shared_ptr.get() |
35 | << " count is: " << shared_ptr.use_count(); |
36 | } |
37 | |
38 | SharedPtrVariant(SharedPtrVariant&& rhs) |
39 | : shared_ptr(std::move(rhs.shared_ptr)) { |
40 | VLOG(3) << "Moving SharedPtrVariant of " << shared_ptr.get() |
41 | << " count is: " << shared_ptr.use_count(); |
42 | } |
43 | |
44 | SharedPtrVariant& operator=(const SharedPtrVariant& rhs) = delete; |
45 | |
46 | SharedPtrVariant& operator=(SharedPtrVariant&& rhs) { |
47 | if (&rhs == this) return *this; |
48 | std::swap(shared_ptr, rhs.shared_ptr); |
49 | VLOG(3) << "Move-assign of SharedPtrVariant of " << shared_ptr.get() |
50 | << " count is: " << shared_ptr.use_count(); |
51 | return *this; |
52 | } |
53 | |
54 | SharedPtrVariant(const SharedPtrVariant& rhs) : shared_ptr(rhs.shared_ptr) { |
55 | VLOG(3) << "Copying SharedPtrVariant of " << shared_ptr.get() |
56 | << " count is: " << shared_ptr.use_count(); |
57 | } |
58 | |
59 | ~SharedPtrVariant() { |
60 | VLOG(3) << "Destroying SharedPtrVariant of " << shared_ptr.get() |
61 | << " count is: " << shared_ptr.use_count(); |
62 | } |
63 | |
64 | void Encode(VariantTensorData*) const { |
65 | // Not supported. |
66 | } |
67 | |
68 | bool Decode(const VariantTensorData&) { |
69 | return false; // Not supported. |
70 | } |
71 | }; |
72 | |
73 | } // namespace tensorflow |
74 | |
75 | #endif // TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_ |
76 | |