1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_ |
18 | |
19 | #include <string> |
20 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
21 | #include "tensorflow/core/framework/tensor_shape.h" |
22 | #include "tensorflow/core/framework/tensor_slice.pb.h" |
23 | #include "tensorflow/core/lib/core/status.h" |
24 | #include "tensorflow/core/lib/core/stringpiece.h" |
25 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
26 | #include "tensorflow/core/platform/logging.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | // A tensor slice represents a slice of a given tensor. It is represented by a |
31 | // list of (start, length) pairs, where the size of the list is the rank of the |
32 | // tensor. |
33 | |
34 | class TensorSlice { |
35 | public: |
36 | // Construct a tensor slice: you have a number of ways: |
37 | // -- creating an empty slice |
38 | // -- from just a dimension (in this case it will create a full slice) |
39 | // -- from an array of pairs of integers. |
40 | // -- from a TensorSliceProto protocol buffer |
41 | // -- from a string format of "start,length:start,length..." where each |
42 | // "start,length" pair represents the slice on one dimension. We allow a |
43 | // special "-" that means "everything for this dimension". One such example |
44 | // is: 0,10:-:14,1:-:- |
45 | TensorSlice() {} |
46 | explicit TensorSlice(int dim); |
47 | explicit TensorSlice(const TensorSliceProto& proto); |
48 | explicit TensorSlice( |
49 | std::initializer_list<std::pair<int64_t, int64_t>> extents); |
50 | |
51 | // This factory methods should be used instead of the constructor that takes a |
52 | // `TensorSliceProto` if calling code cannot validate that the sizes specify a |
53 | // valid `TensorSlice`. |
54 | static Status BuildTensorSlice(const TensorSliceProto& proto, |
55 | TensorSlice* output); |
56 | |
57 | static Status Parse(const string& str, TensorSlice* output); |
58 | static TensorSlice ParseOrDie(const string& str) { |
59 | TensorSlice ret; |
60 | Status s = Parse(str, &ret); |
61 | if (!s.ok()) { |
62 | LOG(FATAL) << "Could not parse TensorSlice" ; |
63 | } |
64 | return ret; |
65 | } |
66 | |
67 | void Clear(); |
68 | |
69 | // Accessors |
70 | int dims() const { return starts_.size(); } |
71 | |
72 | int64_t start(int d) const { |
73 | DCHECK_GE(d, 0); |
74 | DCHECK_LT(d, dims()); |
75 | return starts_[d]; |
76 | } |
77 | |
78 | int64_t length(int d) const { |
79 | DCHECK_GE(d, 0); |
80 | DCHECK_LT(d, dims()); |
81 | return lengths_[d]; |
82 | } |
83 | |
84 | int64_t end(int d) const { |
85 | DCHECK_GE(d, 0); |
86 | DCHECK_LT(d, dims()); |
87 | return start(d) + length(d); |
88 | } |
89 | |
90 | void set_start(int d, int64_t x) { |
91 | DCHECK_GE(d, 0); |
92 | DCHECK_LT(d, dims()); |
93 | DCHECK_GE(x, 0); |
94 | starts_[d] = x; |
95 | } |
96 | |
97 | void set_length(int d, int64_t x) { |
98 | DCHECK_GE(d, 0); |
99 | DCHECK_LT(d, dims()); |
100 | lengths_[d] = x; |
101 | } |
102 | |
103 | // If we have a full slice along dimension "d". |
104 | bool IsFullAt(int d) const { |
105 | return lengths_[d] == kFullExtent && starts_[d] == 0; |
106 | } |
107 | |
108 | // If this is a full slice, i.e. IsFullAt(d) for every d. |
109 | bool IsFull() const; |
110 | |
111 | // Set the slice to be a full slice of "dim" dimensions |
112 | void SetFullSlice(int dim); |
113 | |
114 | // Extend a slice to "dim" dimensions: all the added dimensions are full. |
115 | // Requires: dim >= dims(). |
116 | void Extend(int dim); |
117 | |
118 | // Conversion of a TensorSlice to other formats |
119 | void AsProto(TensorSliceProto* proto) const; |
120 | string DebugString() const; |
121 | |
122 | // Fill *indices and *sizes from *this (so that we can use the slice() |
123 | // function in eigen tensor). We need a tensor shape in case some of the |
124 | // slices are full slices. |
125 | // We allow NDIMS to be greater than dims(), in which case we will pad the |
126 | // higher dimensions with trivial dimensions. |
127 | template <int NDIMS> |
128 | void FillIndicesAndSizes( |
129 | const TensorShape& shape, |
130 | Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices, |
131 | Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const; |
132 | |
133 | // Interaction with other TensorSlices. |
134 | |
135 | // Compute the intersection with another slice and if "result" is not |
136 | // nullptr, store the results in *result; returns true if there is any real |
137 | // intersection. |
138 | bool Intersect(const TensorSlice& other, TensorSlice* result) const; |
139 | // A short hand. |
140 | bool Overlaps(const TensorSlice& other) const { |
141 | return Intersect(other, nullptr); |
142 | } |
143 | |
144 | // Equals iff "*this" and "other" are logically equivalent. |
145 | bool operator==(const TensorSlice& other) const; |
146 | bool operator!=(const TensorSlice& other) const { return !(*this == other); } |
147 | |
148 | // Interaction with TensorShape. |
149 | |
150 | // Slices a shape and stores the result into *result_shape. |
151 | // Requires that the shape and *this have the same rank. |
152 | // For example, given a tensor shape of {3, 4, 5}, and a slice of |
153 | // 1,2:-:0,2, the result shape is {2, 4, 2}. |
154 | Status SliceTensorShape(const TensorShape& shape, |
155 | TensorShape* result_shape) const; |
156 | |
157 | // Given slice "sub" where "sub" is fully contained in *this, |
158 | // (meaning that the intersection of "sub" and *this equals "sub"), computes |
159 | // the "relative" slice of "sub" with respect to *this. |
160 | // |
161 | // In other words, if we use A>S to denote slicing a shape S with a slice A, |
162 | // then the function is computing a slice X such that: |
163 | // X > (this > S) = sub > S |
164 | // for any shape S. |
165 | // |
166 | // In general, along every dimension, the start of the relative slice is the |
167 | // start of the "sub" slice minus the start of *this; the length of the |
168 | // relative slice is the length of the "sub" slice. |
169 | // |
170 | // For example, say we have a shape of {3, 4, 5}, "this" is 0,2:-:1,2, and |
171 | // "sub" is 1,1:2:2,1,2, then the related slice is 1,1:2,2:0,2. |
172 | // |
173 | // The caller needs to make sure that "sub" is indeed a sub-slice of *this; |
174 | // otherwise the result is undefined. |
175 | void ComputeRelative(const TensorSlice& sub, TensorSlice* relative) const; |
176 | |
177 | // Updates the slice in such a way that it fully covers "other" slice. |
178 | // Note, "other" slice should refer to the same tensor shape. |
179 | // Example: |
180 | // given a slice [2:4, :, 3:] and "other" slice [:, 1:4, 2:4] the |
181 | // updated slice would be [:, :, 2:]. Here is why: |
182 | // dim 0: "2:4" U ":" -> ":" |
183 | // dim 1: ":" U "1-4" -> ":" |
184 | // dim 2: "3:" U "2:4" -> "2:" |
185 | void UpdateToCover(const TensorSlice& other); |
186 | |
187 | // Returns true if the length field was specified in an Extent. |
188 | static bool HasExtentLength(const TensorSliceProto::Extent& extent); |
189 | |
190 | // Returns the value of the length field in an Extent, or -1 if it |
191 | // is not present. |
192 | static int64_t GetExtentLength(const TensorSliceProto::Extent& extent); |
193 | |
194 | private: |
195 | // a length value of kFullExtent (-1) means we have a full slice at this |
196 | // dimension. It's defined in tensor_slice.cc. |
197 | static const int64_t kFullExtent; |
198 | |
199 | // TODO(yangke): switch to Eigen once it supports variable size arrays. |
200 | // A value of |
201 | gtl::InlinedVector<int64_t, 4> starts_; |
202 | gtl::InlinedVector<int64_t, 4> lengths_; |
203 | }; |
204 | |
205 | template <int NDIMS> |
206 | void TensorSlice::FillIndicesAndSizes( |
207 | const TensorShape& shape, Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices, |
208 | Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const { |
209 | CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape " |
210 | << "slices: shape = " << shape.DebugString() |
211 | << ", slice = " << DebugString(); |
212 | CHECK_GE(NDIMS, dims()) << "Asking for a " << NDIMS << "-dim slice from " |
213 | << "a slice of dimension " << dims(); |
214 | for (int d = 0; d < dims(); ++d) { |
215 | if (IsFullAt(d)) { |
216 | (*indices)[d] = 0; |
217 | (*sizes)[d] = shape.dim_size(d); |
218 | } else { |
219 | (*indices)[d] = starts_[d]; |
220 | (*sizes)[d] = lengths_[d]; |
221 | } |
222 | } |
223 | for (int d = dims(); d < NDIMS; ++d) { |
224 | (*indices)[d] = 0; |
225 | (*sizes)[d] = 1; |
226 | } |
227 | } |
228 | |
229 | } // namespace tensorflow |
230 | |
231 | #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_ |
232 | |