1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include <tvm/relay/attrs/random.h>
21#include <tvm/relay/op.h>
22
23namespace tvm {
24namespace relay {
25
// Register ThreefryGenerateAttrs as a node type. It is the attribute type attached to the
// "random.threefry_generate" operator registered later in this file.
TVM_REGISTER_NODE_TYPE(ThreefryGenerateAttrs);
27
28static TensorType ThreefryKeyType() { return TensorType({10}, tvm::DataType::UInt(64)); }
29
30bool ThreefryGenerateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
31 const TypeReporter& reporter) {
32 const ThreefryGenerateAttrs* param = attrs.as<ThreefryGenerateAttrs>();
33 ICHECK_EQ(types.size(), 2) << "ThreefryGenerate should have one input and one output";
34
35 reporter->Assign(types[0], ThreefryKeyType());
36
37 std::vector<IndexExpr> oshape;
38 for (auto& x : param->out_shape) {
39 oshape.push_back(x);
40 }
41 // generate returns the next key and an array of random values
42 // TODO(@tkonolige, @altanh): support other output dtypes?
43 reporter->Assign(types[1],
44 TupleType({ThreefryKeyType(), TensorType(oshape, tvm::DataType::UInt(64))}));
45 return true;
46}
47
48Expr MakeThreefryGenerate(Expr key, Array<Integer> out_shape) {
49 auto attrs = make_object<ThreefryGenerateAttrs>();
50 attrs->out_shape = out_shape;
51 static const Op& op = Op::Get("random.threefry_generate");
52 return Call(op, {key}, Attrs(attrs), {});
53}
54
// Expose MakeThreefryGenerate as the global function
// "relay.op.random._make.threefry_generate".
TVM_REGISTER_GLOBAL("relay.op.random._make.threefry_generate").set_body_typed(MakeThreefryGenerate);

// Register the "random.threefry_generate" operator: one argument named "key",
// ThreefryGenerateAttrs as its attribute type, and ThreefryGenerateRel as its type relation.
RELAY_REGISTER_OP("random.threefry_generate")
    .describe(
        R"doc(Generate an array of random numbers using the Threefry algorithm.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .set_attrs_type<ThreefryGenerateAttrs>()
    .add_argument("key", "Tensor", "Input Threefry key")
    .add_type_rel("ThreefryGenerate", ThreefryGenerateRel);
64
65bool ThreefrySplitRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
66 const TypeReporter& reporter) {
67 ICHECK_EQ(types.size(), 2) << "ThreefrySplit should have one input and one output";
68
69 reporter->Assign(types[0], ThreefryKeyType());
70 reporter->Assign(types[1], TupleType({ThreefryKeyType(), ThreefryKeyType()}));
71
72 return true;
73}
74
75Expr MakeThreefrySplit(Expr key) {
76 static const Op& op = Op::Get("random.threefry_split");
77 return Call(op, {key}, Attrs(), {});
78}
79
// Expose MakeThreefrySplit as the global function "relay.op.random._make.threefry_split".
TVM_REGISTER_GLOBAL("relay.op.random._make.threefry_split").set_body_typed(MakeThreefrySplit);

// Register the "random.threefry_split" operator: one argument named "key" and
// ThreefrySplitRel as its type relation. No attribute type is set for this operator.
RELAY_REGISTER_OP("random.threefry_split")
    .describe(R"doc(Split the input Threefry key into two new ones.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("key", "Tensor", "Input Threefry key")
    .add_type_rel("ThreefrySplit", ThreefrySplitRel);
87
// Register UniformAttrs as a node type. It is the attribute type attached to the
// "random.uniform" operator registered later in this file.
TVM_REGISTER_NODE_TYPE(UniformAttrs);
89
90bool UniformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
91 const TypeReporter& reporter) {
92 const UniformAttrs* param = attrs.as<UniformAttrs>();
93 ICHECK_EQ(types.size(), 4) << "Uniform should have three inputs and one output";
94
95 std::vector<IndexExpr> oshape;
96 for (auto& x : param->out_shape) {
97 oshape.push_back(x);
98 }
99 DataType out_dtype = param->out_dtype;
100 // we are supporting float32 and float64 at the moment.
101 if (!(out_dtype.is_float() && (out_dtype.bits() == 32 || out_dtype.bits() == 64))) {
102 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
103 << "We only support generating uniform random value of "
104 << "type float32 or float64, got " << out_dtype << ".");
105 return false;
106 }
107 reporter->Assign(types[0], ThreefryKeyType());
108 reporter->Assign(types[1], TensorType({}, out_dtype));
109 reporter->Assign(types[2], TensorType({}, out_dtype));
110 // generate returns the next key and an array of random values
111 reporter->Assign(types[3], TupleType({ThreefryKeyType(), TensorType(oshape, out_dtype)}));
112 return true;
113}
114
115Expr MakeUniform(Expr key, Expr low, Expr high, Array<Integer> out_shape, DataType out_dtype) {
116 auto attrs = make_object<UniformAttrs>();
117 attrs->out_shape = out_shape;
118 attrs->out_dtype = out_dtype;
119 static const Op& op = Op::Get("random.uniform");
120 return Call(op, {key, low, high}, Attrs(attrs), {});
121}
122
// Expose MakeUniform as the global function "relay.op.random._make.uniform".
TVM_REGISTER_GLOBAL("relay.op.random._make.uniform").set_body_typed(MakeUniform);

// Register the "random.uniform" operator: three arguments ("key", "low", "high"),
// UniformAttrs as its attribute type, and UniformRel as its type relation.
RELAY_REGISTER_OP("random.uniform")
    .describe(
        R"doc(Generate an array of random numbers under uniform distribution.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(3)
    .set_attrs_type<UniformAttrs>()
    .add_argument("key", "Tensor", "Input Threefry key")
    .add_argument("low", "Tensor", "Lower bound of the distribution")
    .add_argument("high", "Tensor", "Higher bound of the distribution")
    .add_type_rel("Uniform", UniformRel);
134
// Register NormalAttrs as a node type. It is the attribute type attached to the
// "random.normal" operator registered later in this file.
TVM_REGISTER_NODE_TYPE(NormalAttrs);
136
137bool NormalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
138 const TypeReporter& reporter) {
139 const NormalAttrs* param = attrs.as<NormalAttrs>();
140 ICHECK_EQ(types.size(), 4) << "Normal should have three inputs and one output";
141
142 std::vector<IndexExpr> oshape;
143 for (auto& x : param->out_shape) {
144 oshape.push_back(x);
145 }
146 DataType out_dtype = param->out_dtype;
147 // we are supporting float32 and float64 at the moment.
148 if (!(out_dtype.is_float() && (out_dtype.bits() == 32 || out_dtype.bits() == 64))) {
149 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
150 << "We only support generating Normal random value of "
151 << "type float32 or float64, got " << out_dtype << ".");
152 return false;
153 }
154 reporter->Assign(types[0], ThreefryKeyType());
155 reporter->Assign(types[1], TensorType({}, out_dtype));
156 reporter->Assign(types[2], TensorType({}, out_dtype));
157 // generate returns the next key and an array of random values
158 reporter->Assign(types[3], TupleType({ThreefryKeyType(), TensorType(oshape, out_dtype)}));
159 return true;
160}
161
162Expr MakeNormal(Expr key, Expr mean, Expr scale, Array<Integer> out_shape, DataType out_dtype) {
163 auto attrs = make_object<NormalAttrs>();
164 attrs->out_shape = out_shape;
165 attrs->out_dtype = out_dtype;
166 static const Op& op = Op::Get("random.normal");
167 return Call(op, {key, mean, scale}, Attrs(attrs), {});
168}
169
// Expose MakeNormal as the global function "relay.op.random._make.normal".
TVM_REGISTER_GLOBAL("relay.op.random._make.normal").set_body_typed(MakeNormal);

// Register the "random.normal" operator: three arguments ("key", "mean", "scale"),
// NormalAttrs as its attribute type, and NormalRel as its type relation.
RELAY_REGISTER_OP("random.normal")
    .describe(
        R"doc(Generate an array of random numbers under normal distribution.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(3)
    .set_attrs_type<NormalAttrs>()
    .add_argument("key", "Tensor", "Input Threefry key")
    .add_argument("mean", "Tensor", "Mean of the distribution")
    .add_argument("scale", "Tensor", "Standard deviation of the distribution")
    .add_type_rel("Normal", NormalRel);
181
// Register MultinomialAttrs as a node type. It is the attribute type attached to the
// "random.multinomial" operator registered later in this file.
TVM_REGISTER_NODE_TYPE(MultinomialAttrs);
183
184bool MultinomialRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
185 const TypeReporter& reporter) {
186 const MultinomialAttrs* param = attrs.as<MultinomialAttrs>();
187 ICHECK_EQ(types.size(), 3) << "Normal should have two inputs and one output";
188
189 const auto* data = types[1].as<TensorTypeNode>();
190 if (data == nullptr) {
191 ICHECK(types[1].as<IncompleteTypeNode>())
192 << "multinomial: expect input type to be TensorType but get " << types[0];
193 return false;
194 }
195
196 std::vector<IndexExpr> oshape;
197 for (size_t i = 0; i < data->shape.size() - 1; i++) {
198 oshape.push_back(data->shape[i]);
199 }
200 oshape.push_back(param->num_samples);
201
202 DataType out_dtype = tvm::DataType::Int(32);
203
204 reporter->Assign(types[0], ThreefryKeyType());
205 // generate returns the next key and an array of random values
206 reporter->Assign(types[2], TupleType({ThreefryKeyType(), TensorType(oshape, out_dtype)}));
207 return true;
208}
209
210Expr MakeMultinomial(Expr key, Expr probs, Integer num_samples) {
211 auto attrs = make_object<MultinomialAttrs>();
212 attrs->num_samples = num_samples;
213 static const Op& op = Op::Get("random.multinomial");
214 return Call(op, {key, probs}, Attrs(attrs), {});
215}
216
217TVM_REGISTER_GLOBAL("relay.op.random._make.multinomial").set_body_typed(MakeMultinomial);
218
219RELAY_REGISTER_OP("random.multinomial")
220 .describe(
221 R"doc(Generate an array of random numbers under normal distribution.)doc" TVM_ADD_FILELINE)
222 .set_num_inputs(2)
223 .set_attrs_type<MultinomialAttrs>()
224 .add_argument("key", "Tensor", "Input Threefry key")
225 .add_argument("probs", "Tensor", "Array of probabilities for each corresponding index.")
226 .add_type_rel("Multinomial", MultinomialRel);
227
228} // namespace relay
229} // namespace tvm
230