1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_ |
16 | #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_ |
17 | |
18 | #include "tensorflow/core/framework/tensor.h" |
19 | #include "tensorflow/core/framework/types.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | class TensorKey : public Tensor { |
24 | public: |
25 | using Tensor::Tensor; |
26 | |
27 | TensorKey(const Tensor& t) : Tensor(t) {} |
28 | |
29 | // Equality operator. Needed for absl hashing. |
30 | friend bool operator==(const TensorKey& t1, const TensorKey& t2) { |
31 | if (t1.dtype() != t2.dtype() || t1.shape() != t2.shape()) { |
32 | return false; |
33 | } |
34 | if (DataTypeCanUseMemcpy(t1.dtype())) { |
35 | return t1.tensor_data() == t2.tensor_data(); |
36 | } else if (t1.dtype() == DT_STRING) { |
37 | const auto s1 = t1.unaligned_flat<tstring>(); |
38 | const auto s2 = t2.unaligned_flat<tstring>(); |
39 | for (int64_t i = 0, n = t1.NumElements(); i < n; ++i) { |
40 | if (TF_PREDICT_FALSE(s1(i) != s2(i))) { |
41 | return false; |
42 | } |
43 | } |
44 | return true; |
45 | } else { |
46 | DCHECK(false) << "Unimplemented dtype " << DataTypeString(t1.dtype()) |
47 | << std::endl; |
48 | } |
49 | return false; |
50 | } |
51 | |
52 | friend bool operator!=(const TensorKey& t1, const TensorKey& t2) { |
53 | return !(t1 == t2); |
54 | } |
55 | |
56 | // Needed for absl hash function. |
57 | template <typename H> |
58 | friend H AbslHashValue(H h, const TensorKey& k) { |
59 | if (DataTypeCanUseMemcpy(k.dtype())) { |
60 | return H::combine(std::move(h), k.tensor_data()); |
61 | } else if (k.dtype() == DT_STRING) { |
62 | const auto strs = k.unaligned_flat<tstring>(); |
63 | for (int64_t i = 0, n = k.NumElements(); i < n; ++i) { |
64 | h = H::combine(std::move(h), strs(i)); |
65 | } |
66 | return h; |
67 | } else { |
68 | DCHECK(false) << "Unimplemented dtype " << DataTypeString(k.dtype()) |
69 | << std::endl; |
70 | } |
71 | return h; |
72 | } |
73 | }; |
74 | |
75 | } // namespace tensorflow |
76 | |
77 | #endif |
78 | |