1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_BASE_SLICERANGE_H
17#define GLOW_BASE_SLICERANGE_H
18
#include "glow/Base/Type.h"
#include "glow/Graph/Graph.h"

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
26
27using namespace glow;
28using llvm::cast;
29using llvm::dyn_cast;
30using llvm::isa;
31
32namespace glow {
33
34/// Dimension range representing a contiguous [start, stop) index interval with
35/// start index included and stop index excluded, along some tensor dimension.
36/// The indices are assumed to be 0 based (C indexing).
37using DimRange = std::pair<dim_t, dim_t>;
38
39/// Dimension paddings representing a virtual padding before/after a given
40/// dimension range.
41using DimPads = std::pair<dim_t, dim_t>;
42
43/// Slice range utility class representing the ranges for all the dimensions
44/// of a slice obtained by extraction from a larger tensor.
45class SliceRange {
46
47 /// Vector of ranges for all the dimensions of a slice.
48 std::vector<DimRange> ranges_;
49
50public:
51 SliceRange() = default;
52
53 /// Ctor.
54 explicit SliceRange(std::vector<DimRange> ranges) { ranges_ = ranges; }
55
56 /// Ctor.
57 explicit SliceRange(TypeRef type) {
58 for (auto size : type->dims()) {
59 ranges_.emplace_back(0, size);
60 }
61 }
62
63 /// Ctor.
64 explicit SliceRange(const SliceNode *slice) {
65 SliceRange sliceRange = SliceRange(slice->getResult().getType());
66 auto start = slice->getStart();
67 for (size_t dim = 0, dimEnd = start.size(); dim < dimEnd; ++dim) {
68 sliceRange[dim].first += start[dim];
69 sliceRange[dim].second += start[dim];
70 }
71 ranges_ = sliceRange.getRanges();
72 }
73
74 /// \returns the dimension ranges.
75 llvm::ArrayRef<DimRange> getRanges() const { return ranges_; }
76
77 /// \returns the start values of the dimension ranges.
78 std::vector<dim_t> getStarts() const {
79 std::vector<dim_t> starts(ranges_.size());
80 for (size_t dim = 0, e = ranges_.size(); dim < e; ++dim) {
81 starts[dim] = ranges_[dim].first;
82 }
83 return starts;
84 }
85
86 /// \returns the sizes of the dimension ranges.
87 std::vector<dim_t> getSizes() const {
88 std::vector<dim_t> sizes(ranges_.size());
89 for (size_t dim = 0, e = ranges_.size(); dim < e; ++dim) {
90 sizes[dim] = ranges_[dim].second - ranges_[dim].first;
91 }
92 return sizes;
93 }
94
95 /// \returns the number of dimensions.
96 size_t getNumDims() const { return ranges_.size(); }
97
98 /// \returns a mutable range for the given dimension \p dim.
99 DimRange &operator[](size_t dim) {
100 DCHECK_LT(dim, ranges_.size()) << "Invalid dimension!";
101 return ranges_[dim];
102 }
103
104 /// \returns an immutable range for the given dimension \p dim.
105 const DimRange &operator[](size_t dim) const {
106 DCHECK_LT(dim, ranges_.size()) << "Invalid dimension!";
107 return ranges_[dim];
108 }
109
110 /// \returns whether this slice range is equal to \p other.
111 bool operator==(const SliceRange &other) const {
112 auto rangesOther = other.getRanges();
113 if (ranges_.size() != rangesOther.size()) {
114 return false;
115 }
116 for (size_t dim = 0, e = ranges_.size(); dim < e; ++dim) {
117 if (ranges_[dim] != rangesOther[dim]) {
118 return false;
119 }
120 }
121 return true;
122 }
123
124 /// \returns the range size along dimension \p dim.
125 dim_t getDimSize(size_t dim) const {
126 DCHECK_LT(dim, ranges_.size()) << "Invalid dimension!";
127 return ranges_[dim].second - ranges_[dim].first;
128 }
129
130 /// \returns a slice range by extracting the dimension ranges between
131 /// \p dimStart and \p dimStop (both included).
132 SliceRange extractRanges(size_t dimStart, size_t dimStop) const {
133 DCHECK_LT(dimStart, ranges_.size()) << "Invalid start dimension!";
134 DCHECK_LT(dimStop, ranges_.size()) << "Invalid stop dimension!";
135 DCHECK_LE(dimStart, dimStop) << "Invalid start/stop dimension!";
136 std::vector<DimRange> dimRanges(ranges_.cbegin() + dimStart,
137 ranges_.cbegin() + dimStop + 1);
138 return SliceRange(dimRanges);
139 }
140
141 /// \returns a slice range by shuffling the dimension ranges using the
142 /// indices \p shuffle. The flag \p invert allows optionally to invert
143 /// the shuffle permutation before using it.
144 SliceRange shuffleRanges(llvm::ArrayRef<size_t> shuffle,
145 bool invert = false) const {
146 DCHECK_EQ(ranges_.size(), shuffle.size())
147 << "Mismatch between ranges and shuffle sizes!";
148 std::vector<DimRange> dimRanges(ranges_.size());
149 for (size_t idx = 0, e = ranges_.size(); idx < e; ++idx) {
150 size_t dimInp = invert ? idx : shuffle[idx];
151 size_t dimOut = invert ? shuffle[idx] : idx;
152 DCHECK_LT(dimInp, ranges_.size()) << "Invalid input shuffle index!";
153 DCHECK_LT(dimOut, ranges_.size()) << "Invalid output shuffle index!";
154 dimRanges[dimOut] = ranges_[dimInp];
155 }
156 return SliceRange(dimRanges);
157 }
158
159 /// \returns whether this slice range is empty.
160 bool isEmpty() const {
161 if (!ranges_.size()) {
162 return true;
163 }
164 for (const auto &range : ranges_) {
165 if (!(range.first < range.second)) {
166 return true;
167 }
168 }
169 return false;
170 }
171
172 /// \returns whether both ends of the range for a given dimension \p dim are
173 /// aligned to \p align. For example the range [4, 8) is aligned to 4.
174 bool isDimRangeAligned(size_t dim, dim_t align) const {
175 DCHECK_LT(dim, ranges_.size()) << "Invalid dimension!";
176 return (ranges_[dim].first % align == 0) &&
177 (ranges_[dim].second % align == 0);
178 }
179
180 /// \returns whether this slice range is included by \p other.
181 bool isIncludedBy(const SliceRange &other) const {
182 auto rangesOther = other.getRanges();
183 if (ranges_.size() != rangesOther.size()) {
184 return false;
185 }
186 for (size_t dim = 0, e = ranges_.size(); dim < e; ++dim) {
187 if (!((rangesOther[dim].first <= ranges_[dim].first) &&
188 (ranges_[dim].second <= rangesOther[dim].second))) {
189 return false;
190 }
191 }
192 return true;
193 }
194
195 /// \returns a textual representation of this slice range.
196 std::string toString() const {
197 std::string storage;
198 llvm::raw_string_ostream os(storage);
199 for (size_t dim = 0, e = ranges_.size(); dim < e; ++dim) {
200 os << getDimSize(dim) << "[" << ranges_[dim].first << ":"
201 << ranges_[dim].second << ") ";
202 }
203 return os.str();
204 }
205};
206
207} // namespace glow
208
209#endif // GLOW_BASE_SLICERANGE_H
210