1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file
 * \brief External random functions for tensors.
22 */
23#include <dmlc/thread_local.h>
24#include <tvm/runtime/data_type.h>
25#include <tvm/runtime/logging.h>
26#include <tvm/runtime/registry.h>
27#include <tvm/runtime/threading_backend.h>
28
29#include <algorithm>
30
31#include "mt_random_engine.cc"
32
/*!
 * \brief Dispatch on an 8/16/32-bit signed or unsigned integer DLDataType.
 *
 * Expands to an if / else-if chain. In the branch whose (code, bits) pair
 * matches `type`, `DType` is a typedef for the corresponding fixed-width C++
 * integer type and the trailing arguments are expanded as that branch's body.
 * Any other data type reaches LOG(FATAL), which reports the offending code and
 * bit width.
 *
 * \note 64-bit integers are not dispatched. Before adding them, check callers:
 *       randint computes `numeric_limits<DType>::max() + 1` in int64_t, which
 *       would overflow for int64.
 * \note `type` is evaluated several times; pass a side-effect-free expression.
 *       It is parenthesized so expressions such as `*ptr` expand correctly.
 */
#define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...)                         \
  if ((type).code == kDLInt && (type).bits == 32) {                           \
    typedef int32_t DType;                                                    \
    { __VA_ARGS__ }                                                           \
  } else if ((type).code == kDLInt && (type).bits == 16) {                    \
    typedef int16_t DType;                                                    \
    { __VA_ARGS__ }                                                           \
  } else if ((type).code == kDLInt && (type).bits == 8) {                     \
    typedef int8_t DType;                                                     \
    { __VA_ARGS__ }                                                           \
  } else if ((type).code == kDLUInt && (type).bits == 32) {                   \
    typedef uint32_t DType;                                                   \
    { __VA_ARGS__ }                                                           \
  } else if ((type).code == kDLUInt && (type).bits == 16) {                   \
    typedef uint16_t DType;                                                   \
    { __VA_ARGS__ }                                                           \
  } else if ((type).code == kDLUInt && (type).bits == 8) {                    \
    typedef uint8_t DType;                                                    \
    { __VA_ARGS__ }                                                           \
  } else {                                                                    \
    LOG(FATAL) << "unknown data type: code=" << static_cast<int>((type).code) \
               << ", bits=" << static_cast<int>((type).bits);                 \
  }
55
56namespace tvm {
57namespace contrib {
58
59using namespace runtime;
60
61struct RandomThreadLocalEntry {
62 RandomEngine random_engine;
63 static RandomThreadLocalEntry* ThreadLocal();
64};
65
66typedef dmlc::ThreadLocalStore<RandomThreadLocalEntry> RandomThreadLocalStore;
67
68RandomThreadLocalEntry* RandomThreadLocalEntry::ThreadLocal() {
69 return RandomThreadLocalStore::Get();
70}
71
72TVM_REGISTER_GLOBAL("tvm.contrib.random.randint").set_body([](TVMArgs args, TVMRetValue* ret) {
73 RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
74 int64_t low = args[0];
75 int64_t high = args[1];
76 DLTensor* out = args[2];
77 ICHECK_GT(high, low) << "high must be bigger than low";
78 ICHECK(out->strides == nullptr);
79
80 DLDataType dtype = out->dtype;
81 int64_t size = 1;
82 for (int i = 0; i < out->ndim; ++i) {
83 size *= out->shape[i];
84 }
85
86 DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, {
87 int64_t numeric_low = std::numeric_limits<DType>::min();
88 int64_t numeric_high = std::numeric_limits<DType>::max();
89 numeric_high += 1; // exclusive upper bound
90 low = std::max(low, numeric_low);
91 high = std::min(high, numeric_high);
92
93 if (out->device.device_type == kDLCPU) {
94 // file the data with random byte
95 std::generate_n(static_cast<DType*>(out->data), size, [&]() {
96 unsigned rint = entry->random_engine.GetRandInt();
97 return low + rint % (high - low);
98 });
99 } else {
100 LOG(FATAL) << "Do not support random.randint on this device yet";
101 }
102 })
103});
104
105TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform").set_body([](TVMArgs args, TVMRetValue* ret) {
106 RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
107 double low = args[0];
108 double high = args[1];
109 DLTensor* out = args[2];
110 entry->random_engine.SampleUniform(out, low, high);
111});
112
113TVM_REGISTER_GLOBAL("tvm.contrib.random.normal").set_body([](TVMArgs args, TVMRetValue* ret) {
114 RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
115 double loc = args[0];
116 double scale = args[1];
117 DLTensor* out = args[2];
118 entry->random_engine.SampleNormal(out, loc, scale);
119});
120
121TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill").set_body([](TVMArgs args, TVMRetValue* ret) {
122 RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
123 DLTensor* out = args[0];
124 entry->random_engine.RandomFill(out);
125});
126
127TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill_for_measure")
128 .set_body([](TVMArgs args, TVMRetValue* ret) -> void {
129 static const PackedFunc* curand = Registry::Get("runtime.contrib.curand.RandomFill");
130 DLTensor* out = args[0];
131 if (curand && out->device.device_type == DLDeviceType::kDLCUDA) {
132 if (out->dtype.code == DLDataTypeCode::kDLFloat) {
133 (*curand)(out);
134 return;
135 }
136 }
137 RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
138 entry->random_engine.RandomFillForMeasure(out);
139 });
140
141} // namespace contrib
142} // namespace tvm
143