1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "propagator.h"
20
#include <tvm/relay/expr.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/object.h>

#include <cmath>
#include <utility>
#include <vector>

#include "common.h"
#include "stripe_config.h"
30
31namespace tvm {
32namespace contrib {
33namespace ethosu {
34namespace cascader {
35
36void PropagatorNode::VisitAttrs(AttrVisitor* v) {
37 Array<Array<FloatImm>> tmp_transform;
38 for (const auto& vec : transform_) {
39 tmp_transform.push_back(make_array(vec));
40 }
41 v->Visit("_transform", &tmp_transform);
42 Array<Integer> tmp_arr = make_array(offset_);
43 v->Visit("_offset", &tmp_arr);
44}
45
46Propagator::Propagator(const std::vector<std::vector<float>>& transform,
47 const std::vector<int>& offset) {
48 auto n = make_object<PropagatorNode>();
49 size_t rows = transform.size();
50 ICHECK_GT(rows, 0) << "The transform matrix must have at least 1 row.";
51 size_t columns = transform[0].size();
52 for (const auto& row : transform) {
53 ICHECK_EQ(row.size(), columns)
54 << "All rows of the transform matrix must be of the same length.";
55 }
56 ICHECK_EQ(offset.size(), rows - 1)
57 << "The offset vector length must be equal to the transform matrix rows - 1.";
58 n->transform_ = std::move(transform);
59 n->offset_ = std::move(offset);
60 data_ = std::move(n);
61}
62
63StripeConfig PropagatorNode::propagate(const StripeConfig& stripe_config) const {
64 size_t input_dimensions = transform_[0].size() - 1;
65 size_t output_dimensions = transform_.size() - 1;
66 auto n = make_object<StripeConfigNode>();
67 n->shape_.resize(output_dimensions);
68 n->extent_.resize(output_dimensions);
69 n->strides_.resize(output_dimensions);
70 n->order_.resize(output_dimensions);
71 n->stripes_.resize(output_dimensions);
72 n->offset_.resize(output_dimensions);
73 for (size_t i = 0; i < output_dimensions; i++) {
74 float new_shape_acc{};
75 float new_extent_acc{};
76 const float* row = &transform_[i][0];
77 for (size_t j = 0; j < input_dimensions; j++) {
78 new_shape_acc += row[j] * stripe_config->shape_[j];
79 new_extent_acc += row[j] * stripe_config->extent_[j];
80 n->strides_[i] += row[j] * stripe_config->strides_[j];
81 // Order, stripes and offset should only get re-ordered, so we only
82 // care about whether or not transform elements are non-zero.
83 int non_zero = row[j] != 0;
84 n->order_[i] += non_zero * stripe_config->order_[j];
85 n->stripes_[i] += non_zero * stripe_config->stripes_[j];
86 n->offset_[i] += non_zero * stripe_config->offset_[j];
87 }
88 // Shape and extent gain an additional constant term
89 new_shape_acc += row[input_dimensions];
90 new_extent_acc += row[input_dimensions];
91 // Shape and extent are ceil-rounded back to integers
92 n->shape_[i] = std::ceil(new_shape_acc);
93 n->extent_[i] += std::ceil(new_extent_acc);
94 // Apply the offset
95 n->offset_[i] += offset_[i];
96 // No axis can have '0 stripes', so change all 0 elements to 1
97 n->stripes_[i] = n->stripes_[i] == 0 ? 1 : n->stripes_[i];
98 }
99 // Remember to compute the hash
100 n->ComputeHash_();
101 return StripeConfig(n);
102}
103
104TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.Propagator")
105 .set_body_typed([](Array<Array<FloatImm>> transform, Array<Integer> offset) {
106 std::vector<std::vector<float>> vtransform;
107 for (const auto& vec : transform) {
108 vtransform.push_back(make_vector<float, FloatImm>(vec));
109 }
110 std::vector<int> voffset = make_vector<int, Integer>(offset);
111 return Propagator(vtransform, voffset);
112 });
113
114TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PropagatorPropagate")
115 .set_body_typed([](Propagator propagator, StripeConfig stripe_config) {
116 return propagator->propagate(stripe_config);
117 });
118
119TVM_REGISTER_NODE_TYPE(PropagatorNode);
120
121} // namespace cascader
122} // namespace ethosu
123} // namespace contrib
124} // namespace tvm
125