1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "propagator.h" |
20 | |
#include <tvm/relay/expr.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/object.h>

#include <cmath>
#include <utility>
#include <vector>

#include "common.h"
#include "stripe_config.h"
30 | |
31 | namespace tvm { |
32 | namespace contrib { |
33 | namespace ethosu { |
34 | namespace cascader { |
35 | |
36 | void PropagatorNode::VisitAttrs(AttrVisitor* v) { |
37 | Array<Array<FloatImm>> tmp_transform; |
38 | for (const auto& vec : transform_) { |
39 | tmp_transform.push_back(make_array(vec)); |
40 | } |
41 | v->Visit("_transform" , &tmp_transform); |
42 | Array<Integer> tmp_arr = make_array(offset_); |
43 | v->Visit("_offset" , &tmp_arr); |
44 | } |
45 | |
46 | Propagator::Propagator(const std::vector<std::vector<float>>& transform, |
47 | const std::vector<int>& offset) { |
48 | auto n = make_object<PropagatorNode>(); |
49 | size_t rows = transform.size(); |
50 | ICHECK_GT(rows, 0) << "The transform matrix must have at least 1 row." ; |
51 | size_t columns = transform[0].size(); |
52 | for (const auto& row : transform) { |
53 | ICHECK_EQ(row.size(), columns) |
54 | << "All rows of the transform matrix must be of the same length." ; |
55 | } |
56 | ICHECK_EQ(offset.size(), rows - 1) |
57 | << "The offset vector length must be equal to the transform matrix rows - 1." ; |
58 | n->transform_ = std::move(transform); |
59 | n->offset_ = std::move(offset); |
60 | data_ = std::move(n); |
61 | } |
62 | |
63 | StripeConfig PropagatorNode::propagate(const StripeConfig& stripe_config) const { |
64 | size_t input_dimensions = transform_[0].size() - 1; |
65 | size_t output_dimensions = transform_.size() - 1; |
66 | auto n = make_object<StripeConfigNode>(); |
67 | n->shape_.resize(output_dimensions); |
68 | n->extent_.resize(output_dimensions); |
69 | n->strides_.resize(output_dimensions); |
70 | n->order_.resize(output_dimensions); |
71 | n->stripes_.resize(output_dimensions); |
72 | n->offset_.resize(output_dimensions); |
73 | for (size_t i = 0; i < output_dimensions; i++) { |
74 | float new_shape_acc{}; |
75 | float new_extent_acc{}; |
76 | const float* row = &transform_[i][0]; |
77 | for (size_t j = 0; j < input_dimensions; j++) { |
78 | new_shape_acc += row[j] * stripe_config->shape_[j]; |
79 | new_extent_acc += row[j] * stripe_config->extent_[j]; |
80 | n->strides_[i] += row[j] * stripe_config->strides_[j]; |
81 | // Order, stripes and offset should only get re-ordered, so we only |
82 | // care about whether or not transform elements are non-zero. |
83 | int non_zero = row[j] != 0; |
84 | n->order_[i] += non_zero * stripe_config->order_[j]; |
85 | n->stripes_[i] += non_zero * stripe_config->stripes_[j]; |
86 | n->offset_[i] += non_zero * stripe_config->offset_[j]; |
87 | } |
88 | // Shape and extent gain an additional constant term |
89 | new_shape_acc += row[input_dimensions]; |
90 | new_extent_acc += row[input_dimensions]; |
91 | // Shape and extent are ceil-rounded back to integers |
92 | n->shape_[i] = std::ceil(new_shape_acc); |
93 | n->extent_[i] += std::ceil(new_extent_acc); |
94 | // Apply the offset |
95 | n->offset_[i] += offset_[i]; |
96 | // No axis can have '0 stripes', so change all 0 elements to 1 |
97 | n->stripes_[i] = n->stripes_[i] == 0 ? 1 : n->stripes_[i]; |
98 | } |
99 | // Remember to compute the hash |
100 | n->ComputeHash_(); |
101 | return StripeConfig(n); |
102 | } |
103 | |
104 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.Propagator" ) |
105 | .set_body_typed([](Array<Array<FloatImm>> transform, Array<Integer> offset) { |
106 | std::vector<std::vector<float>> vtransform; |
107 | for (const auto& vec : transform) { |
108 | vtransform.push_back(make_vector<float, FloatImm>(vec)); |
109 | } |
110 | std::vector<int> voffset = make_vector<int, Integer>(offset); |
111 | return Propagator(vtransform, voffset); |
112 | }); |
113 | |
114 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PropagatorPropagate" ) |
115 | .set_body_typed([](Propagator propagator, StripeConfig stripe_config) { |
116 | return propagator->propagate(stripe_config); |
117 | }); |
118 | |
// Register PropagatorNode as a TVM node type with the object/reflection system.
TVM_REGISTER_NODE_TYPE(PropagatorNode);
120 | |
121 | } // namespace cascader |
122 | } // namespace ethosu |
123 | } // namespace contrib |
124 | } // namespace tvm |
125 | |