1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/ir/type.cc
22 * \brief Common type system AST nodes throughout the IR.
23 */
24#include <tvm/ir/type.h>
25#include <tvm/runtime/registry.h>
26namespace tvm {
27
28PrimType::PrimType(runtime::DataType dtype) {
29 ObjectPtr<PrimTypeNode> n = make_object<PrimTypeNode>();
30 n->dtype = dtype;
31 data_ = std::move(n);
32}
33
34TVM_REGISTER_NODE_TYPE(PrimTypeNode);
35
36TVM_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) {
37 return PrimType(dtype);
38});
39
40PointerType::PointerType(Type element_type, String storage_scope) {
41 ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>();
42 n->element_type = std::move(element_type);
43 n->storage_scope = std::move(storage_scope);
44 data_ = std::move(n);
45}
46
47TVM_REGISTER_NODE_TYPE(PointerTypeNode);
48
49TVM_REGISTER_GLOBAL("ir.PointerType")
50 .set_body_typed([](Type element_type, String storage_scope = "") {
51 return PointerType(element_type, storage_scope);
52 });
53
54TypeVar::TypeVar(String name, TypeKind kind, Span span) {
55 ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
56 n->name_hint = std::move(name);
57 n->kind = std::move(kind);
58 n->span = std::move(span);
59 data_ = std::move(n);
60}
61
62TVM_REGISTER_NODE_TYPE(TypeVarNode);
63
64TVM_REGISTER_GLOBAL("ir.TypeVar").set_body_typed([](String name, int kind) {
65 return TypeVar(name, static_cast<TypeKind>(kind));
66});
67
68GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind, Span span) {
69 ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
70 n->name_hint = std::move(name);
71 n->kind = std::move(kind);
72 n->span = std::move(span);
73 data_ = std::move(n);
74}
75
76TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
77
78TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](String name, int kind) {
79 return GlobalTypeVar(name, static_cast<TypeKind>(kind));
80});
81
82FuncType::FuncType(tvm::Array<Type> arg_types, Type ret_type, tvm::Array<TypeVar> type_params,
83 tvm::Array<TypeConstraint> type_constraints, Span span) {
84 ObjectPtr<FuncTypeNode> n = make_object<FuncTypeNode>();
85 n->arg_types = std::move(arg_types);
86 n->ret_type = std::move(ret_type);
87 n->type_params = std::move(type_params);
88 n->type_constraints = std::move(type_constraints);
89 n->span = std::move(span);
90 data_ = std::move(n);
91}
92
93TVM_REGISTER_NODE_TYPE(FuncTypeNode);
94
95TVM_REGISTER_GLOBAL("ir.FuncType")
96 .set_body_typed([](tvm::Array<Type> arg_types, Type ret_type, tvm::Array<TypeVar> type_params,
97 tvm::Array<TypeConstraint> type_constraints) {
98 return FuncType(arg_types, ret_type, type_params, type_constraints);
99 });
100
101TupleType::TupleType(Array<Type> fields, Span span) {
102 ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();
103 n->fields = std::move(fields);
104 n->span = std::move(span);
105 data_ = std::move(n);
106}
107
108TupleType TupleType::Empty() { return TupleType(Array<Type>()); }
109
110TVM_REGISTER_NODE_TYPE(TupleTypeNode);
111
112TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array<Type> fields) {
113 return TupleType(fields);
114});
115
116IncompleteType::IncompleteType(TypeKind kind, Span span) {
117 auto n = make_object<IncompleteTypeNode>();
118 n->kind = std::move(kind);
119 n->span = std::move(span);
120 data_ = std::move(n);
121}
122
123TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
124
125TVM_REGISTER_GLOBAL("ir.IncompleteType").set_body_typed([](int kind) {
126 return IncompleteType(static_cast<TypeKind>(kind));
127});
128
129RelayRefType::RelayRefType(Type value, Span span) {
130 ObjectPtr<RelayRefTypeNode> n = make_object<RelayRefTypeNode>();
131 n->value = std::move(value);
132 n->span = std::move(span);
133 data_ = std::move(n);
134}
135
136TVM_REGISTER_GLOBAL("ir.RelayRefType").set_body_typed([](Type value) {
137 return RelayRefType(value);
138});
139
140TVM_REGISTER_NODE_TYPE(RelayRefTypeNode);
141
142} // namespace tvm
143