1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file External random functions for tensor. |
22 | */ |
23 | #include <dmlc/thread_local.h> |
24 | #include <tvm/runtime/data_type.h> |
25 | #include <tvm/runtime/logging.h> |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/runtime/threading_backend.h> |
28 | |
29 | #include <algorithm> |
30 | |
31 | #include "mt_random_engine.cc" |
32 | |
/*!
 * \brief Dispatch on an integer DLDataType and run a code block with a matching typedef.
 *
 * Expands to an if / else-if chain that tests `type.code` and `type.bits`.
 * For each supported integer type (int32, int16, int8, uint32, uint16, uint8)
 * the matching branch declares `typedef <ctype> DType;` and then runs
 * `__VA_ARGS__` inside its own brace scope. Any other dtype, including 64-bit
 * integers and every non-integer type code, reaches LOG(FATAL) "unknown data type".
 *
 * The expansion is a bare if/else statement (not wrapped in do { } while (0)),
 * so call sites in this file invoke it without a trailing semicolon.
 *
 * \param type  A DLDataType value; only its .code and .bits fields are read.
 * \param DType Name of the typedef introduced inside each branch.
 * \param ...   Code block instantiated once per supported integer type.
 */
#define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \
  if (type.code == kDLInt && type.bits == 32) {      \
    typedef int32_t DType;                           \
    { __VA_ARGS__ }                                  \
  } else if (type.code == kDLInt && type.bits == 16) { \
    typedef int16_t DType;                           \
    { __VA_ARGS__ }                                  \
  } else if (type.code == kDLInt && type.bits == 8) { \
    typedef int8_t DType;                            \
    { __VA_ARGS__ }                                  \
  } else if (type.code == kDLUInt && type.bits == 32) { \
    typedef uint32_t DType;                          \
    { __VA_ARGS__ }                                  \
  } else if (type.code == kDLUInt && type.bits == 16) { \
    typedef uint16_t DType;                          \
    { __VA_ARGS__ }                                  \
  } else if (type.code == kDLUInt && type.bits == 8) { \
    typedef uint8_t DType;                           \
    { __VA_ARGS__ }                                  \
  } else {                                           \
    LOG(FATAL) << "unknown data type";               \
  }
55 | |
56 | namespace tvm { |
57 | namespace contrib { |
58 | |
59 | using namespace runtime; |
60 | |
61 | struct RandomThreadLocalEntry { |
62 | RandomEngine random_engine; |
63 | static RandomThreadLocalEntry* ThreadLocal(); |
64 | }; |
65 | |
66 | typedef dmlc::ThreadLocalStore<RandomThreadLocalEntry> RandomThreadLocalStore; |
67 | |
68 | RandomThreadLocalEntry* RandomThreadLocalEntry::ThreadLocal() { |
69 | return RandomThreadLocalStore::Get(); |
70 | } |
71 | |
72 | TVM_REGISTER_GLOBAL("tvm.contrib.random.randint" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
73 | RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); |
74 | int64_t low = args[0]; |
75 | int64_t high = args[1]; |
76 | DLTensor* out = args[2]; |
77 | ICHECK_GT(high, low) << "high must be bigger than low" ; |
78 | ICHECK(out->strides == nullptr); |
79 | |
80 | DLDataType dtype = out->dtype; |
81 | int64_t size = 1; |
82 | for (int i = 0; i < out->ndim; ++i) { |
83 | size *= out->shape[i]; |
84 | } |
85 | |
86 | DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, { |
87 | int64_t numeric_low = std::numeric_limits<DType>::min(); |
88 | int64_t numeric_high = std::numeric_limits<DType>::max(); |
89 | numeric_high += 1; // exclusive upper bound |
90 | low = std::max(low, numeric_low); |
91 | high = std::min(high, numeric_high); |
92 | |
93 | if (out->device.device_type == kDLCPU) { |
94 | // file the data with random byte |
95 | std::generate_n(static_cast<DType*>(out->data), size, [&]() { |
96 | unsigned rint = entry->random_engine.GetRandInt(); |
97 | return low + rint % (high - low); |
98 | }); |
99 | } else { |
100 | LOG(FATAL) << "Do not support random.randint on this device yet" ; |
101 | } |
102 | }) |
103 | }); |
104 | |
105 | TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
106 | RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); |
107 | double low = args[0]; |
108 | double high = args[1]; |
109 | DLTensor* out = args[2]; |
110 | entry->random_engine.SampleUniform(out, low, high); |
111 | }); |
112 | |
113 | TVM_REGISTER_GLOBAL("tvm.contrib.random.normal" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
114 | RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); |
115 | double loc = args[0]; |
116 | double scale = args[1]; |
117 | DLTensor* out = args[2]; |
118 | entry->random_engine.SampleNormal(out, loc, scale); |
119 | }); |
120 | |
121 | TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
122 | RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); |
123 | DLTensor* out = args[0]; |
124 | entry->random_engine.RandomFill(out); |
125 | }); |
126 | |
127 | TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill_for_measure" ) |
128 | .set_body([](TVMArgs args, TVMRetValue* ret) -> void { |
129 | static const PackedFunc* curand = Registry::Get("runtime.contrib.curand.RandomFill" ); |
130 | DLTensor* out = args[0]; |
131 | if (curand && out->device.device_type == DLDeviceType::kDLCUDA) { |
132 | if (out->dtype.code == DLDataTypeCode::kDLFloat) { |
133 | (*curand)(out); |
134 | return; |
135 | } |
136 | } |
137 | RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); |
138 | entry->random_engine.RandomFillForMeasure(out); |
139 | }); |
140 | |
141 | } // namespace contrib |
142 | } // namespace tvm |
143 | |