1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include <tvm/relay/attrs/random.h> |
21 | #include <tvm/relay/op.h> |
22 | |
23 | namespace tvm { |
24 | namespace relay { |
25 | |
26 | TVM_REGISTER_NODE_TYPE(ThreefryGenerateAttrs); |
27 | |
28 | static TensorType ThreefryKeyType() { return TensorType({10}, tvm::DataType::UInt(64)); } |
29 | |
30 | bool ThreefryGenerateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
31 | const TypeReporter& reporter) { |
32 | const ThreefryGenerateAttrs* param = attrs.as<ThreefryGenerateAttrs>(); |
33 | ICHECK_EQ(types.size(), 2) << "ThreefryGenerate should have one input and one output" ; |
34 | |
35 | reporter->Assign(types[0], ThreefryKeyType()); |
36 | |
37 | std::vector<IndexExpr> oshape; |
38 | for (auto& x : param->out_shape) { |
39 | oshape.push_back(x); |
40 | } |
41 | // generate returns the next key and an array of random values |
42 | // TODO(@tkonolige, @altanh): support other output dtypes? |
43 | reporter->Assign(types[1], |
44 | TupleType({ThreefryKeyType(), TensorType(oshape, tvm::DataType::UInt(64))})); |
45 | return true; |
46 | } |
47 | |
48 | Expr MakeThreefryGenerate(Expr key, Array<Integer> out_shape) { |
49 | auto attrs = make_object<ThreefryGenerateAttrs>(); |
50 | attrs->out_shape = out_shape; |
51 | static const Op& op = Op::Get("random.threefry_generate" ); |
52 | return Call(op, {key}, Attrs(attrs), {}); |
53 | } |
54 | |
// Register MakeThreefryGenerate as the global function
// "relay.op.random._make.threefry_generate".
TVM_REGISTER_GLOBAL("relay.op.random._make.threefry_generate").set_body_typed(MakeThreefryGenerate);

// random.threefry_generate takes a single key tensor and carries its output
// shape in ThreefryGenerateAttrs. Type inference is done by ThreefryGenerateRel.
RELAY_REGISTER_OP("random.threefry_generate")
    .describe(
        R"doc(Generate an array of random numbers using the Threefry algorithm.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .set_attrs_type<ThreefryGenerateAttrs>()
    .add_argument("key", "Tensor", "Input Threefry key")
    .add_type_rel("ThreefryGenerate", ThreefryGenerateRel);
64 | |
65 | bool ThreefrySplitRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
66 | const TypeReporter& reporter) { |
67 | ICHECK_EQ(types.size(), 2) << "ThreefrySplit should have one input and one output" ; |
68 | |
69 | reporter->Assign(types[0], ThreefryKeyType()); |
70 | reporter->Assign(types[1], TupleType({ThreefryKeyType(), ThreefryKeyType()})); |
71 | |
72 | return true; |
73 | } |
74 | |
75 | Expr MakeThreefrySplit(Expr key) { |
76 | static const Op& op = Op::Get("random.threefry_split" ); |
77 | return Call(op, {key}, Attrs(), {}); |
78 | } |
79 | |
// Register MakeThreefrySplit as the global function
// "relay.op.random._make.threefry_split".
TVM_REGISTER_GLOBAL("relay.op.random._make.threefry_split").set_body_typed(MakeThreefrySplit);

// random.threefry_split takes a single key tensor. No attrs type is registered
// for this op, and type inference is done by ThreefrySplitRel.
RELAY_REGISTER_OP("random.threefry_split")
    .describe(R"doc(Split the input Threefry key into two new ones.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("key", "Tensor", "Input Threefry key")
    .add_type_rel("ThreefrySplit", ThreefrySplitRel);

// Register the attribute node type used by random.uniform.
TVM_REGISTER_NODE_TYPE(UniformAttrs);
89 | |
90 | bool UniformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
91 | const TypeReporter& reporter) { |
92 | const UniformAttrs* param = attrs.as<UniformAttrs>(); |
93 | ICHECK_EQ(types.size(), 4) << "Uniform should have three inputs and one output" ; |
94 | |
95 | std::vector<IndexExpr> oshape; |
96 | for (auto& x : param->out_shape) { |
97 | oshape.push_back(x); |
98 | } |
99 | DataType out_dtype = param->out_dtype; |
100 | // we are supporting float32 and float64 at the moment. |
101 | if (!(out_dtype.is_float() && (out_dtype.bits() == 32 || out_dtype.bits() == 64))) { |
102 | reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) |
103 | << "We only support generating uniform random value of " |
104 | << "type float32 or float64, got " << out_dtype << "." ); |
105 | return false; |
106 | } |
107 | reporter->Assign(types[0], ThreefryKeyType()); |
108 | reporter->Assign(types[1], TensorType({}, out_dtype)); |
109 | reporter->Assign(types[2], TensorType({}, out_dtype)); |
110 | // generate returns the next key and an array of random values |
111 | reporter->Assign(types[3], TupleType({ThreefryKeyType(), TensorType(oshape, out_dtype)})); |
112 | return true; |
113 | } |
114 | |
115 | Expr MakeUniform(Expr key, Expr low, Expr high, Array<Integer> out_shape, DataType out_dtype) { |
116 | auto attrs = make_object<UniformAttrs>(); |
117 | attrs->out_shape = out_shape; |
118 | attrs->out_dtype = out_dtype; |
119 | static const Op& op = Op::Get("random.uniform" ); |
120 | return Call(op, {key, low, high}, Attrs(attrs), {}); |
121 | } |
122 | |
// Register MakeUniform as the global function "relay.op.random._make.uniform".
TVM_REGISTER_GLOBAL("relay.op.random._make.uniform").set_body_typed(MakeUniform);

// random.uniform takes (key, low, high) and carries out_shape/out_dtype in
// UniformAttrs. Type inference is done by UniformRel.
RELAY_REGISTER_OP("random.uniform")
    .describe(
        R"doc(Generate an array of random numbers under uniform distribution.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(3)
    .set_attrs_type<UniformAttrs>()
    .add_argument("key", "Tensor", "Input Threefry key")
    .add_argument("low", "Tensor", "Lower bound of the distribution")
    .add_argument("high", "Tensor", "Higher bound of the distribution")
    .add_type_rel("Uniform", UniformRel);

// Register the attribute node type used by random.normal.
TVM_REGISTER_NODE_TYPE(NormalAttrs);
136 | |
137 | bool NormalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
138 | const TypeReporter& reporter) { |
139 | const NormalAttrs* param = attrs.as<NormalAttrs>(); |
140 | ICHECK_EQ(types.size(), 4) << "Normal should have three inputs and one output" ; |
141 | |
142 | std::vector<IndexExpr> oshape; |
143 | for (auto& x : param->out_shape) { |
144 | oshape.push_back(x); |
145 | } |
146 | DataType out_dtype = param->out_dtype; |
147 | // we are supporting float32 and float64 at the moment. |
148 | if (!(out_dtype.is_float() && (out_dtype.bits() == 32 || out_dtype.bits() == 64))) { |
149 | reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) |
150 | << "We only support generating Normal random value of " |
151 | << "type float32 or float64, got " << out_dtype << "." ); |
152 | return false; |
153 | } |
154 | reporter->Assign(types[0], ThreefryKeyType()); |
155 | reporter->Assign(types[1], TensorType({}, out_dtype)); |
156 | reporter->Assign(types[2], TensorType({}, out_dtype)); |
157 | // generate returns the next key and an array of random values |
158 | reporter->Assign(types[3], TupleType({ThreefryKeyType(), TensorType(oshape, out_dtype)})); |
159 | return true; |
160 | } |
161 | |
162 | Expr MakeNormal(Expr key, Expr mean, Expr scale, Array<Integer> out_shape, DataType out_dtype) { |
163 | auto attrs = make_object<NormalAttrs>(); |
164 | attrs->out_shape = out_shape; |
165 | attrs->out_dtype = out_dtype; |
166 | static const Op& op = Op::Get("random.normal" ); |
167 | return Call(op, {key, mean, scale}, Attrs(attrs), {}); |
168 | } |
169 | |
// Register MakeNormal as the global function "relay.op.random._make.normal".
TVM_REGISTER_GLOBAL("relay.op.random._make.normal").set_body_typed(MakeNormal);

// random.normal takes (key, mean, scale) and carries out_shape/out_dtype in
// NormalAttrs. Type inference is done by NormalRel.
RELAY_REGISTER_OP("random.normal")
    .describe(
        R"doc(Generate an array of random numbers under normal distribution.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(3)
    .set_attrs_type<NormalAttrs>()
    .add_argument("key", "Tensor", "Input Threefry key")
    .add_argument("mean", "Tensor", "Mean of the distribution")
    .add_argument("scale", "Tensor", "Standard deviation of the distribution")
    .add_type_rel("Normal", NormalRel);

// Register the attribute node type used by random.multinomial.
TVM_REGISTER_NODE_TYPE(MultinomialAttrs);
183 | |
184 | bool MultinomialRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
185 | const TypeReporter& reporter) { |
186 | const MultinomialAttrs* param = attrs.as<MultinomialAttrs>(); |
187 | ICHECK_EQ(types.size(), 3) << "Normal should have two inputs and one output" ; |
188 | |
189 | const auto* data = types[1].as<TensorTypeNode>(); |
190 | if (data == nullptr) { |
191 | ICHECK(types[1].as<IncompleteTypeNode>()) |
192 | << "multinomial: expect input type to be TensorType but get " << types[0]; |
193 | return false; |
194 | } |
195 | |
196 | std::vector<IndexExpr> oshape; |
197 | for (size_t i = 0; i < data->shape.size() - 1; i++) { |
198 | oshape.push_back(data->shape[i]); |
199 | } |
200 | oshape.push_back(param->num_samples); |
201 | |
202 | DataType out_dtype = tvm::DataType::Int(32); |
203 | |
204 | reporter->Assign(types[0], ThreefryKeyType()); |
205 | // generate returns the next key and an array of random values |
206 | reporter->Assign(types[2], TupleType({ThreefryKeyType(), TensorType(oshape, out_dtype)})); |
207 | return true; |
208 | } |
209 | |
210 | Expr MakeMultinomial(Expr key, Expr probs, Integer num_samples) { |
211 | auto attrs = make_object<MultinomialAttrs>(); |
212 | attrs->num_samples = num_samples; |
213 | static const Op& op = Op::Get("random.multinomial" ); |
214 | return Call(op, {key, probs}, Attrs(attrs), {}); |
215 | } |
216 | |
217 | TVM_REGISTER_GLOBAL("relay.op.random._make.multinomial" ).set_body_typed(MakeMultinomial); |
218 | |
219 | RELAY_REGISTER_OP("random.multinomial" ) |
220 | .describe( |
221 | R"doc(Generate an array of random numbers under normal distribution.)doc" TVM_ADD_FILELINE) |
222 | .set_num_inputs(2) |
223 | .set_attrs_type<MultinomialAttrs>() |
224 | .add_argument("key" , "Tensor" , "Input Threefry key" ) |
225 | .add_argument("probs" , "Tensor" , "Array of probabilities for each corresponding index." ) |
226 | .add_type_rel("Multinomial" , MultinomialRel); |
227 | |
228 | } // namespace relay |
229 | } // namespace tvm |
230 | |