1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "tensor_config.h" |
20 | |
21 | #include <tvm/runtime/container/array.h> |
22 | #include <tvm/runtime/object.h> |
23 | #include <tvm/runtime/registry.h> |
24 | |
25 | #include <string> |
26 | #include <utility> |
27 | #include <vector> |
28 | |
29 | #include "common.h" |
30 | |
31 | namespace tvm { |
32 | namespace contrib { |
33 | namespace ethosu { |
34 | namespace cascader { |
35 | |
36 | void MemoryRegionNode::VisitAttrs(AttrVisitor* v) { |
37 | v->Visit("name" , &name); |
38 | v->Visit("size" , &size); |
39 | v->Visit("read_bandwidth" , &read_bandwidth); |
40 | v->Visit("write_bandwidth" , &write_bandwidth); |
41 | v->Visit("read_latency" , &read_latency); |
42 | v->Visit("write_latency" , &write_latency); |
43 | v->Visit("burst_length" , &burst_length); |
44 | } |
45 | |
46 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.MemoryRegion" ) |
47 | .set_body_typed([](String name, int size, int read_bandwidth, int write_bandwidth, |
48 | int read_latency, int write_latency, int burst_length) { |
49 | return MemoryRegion(name, size, read_bandwidth, write_bandwidth, read_latency, write_latency, |
50 | burst_length); |
51 | }); |
52 | |
53 | TVM_REGISTER_NODE_TYPE(MemoryRegionNode); |
54 | |
55 | void TensorConfigNode::VisitAttrs(AttrVisitor* v) { |
56 | v->Visit("_tensor" , &tensor_); |
57 | v->Visit("_home_region" , &home_region_); |
58 | int state = static_cast<int>(state_); |
59 | v->Visit("_state" , &state); |
60 | int buffer_mode = static_cast<int>(buffer_mode_); |
61 | v->Visit("_buffer_mode" , &buffer_mode); |
62 | Array<StripeConfig> tmp_arr(stripe_configs_); |
63 | v->Visit("_stripe_configs" , &tmp_arr); |
64 | v->Visit("_copy_tensor" , ©_tensor_); |
65 | v->Visit("_copy_region" , ©_region_); |
66 | int64_t tmp_hash = static_cast<int64_t>(hash_); |
67 | v->Visit("_hash" , &tmp_hash); |
68 | } |
69 | |
70 | int TensorConfigNode::GetBufferSize() const { |
71 | if (buffer_mode_ == BufferMode::RECOMPUTE) { |
72 | return GetRecomputeBufferSize_(); |
73 | } else { |
74 | return GetRollingBufferSize_(); |
75 | } |
76 | } |
77 | |
78 | void TensorConfigNode::ComputeHash_() { |
79 | hash_ = ObjectHash()(tensor_); |
80 | hash_combine(&hash_, std::hash<std::string>()(home_region_->name)); |
81 | hash_combine(&hash_, std::hash<int>()(static_cast<int>(state_))); |
82 | hash_combine(&hash_, std::hash<int>()(static_cast<int>(buffer_mode_))); |
83 | hash_combine(&hash_, hash_vector(stripe_configs_)); |
84 | hash_combine(&hash_, std::hash<bool>()(copy_tensor_)); |
85 | hash_combine(&hash_, std::hash<std::string>()(copy_region_->name)); |
86 | } |
87 | |
88 | int TensorConfigNode::GetRecomputeBufferSize_() const { |
89 | size_t buffer_size = 0; |
90 | for (const auto& stripe_config : stripe_configs_) { |
91 | buffer_size += mul_reduce(stripe_config->GetShape()); |
92 | } |
93 | return buffer_size * tensor_->GetDataType().bytes() * tensor_->GetCompressionRatio(); |
94 | } |
95 | |
96 | int TensorConfigNode::GetRollingBufferSize_() const { |
97 | int buffer_size = 0; |
98 | for (const auto& stripe_config : stripe_configs_) { |
99 | int rolling_axis = -1; |
100 | for (size_t i = 0; i < stripe_config->GetOrder().size(); i++) { |
101 | // The axis must be striped (> 1 stripes) and ordered (order != 0) |
102 | if (stripe_config->GetStripes()[i] > 1 && stripe_config->GetOrder()[i] != 0) { |
103 | // If we've yet to find a possible rolling axis, use this one |
104 | if (rolling_axis == -1) { |
105 | rolling_axis = i; |
106 | continue; |
107 | } |
108 | // Otherwise, replace the rolling axis if the current axis has an earlier order |
109 | if (stripe_config->GetOrder()[i] < stripe_config->GetOrder()[rolling_axis]) { |
110 | rolling_axis = i; |
111 | } |
112 | } |
113 | } |
114 | // If we didn't find a rolling axis, just use axis 0 |
115 | if (rolling_axis == -1) { |
116 | rolling_axis = 0; |
117 | } |
118 | int rolling_size = 1; |
119 | for (size_t i = 0; i < tensor_->GetShape().size(); i++) { |
120 | if (static_cast<int>(i) == rolling_axis) { |
121 | rolling_size *= stripe_config->GetShape()[i]; |
122 | } else { |
123 | rolling_size *= tensor_->GetShape()[i]; |
124 | } |
125 | } |
126 | buffer_size += rolling_size; |
127 | } |
128 | return buffer_size * tensor_->GetDataType().bytes() * tensor_->GetCompressionRatio(); |
129 | } |
130 | |
131 | TensorConfig::TensorConfig(const Tensor& tensor, const MemoryRegion& home_region, |
132 | TensorConfigState state, BufferMode buffer_mode, |
133 | const std::vector<StripeConfig>& stripe_configs, bool copy_tensor, |
134 | const MemoryRegion& copy_region) { |
135 | auto n = make_object<TensorConfigNode>(); |
136 | n->tensor_ = std::move(tensor); |
137 | n->home_region_ = std::move(home_region); |
138 | n->state_ = state; |
139 | n->buffer_mode_ = buffer_mode; |
140 | n->stripe_configs_ = std::move(stripe_configs); |
141 | n->copy_tensor_ = copy_tensor; |
142 | n->copy_region_ = std::move(copy_region); |
143 | n->ComputeHash_(); |
144 | data_ = std::move(n); |
145 | } |
146 | |
147 | inline bool TensorConfig::operator==(const TensorConfig& other) const { |
148 | if (get() == other.get()) return true; |
149 | if (get() == nullptr || other.get() == nullptr) return false; |
150 | if ((*this)->tensor_ == other->tensor_ && (*this)->home_region_ == other->home_region_ && |
151 | (*this)->state_ == other->state_ && (*this)->buffer_mode_ == other->buffer_mode_ && |
152 | (*this)->stripe_configs_ == other->stripe_configs_ && |
153 | (*this)->copy_tensor_ == other->copy_tensor_ && |
154 | (*this)->copy_region_ == other->copy_region_) { |
155 | return true; |
156 | } |
157 | return false; |
158 | } |
159 | |
160 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorConfig" ) |
161 | .set_body_typed([](Tensor tensor, MemoryRegion home_region, int state, int buffer_mode, |
162 | Array<StripeConfig> stripe_configs, bool copy_tensor, |
163 | MemoryRegion copy_region) { |
164 | TensorConfigState estate = static_cast<TensorConfigState>(state); |
165 | BufferMode ebuffer_mode = static_cast<BufferMode>(buffer_mode); |
166 | std::vector<StripeConfig> vstripe_configs(stripe_configs.begin(), stripe_configs.end()); |
167 | return TensorConfig(tensor, home_region, estate, ebuffer_mode, vstripe_configs, copy_tensor, |
168 | copy_region); |
169 | }); |
170 | |
171 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorConfigEqual" ) |
172 | .set_body_method(&TensorConfig::operator==); |
173 | |
174 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.TensorConfigGetBufferSize" ) |
175 | .set_body_method<TensorConfig>(&TensorConfigNode::GetBufferSize); |
176 | |
177 | TVM_REGISTER_NODE_TYPE(TensorConfigNode); |
178 | |
179 | } // namespace cascader |
180 | } // namespace ethosu |
181 | } // namespace contrib |
182 | } // namespace tvm |
183 | |