1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/contrib/ethosu/cascader/propagator.h |
22 | * \brief Propagator class for the NPU cascader |
23 | */ |
24 | #ifndef TVM_CONTRIB_ETHOSU_CASCADER_PROPAGATOR_H_ |
25 | #define TVM_CONTRIB_ETHOSU_CASCADER_PROPAGATOR_H_ |
26 | |
27 | #include <tvm/node/reflection.h> |
28 | #include <tvm/runtime/object.h> |
29 | |
30 | #include <vector> |
31 | |
32 | namespace tvm { |
33 | namespace contrib { |
34 | namespace ethosu { |
35 | namespace cascader { |
36 | |
37 | class Propagator; |
38 | class StripeConfig; |
39 | |
40 | /*! \brief Node to represent a Propagator */ |
41 | class PropagatorNode : public Object { |
42 | public: |
43 | void VisitAttrs(AttrVisitor* v); |
44 | |
45 | /*! \return The transform matrix to apply to the StripeConfigs */ |
46 | const std::vector<std::vector<float>> GetTransform() const { return transform_; } |
47 | /*! \return The offset vector to apply to the StripeConfigs */ |
48 | const std::vector<int> GetOffset() const { return offset_; } |
49 | /*! \return The number of input dimensions */ |
50 | size_t GetInputDims() const { return offset_.size(); } |
51 | /*! \return The number of output dimensions */ |
52 | size_t GetOutputDims() const { return transform_[0].size() - 1; } |
53 | /*! |
54 | * \brief Propagate a StripeConfig through the transform and offset matrices. |
55 | * \param stripe_config The StripeConfig to propagate. |
56 | * \return The transformed StripeConfig. |
57 | * \note The propagation proceeds as follows: |
58 | * |
59 | * Both the stripe shape and extent have 1 appended to them (so they pick up |
60 | * constant factors from the affine transform) and are then multiplied by the |
61 | * transform matrix. The result is then ceil-rounded and has the trailing 1 |
62 | * stripped to give the new shape and extent. |
63 | * |
64 | * The strides has 0 appended to it (so it doesn't pick up constant factors) |
65 | * and is then multiplied by the transform matrix. The trailing 0 is stripped. |
66 | * |
67 | * For the remaining three values we introduce the concept of the 'binarized' |
68 | * transform matrix. This is the transform matrix but with every non-zero element |
69 | * set to 1. It represents how axes get re-ordered as part of the propagation. |
70 | * |
71 | * [2, 0, 0, 1] [1, 0, 0, 1] |
72 | * [0, 0, 0.4, 2] binarize [0, 0, 1, 1] |
73 | * [0, 1.5, 0, 0] ----> [0, 1, 0, 0] |
74 | * [0, 0, 0, 1] [0, 0, 0, 1] |
75 | * |
76 | * The order has 0 appended to it and is multiplied by the 'binarized' transform |
77 | * matrix. The trailing 0 is then stripped. |
78 | * |
79 | * The stripes has 0 appended to it and multiplied by the 'binarized' transform |
80 | * matrix. The trailing 0 is then stripped and any remaining 0 elements that |
81 | * were introduced by the transform are set instead to 1. |
82 | * |
83 | * The stripe offset is multiplied by the 'binarized' transform matrix and is |
84 | * then summed with the propagator offset. |
85 | */ |
86 | StripeConfig propagate(const StripeConfig& stripe_config) const; |
87 | |
88 | static constexpr const char* _type_key = "contrib.ethosu.cascader.Propagator" ; |
89 | TVM_DECLARE_FINAL_OBJECT_INFO(PropagatorNode, Object); |
90 | |
91 | protected: |
92 | friend class Propagator; |
93 | |
94 | /*! \brief The transform matrix to apply to the StripeConfigs */ |
95 | std::vector<std::vector<float>> transform_; |
96 | /*! \brief The offset vector to apply to the StripeConfigs */ |
97 | std::vector<int> offset_; |
98 | }; |
99 | |
100 | /*! |
101 | * \brief A class to transform StripeConfigs according to the data dependencies |
102 | between Part outputs and inputs. The dependency is represented as an affine |
103 | transformation matrix + an offset vector. Using this, an output StripeConfig |
104 | can be propagated through a Part to arrive at the input StripeConfigs. |
105 | * \note The transform matrix should be a 2D affine transform matrix. |
106 | * As an example, consider a (1, 1, 2, 32) output stripe for an NHWC pooling |
107 | * operation with a 3x3 pool size: |
108 | * |
109 | * [1, 0, 0, 0, 0] [ 1] [ 1] |
110 | * [0, 1, 0, 0, 2] [ 1] [ 3] |
111 | * [0, 0, 1, 0, 2] x [ 2] = [ 4] |
112 | * [0, 0, 0, 1, 0] [32] [32] |
113 | * [0, 0, 0, 0, 1] [ 1] [ 1] |
114 | * |
115 | * Using the appropriate affine matrix we see that the required input data to |
116 | * produce that output stripe is a (1, 3, 4, 32) stripe. These matrices should |
117 | * be derived for the Parts to relate input and output data dependencies. |
118 | * |
119 | * The offset is a 1D vector representing the first tensor element to read. |
120 | * Often this is just the 0 element, but for an operator such as pad it may be |
121 | * negative. For instance, a symmetric padding by 1 of a 2D tensor would require |
122 | * the offset vector [-1, -1]. Additionally, positive offsets may be required |
123 | * for operators like strided_slice where only part of a tensor is read from. |
124 | */ |
125 | class Propagator : public ObjectRef { |
126 | public: |
127 | Propagator(const std::vector<std::vector<float>>& transform, const std::vector<int>& offset); |
128 | |
129 | TVM_DEFINE_OBJECT_REF_METHODS(Propagator, ObjectRef, PropagatorNode); |
130 | }; |
131 | |
132 | } // namespace cascader |
133 | } // namespace ethosu |
134 | } // namespace contrib |
135 | } // namespace tvm |
136 | |
137 | #endif // TVM_CONTRIB_ETHOSU_CASCADER_PROPAGATOR_H_ |
138 | |