/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file random/mt_random_engine.cc
 * \brief mt19937 random engine
 */
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/threading_backend.h>

#include <algorithm>
#include <ctime>
#include <random>
#include <thread>

#include "../3rdparty/compiler-rt/builtin_fp16.h"
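// builtin_fp16.h supplies __truncXfYf2__, used below to narrow float32 samples
// into IEEE float16 bit patterns when filling 16-bit tensors.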

namespace tvm {
namespace contrib {

/*!
 * \brief An interface for generating [tensors of] random numbers.
 */
class RandomEngine {
 public:
  /*!
   * \brief Creates a RandomEngine using a default seed.
   */
  RandomEngine() { this->Seed(time(nullptr)); }

  /*!
   * \brief Creates a RandomEngine, suggesting the use of a provided seed.
   */
  explicit RandomEngine(unsigned seed) { this->Seed(seed); }

  /*!
   * \brief Seeds the underlying RNG, if possible.
   */
  inline void Seed(unsigned seed) {
    rnd_engine_.seed(seed);
    this->rseed_ = static_cast<unsigned>(seed);
  }

  /*!
   * \return the seed associated with the underlying RNG.
   */
  inline unsigned GetSeed() const { return rseed_; }

  /*!
   * \return a random integer sampled from the RNG.
   */
  inline unsigned GetRandInt() { return rnd_engine_(); }

  /*!
   * \brief Fills a tensor with values drawn from Unif(low, high).
   */
  void SampleUniform(DLTensor* data, float low, float high) {
    ICHECK_GT(high, low) << "high must be greater than low";
    ICHECK(data->strides == nullptr);

    DLDataType dtype = data->dtype;
    int64_t size = 1;
    for (int i = 0; i < data->ndim; ++i) {
      size *= data->shape[i];
    }

    ICHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1);

    if (data->device.device_type == kDLCPU) {
      std::uniform_real_distribution<float> uniform_dist(low, high);
      std::generate_n(static_cast<float*>(data->data), size,
                      [&]() { return uniform_dist(rnd_engine_); });
    } else {
      LOG(FATAL) << "random.uniform is not supported on this device yet";
    }
  }

  /*!
   * \brief Fills a tensor with values drawn from Normal(loc, scale**2).
   */
  void SampleNormal(DLTensor* data, float loc, float scale) {
    ICHECK_GT(scale, 0) << "standard deviation must be positive";
    ICHECK(data->strides == nullptr);

    DLDataType dtype = data->dtype;
    int64_t size = 1;
    for (int i = 0; i < data->ndim; ++i) {
      size *= data->shape[i];
    }

    ICHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1);

    if (data->device.device_type == kDLCPU) {
      std::normal_distribution<float> normal_dist(loc, scale);
      std::generate_n(static_cast<float*>(data->data), size,
                      [&]() { return normal_dist(rnd_engine_); });
    } else {
      LOG(FATAL) << "random.normal is not supported on this device yet";
    }
  }

  void RandomFill(DLTensor* data) {
    if (data->device.device_type == kDLCPU) {
      FillData(data);
    } else {
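      // Non-CPU target: generate the values into a host-side staging buffer,
      // then copy the result over to the destination device.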
      runtime::NDArray local = runtime::NDArray::Empty(
          std::vector<int64_t>{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0});
      DLTensor* tensor = const_cast<DLTensor*>(local.operator->());
      FillData(tensor);
      runtime::NDArray::CopyFromTo(tensor, data);
    }
  }

  void RandomFillForMeasure(DLTensor* data) {
    if (data->device.device_type == kDLCPU) {
      FillDataForMeasure(data);
    } else {
      runtime::NDArray local = runtime::NDArray::Empty(
          std::vector<int64_t>{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0});
      DLTensor* tensor = const_cast<DLTensor*>(local.operator->());
      FillDataForMeasure(tensor);
      runtime::NDArray::CopyFromTo(tensor, data);
    }
  }

 private:
  void FillDataImpl(void* data, int64_t st, int64_t ed, DLDataType dtype) {
    // Draw values in [1.0, 10.0) rather than [0.0, 1.0) so that quantized dtypes
    // (uint8 / int8) still receive non-zero data.
    std::uniform_real_distribution<> dist(1.0, 10.0);
    // Generating through a floating-point distribution works for both float and
    // integer element types.
    if (dtype.bits == 1) {
      std::generate_n(static_cast<bool*>(data) + st, ed - st, [&]() { return dist(rnd_engine_); });
    } else if (dtype.bits == 4) {
      // For uint4/int4, two values are packed into a single byte. To keep both
      // nibbles non-zero, draw from 17 - 30 instead.
      std::uniform_real_distribution<> packed_dist(17.0, 30.0);
      std::generate_n(reinterpret_cast<uint8_t*>(data) + st, ed - st,
                      [&]() { return packed_dist(rnd_engine_); });
    } else if (dtype.bits == 8) {
      std::generate_n(static_cast<uint8_t*>(data) + st, ed - st,
                      [&]() { return dist(rnd_engine_); });
    } else if (dtype.bits == 16) {
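      // __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10> converts a
      // binary32 value (23 significand bits) into a binary16 bit pattern
      // (10 significand bits), which is what a 16-bit float tensor stores.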
      std::generate_n(static_cast<uint16_t*>(data) + st, ed - st, [&]() {
        return __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
            static_cast<float>(dist(rnd_engine_)));
      });
    } else if (dtype.bits == 32) {
      std::generate_n(static_cast<float*>(data) + st, ed - st, [&]() { return dist(rnd_engine_); });
    } else if (dtype.bits == 64) {
      std::generate_n(static_cast<double*>(data) + st, ed - st,
                      [&]() { return dist(rnd_engine_); });
    } else {
      LOG(FATAL) << "Unsupported dtype: code=" << static_cast<int>(dtype.code)
                 << ", bits=" << dtype.bits;
    }
  }

  void FillData(DLTensor* tensor) {
    int64_t size = 1;
    for (int i = 0; i < tensor->ndim; ++i) {
      size *= tensor->shape[i];
    }
    DLDataType dtype = tensor->dtype;
    if (dtype.bits == 1 || dtype.bits == 4 || dtype.bits == 8 || dtype.bits == 16 ||
        dtype.bits == 32 || dtype.bits == 64) {
      FillDataImpl(tensor->data, 0, size, dtype);
    } else {
      LOG(FATAL) << "Unsupported dtype: code=" << static_cast<int>(dtype.code)
                 << ", bits=" << dtype.bits;
    }
  }

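  // Parallel variant used when preparing measurement inputs: the flat buffer is
  // split into contiguous chunks, and each chunk is filled by one worker of the
  // TVM thread pool via TVMBackendParallelLaunch.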
  void FillDataForMeasure(DLTensor* tensor) {
    struct ParallelTask {
      static int RunTask(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
        ParallelTask* task = static_cast<ParallelTask*>(cdata);
        task->Run(task_id, penv->num_task);
        return 0;
      }

      void Run(int i, int num_tasks) {
        // Round the chunk size up so the trailing remainder (when size is not a
        // multiple of num_tasks) still gets filled; clamp st so tasks that fall
        // past the end simply do nothing.
        int64_t chunk_size = (size + num_tasks - 1) / num_tasks;
        int64_t st = std::min(i * chunk_size, size);
        int64_t ed = std::min(st + chunk_size, size);
        self->FillDataImpl(data, st, ed, dtype);
      }

      RandomEngine* self;
      void* data;
      int64_t size;
      DLDataType dtype;
    };

    ParallelTask task;
    task.self = this;
    task.data = tensor->data;
    DLDataType dtype = task.dtype = tensor->dtype;
    int64_t& size = task.size = 1;
    for (int i = 0; i < tensor->ndim; ++i) {
      size *= tensor->shape[i];
    }
    if (dtype.bits == 1 || dtype.bits == 4 || dtype.bits == 8 || dtype.bits == 16 ||
        dtype.bits == 32 || dtype.bits == 64) {
      int res = TVMBackendParallelLaunch(ParallelTask::RunTask, &task, 0);
      ICHECK_EQ(res, 0) << "RandomFillForMeasure: TVMBackendParallelLaunch failed";
    } else {
      LOG(FATAL) << "Unsupported dtype: code=" << static_cast<int>(dtype.code)
                 << ", bits=" << dtype.bits;
    }
  }

 private:
  std::mt19937 rnd_engine_;
  unsigned rseed_;
};
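
// A minimal usage sketch (illustrative only; "cpu_tensor" and "device_tensor"
// are hypothetical DLTensor* handles, and real callers typically reach this
// engine through packed-function wrappers rather than constructing it here):
//
//   RandomEngine engine(/*seed=*/42);              // deterministic stream
//   engine.SampleUniform(cpu_tensor, 0.0f, 1.0f);  // float32 CPU tensor only
//   engine.RandomFill(device_tensor);              // any supported dtype/device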

}  // namespace contrib
}  // namespace tvm