1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/support/scalars.h
22 * \brief Helpers for converting between scalars in native, text, TIR immediate and NDArray forms.
23 */
24
25#ifndef TVM_SUPPORT_SCALARS_H_
26#define TVM_SUPPORT_SCALARS_H_
27
#include <cstdint>
#include <string>
#include <utility>
30
31#include "tvm/ir/expr.h"
32#include "tvm/relay/expr.h"
33#include "tvm/runtime/ndarray.h"
34
35namespace tvm {
36namespace support {
37
38/*! \brief Returns true if a tensor of empty shape and given dtype is considered a Relay scalar. */
39bool IsSimpleScalarDtype(DataType dtype);
40
41/*! \brief Returns true if \p constant_node is a float/int/bool scalar. */
42bool IsSimpleScalar(const relay::ConstantNode* constant_node);
43
44/*! \brief Returns NDArray 'scalar' for given TIR immediate. */
45runtime::NDArray IntImmToNDArray(const IntImm& int_imm);
46runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm);
47runtime::NDArray BoolToNDArray(bool value);
48
49/*! \brief Returns Relay literal text for NDArray 'scalar'. */
50std::string NDArrayScalarToString(const runtime::NDArray& data);
51
52/*! \brief Returns Relay literal text for given TIR immediate. */
53std::string IntImmToString(const IntImm& int_imm);
54std::string FloatImmToString(const FloatImm& float_imm);
55
56/*!
57 * \brief Returns TIR immediate for given value and width. Result will be null if value is
58 * out of range in width. Note however for floating point we don't check if the value is
59 * representable without loss of precision.
60 */
61IntImm ValueToIntImm(int64_t value, int width);
62FloatImm ValueToFloatImm(double value, int width);
63
64// 2^15 * (1 + 1023/1024)
65// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
66constexpr double kMaxFloat16 = 65504.0;
67
68} // namespace support
69} // namespace tvm
70
71#endif // TVM_SUPPORT_SCALARS_H_
72