1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/ir/affine_type.cc |
22 | * \brief The Type information for quantized nodes. |
23 | */ |
#include <tvm/ir/affine_type.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>

#include <utility>
27 | |
28 | namespace tvm { |
29 | |
30 | using tvm::ReprPrinter; |
31 | using namespace tvm::runtime; |
32 | |
33 | TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, |
34 | int axis) { |
35 | ObjectPtr<TensorAffineTypeNode> n = make_object<TensorAffineTypeNode>(); |
36 | n->scale = std::move(scale); |
37 | n->zero_point = std::move(zero_point); |
38 | n->dtype = std::move(dtype); |
39 | n->axis = std::move(axis); |
40 | data_ = std::move(n); |
41 | } |
42 | |
43 | TVM_REGISTER_NODE_TYPE(TensorAffineTypeNode); |
44 | |
45 | TVM_REGISTER_GLOBAL("ir.TensorAffineType" ) |
46 | .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis) { |
47 | return TensorAffineType(scale, zero_point, dtype, axis); |
48 | }); |
49 | |
50 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
51 | .set_dispatch<TensorAffineTypeNode>([](const ObjectRef& ref, ReprPrinter* p) { |
52 | auto* node = static_cast<const TensorAffineTypeNode*>(ref.get()); |
53 | p->stream << "TensorAffineType(" << node->scale << ", " << node->zero_point << ", " |
54 | << node->dtype << ", " << node->axis << ")" ; |
55 | }); |
56 | |
57 | TupleAffineType::TupleAffineType(Array<TensorAffineType> types) { |
58 | ObjectPtr<TupleAffineTypeNode> n = make_object<TupleAffineTypeNode>(); |
59 | n->types = std::move(types); |
60 | data_ = std::move(n); |
61 | } |
62 | |
63 | TVM_REGISTER_NODE_TYPE(TupleAffineTypeNode); |
64 | |
65 | TVM_REGISTER_GLOBAL("ir.TupleAffineType" ).set_body_typed([](Array<TensorAffineType> types) { |
66 | return TupleAffineType(types); |
67 | }); |
68 | |
69 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
70 | .set_dispatch<TupleAffineTypeNode>([](const ObjectRef& ref, ReprPrinter* p) { |
71 | auto* node = static_cast<const TupleAffineTypeNode*>(ref.get()); |
72 | p->stream << "TupleAffineType([" ; |
73 | for (size_t i = 0; i < node->types.size(); ++i) { |
74 | p->stream << node->types[i]; |
75 | if (i < node->types.size() - 1) { |
76 | p->stream << ", " ; |
77 | } |
78 | } |
79 | p->stream << "])" ; |
80 | }); |
81 | |
82 | } // namespace tvm |
83 | |