1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file attr_functor.h |
 * \brief A way to define a functor with an arbitrary function signature
 *  that dispatches on common attribute types.
24 | * |
25 | * Common attributes include: |
26 | * - int, float, str constants |
27 | * - array of attributes |
28 | * - map of attributes |
29 | */ |
30 | #ifndef TVM_IR_ATTR_FUNCTOR_H_ |
31 | #define TVM_IR_ATTR_FUNCTOR_H_ |
32 | |
33 | #include <tvm/node/functor.h> |
34 | #include <tvm/tir/expr.h> |
35 | |
36 | #include <utility> |
37 | |
38 | namespace tvm { |
39 | |
/*!
 * \brief Primary template of the attribute functor; declared only.
 *
 * A definition is supplied by the partial specialization for function types
 * of the form R(const ObjectRef& n, Args...).
 *
 * \tparam FSignature The function signature of the functor.
 */
template <typename FSignature>
class AttrFunctor;
42 | |
/*!
 * \brief Function body that forwards a visit call to VisitAttrDefault_.
 *
 * Expands to a complete `{ ... }` body. At the point of use it needs a
 * parameter named `op`, a parameter pack declared as `Args... args`, and a
 * callable `VisitAttrDefault_` in scope.
 */
#define ATTR_FUNCTOR_DEFAULT \
  { \
    return VisitAttrDefault_(op, std::forward<Args>(args)...); \
  }
45 | |
/*!
 * \brief Register a dispatch entry for node type OP on an object named `vtable`.
 *
 * The registered callback casts the incoming node to `const OP*` and calls
 * `self->VisitAttr_` with it, passing the remaining arguments along.
 * At the point of use it needs `vtable`, `TSelf`, `ObjectRef` and the pack
 * `Args` in scope. The expansion already ends with a semicolon.
 *
 * \param OP The node type to register.
 */
#define ATTR_FUNCTOR_DISPATCH(OP) \
  vtable.template set_dispatch<OP>( \
      [](const ObjectRef& n, TSelf* self, Args... args) { \
        return self->VisitAttr_(static_cast<const OP*>(n.get()), \
                                std::forward<Args>(args)...); \
      });
50 | |
51 | // A functor for common attribute information. |
52 | template <typename R, typename... Args> |
53 | class AttrFunctor<R(const ObjectRef& n, Args...)> { |
54 | private: |
55 | using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>; |
56 | using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; |
57 | |
58 | public: |
59 | /*! \brief the result type of this functor */ |
60 | using result_type = R; |
61 | /*! \brief virtual destructor */ |
62 | virtual ~AttrFunctor() {} |
63 | /*! |
64 | * \brief The functor call. |
65 | * \param n The expression node. |
66 | * \param args Additional arguments. |
67 | * \return The result of the call |
68 | */ |
69 | virtual R VisitAttr(const ObjectRef& n, Args... args) { |
70 | static FType vtable = InitVTable(); |
71 | if (vtable.can_dispatch(n)) { |
72 | return vtable(n, this, std::forward<Args>(args)...); |
73 | } else { |
74 | return VisitAttrDefault_(n.get(), std::forward<Args>(args)...); |
75 | } |
76 | } |
77 | virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; |
78 | virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
79 | virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
80 | virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
81 | virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
82 | // deep comparison of symbolic integer expressions. |
83 | virtual R VisitAttr_(const tir::VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
84 | virtual R VisitAttr_(const tir::SizeVarNode* op, Args... args) { |
85 | return VisitAttr_(static_cast<const tir::VarNode*>(op), std::forward<Args>(args)...); |
86 | } |
87 | virtual R VisitAttr_(const tir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
88 | virtual R VisitAttr_(const tir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
89 | virtual R VisitAttr_(const tir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
90 | virtual R VisitAttr_(const tir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
91 | virtual R VisitAttr_(const tir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
92 | virtual R VisitAttr_(const tir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
93 | virtual R VisitAttr_(const tir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
94 | virtual R VisitAttr_(const tir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
95 | virtual R VisitAttr_(const tir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
96 | virtual R VisitAttr_(const tir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
97 | virtual R VisitAttr_(const tir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
98 | virtual R VisitAttr_(const tir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
99 | virtual R VisitAttr_(const tir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
100 | virtual R VisitAttr_(const tir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
101 | virtual R VisitAttr_(const tir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
102 | virtual R VisitAttr_(const tir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
103 | virtual R VisitAttr_(const tir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
104 | virtual R VisitAttr_(const tir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
105 | virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
106 | virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
107 | virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; |
108 | |
109 | private: |
110 | // initialize the vtable. |
111 | static FType InitVTable() { |
112 | using namespace tir; |
113 | FType vtable; |
114 | // Set dispatch |
115 | ATTR_FUNCTOR_DISPATCH(ArrayNode); |
116 | ATTR_FUNCTOR_DISPATCH(IntImmNode); |
117 | ATTR_FUNCTOR_DISPATCH(FloatImmNode); |
118 | ATTR_FUNCTOR_DISPATCH(StringImmNode); |
119 | ATTR_FUNCTOR_DISPATCH(VarNode); |
120 | ATTR_FUNCTOR_DISPATCH(SizeVarNode); |
121 | ATTR_FUNCTOR_DISPATCH(AddNode); |
122 | ATTR_FUNCTOR_DISPATCH(SubNode); |
123 | ATTR_FUNCTOR_DISPATCH(MulNode); |
124 | ATTR_FUNCTOR_DISPATCH(DivNode); |
125 | ATTR_FUNCTOR_DISPATCH(ModNode); |
126 | ATTR_FUNCTOR_DISPATCH(FloorDivNode); |
127 | ATTR_FUNCTOR_DISPATCH(FloorModNode); |
128 | ATTR_FUNCTOR_DISPATCH(MinNode); |
129 | ATTR_FUNCTOR_DISPATCH(MaxNode); |
130 | ATTR_FUNCTOR_DISPATCH(GENode); |
131 | ATTR_FUNCTOR_DISPATCH(GTNode); |
132 | ATTR_FUNCTOR_DISPATCH(LENode); |
133 | ATTR_FUNCTOR_DISPATCH(LTNode); |
134 | ATTR_FUNCTOR_DISPATCH(EQNode); |
135 | ATTR_FUNCTOR_DISPATCH(NENode); |
136 | ATTR_FUNCTOR_DISPATCH(AndNode); |
137 | ATTR_FUNCTOR_DISPATCH(OrNode); |
138 | ATTR_FUNCTOR_DISPATCH(NotNode); |
139 | ATTR_FUNCTOR_DISPATCH(CastNode); |
140 | ATTR_FUNCTOR_DISPATCH(CallNode); |
141 | ATTR_FUNCTOR_DISPATCH(SelectNode); |
142 | return vtable; |
143 | } |
144 | }; |
145 | |
146 | } // namespace tvm |
147 | #endif // TVM_IR_ATTR_FUNCTOR_H_ |
148 | |