1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file math.cc |
22 | * \brief Math operators. |
23 | */ |
24 | #include <tvm/relay/expr.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/topi/einsum.h> |
27 | |
28 | #include "../make_op.h" |
29 | #include "../op_common.h" |
30 | #include "../type_relations.h" |
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | |
// relay.einsum
// Register the EinsumAttrs node type; these attrs carry the `equation` field that
// EinsumRel, EinsumCompute and MakeEinsum below read or write.
// NOTE(review): the exact effect of TVM_REGISTER_NODE_TYPE is defined by the macro
// elsewhere in TVM — confirm there.
TVM_REGISTER_NODE_TYPE(EinsumAttrs);
37 | |
38 | bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
39 | const TypeReporter& reporter) { |
40 | // Check attrs |
41 | const EinsumAttrs* param = attrs.as<EinsumAttrs>(); |
42 | if (param == nullptr) { |
43 | reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) |
44 | << "the call attributes are not defined" ); |
45 | return false; |
46 | } |
47 | |
48 | // types: [data, result] |
49 | ICHECK_EQ(types.size(), 2) << "the arity of einsum is 2, not " << types.size(); |
50 | |
51 | // Check input type is a tuple. |
52 | const auto* tensor_tuple = types[0].as<TupleTypeNode>(); |
53 | if (tensor_tuple == nullptr) { |
54 | reporter->GetDiagCtx().EmitFatal( |
55 | Diagnostic::Error(reporter->GetSpan()) |
56 | << "einsum requires a tuple of tensors as the first argument, found " |
57 | << PrettyPrint(types[0])); |
58 | return false; |
59 | } |
60 | |
61 | // Check the input tuple consists of tensors with consistent dtype. |
62 | const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]); |
63 | const DataType dtype = first->dtype; |
64 | std::vector<Array<PrimExpr>> input_shapes; |
65 | for (const Type& ele : tensor_tuple->fields) { |
66 | if (ele.as<IncompleteTypeNode>()) { |
67 | return false; |
68 | } |
69 | |
70 | const auto& e = Downcast<TensorType>(ele); |
71 | |
72 | const DataType& e_dtype = e->dtype; |
73 | if (e_dtype != dtype) { |
74 | throw Error("relay.einsum requires all tensors have the same dtype" ); |
75 | } |
76 | input_shapes.push_back(e->shape); |
77 | } |
78 | |
79 | // Calculate output shape |
80 | Array<IndexExpr> oshape = topi::InferEinsumShape(param->equation, input_shapes); |
81 | |
82 | auto rtype = TensorType(oshape, dtype); |
83 | reporter->Assign(types[1], rtype); |
84 | return true; |
85 | } |
86 | |
87 | Array<te::Tensor> EinsumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
88 | const Type& out_type) { |
89 | const EinsumAttrs* param = attrs.as<EinsumAttrs>(); |
90 | ICHECK(param != nullptr); |
91 | return Array<te::Tensor>{topi::einsum(param->equation, inputs)}; |
92 | } |
93 | |
94 | Expr MakeEinsum(Expr data, String equation) { |
95 | auto attrs = make_object<EinsumAttrs>(); |
96 | attrs->equation = std::move(equation); |
97 | static const Op& op = Op::Get("einsum" ); |
98 | return Call(op, {data}, Attrs(attrs), {}); |
99 | } |
100 | |
// Register MakeEinsum as a typed global function under the name
// "relay.op._make.einsum".
// NOTE(review): presumably this is the entry point the Python frontend calls — confirm.
TVM_REGISTER_GLOBAL("relay.op._make.einsum" ).set_body_typed(MakeEinsum);
102 | |
// Register the "einsum" Relay operator: one input (a tuple of tensors), attributes
// of type EinsumAttrs, EinsumRel as its type relation and EinsumCompute as its
// compute function.
RELAY_REGISTER_OP("einsum" )
    .describe(R"doc(Evaluates the Einstein summation convention
on the operands)doc" TVM_ADD_FILELINE)
    .set_attrs_type<EinsumAttrs>()
    // A single argument: all operands are packed into one tuple.
    .set_num_inputs(1)
    .add_argument("data" , "Tuple of Tensors" , "The input list of tensors." )
    .set_support_level(11)
    .add_type_rel("Einsum" , EinsumRel)
    .set_attr<FTVMCompute>("FTVMCompute" , EinsumCompute)
    // NOTE(review): the op is tagged kInjective for TOpPattern; confirm this is the
    // intended fusion pattern for a contraction-style op.
    .set_attr<TOpPattern>("TOpPattern" , kInjective);
113 | |
114 | } // namespace relay |
115 | } // namespace tvm |
116 | |