1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "tensor_config.h"
20
21#include <tvm/runtime/container/array.h>
22#include <tvm/runtime/object.h>
23#include <tvm/runtime/registry.h>
24
25#include <string>
26#include <utility>
27#include <vector>
28
29#include "common.h"
30
31namespace tvm {
32namespace contrib {
33namespace ethosu {
34namespace cascader {
35
/*!
 * \brief Expose the MemoryRegionNode fields to an AttrVisitor (TVM reflection).
 * \param v The visitor. Each field is passed by address under the key shown, so the
 *          visitor operates directly on the stored members.
 * \note The order of the Visit calls and the key strings define the reflected layout;
 *       keep them stable.
 */
void MemoryRegionNode::VisitAttrs(AttrVisitor* v) {
  v->Visit("name", &name);
  v->Visit("size", &size);
  v->Visit("read_bandwidth", &read_bandwidth);
  v->Visit("write_bandwidth", &write_bandwidth);
  v->Visit("read_latency", &read_latency);
  v->Visit("write_latency", &write_latency);
  v->Visit("burst_length", &burst_length);
}
45
// Packed-function constructor for MemoryRegion, callable through the global name below.
// The typed arguments are forwarded positionally to the MemoryRegion constructor.
TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.MemoryRegion")
    .set_body_typed([](String name, int size, int read_bandwidth, int write_bandwidth,
                       int read_latency, int write_latency, int burst_length) {
      return MemoryRegion(name, size, read_bandwidth, write_bandwidth, read_latency, write_latency,
                          burst_length);
    });

// Register MemoryRegionNode with TVM's node type registry so it can be reflected.
TVM_REGISTER_NODE_TYPE(MemoryRegionNode);
54
55void TensorConfigNode::VisitAttrs(AttrVisitor* v) {
56 v->Visit("_tensor", &tensor_);
57 v->Visit("_home_region", &home_region_);
58 int state = static_cast<int>(state_);
59 v->Visit("_state", &state);
60 int buffer_mode = static_cast<int>(buffer_mode_);
61 v->Visit("_buffer_mode", &buffer_mode);
62 Array<StripeConfig> tmp_arr(stripe_configs_);
63 v->Visit("_stripe_configs", &tmp_arr);
64 v->Visit("_copy_tensor", &copy_tensor_);
65 v->Visit("_copy_region", &copy_region_);
66 int64_t tmp_hash = static_cast<int64_t>(hash_);
67 v->Visit("_hash", &tmp_hash);
68}
69
70int TensorConfigNode::GetBufferSize() const {
71 if (buffer_mode_ == BufferMode::RECOMPUTE) {
72 return GetRecomputeBufferSize_();
73 } else {
74 return GetRollingBufferSize_();
75 }
76}
77
78void TensorConfigNode::ComputeHash_() {
79 hash_ = ObjectHash()(tensor_);
80 hash_combine(&hash_, std::hash<std::string>()(home_region_->name));
81 hash_combine(&hash_, std::hash<int>()(static_cast<int>(state_)));
82 hash_combine(&hash_, std::hash<int>()(static_cast<int>(buffer_mode_)));
83 hash_combine(&hash_, hash_vector(stripe_configs_));
84 hash_combine(&hash_, std::hash<bool>()(copy_tensor_));
85 hash_combine(&hash_, std::hash<std::string>()(copy_region_->name));
86}
87
88int TensorConfigNode::GetRecomputeBufferSize_() const {
89 size_t buffer_size = 0;
90 for (const auto& stripe_config : stripe_configs_) {
91 buffer_size += mul_reduce(stripe_config->GetShape());
92 }
93 return buffer_size * tensor_->GetDataType().bytes() * tensor_->GetCompressionRatio();
94}
95
96int TensorConfigNode::GetRollingBufferSize_() const {
97 int buffer_size = 0;
98 for (const auto& stripe_config : stripe_configs_) {
99 int rolling_axis = -1;
100 for (size_t i = 0; i < stripe_config->GetOrder().size(); i++) {
101 // The axis must be striped (> 1 stripes) and ordered (order != 0)
102 if (stripe_config->GetStripes()[i] > 1 && stripe_config->GetOrder()[i] != 0) {
103 // If we've yet to find a possible rolling axis, use this one
104 if (rolling_axis == -1) {
105 rolling_axis = i;
106 continue;
107 }
108 // Otherwise, replace the rolling axis if the current axis has an earlier order
109 if (stripe_config->GetOrder()[i] < stripe_config->GetOrder()[rolling_axis]) {
110 rolling_axis = i;
111 }
112 }
113 }
114 // If we didn't find a rolling axis, just use axis 0
115 if (rolling_axis == -1) {
116 rolling_axis = 0;
117 }
118 int rolling_size = 1;
119 for (size_t i = 0; i < tensor_->GetShape().size(); i++) {
120 if (static_cast<int>(i) == rolling_axis) {
121 rolling_size *= stripe_config->GetShape()[i];
122 } else {
123 rolling_size *= tensor_->GetShape()[i];
124 }
125 }
126 buffer_size += rolling_size;
127 }
128 return buffer_size * tensor_->GetDataType().bytes() * tensor_->GetCompressionRatio();
129}
130
131TensorConfig::TensorConfig(const Tensor& tensor, const MemoryRegion& home_region,
132 TensorConfigState state, BufferMode buffer_mode,
133 const std::vector<StripeConfig>& stripe_configs, bool copy_tensor,
134 const MemoryRegion& copy_region) {
135 auto n = make_object<TensorConfigNode>();
136 n->tensor_ = std::move(tensor);
137 n->home_region_ = std::move(home_region);
138 n->state_ = state;
139 n->buffer_mode_ = buffer_mode;
140 n->stripe_configs_ = std::move(stripe_configs);
141 n->copy_tensor_ = copy_tensor;
142 n->copy_region_ = std::move(copy_region);
143 n->ComputeHash_();
144 data_ = std::move(n);
145}
146
147inline bool TensorConfig::operator==(const TensorConfig& other) const {
148 if (get() == other.get()) return true;
149 if (get() == nullptr || other.get() == nullptr) return false;
150 if ((*this)->tensor_ == other->tensor_ && (*this)->home_region_ == other->home_region_ &&
151 (*this)->state_ == other->state_ && (*this)->buffer_mode_ == other->buffer_mode_ &&
152 (*this)->stripe_configs_ == other->stripe_configs_ &&
153 (*this)->copy_tensor_ == other->copy_tensor_ &&
154 (*this)->copy_region_ == other->copy_region_) {
155 return true;
156 }
157 return false;
158}
159
// Packed-function constructor for TensorConfig. The enum arguments arrive as plain ints
// and the stripe configs as a TVM Array. Both are converted to the C++ types before
// calling the TensorConfig constructor.
// NOTE(review): `state` and `buffer_mode` are static_cast to their enums without a range
// check, so an out-of-range int from the caller is not rejected here.
TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorConfig")
    .set_body_typed([](Tensor tensor, MemoryRegion home_region, int state, int buffer_mode,
                       Array<StripeConfig> stripe_configs, bool copy_tensor,
                       MemoryRegion copy_region) {
      TensorConfigState estate = static_cast<TensorConfigState>(state);
      BufferMode ebuffer_mode = static_cast<BufferMode>(buffer_mode);
      std::vector<StripeConfig> vstripe_configs(stripe_configs.begin(), stripe_configs.end());
      return TensorConfig(tensor, home_region, estate, ebuffer_mode, vstripe_configs, copy_tensor,
                          copy_region);
    });

// Exposes TensorConfig::operator== (field-wise equality) as a packed function.
TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorConfigEqual")
    .set_body_method(&TensorConfig::operator==);

// Exposes TensorConfigNode::GetBufferSize, invoked on the node behind a TensorConfig.
TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorConfigGetBufferSize")
    .set_body_method<TensorConfig>(&TensorConfigNode::GetBufferSize);

// Register TensorConfigNode with TVM's node type registry so it can be reflected.
TVM_REGISTER_NODE_TYPE(TensorConfigNode);
178
179} // namespace cascader
180} // namespace ethosu
181} // namespace contrib
182} // namespace tvm
183