1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file 3rdparty/byodt/my-custom-datatype.cc
22 * \brief Example Custom Datatype with the Bring Your Own Datatypes (BYODT) framework.
23 * This is a toy example that under the hood simulates floats.
24 *
25 * Users interested in using the BYODT framework can use this file as a template.
26 *
27 * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist?
28 */
#include <tvm/runtime/c_runtime_api.h>

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
34
35// Custom datatypes are stored as bits in a uint of the appropriate bit length.
// Thus, when TVM calls these C functions,
// the arguments are uints that need to be reinterpreted as your custom datatype.
38//
39// When returning, your custom datatype needs to be re-wrapped into a uint,
40// which can be thought of as just a wrapper for the raw bits that represent your custom datatype.
41template <class T>
42TVM_DLL T Uint32ToCustom32(uint32_t in) {
43 // This is a helper function to interpret the uint as your custom dataype.
44 // The following line should be replaced with the appropriate function
45 // that interprets the bits in `in` and returns your custom datatype
46 T* custom = reinterpret_cast<T*>(&in);
47 return *custom;
48}
49
50template <class T>
51TVM_DLL uint32_t Custom32ToUint32(T in) {
52 // This is a helper function to wrap your custom datatype in a uint.
53 // the following line should be replaced with the appropriate function
54 // that converts your custom datatype into a uint
55 uint32_t* bits = reinterpret_cast<uint32_t*>(&in);
56 return *bits;
57}
58
59extern "C" {
60TVM_DLL uint32_t MinCustom32() {
61 // return minimum representable value
62 float min = std::numeric_limits<float>::lowest();
63 return Custom32ToUint32<float>(min);
64}
65
66TVM_DLL float Custom32ToFloat(uint32_t in) {
67 // cast from custom datatype to float
68 float custom_datatype = Uint32ToCustom32<float>(in);
69 // our custom datatype is float, so the following redundant cast to float
70 // is to remind users to cast their own custom datatype to float
71 return static_cast<float>(custom_datatype);
72}
73
74TVM_DLL uint32_t FloatToCustom32(float in) {
75 // cast from float to custom datatype
76 return Custom32ToUint32<float>(in);
77}
78
79TVM_DLL uint32_t Custom32Add(uint32_t a, uint32_t b) {
80 // add operation
81 float acustom = Uint32ToCustom32<float>(a);
82 float bcustom = Uint32ToCustom32<float>(b);
83 return Custom32ToUint32<float>(acustom + bcustom);
84}
85
86TVM_DLL uint32_t Custom32Sub(uint32_t a, uint32_t b) {
87 // subtract
88 float acustom = Uint32ToCustom32<float>(a);
89 float bcustom = Uint32ToCustom32<float>(b);
90 return Custom32ToUint32<float>(acustom - bcustom);
91}
92
93TVM_DLL uint32_t Custom32Mul(uint32_t a, uint32_t b) {
94 // multiply
95 float acustom = Uint32ToCustom32<float>(a);
96 float bcustom = Uint32ToCustom32<float>(b);
97 return Custom32ToUint32<float>(acustom * bcustom);
98}
99
100TVM_DLL uint32_t Custom32Div(uint32_t a, uint32_t b) {
101 // divide
102 float acustom = Uint32ToCustom32<float>(a);
103 float bcustom = Uint32ToCustom32<float>(b);
104 return Custom32ToUint32<float>(acustom / bcustom);
105}
106
107TVM_DLL uint32_t Custom32Max(uint32_t a, uint32_t b) {
108 // max
109 float acustom = Uint32ToCustom32<float>(a);
110 float bcustom = Uint32ToCustom32<float>(b);
111 return Custom32ToUint32<float>(acustom > bcustom ? acustom : bcustom);
112}
113
114TVM_DLL uint32_t Custom32Sqrt(uint32_t a) {
115 // sqrt
116 float acustom = Uint32ToCustom32<float>(a);
117 return Custom32ToUint32<float>(sqrt(acustom));
118}
119
120TVM_DLL uint32_t Custom32Exp(uint32_t a) {
121 // exponential
122 float acustom = Uint32ToCustom32<float>(a);
123 return Custom32ToUint32<float>(exp(acustom));
124}
125
126TVM_DLL uint32_t Custom32Log(uint32_t a) {
127 // log
128 float acustom = Uint32ToCustom32<float>(a);
129 return Custom32ToUint32<float>(log(acustom));
130}
131
132TVM_DLL uint32_t Custom32Sigmoid(uint32_t a) {
133 // sigmoid
134 float acustom = Uint32ToCustom32<float>(a);
135 float one = 1.0f;
136 return Custom32ToUint32<float>(one / (one + exp(-acustom)));
137}
138
139TVM_DLL uint32_t Custom32Tanh(uint32_t a) {
140 // tanh
141 float acustom = Uint32ToCustom32<float>(a);
142 return Custom32ToUint32<float>(tanh(acustom));
143}
144}
145