1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_BASE_SLICERANGE_H |
17 | #define GLOW_BASE_SLICERANGE_H |
18 | |
#include "glow/Base/Type.h"
#include "glow/Graph/Graph.h"

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
26 | |
27 | using namespace glow; |
28 | using llvm::cast; |
29 | using llvm::dyn_cast; |
30 | using llvm::isa; |
31 | |
32 | namespace glow { |
33 | |
/// Dimension range representing a contiguous [start, stop) index interval
/// along some tensor dimension: the start index (first) is included and the
/// stop index (second) is excluded. The indices are assumed to be 0 based
/// (C indexing). For example the range {2, 5} covers the indices 2, 3 and 4.
using DimRange = std::pair<dim_t, dim_t>;

/// Dimension paddings representing a virtual padding before/after a given
/// dimension range. NOTE(review): presumably `first` holds the padding before
/// and `second` the padding after, following the before/after order; confirm
/// against users of this alias.
using DimPads = std::pair<dim_t, dim_t>;
42 | |
43 | /// Slice range utility class representing the ranges for all the dimensions |
44 | /// of a slice obtained by extraction from a larger tensor. |
45 | class SliceRange { |
46 | |
47 | /// Vector of ranges for all the dimensions of a slice. |
48 | std::vector<DimRange> ranges_; |
49 | |
50 | public: |
51 | SliceRange() = default; |
52 | |
53 | /// Ctor. |
54 | explicit SliceRange(std::vector<DimRange> ranges) { ranges_ = ranges; } |
55 | |
56 | /// Ctor. |
57 | explicit SliceRange(TypeRef type) { |
58 | for (auto size : type->dims()) { |
59 | ranges_.emplace_back(0, size); |
60 | } |
61 | } |
62 | |
63 | /// Ctor. |
64 | explicit SliceRange(const SliceNode *slice) { |
65 | SliceRange sliceRange = SliceRange(slice->getResult().getType()); |
66 | auto start = slice->getStart(); |
67 | for (size_t dim = 0, dimEnd = start.size(); dim < dimEnd; ++dim) { |
68 | sliceRange[dim].first += start[dim]; |
69 | sliceRange[dim].second += start[dim]; |
70 | } |
71 | ranges_ = sliceRange.getRanges(); |
72 | } |
73 | |
74 | /// \returns the dimension ranges. |
75 | llvm::ArrayRef<DimRange> getRanges() const { return ranges_; } |
76 | |
77 | /// \returns the start values of the dimension ranges. |
78 | std::vector<dim_t> getStarts() const { |
79 | std::vector<dim_t> starts(ranges_.size()); |
80 | for (size_t dim = 0, e = ranges_.size(); dim < e; ++dim) { |
81 | starts[dim] = ranges_[dim].first; |
82 | } |
83 | return starts; |
84 | } |
85 | |
86 | /// \returns the sizes of the dimension ranges. |
87 | std::vector<dim_t> getSizes() const { |
88 | std::vector<dim_t> sizes(ranges_.size()); |
89 | for (size_t dim = 0, e = ranges_.size(); dim < e; ++dim) { |
90 | sizes[dim] = ranges_[dim].second - ranges_[dim].first; |
91 | } |
92 | return sizes; |
93 | } |
94 | |
95 | /// \returns the number of dimensions. |
96 | size_t getNumDims() const { return ranges_.size(); } |
97 | |
98 | /// \returns a mutable range for the given dimension \p dim. |
99 | DimRange &operator[](size_t dim) { |
100 | DCHECK_LT(dim, ranges_.size()) << "Invalid dimension!" ; |
101 | return ranges_[dim]; |
102 | } |
103 | |
104 | /// \returns an immutable range for the given dimension \p dim. |
105 | const DimRange &operator[](size_t dim) const { |
106 | DCHECK_LT(dim, ranges_.size()) << "Invalid dimension!" ; |
107 | return ranges_[dim]; |
108 | } |
109 | |
110 | /// \returns whether this slice range is equal to \p other. |
111 | bool operator==(const SliceRange &other) const { |
112 | auto rangesOther = other.getRanges(); |
113 | if (ranges_.size() != rangesOther.size()) { |
114 | return false; |
115 | } |
116 | for (size_t dim = 0, e = ranges_.size(); dim < e; ++dim) { |
117 | if (ranges_[dim] != rangesOther[dim]) { |
118 | return false; |
119 | } |
120 | } |
121 | return true; |
122 | } |
123 | |
124 | /// \returns the range size along dimension \p dim. |
125 | dim_t getDimSize(size_t dim) const { |
126 | DCHECK_LT(dim, ranges_.size()) << "Invalid dimension!" ; |
127 | return ranges_[dim].second - ranges_[dim].first; |
128 | } |
129 | |
130 | /// \returns a slice range by extracting the dimension ranges between |
131 | /// \p dimStart and \p dimStop (both included). |
132 | SliceRange (size_t dimStart, size_t dimStop) const { |
133 | DCHECK_LT(dimStart, ranges_.size()) << "Invalid start dimension!" ; |
134 | DCHECK_LT(dimStop, ranges_.size()) << "Invalid stop dimension!" ; |
135 | DCHECK_LE(dimStart, dimStop) << "Invalid start/stop dimension!" ; |
136 | std::vector<DimRange> dimRanges(ranges_.cbegin() + dimStart, |
137 | ranges_.cbegin() + dimStop + 1); |
138 | return SliceRange(dimRanges); |
139 | } |
140 | |
141 | /// \returns a slice range by shuffling the dimension ranges using the |
142 | /// indices \p shuffle. The flag \p invert allows optionally to invert |
143 | /// the shuffle permutation before using it. |
144 | SliceRange shuffleRanges(llvm::ArrayRef<size_t> shuffle, |
145 | bool invert = false) const { |
146 | DCHECK_EQ(ranges_.size(), shuffle.size()) |
147 | << "Mismatch between ranges and shuffle sizes!" ; |
148 | std::vector<DimRange> dimRanges(ranges_.size()); |
149 | for (size_t idx = 0, e = ranges_.size(); idx < e; ++idx) { |
150 | size_t dimInp = invert ? idx : shuffle[idx]; |
151 | size_t dimOut = invert ? shuffle[idx] : idx; |
152 | DCHECK_LT(dimInp, ranges_.size()) << "Invalid input shuffle index!" ; |
153 | DCHECK_LT(dimOut, ranges_.size()) << "Invalid output shuffle index!" ; |
154 | dimRanges[dimOut] = ranges_[dimInp]; |
155 | } |
156 | return SliceRange(dimRanges); |
157 | } |
158 | |
159 | /// \returns whether this slice range is empty. |
160 | bool isEmpty() const { |
161 | if (!ranges_.size()) { |
162 | return true; |
163 | } |
164 | for (const auto &range : ranges_) { |
165 | if (!(range.first < range.second)) { |
166 | return true; |
167 | } |
168 | } |
169 | return false; |
170 | } |
171 | |
172 | /// \returns whether both ends of the range for a given dimension \p dim are |
173 | /// aligned to \p align. For example the range [4, 8) is aligned to 4. |
174 | bool isDimRangeAligned(size_t dim, dim_t align) const { |
175 | DCHECK_LT(dim, ranges_.size()) << "Invalid dimension!" ; |
176 | return (ranges_[dim].first % align == 0) && |
177 | (ranges_[dim].second % align == 0); |
178 | } |
179 | |
180 | /// \returns whether this slice range is included by \p other. |
181 | bool isIncludedBy(const SliceRange &other) const { |
182 | auto rangesOther = other.getRanges(); |
183 | if (ranges_.size() != rangesOther.size()) { |
184 | return false; |
185 | } |
186 | for (size_t dim = 0, e = ranges_.size(); dim < e; ++dim) { |
187 | if (!((rangesOther[dim].first <= ranges_[dim].first) && |
188 | (ranges_[dim].second <= rangesOther[dim].second))) { |
189 | return false; |
190 | } |
191 | } |
192 | return true; |
193 | } |
194 | |
195 | /// \returns a textual representation of this slice range. |
196 | std::string toString() const { |
197 | std::string storage; |
198 | llvm::raw_string_ostream os(storage); |
199 | for (size_t dim = 0, e = ranges_.size(); dim < e; ++dim) { |
200 | os << getDimSize(dim) << "[" << ranges_[dim].first << ":" |
201 | << ranges_[dim].second << ") " ; |
202 | } |
203 | return os.str(); |
204 | } |
205 | }; |
206 | |
207 | } // namespace glow |
208 | |
209 | #endif // GLOW_BASE_SLICERANGE_H |
210 | |