1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
17#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
18
19#include <functional>
20#include <iostream>
21#include <memory>
22#include <type_traits>
23#include <unordered_map>
24#include <utility>
25
26#include "absl/memory/memory.h"
27#include "tensorflow/core/framework/type_index.h"
28#include "tensorflow/core/framework/variant_tensor_data.h"
29#include "tensorflow/core/platform/logging.h"
30#include "tensorflow/core/platform/strcat.h"
31
32namespace tensorflow {
33
34template <typename T>
35std::string TypeNameVariant(const T& value);
36
37template <typename T>
38std::string DebugStringVariant(const T& value);
39
40// Allows for specializations of Variant Decoding. `data` may be modified in
41// the process of decoding to `value`.
42template <typename T>
43bool DecodeVariant(VariantTensorData* data, T* value);
44
45template <typename T>
46bool DecodeVariant(std::string* buf, T* value);
47
48template <typename T>
49void EncodeVariant(const T& value, VariantTensorData* data);
50
51template <typename T>
52void EncodeVariant(const T& value, std::string* buf);
53
54// This is an implementation of a type-erased container that can store an
55// object of any type. The implementation is very similar to std::any, but has
56// restrictions on the types of objects that can be stored, and eschews some of
57// the fancier constructors available for std::any. An object of
58// tensorflow::Variant is intended to be used as the value that will be stored
59// in a tensorflow::Tensor object when its type is DT_VARIANT.
60//
61// tensorflow::Variant can store an object of a class that satisfies the
62// following constraints:
63//
64// * The class is CopyConstructible.
65// * The class has a default constructor.
66// * It's either a protocol buffer, a tensorflow::Tensor, or defines the
67// following functions:
68//
69// string TypeName() const;
70// void Encode(VariantTensorData* data) const;
71// bool Decode(VariantTensorData data);
72//
// Simple POD types can omit the Encode/Decode functions; helper methods provide
// them.
75// Here are some typical usage patterns:
76//
77// Variant x = 10;
78// EXPECT_EQ(*x.get<int>(), 10);
79//
80// Tensor t(DT_FLOAT, TensorShape({}));
81// t.flat<float>()(0) = 42.0f;
82// Variant x = t;
83// EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f);
84//
85// Accessing the stored object:
86//
// The get<T> function is the main mechanism to access the object
// stored in the container. It is type-safe: calling get<T> when the
// stored object's type is not T returns nullptr. A raw pointer to the
// stored object can be obtained by calling get<void>().
92//
93// Serializing/deserializing Variant object:
94//
95// The Variant class delegates serializing and deserializing operations to the
96// contained object. Helper functions to do these operations are provided for
97// POD data types, tensorflow::Tensor, and protocol buffer objects. However,
98// other classes have to provide Encode/Decode functions to handle
99// serialization.
100//
101// Objects stored in a Variant object often contain references to other
// tensorflow::Tensors of primitive types (e.g., a list of tensorflow::Tensors).
103// To efficiently support those use cases, a structure is imposed on the
104// serialization format. Namely, classes should serialize their contents into a
105// VariantTensorData object:
106//
107// struct VariantTensorData {
108// string type_name;
109// string metadata;
110// std::vector<Tensor> tensors;
111// };
112//
113// Objects with references to other Tensors can simply store those tensors in
// the `tensors` field, and serialize other metadata content into the
115// `metadata` field.
116//
117// Serialization example:
118//
119// Foo f = Foo {...};
120// Variant x = f;
121// string serialized_f;
122// x.Encode(&serialized_f);
123//
124// Variant y = Foo(); // default constructed Foo.
125// y.Decode(std::move(serialized_f));
126// EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
127//
128//
129// A Variant storing serialized Variant data (a value of type
130// VariantTensorDataProto) has different behavior from a standard Variant.
131// Namely, its TypeName matches the TypeName of the original Variant;
132// and its non-const get method performs lazy deserialization.
133//
134// Decode and copy example:
135//
136// Foo f = Foo {...};
137// Variant x = f;
138//
139// VariantTensorData serialized_data_f;
140// VariantTensorDataProto serialized_proto_f;
141// x.Encode(&serialized_data_f);
142// serialized_data_f.ToProto(&serialized_proto_f);
143//
144// Variant y_type_unknown = serialized_proto_f; // Store serialized Variant.
145//
146// EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo.
147// EXPECT_EQ(TypeIndex::Make<VariantTensorDataProto>(),
148// y_type_unknown.TypeId());
149//
150class Variant {
151 public:
152 // Constructs a Variant holding no value (aka `is_empty()`).
153 //
154 // This is done by pointing at nullptr via the heap value.
155 Variant() noexcept : heap_value_(/*pointer=*/nullptr), is_inline_(false) {}
156
157 ~Variant();
158
159 Variant(const Variant& other);
160 Variant(Variant&& other) noexcept;
161
162 // Make sure that the type is CopyConstructible and not a
163 // tensorflow::Variant object itself. We want the copy constructor to be
164 // chosen for the tensorflow::Variant case.
165 template <typename T, typename VT = typename std::decay<T>::type,
166 typename std::enable_if<!std::is_same<Variant, VT>::value &&
167 std::is_move_constructible<VT>::value,
168 void>::type* = nullptr>
169 Variant(T&& value);
170
171 template <typename T, typename VT = typename std::decay<T>::type,
172 typename std::enable_if<!std::is_same<Variant, VT>::value &&
173 std::is_copy_constructible<VT>::value,
174 void>::type* = nullptr>
175 Variant(const T& value);
176
177 template <typename T, typename VT = typename std::decay<T>::type,
178 typename std::enable_if<!std::is_same<Variant, VT>::value &&
179 std::is_copy_constructible<VT>::value,
180 void>::type* = nullptr>
181 Variant& operator=(const T& value);
182
183 template <typename T, typename VT = typename std::decay<T>::type,
184 typename std::enable_if<!std::is_same<Variant, VT>::value &&
185 std::is_move_constructible<VT>::value,
186 void>::type* = nullptr>
187 Variant& operator=(T&& value);
188
189 Variant& operator=(const Variant& rhs) {
190 if (&rhs == this) return *this;
191 Variant(rhs).swap(*this);
192 return *this;
193 }
194
195 Variant& operator=(Variant&& rhs) noexcept {
196 if (&rhs == this) return *this;
197 Variant(std::move(rhs)).swap(*this);
198 return *this;
199 }
200
201 // Constructs a value of type T with the given args in-place in this Variant.
202 // Returns a reference to the newly constructed value.
203 // The signature is based on std::variant<Types...>::emplace() in C++17.
204 template <typename T, class... Args>
205 T& emplace(Args&&... args) {
206 ResetMemory();
207 is_inline_ = CanInlineType<T>();
208 if (is_inline_) {
209 new (&inline_value_)
210 InlineValue(InlineValue::Tag<T>{}, std::forward<Args>(args)...);
211 return static_cast<Variant::Value<T>*>(inline_value_.AsValueInterface())
212 ->value;
213 } else {
214 new (&heap_value_) HeapValue(
215 absl::make_unique<Value<T>>(InPlace(), std::forward<Args>(args)...));
216 return static_cast<Variant::Value<T>*>(heap_value_.get())->value;
217 }
218 }
219
220 bool is_empty() const { return GetValue() == nullptr; }
221
222 void clear() noexcept;
223
224 void swap(Variant& other) noexcept;
225
226 // Note, unlike TypeName(), TypeId() does not return the TypeIndex
227 // of the original type when a TensorValueDataProto is stored as the
228 // value. In this case, it returns the TypeIndex of TensorValueDataProto.
229 TypeIndex TypeId() const {
230 const TypeIndex VoidTypeIndex = TypeIndex::Make<void>();
231 if (is_empty()) {
232 return VoidTypeIndex;
233 }
234 return GetValue()->TypeId();
235 }
236
237 std::string DebugString() const {
238 return strings::StrCat("Variant<type: ", TypeName(),
239 " value: ", SummarizeValue(), ">");
240 }
241
242 std::string SummarizeValue() const {
243 return is_empty() ? "[empty]" : GetValue()->DebugString();
244 }
245
246 // Returns a pointer to the stored value if it is type T, or nullptr
247 // otherwise.
248 template <typename T>
249 T* get() {
250 const TypeIndex TTypeIndex = TypeIndex::Make<T>();
251 if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
252 return std::addressof(static_cast<Variant::Value<T>*>(GetValue())->value);
253 }
254
255 // Returns a pointer to the stored value if it is type T, or nullptr
256 // otherwise.
257 template <typename T>
258 const T* get() const {
259 const TypeIndex TTypeIndex = TypeIndex::Make<T>();
260 if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
261 return std::addressof(
262 static_cast<const Variant::Value<T>*>(GetValue())->value);
263 }
264
265 // Returns TypeNameVariant(value).
266 //
267 // In the special case that a serialized Variant is stored (value
268 // is a VariantTensorDataProto), returns value.TypeName(), the
269 // TypeName field stored in the VariantTensorDataProto buffer.
270 std::string TypeName() const {
271 if (is_empty()) {
272 return "";
273 }
274 return GetValue()->TypeName();
275 }
276
277 // Serialize the contents of the stored object into `data`.
278 void Encode(VariantTensorData* data) const {
279 if (!is_empty()) {
280 GetValue()->Encode(data);
281 }
282 }
283
284 // Deserialize `data` and update the stored object.
285 bool Decode(VariantTensorData data);
286
287 // Helper methods to directly serialize/deserialize from strings.
288 void Encode(std::string* buf) const {
289 if (!is_empty()) {
290 GetValue()->Encode(buf);
291 }
292 }
293 bool Decode(std::string buf) {
294 if (!is_empty()) {
295 return GetValue()->Decode(std::move(buf));
296 }
297 return true;
298 }
299
300 template <typename VT>
301 static constexpr bool CanInlineType() {
302 return ((sizeof(Value<VT>) <= InlineValue::kMaxValueSize) &&
303 (alignof(Value<VT>) <= kMaxInlineValueAlignSize));
304 }
305
306 private:
307 struct in_place_t {};
308 static constexpr in_place_t InPlace() { return in_place_t{}; }
309
310 struct ValueInterface {
311 virtual ~ValueInterface() = default;
312 virtual TypeIndex TypeId() const = 0;
313 virtual void* RawPtr() = 0;
314 virtual const void* RawPtr() const = 0;
315 virtual std::unique_ptr<ValueInterface> Clone() const = 0;
316 virtual void CloneInto(ValueInterface* memory) const = 0;
317 virtual void MoveAssign(ValueInterface* memory) = 0;
318 virtual void MoveInto(ValueInterface* memory) = 0;
319 virtual std::string TypeName() const = 0;
320 virtual std::string DebugString() const = 0;
321 virtual void Encode(VariantTensorData* data) const = 0;
322 virtual bool Decode(VariantTensorData data) = 0;
323 virtual void Encode(std::string* buf) const = 0;
324 virtual bool Decode(std::string data) = 0;
325 };
326
327 template <typename T>
328 struct Value final : ValueInterface {
329 template <class... Args>
330 explicit Value(in_place_t /*tag*/, Args&&... args)
331 : value(std::forward<Args>(args)...) {}
332
333 // NOTE(ebrevdo): Destructor must be explicitly defined for CUDA to happily
334 // build `alignof(Variant<void*>)`.
335 ~Value() final = default;
336
337 TypeIndex TypeId() const final {
338 const TypeIndex value_type_index =
339 TypeIndex::Make<typename std::decay<T>::type>();
340 return value_type_index;
341 }
342
343 void* RawPtr() final { return &value; }
344
345 const void* RawPtr() const final { return &value; }
346
347 std::unique_ptr<ValueInterface> Clone() const final {
348 return absl::make_unique<Value>(InPlace(), value);
349 }
350
351 void MoveAssign(ValueInterface* memory) final {
352 CHECK(TypeId() == memory->TypeId())
353 << TypeId().name() << " vs. " << memory->TypeId().name();
354 static_cast<Value*>(memory)->value = std::move(value);
355 }
356
357 void CloneInto(ValueInterface* memory) const final {
358 new (memory) Value(InPlace(), value);
359 }
360
361 void MoveInto(ValueInterface* memory) final {
362 new (memory) Value(InPlace(), std::move(value));
363 }
364
365 std::string TypeName() const final { return TypeNameVariant(value); }
366
367 std::string DebugString() const final { return DebugStringVariant(value); }
368
369 void Encode(VariantTensorData* data) const final {
370 EncodeVariant(value, data);
371 }
372
373 bool Decode(VariantTensorData data) final {
374 return DecodeVariant(&data, &value);
375 }
376
377 void Encode(std::string* buf) const final { EncodeVariant(value, buf); }
378
379 bool Decode(std::string buf) final { return DecodeVariant(&buf, &value); }
380
381 T value;
382 };
383 static constexpr int kMaxInlineValueAlignSize = alignof(Value<void*>);
384
385 using HeapValue = std::unique_ptr<ValueInterface>;
386
387 struct InlineValue {
388 // We try to size InlineValue so that sizeof(Variant) <= 64 and it can fit
389 // into the aligned space of a TensorBuffer.
390 static constexpr int kMaxValueSize = (64 - /*some extra padding=*/8);
391
392 typedef char ValueDataArray[kMaxValueSize];
393 alignas(kMaxInlineValueAlignSize) ValueDataArray value_data;
394
395 // Tag is used for deducing the right type when constructing a Value in
396 // place.
397 template <typename VT>
398 struct Tag {};
399
400 template <typename VT, class... Args>
401 explicit InlineValue(Tag<VT> /*tag*/, Args&&... args) noexcept {
402 Value<VT>* inline_value_data = reinterpret_cast<Value<VT>*>(value_data);
403 new (inline_value_data) Value<VT>(InPlace(), std::forward<Args>(args)...);
404 }
405
406 InlineValue(const InlineValue& other) noexcept {
407 other.AsValueInterface()->CloneInto(AsValueInterface());
408 }
409
410 InlineValue(InlineValue&& other) noexcept {
411 other.AsValueInterface()->MoveInto(AsValueInterface());
412 }
413
414 void ResetMemory() { AsValueInterface()->~ValueInterface(); }
415
416 InlineValue& operator=(const InlineValue& other) {
417 if (&other == this) return *this;
418 ResetMemory();
419 other.AsValueInterface()->CloneInto(AsValueInterface());
420 return *this;
421 }
422
423 InlineValue& operator=(InlineValue&& other) {
424 if (&other == this) return *this;
425 if (AsValueInterface()->TypeId() == other.AsValueInterface()->TypeId()) {
426 other.AsValueInterface()->MoveAssign(AsValueInterface());
427 } else {
428 ResetMemory();
429 other.AsValueInterface()->MoveInto(AsValueInterface());
430 }
431 return *this;
432 }
433
434 ValueInterface* AsValueInterface() {
435 return reinterpret_cast<ValueInterface*>(value_data);
436 }
437
438 const ValueInterface* AsValueInterface() const {
439 return reinterpret_cast<const ValueInterface*>(value_data);
440 }
441
442 ~InlineValue() { ResetMemory(); }
443 };
444
445 union {
446 HeapValue heap_value_;
447 InlineValue inline_value_;
448 };
449 // is_inline_ provides discrimination between which member of the prior union
450 // is currently within it's lifetime. To switch from one member to the other,
451 // the destructor must be called on the currently alive member before calling
452 // the constructor on the other member. In effect, a member is expected to be
453 // live at any given time and that member is tracked via this boolean.
454 bool is_inline_;
455
456 bool IsInlineValue() const { return is_inline_; }
457
458 // ResetMemory causes the destructor of the currently active member of the
459 // union to be run. This must be follwed with a placement new call on the
460 // member whose lifetime is to start. Additionally, is_inline_ needs to be set
461 // accordingly. ResetAndSetInline and ResetAndSetHeap are simple helper
462 // functions for performing the actions that are required to follow.
463 void ResetMemory() {
464 if (IsInlineValue()) {
465 inline_value_.~InlineValue();
466 } else {
467 heap_value_.~HeapValue();
468 }
469 }
470
471 // ResetAndSetInline clears the current state and then constructs a new value
472 // inline with the provided arguments.
473 template <typename... Args>
474 void ResetAndSetInline(Args&&... args) noexcept {
475 ResetMemory();
476 new (&inline_value_) InlineValue(std::forward<Args>(args)...);
477 is_inline_ = true;
478 }
479
480 // ResetAndSetHeap clears the current state then constructs a new value on the
481 // heap with the provided arguments.
482 template <typename... Args>
483 void ResetAndSetHeap(Args&&... args) noexcept {
484 ResetMemory();
485 new (&heap_value_) HeapValue(std::forward<Args>(args)...);
486 is_inline_ = false;
487 }
488
489 ValueInterface* GetValue() {
490 if (IsInlineValue()) {
491 return inline_value_.AsValueInterface();
492 } else {
493 return heap_value_.get();
494 }
495 }
496
497 const ValueInterface* GetValue() const {
498 if (IsInlineValue()) {
499 return inline_value_.AsValueInterface();
500 } else {
501 return heap_value_.get();
502 }
503 }
504
505 // PRECONDITION: Called on construction or ResetMemory() has been called
506 // before this method.
507 template <typename VT, typename T>
508 void InsertValue(T&& value) {
509 if (IsInlineValue()) {
510 new (&inline_value_)
511 InlineValue(InlineValue::Tag<VT>{}, std::forward<T>(value));
512 } else {
513 new (&heap_value_) HeapValue(
514 absl::make_unique<Value<VT>>(InPlace(), std::forward<T>(value)));
515 }
516 }
517};
518
519// Make sure that a Variant object can reside in a 64-byte aligned Tensor
520// buffer.
521static_assert(sizeof(Variant) <= 64,
522 "Expected internal representation to be 64 bytes.");
523
524inline Variant::Variant(const Variant& other)
525 : is_inline_(other.IsInlineValue()) {
526 if (IsInlineValue()) {
527 new (&inline_value_) InlineValue(other.inline_value_);
528 } else {
529 new (&heap_value_)
530 HeapValue(other.heap_value_ ? other.heap_value_->Clone() : nullptr);
531 }
532}
533
534inline Variant::Variant(Variant&& other) noexcept
535 : is_inline_(other.IsInlineValue()) {
536 if (IsInlineValue()) {
537 new (&inline_value_) InlineValue(std::move(other.inline_value_));
538 } else {
539 new (&heap_value_) HeapValue(std::move(other.heap_value_));
540 }
541}
542
543template <typename T, typename VT,
544 typename std::enable_if<!std::is_same<Variant, VT>::value &&
545 std::is_move_constructible<VT>::value,
546 void>::type*>
547inline Variant::Variant(T&& value) : is_inline_(CanInlineType<VT>()) {
548 InsertValue<VT>(std::forward<T>(value));
549}
550
551template <typename T, typename VT,
552 typename std::enable_if<!std::is_same<Variant, VT>::value &&
553 std::is_copy_constructible<VT>::value,
554 void>::type*>
555inline Variant::Variant(const T& value) : is_inline_(CanInlineType<VT>()) {
556 InsertValue<VT>(value);
557}
558
559template <typename T, typename VT,
560 typename std::enable_if<!std::is_same<Variant, VT>::value &&
561 std::is_move_constructible<VT>::value,
562 void>::type*>
563inline Variant& Variant::operator=(T&& value) {
564 ResetMemory();
565 is_inline_ = CanInlineType<VT>();
566 InsertValue<VT>(std::forward<T>(value));
567 return *this;
568}
569
570template <typename T, typename VT,
571 typename std::enable_if<!std::is_same<Variant, VT>::value &&
572 std::is_copy_constructible<VT>::value,
573 void>::type*>
574inline Variant& Variant::operator=(const T& value) {
575 ResetMemory();
576 is_inline_ = CanInlineType<VT>();
577 InsertValue<VT>(value);
578 return *this;
579}
580
581inline void Variant::clear() noexcept {
582 // We set the internal unique_ptr to nullptr so that we preserve the
583 // invariant that one of the two states must be set at all times. nullptr
584 // indicates that the variant is empty.
585 ResetAndSetHeap(/*pointer=*/nullptr);
586}
587
// Exchanges the contents of *this and `other`. Each side tracks its live union
// member through is_inline_. The empty cases and the four inline/heap
// combinations are therefore handled separately, so that each union member is
// destroyed and constructed through the ResetAndSet* helpers.
inline void Variant::swap(Variant& other) noexcept {
  if (is_empty()) {
    // *this holds nothing: take over other's value, then leave other empty.
    if (other.IsInlineValue()) {
      ResetAndSetInline(std::move(other.inline_value_));
    } else {
      ResetAndSetHeap(std::move(other.heap_value_));
    }
    other.clear();
  } else if (other.is_empty()) {
    // Mirror image of the case above.
    if (IsInlineValue()) {
      other.ResetAndSetInline(std::move(inline_value_));
    } else {
      other.ResetAndSetHeap(std::move(heap_value_));
    }
    clear();
  } else {  // Both Variants have values.
    if (other.IsInlineValue() && IsInlineValue()) {
      // Uses InlineValue's move constructor / move assignment.
      std::swap(inline_value_, other.inline_value_);
    } else if (!other.IsInlineValue() && !IsInlineValue()) {
      std::swap(heap_value_, other.heap_value_);
    } else if (other.IsInlineValue() && !IsInlineValue()) {
      // Park our heap pointer in a local first: ResetAndSetInline destroys the
      // heap_value_ member before constructing the inline one.
      HeapValue v = std::move(heap_value_);
      ResetAndSetInline(std::move(other.inline_value_));
      other.ResetAndSetHeap(std::move(v));
    } else {  // !other.IsInlineValue() && IsInlineValue()
      // Same dance with the roles reversed.
      HeapValue v = std::move(other.heap_value_);
      other.ResetAndSetInline(std::move(inline_value_));
      ResetAndSetHeap(std::move(v));
    }
  }
}
619
620template <>
621void* Variant::get();
622
623template <>
624const void* Variant::get() const;
625
626} // end namespace tensorflow
627
628#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
629