1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file 3rdparty/byodt/my-custom-datatype.cc |
22 | * \brief Example Custom Datatype with the Bring Your Own Datatypes (BYODT) framework. |
23 | * This is a toy example that under the hood simulates floats. |
24 | * |
25 | * Users interested in using the BYODT framework can use this file as a template. |
26 | * |
27 | * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist? |
28 | */ |
#include <tvm/runtime/c_runtime_api.h>

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
34 | |
// Custom datatypes are stored as raw bits in an unsigned integer (uint) of the appropriate bit length.
// Thus, when TVM calls these C functions,
// the arguments are uints that need to be reinterpreted as your custom datatype.
//
// When returning, your custom datatype needs to be re-wrapped into a uint,
// which can be thought of as just a wrapper for the raw bits that represent your custom datatype.
41 | template <class T> |
42 | TVM_DLL T Uint32ToCustom32(uint32_t in) { |
43 | // This is a helper function to interpret the uint as your custom dataype. |
44 | // The following line should be replaced with the appropriate function |
45 | // that interprets the bits in `in` and returns your custom datatype |
46 | T* custom = reinterpret_cast<T*>(&in); |
47 | return *custom; |
48 | } |
49 | |
50 | template <class T> |
51 | TVM_DLL uint32_t Custom32ToUint32(T in) { |
52 | // This is a helper function to wrap your custom datatype in a uint. |
53 | // the following line should be replaced with the appropriate function |
54 | // that converts your custom datatype into a uint |
55 | uint32_t* bits = reinterpret_cast<uint32_t*>(&in); |
56 | return *bits; |
57 | } |
58 | |
59 | extern "C" { |
60 | TVM_DLL uint32_t MinCustom32() { |
61 | // return minimum representable value |
62 | float min = std::numeric_limits<float>::lowest(); |
63 | return Custom32ToUint32<float>(min); |
64 | } |
65 | |
66 | TVM_DLL float Custom32ToFloat(uint32_t in) { |
67 | // cast from custom datatype to float |
68 | float custom_datatype = Uint32ToCustom32<float>(in); |
69 | // our custom datatype is float, so the following redundant cast to float |
70 | // is to remind users to cast their own custom datatype to float |
71 | return static_cast<float>(custom_datatype); |
72 | } |
73 | |
74 | TVM_DLL uint32_t FloatToCustom32(float in) { |
75 | // cast from float to custom datatype |
76 | return Custom32ToUint32<float>(in); |
77 | } |
78 | |
79 | TVM_DLL uint32_t Custom32Add(uint32_t a, uint32_t b) { |
80 | // add operation |
81 | float acustom = Uint32ToCustom32<float>(a); |
82 | float bcustom = Uint32ToCustom32<float>(b); |
83 | return Custom32ToUint32<float>(acustom + bcustom); |
84 | } |
85 | |
86 | TVM_DLL uint32_t Custom32Sub(uint32_t a, uint32_t b) { |
87 | // subtract |
88 | float acustom = Uint32ToCustom32<float>(a); |
89 | float bcustom = Uint32ToCustom32<float>(b); |
90 | return Custom32ToUint32<float>(acustom - bcustom); |
91 | } |
92 | |
93 | TVM_DLL uint32_t Custom32Mul(uint32_t a, uint32_t b) { |
94 | // multiply |
95 | float acustom = Uint32ToCustom32<float>(a); |
96 | float bcustom = Uint32ToCustom32<float>(b); |
97 | return Custom32ToUint32<float>(acustom * bcustom); |
98 | } |
99 | |
100 | TVM_DLL uint32_t Custom32Div(uint32_t a, uint32_t b) { |
101 | // divide |
102 | float acustom = Uint32ToCustom32<float>(a); |
103 | float bcustom = Uint32ToCustom32<float>(b); |
104 | return Custom32ToUint32<float>(acustom / bcustom); |
105 | } |
106 | |
107 | TVM_DLL uint32_t Custom32Max(uint32_t a, uint32_t b) { |
108 | // max |
109 | float acustom = Uint32ToCustom32<float>(a); |
110 | float bcustom = Uint32ToCustom32<float>(b); |
111 | return Custom32ToUint32<float>(acustom > bcustom ? acustom : bcustom); |
112 | } |
113 | |
114 | TVM_DLL uint32_t Custom32Sqrt(uint32_t a) { |
115 | // sqrt |
116 | float acustom = Uint32ToCustom32<float>(a); |
117 | return Custom32ToUint32<float>(sqrt(acustom)); |
118 | } |
119 | |
120 | TVM_DLL uint32_t Custom32Exp(uint32_t a) { |
121 | // exponential |
122 | float acustom = Uint32ToCustom32<float>(a); |
123 | return Custom32ToUint32<float>(exp(acustom)); |
124 | } |
125 | |
126 | TVM_DLL uint32_t Custom32Log(uint32_t a) { |
127 | // log |
128 | float acustom = Uint32ToCustom32<float>(a); |
129 | return Custom32ToUint32<float>(log(acustom)); |
130 | } |
131 | |
132 | TVM_DLL uint32_t Custom32Sigmoid(uint32_t a) { |
133 | // sigmoid |
134 | float acustom = Uint32ToCustom32<float>(a); |
135 | float one = 1.0f; |
136 | return Custom32ToUint32<float>(one / (one + exp(-acustom))); |
137 | } |
138 | |
139 | TVM_DLL uint32_t Custom32Tanh(uint32_t a) { |
140 | // tanh |
141 | float acustom = Uint32ToCustom32<float>(a); |
142 | return Custom32ToUint32<float>(tanh(acustom)); |
143 | } |
144 | } |
145 | |