1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "stripe_config.h" |
20 | |
21 | #include <tvm/runtime/container/array.h> |
22 | #include <tvm/runtime/object.h> |
23 | #include <tvm/runtime/registry.h> |
24 | |
25 | #include <algorithm> |
26 | #include <limits> |
27 | #include <map> |
28 | #include <utility> |
29 | #include <vector> |
30 | |
31 | #include "common.h" |
32 | |
33 | namespace tvm { |
34 | namespace contrib { |
35 | namespace ethosu { |
36 | namespace cascader { |
37 | |
/*!
 * \brief Form the weighted Cartesian product of a list of value->count maps.
 * \param values One map per position; each maps a candidate value to its count.
 * \return A map from every combination (one value chosen from each map, in order) to the
 *  product of the chosen counts. An empty \p values yields the single empty combination
 *  with count 1, and if any map is empty the result is empty.
 * \note Built iteratively so no prefix of \p values is copied, and an empty input is handled
 *  instead of decrementing past the beginning of the vector.
 */
template <class T>
std::map<std::vector<T>, int> MultiplyCombinations(std::vector<std::map<T, int>> values) {
  // Start from the identity of the product: one empty combination with count 1.
  std::map<std::vector<T>, int> combs;
  combs[std::vector<T>()] = 1;
  for (const auto& value_counts : values) {
    std::map<std::vector<T>, int> new_combs;
    for (const auto& comb_it : combs) {
      for (const auto& val_it : value_counts) {
        std::vector<T> new_comb(comb_it.first);
        new_comb.push_back(val_it.first);
        new_combs[std::move(new_comb)] = val_it.second * comb_it.second;
      }
    }
    combs = std::move(new_combs);
  }
  return combs;
}
59 | |
60 | std::map<std::vector<int>, int> CountStripes(const StripeConfig& stripe_config, |
61 | bool enable_sliding_window = false) { |
62 | std::vector<std::map<int, int>> per_axis_sizes(stripe_config->GetOrder().size()); |
63 | for (size_t axis = 0; axis < stripe_config->GetOrder().size(); axis++) { |
64 | int start = stripe_config->GetOffset()[axis]; |
65 | size_t stripe_count = static_cast<size_t>(stripe_config->GetStripes()[axis]); |
66 | int stride = stripe_config->GetStrides()[axis]; |
67 | int shape = stripe_config->GetShape()[axis]; |
68 | int extent = stripe_config->GetExtent()[axis]; |
69 | int low; |
70 | int high = std::numeric_limits<int>::min(); |
71 | for (size_t i = 0; i < stripe_count; i++) { |
72 | // Calculate the 'non-edge case' sizes in one go to save effort |
73 | if (!enable_sliding_window || i > 0) { |
74 | if (start >= 0 && extent - shape - start >= 0 && stride > 0) { |
75 | int whole_stripes = |
76 | std::min(static_cast<int>(stripe_count - i), (extent - shape - start) / stride + 1); |
77 | if (enable_sliding_window) { |
78 | per_axis_sizes[axis][stride] += whole_stripes; |
79 | } else { |
80 | per_axis_sizes[axis][shape] += whole_stripes; |
81 | } |
82 | i += whole_stripes - 1; |
83 | start += whole_stripes * stride; |
84 | high = std::min(start - stride + shape, extent); |
85 | continue; |
86 | } |
87 | } |
88 | low = std::max(start, 0); |
89 | if (enable_sliding_window) { |
90 | low = std::max(low, high); |
91 | } |
92 | high = std::min(start + shape, extent); |
93 | int size = high - low; |
94 | if (size > 0) { |
95 | per_axis_sizes[axis][size]++; |
96 | } |
97 | start += stride; |
98 | } |
99 | } |
100 | return MultiplyCombinations(per_axis_sizes); |
101 | } |
102 | |
103 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CountStripes" ) |
104 | .set_body_typed([](StripeConfig stripe_config, bool enable_sliding_window) { |
105 | Map<Array<Integer>, Integer> ret; |
106 | auto stripe_counts = CountStripes(stripe_config, enable_sliding_window); |
107 | for (const auto& it : stripe_counts) { |
108 | ret.Set(make_array(it.first), it.second); |
109 | } |
110 | return ret; |
111 | }); |
112 | |
113 | void StripeConfigNode::VisitAttrs(AttrVisitor* v) { |
114 | Array<Integer> tmp_arr = make_array(shape_); |
115 | v->Visit("_shape" , &tmp_arr); |
116 | tmp_arr = make_array(extent_); |
117 | v->Visit("_extent" , &tmp_arr); |
118 | tmp_arr = make_array(order_); |
119 | v->Visit("_order" , &tmp_arr); |
120 | tmp_arr = make_array(stripes_); |
121 | v->Visit("_stripes" , &tmp_arr); |
122 | tmp_arr = make_array(offset_); |
123 | v->Visit("_offset" , &tmp_arr); |
124 | Array<FloatImm> tmp_float_arr = make_array(strides_); |
125 | v->Visit("_strides" , &tmp_float_arr); |
126 | int64_t tmp_hash = static_cast<int64_t>(hash_); |
127 | v->Visit("_hash" , &tmp_hash); |
128 | } |
129 | |
130 | void StripeConfigNode::ComputeHash_() { |
131 | hash_ = hash_vector(shape_); |
132 | hash_combine(&hash_, hash_vector(extent_)); |
133 | hash_combine(&hash_, hash_vector(strides_)); |
134 | hash_combine(&hash_, hash_vector(order_)); |
135 | hash_combine(&hash_, hash_vector(stripes_)); |
136 | hash_combine(&hash_, hash_vector(offset_)); |
137 | } |
138 | |
139 | StripeConfig::StripeConfig(const std::vector<int>& shape, const std::vector<int>& extent, |
140 | const std::vector<float>& strides, const std::vector<int>& order, |
141 | const std::vector<int>& stripes, const std::vector<int>& offset) { |
142 | auto n = make_object<StripeConfigNode>(); |
143 | n->shape_ = std::move(shape); |
144 | n->extent_ = std::move(extent); |
145 | n->strides_ = std::move(strides); |
146 | n->order_ = std::move(order); |
147 | n->stripes_ = std::move(stripes); |
148 | n->offset_ = std::move(offset); |
149 | n->ComputeHash_(); |
150 | data_ = std::move(n); |
151 | } |
152 | |
153 | inline bool StripeConfig::operator==(const StripeConfig& other) const { |
154 | if (get() == other.get()) return true; |
155 | if (get() == nullptr || other.get() == nullptr) return false; |
156 | return ((*this)->shape_ == other->shape_ && (*this)->extent_ == other->extent_ && |
157 | (*this)->strides_ == other->strides_ && (*this)->order_ == other->order_ && |
158 | (*this)->stripes_ == other->stripes_ && (*this)->offset_ == other->offset_); |
159 | } |
160 | |
161 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.StripeConfig" ) |
162 | .set_body_typed([](Array<Integer> shape, Array<Integer> extent, Array<FloatImm> strides, |
163 | Array<Integer> order, Array<Integer> stripes, Array<Integer> offset) { |
164 | std::vector<int> vshape = make_vector<int, Integer>(shape); |
165 | std::vector<int> vextent = make_vector<int, Integer>(extent); |
166 | std::vector<float> vstrides = make_vector<float, FloatImm>(strides); |
167 | std::vector<int> vorder = make_vector<int, Integer>(order); |
168 | std::vector<int> vstripes = make_vector<int, Integer>(stripes); |
169 | std::vector<int> voffset = make_vector<int, Integer>(offset); |
170 | return StripeConfig(vshape, vextent, vstrides, vorder, vstripes, voffset); |
171 | }); |
172 | |
// Global equality check for StripeConfigs. The member-function pointer is registered
// directly, so the call compares its two arguments with StripeConfig::operator==.
TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.StripeConfigEqual" )
    .set_body_method(&StripeConfig::operator==);
175 | |
// Register StripeConfigNode with TVM's node type registry.
TVM_REGISTER_NODE_TYPE(StripeConfigNode);
177 | |
178 | } // namespace cascader |
179 | } // namespace ethosu |
180 | } // namespace contrib |
181 | } // namespace tvm |
182 | |