1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file attr_functor.h
 * \brief Defines a functor with an arbitrary function signature
 *  that dispatches on common attribute types.
 *
 * Common attributes include:
 * - int, float, str constants
 * - array of attributes
 * - map of attributes
 * - symbolic integer expressions (variables, arithmetic, comparison and logical nodes)
 */
30#ifndef TVM_IR_ATTR_FUNCTOR_H_
31#define TVM_IR_ATTR_FUNCTOR_H_
32
33#include <tvm/node/functor.h>
34#include <tvm/tir/expr.h>
35
36#include <utility>
37
38namespace tvm {
39
/*!
 * \brief Primary template of the attribute functor.
 *
 * Only declared here; the usable definition is the partial specialization
 * for function types of the form R(const ObjectRef& n, Args...).
 *
 * \tparam FType The function signature of the functor.
 */
template <typename FType>
class AttrFunctor;
42
/*!
 * \brief Function body that hands the node over to VisitAttrDefault_.
 *
 * Meant to be placed directly after a member function declarator. It expects
 * a parameter named `op`, a parameter pack `args` of types `Args...`, and a
 * callable `VisitAttrDefault_` to be visible at the point of expansion.
 */
#define ATTR_FUNCTOR_DEFAULT                                      \
  {                                                               \
    return VisitAttrDefault_(op, std::forward<Args>(args)...);    \
  }
45
/*!
 * \brief Register the VisitAttr_ overload for node type OP in the dispatch table.
 *
 * Expands to a complete statement (including the trailing semicolon). It
 * expects a table named `vtable`, the alias `TSelf` and the pack `Args` to be
 * visible at the point of expansion. The registered callback casts the
 * referenced object to `const OP*` and calls `self->VisitAttr_` with it.
 *
 * \param OP The node type to register.
 */
#define ATTR_FUNCTOR_DISPATCH(OP)                                                      \
  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
    const auto* typed_node = static_cast<const OP*>(n.get());                          \
    return self->VisitAttr_(typed_node, std::forward<Args>(args)...);                  \
  });
50
51// A functor for common attribute information.
52template <typename R, typename... Args>
53class AttrFunctor<R(const ObjectRef& n, Args...)> {
54 private:
55 using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>;
56 using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
57
58 public:
59 /*! \brief the result type of this functor */
60 using result_type = R;
61 /*! \brief virtual destructor */
62 virtual ~AttrFunctor() {}
63 /*!
64 * \brief The functor call.
65 * \param n The expression node.
66 * \param args Additional arguments.
67 * \return The result of the call
68 */
69 virtual R VisitAttr(const ObjectRef& n, Args... args) {
70 static FType vtable = InitVTable();
71 if (vtable.can_dispatch(n)) {
72 return vtable(n, this, std::forward<Args>(args)...);
73 } else {
74 return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
75 }
76 }
77 virtual R VisitAttrDefault_(const Object* node, Args... args) = 0;
78 virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
79 virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
80 virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
81 virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
82 // deep comparison of symbolic integer expressions.
83 virtual R VisitAttr_(const tir::VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
84 virtual R VisitAttr_(const tir::SizeVarNode* op, Args... args) {
85 return VisitAttr_(static_cast<const tir::VarNode*>(op), std::forward<Args>(args)...);
86 }
87 virtual R VisitAttr_(const tir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
88 virtual R VisitAttr_(const tir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
89 virtual R VisitAttr_(const tir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
90 virtual R VisitAttr_(const tir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
91 virtual R VisitAttr_(const tir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
92 virtual R VisitAttr_(const tir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
93 virtual R VisitAttr_(const tir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
94 virtual R VisitAttr_(const tir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
95 virtual R VisitAttr_(const tir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
96 virtual R VisitAttr_(const tir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
97 virtual R VisitAttr_(const tir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
98 virtual R VisitAttr_(const tir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
99 virtual R VisitAttr_(const tir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
100 virtual R VisitAttr_(const tir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
101 virtual R VisitAttr_(const tir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
102 virtual R VisitAttr_(const tir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
103 virtual R VisitAttr_(const tir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
104 virtual R VisitAttr_(const tir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
105 virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
106 virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
107 virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
108
109 private:
110 // initialize the vtable.
111 static FType InitVTable() {
112 using namespace tir;
113 FType vtable;
114 // Set dispatch
115 ATTR_FUNCTOR_DISPATCH(ArrayNode);
116 ATTR_FUNCTOR_DISPATCH(IntImmNode);
117 ATTR_FUNCTOR_DISPATCH(FloatImmNode);
118 ATTR_FUNCTOR_DISPATCH(StringImmNode);
119 ATTR_FUNCTOR_DISPATCH(VarNode);
120 ATTR_FUNCTOR_DISPATCH(SizeVarNode);
121 ATTR_FUNCTOR_DISPATCH(AddNode);
122 ATTR_FUNCTOR_DISPATCH(SubNode);
123 ATTR_FUNCTOR_DISPATCH(MulNode);
124 ATTR_FUNCTOR_DISPATCH(DivNode);
125 ATTR_FUNCTOR_DISPATCH(ModNode);
126 ATTR_FUNCTOR_DISPATCH(FloorDivNode);
127 ATTR_FUNCTOR_DISPATCH(FloorModNode);
128 ATTR_FUNCTOR_DISPATCH(MinNode);
129 ATTR_FUNCTOR_DISPATCH(MaxNode);
130 ATTR_FUNCTOR_DISPATCH(GENode);
131 ATTR_FUNCTOR_DISPATCH(GTNode);
132 ATTR_FUNCTOR_DISPATCH(LENode);
133 ATTR_FUNCTOR_DISPATCH(LTNode);
134 ATTR_FUNCTOR_DISPATCH(EQNode);
135 ATTR_FUNCTOR_DISPATCH(NENode);
136 ATTR_FUNCTOR_DISPATCH(AndNode);
137 ATTR_FUNCTOR_DISPATCH(OrNode);
138 ATTR_FUNCTOR_DISPATCH(NotNode);
139 ATTR_FUNCTOR_DISPATCH(CastNode);
140 ATTR_FUNCTOR_DISPATCH(CallNode);
141 ATTR_FUNCTOR_DISPATCH(SelectNode);
142 return vtable;
143 }
144};
145
146} // namespace tvm
147#endif // TVM_IR_ATTR_FUNCTOR_H_
148