1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/ir/type.cc |
22 | * \brief Common type system AST nodes throughout the IR. |
23 | */ |
24 | #include <tvm/ir/type.h> |
25 | #include <tvm/runtime/registry.h> |
26 | namespace tvm { |
27 | |
28 | PrimType::PrimType(runtime::DataType dtype) { |
29 | ObjectPtr<PrimTypeNode> n = make_object<PrimTypeNode>(); |
30 | n->dtype = dtype; |
31 | data_ = std::move(n); |
32 | } |
33 | |
34 | TVM_REGISTER_NODE_TYPE(PrimTypeNode); |
35 | |
36 | TVM_REGISTER_GLOBAL("ir.PrimType" ).set_body_typed([](runtime::DataType dtype) { |
37 | return PrimType(dtype); |
38 | }); |
39 | |
40 | PointerType::PointerType(Type element_type, String storage_scope) { |
41 | ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>(); |
42 | n->element_type = std::move(element_type); |
43 | n->storage_scope = std::move(storage_scope); |
44 | data_ = std::move(n); |
45 | } |
46 | |
47 | TVM_REGISTER_NODE_TYPE(PointerTypeNode); |
48 | |
49 | TVM_REGISTER_GLOBAL("ir.PointerType" ) |
50 | .set_body_typed([](Type element_type, String storage_scope = "" ) { |
51 | return PointerType(element_type, storage_scope); |
52 | }); |
53 | |
54 | TypeVar::TypeVar(String name, TypeKind kind, Span span) { |
55 | ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>(); |
56 | n->name_hint = std::move(name); |
57 | n->kind = std::move(kind); |
58 | n->span = std::move(span); |
59 | data_ = std::move(n); |
60 | } |
61 | |
62 | TVM_REGISTER_NODE_TYPE(TypeVarNode); |
63 | |
64 | TVM_REGISTER_GLOBAL("ir.TypeVar" ).set_body_typed([](String name, int kind) { |
65 | return TypeVar(name, static_cast<TypeKind>(kind)); |
66 | }); |
67 | |
68 | GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind, Span span) { |
69 | ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>(); |
70 | n->name_hint = std::move(name); |
71 | n->kind = std::move(kind); |
72 | n->span = std::move(span); |
73 | data_ = std::move(n); |
74 | } |
75 | |
76 | TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); |
77 | |
78 | TVM_REGISTER_GLOBAL("ir.GlobalTypeVar" ).set_body_typed([](String name, int kind) { |
79 | return GlobalTypeVar(name, static_cast<TypeKind>(kind)); |
80 | }); |
81 | |
82 | FuncType::FuncType(tvm::Array<Type> arg_types, Type ret_type, tvm::Array<TypeVar> type_params, |
83 | tvm::Array<TypeConstraint> type_constraints, Span span) { |
84 | ObjectPtr<FuncTypeNode> n = make_object<FuncTypeNode>(); |
85 | n->arg_types = std::move(arg_types); |
86 | n->ret_type = std::move(ret_type); |
87 | n->type_params = std::move(type_params); |
88 | n->type_constraints = std::move(type_constraints); |
89 | n->span = std::move(span); |
90 | data_ = std::move(n); |
91 | } |
92 | |
93 | TVM_REGISTER_NODE_TYPE(FuncTypeNode); |
94 | |
95 | TVM_REGISTER_GLOBAL("ir.FuncType" ) |
96 | .set_body_typed([](tvm::Array<Type> arg_types, Type ret_type, tvm::Array<TypeVar> type_params, |
97 | tvm::Array<TypeConstraint> type_constraints) { |
98 | return FuncType(arg_types, ret_type, type_params, type_constraints); |
99 | }); |
100 | |
101 | TupleType::TupleType(Array<Type> fields, Span span) { |
102 | ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>(); |
103 | n->fields = std::move(fields); |
104 | n->span = std::move(span); |
105 | data_ = std::move(n); |
106 | } |
107 | |
108 | TupleType TupleType::Empty() { return TupleType(Array<Type>()); } |
109 | |
110 | TVM_REGISTER_NODE_TYPE(TupleTypeNode); |
111 | |
112 | TVM_REGISTER_GLOBAL("ir.TupleType" ).set_body_typed([](Array<Type> fields) { |
113 | return TupleType(fields); |
114 | }); |
115 | |
116 | IncompleteType::IncompleteType(TypeKind kind, Span span) { |
117 | auto n = make_object<IncompleteTypeNode>(); |
118 | n->kind = std::move(kind); |
119 | n->span = std::move(span); |
120 | data_ = std::move(n); |
121 | } |
122 | |
123 | TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); |
124 | |
125 | TVM_REGISTER_GLOBAL("ir.IncompleteType" ).set_body_typed([](int kind) { |
126 | return IncompleteType(static_cast<TypeKind>(kind)); |
127 | }); |
128 | |
129 | RelayRefType::RelayRefType(Type value, Span span) { |
130 | ObjectPtr<RelayRefTypeNode> n = make_object<RelayRefTypeNode>(); |
131 | n->value = std::move(value); |
132 | n->span = std::move(span); |
133 | data_ = std::move(n); |
134 | } |
135 | |
136 | TVM_REGISTER_GLOBAL("ir.RelayRefType" ).set_body_typed([](Type value) { |
137 | return RelayRefType(value); |
138 | }); |
139 | |
140 | TVM_REGISTER_NODE_TYPE(RelayRefTypeNode); |
141 | |
142 | } // namespace tvm |
143 | |