1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file tvm/ir/adt.h
 * \brief Algebraic data type definitions.
 *
 * We adopt Relay's ADT definition as a unified class
 * for describing structured data.
 */
27#ifndef TVM_IR_ADT_H_
28#define TVM_IR_ADT_H_
29
30#include <tvm/ir/expr.h>
31#include <tvm/ir/type.h>
32#include <tvm/node/node.h>
33#include <tvm/runtime/container/adt.h>
34#include <tvm/runtime/container/array.h>
35#include <tvm/runtime/container/string.h>
36#include <tvm/runtime/object.h>
37
38#include <string>
39
40namespace tvm {
41
42/*!
43 * \brief ADT constructor.
44 * Constructors compare by pointer equality.
45 * \sa Constructor
46 */
47class ConstructorNode : public RelayExprNode {
48 public:
49 /*! \brief The name (only a hint) */
50 String name_hint;
51 /*! \brief Input to the constructor. */
52 Array<Type> inputs;
53 /*! \brief The datatype the constructor will construct. */
54 GlobalTypeVar belong_to;
55 /*! \brief Index in the table of constructors (set when the type is registered). */
56 mutable int32_t tag = -1;
57
58 ConstructorNode() {}
59
60 void VisitAttrs(AttrVisitor* v) {
61 v->Visit("name_hint", &name_hint);
62 v->Visit("inputs", &inputs);
63 v->Visit("belong_to", &belong_to);
64 v->Visit("tag", &tag);
65 v->Visit("span", &span);
66 v->Visit("_checked_type_", &checked_type_);
67 }
68
69 bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const {
70 // Use namehint for now to be consistent with the legacy relay impl
71 // TODO(tvm-team) revisit, need to check the type var.
72 return equal(name_hint, other->name_hint) && equal(inputs, other->inputs);
73 }
74
75 void SHashReduce(SHashReducer hash_reduce) const {
76 hash_reduce(name_hint);
77 hash_reduce(inputs);
78 }
79
80 static constexpr const char* _type_key = "relay.Constructor";
81 TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
82};
83
84/*!
85 * \brief Managed reference to ConstructorNode
86 * \sa ConstructorNode
87 */
88class Constructor : public RelayExpr {
89 public:
90 /*!
91 * \brief Constructor
92 * \param name_hint the name of the constructor.
93 * \param inputs The input types.
94 * \param belong_to The data type var the constructor will construct.
95 */
96 TVM_DLL Constructor(String name_hint, Array<Type> inputs, GlobalTypeVar belong_to);
97
98 TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
99};
100
101/*! \brief TypeData container node */
102class TypeDataNode : public TypeNode {
103 public:
104 /*!
105 * \brief The header is simply the name of the ADT.
106 * We adopt nominal typing for ADT definitions;
107 * that is, differently-named ADT definitions with same constructors
108 * have different types.
109 */
110 GlobalTypeVar header;
111 /*! \brief The type variables (to allow for polymorphism). */
112 Array<TypeVar> type_vars;
113 /*! \brief The constructors. */
114 Array<Constructor> constructors;
115
116 void VisitAttrs(AttrVisitor* v) {
117 v->Visit("header", &header);
118 v->Visit("type_vars", &type_vars);
119 v->Visit("constructors", &constructors);
120 v->Visit("span", &span);
121 }
122
123 bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const {
124 return equal.DefEqual(header, other->header) && equal.DefEqual(type_vars, other->type_vars) &&
125 equal(constructors, other->constructors);
126 }
127
128 void SHashReduce(SHashReducer hash_reduce) const {
129 hash_reduce.DefHash(header);
130 hash_reduce.DefHash(type_vars);
131 hash_reduce(constructors);
132 }
133
134 static constexpr const char* _type_key = "relay.TypeData";
135 TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
136};
137
138/*!
139 * \brief Stores all data for an Algebraic Data Type (ADT).
140 *
141 * In particular, it stores the handle (global type var) for an ADT
142 * and the constructors used to build it and is kept in the module. Note
143 * that type parameters are also indicated in the type data: this means that
144 * for any instance of an ADT, the type parameters must be indicated. That is,
145 * an ADT definition is treated as a type-level function, so an ADT handle
146 * must be wrapped in a TypeCall node that instantiates the type-level arguments.
147 * The kind checker enforces this.
148 */
149class TypeData : public Type {
150 public:
151 /*!
152 * \brief Constructor
153 * \param header the name of ADT.
154 * \param type_vars type variables.
155 * \param constructors constructors field.
156 */
157 TVM_DLL TypeData(GlobalTypeVar header, Array<TypeVar> type_vars, Array<Constructor> constructors);
158
159 TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode);
160};
161
162} // namespace tvm
163#endif // TVM_IR_ADT_H_
164