1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "stripe_config.h"
20
21#include <tvm/runtime/container/array.h>
22#include <tvm/runtime/object.h>
23#include <tvm/runtime/registry.h>
24
25#include <algorithm>
26#include <limits>
27#include <map>
28#include <utility>
29#include <vector>
30
31#include "common.h"
32
33namespace tvm {
34namespace contrib {
35namespace ethosu {
36namespace cascader {
37
/*!
 * \brief Form the weighted Cartesian product of several value->count maps.
 * \param values One map per position; each maps a candidate value to how many times it occurs.
 * \return A map from every combination (one value chosen from each map, in the same order as
 *  \p values) to the product of the counts of the chosen values.
 * \note An empty \p values yields a single empty combination with count 1 (the identity of the
 *  product). If any map in \p values is empty, the result is empty.
 */
template <class T>
std::map<std::vector<T>, int> MultiplyCombinations(const std::vector<std::map<T, int>>& values) {
  // Start from the single empty combination and extend every combination by one position at
  // a time, multiplying the counts as we go.
  std::map<std::vector<T>, int> combs;
  combs[std::vector<T>()] = 1;
  for (const auto& position : values) {
    std::map<std::vector<T>, int> new_combs;
    for (const auto& comb_it : combs) {
      for (const auto& val_it : position) {
        std::vector<T> new_comb(comb_it.first);
        new_comb.push_back(val_it.first);
        // Keys are unique: distinct prefixes extended by distinct last elements.
        new_combs.emplace(std::move(new_comb), val_it.second * comb_it.second);
      }
    }
    combs = std::move(new_combs);
  }
  return combs;
}
59
60std::map<std::vector<int>, int> CountStripes(const StripeConfig& stripe_config,
61 bool enable_sliding_window = false) {
62 std::vector<std::map<int, int>> per_axis_sizes(stripe_config->GetOrder().size());
63 for (size_t axis = 0; axis < stripe_config->GetOrder().size(); axis++) {
64 int start = stripe_config->GetOffset()[axis];
65 size_t stripe_count = static_cast<size_t>(stripe_config->GetStripes()[axis]);
66 int stride = stripe_config->GetStrides()[axis];
67 int shape = stripe_config->GetShape()[axis];
68 int extent = stripe_config->GetExtent()[axis];
69 int low;
70 int high = std::numeric_limits<int>::min();
71 for (size_t i = 0; i < stripe_count; i++) {
72 // Calculate the 'non-edge case' sizes in one go to save effort
73 if (!enable_sliding_window || i > 0) {
74 if (start >= 0 && extent - shape - start >= 0 && stride > 0) {
75 int whole_stripes =
76 std::min(static_cast<int>(stripe_count - i), (extent - shape - start) / stride + 1);
77 if (enable_sliding_window) {
78 per_axis_sizes[axis][stride] += whole_stripes;
79 } else {
80 per_axis_sizes[axis][shape] += whole_stripes;
81 }
82 i += whole_stripes - 1;
83 start += whole_stripes * stride;
84 high = std::min(start - stride + shape, extent);
85 continue;
86 }
87 }
88 low = std::max(start, 0);
89 if (enable_sliding_window) {
90 low = std::max(low, high);
91 }
92 high = std::min(start + shape, extent);
93 int size = high - low;
94 if (size > 0) {
95 per_axis_sizes[axis][size]++;
96 }
97 start += stride;
98 }
99 }
100 return MultiplyCombinations(per_axis_sizes);
101}
102
103TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CountStripes")
104 .set_body_typed([](StripeConfig stripe_config, bool enable_sliding_window) {
105 Map<Array<Integer>, Integer> ret;
106 auto stripe_counts = CountStripes(stripe_config, enable_sliding_window);
107 for (const auto& it : stripe_counts) {
108 ret.Set(make_array(it.first), it.second);
109 }
110 return ret;
111 });
112
113void StripeConfigNode::VisitAttrs(AttrVisitor* v) {
114 Array<Integer> tmp_arr = make_array(shape_);
115 v->Visit("_shape", &tmp_arr);
116 tmp_arr = make_array(extent_);
117 v->Visit("_extent", &tmp_arr);
118 tmp_arr = make_array(order_);
119 v->Visit("_order", &tmp_arr);
120 tmp_arr = make_array(stripes_);
121 v->Visit("_stripes", &tmp_arr);
122 tmp_arr = make_array(offset_);
123 v->Visit("_offset", &tmp_arr);
124 Array<FloatImm> tmp_float_arr = make_array(strides_);
125 v->Visit("_strides", &tmp_float_arr);
126 int64_t tmp_hash = static_cast<int64_t>(hash_);
127 v->Visit("_hash", &tmp_hash);
128}
129
130void StripeConfigNode::ComputeHash_() {
131 hash_ = hash_vector(shape_);
132 hash_combine(&hash_, hash_vector(extent_));
133 hash_combine(&hash_, hash_vector(strides_));
134 hash_combine(&hash_, hash_vector(order_));
135 hash_combine(&hash_, hash_vector(stripes_));
136 hash_combine(&hash_, hash_vector(offset_));
137}
138
139StripeConfig::StripeConfig(const std::vector<int>& shape, const std::vector<int>& extent,
140 const std::vector<float>& strides, const std::vector<int>& order,
141 const std::vector<int>& stripes, const std::vector<int>& offset) {
142 auto n = make_object<StripeConfigNode>();
143 n->shape_ = std::move(shape);
144 n->extent_ = std::move(extent);
145 n->strides_ = std::move(strides);
146 n->order_ = std::move(order);
147 n->stripes_ = std::move(stripes);
148 n->offset_ = std::move(offset);
149 n->ComputeHash_();
150 data_ = std::move(n);
151}
152
153inline bool StripeConfig::operator==(const StripeConfig& other) const {
154 if (get() == other.get()) return true;
155 if (get() == nullptr || other.get() == nullptr) return false;
156 return ((*this)->shape_ == other->shape_ && (*this)->extent_ == other->extent_ &&
157 (*this)->strides_ == other->strides_ && (*this)->order_ == other->order_ &&
158 (*this)->stripes_ == other->stripes_ && (*this)->offset_ == other->offset_);
159}
160
161TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.StripeConfig")
162 .set_body_typed([](Array<Integer> shape, Array<Integer> extent, Array<FloatImm> strides,
163 Array<Integer> order, Array<Integer> stripes, Array<Integer> offset) {
164 std::vector<int> vshape = make_vector<int, Integer>(shape);
165 std::vector<int> vextent = make_vector<int, Integer>(extent);
166 std::vector<float> vstrides = make_vector<float, FloatImm>(strides);
167 std::vector<int> vorder = make_vector<int, Integer>(order);
168 std::vector<int> vstripes = make_vector<int, Integer>(stripes);
169 std::vector<int> voffset = make_vector<int, Integer>(offset);
170 return StripeConfig(vshape, vextent, vstrides, vorder, vstripes, voffset);
171 });
172
// Expose StripeConfig::operator== to the FFI as "contrib.ethosu.cascader.StripeConfigEqual".
// set_body_method wraps the const member function; the packed function presumably receives
// the StripeConfig it is called on as its first argument (TVM convention — confirm).
TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.StripeConfigEqual")
    .set_body_method(&StripeConfig::operator==);

// Register StripeConfigNode as a TVM node type.
TVM_REGISTER_NODE_TYPE(StripeConfigNode);
177
178} // namespace cascader
179} // namespace ethosu
180} // namespace contrib
181} // namespace tvm
182