1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file math.cc
22 * \brief Math operators.
23 */
24#include <tvm/relay/expr.h>
25#include <tvm/relay/op.h>
26#include <tvm/topi/einsum.h>
27
28#include "../make_op.h"
29#include "../op_common.h"
30#include "../type_relations.h"
31
32namespace tvm {
33namespace relay {
34
// relay.einsum
// Register the EinsumAttrs attribute node type (declared elsewhere) with TVM's node type registry.
TVM_REGISTER_NODE_TYPE(EinsumAttrs);
37
38bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
39 const TypeReporter& reporter) {
40 // Check attrs
41 const EinsumAttrs* param = attrs.as<EinsumAttrs>();
42 if (param == nullptr) {
43 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
44 << "the call attributes are not defined");
45 return false;
46 }
47
48 // types: [data, result]
49 ICHECK_EQ(types.size(), 2) << "the arity of einsum is 2, not " << types.size();
50
51 // Check input type is a tuple.
52 const auto* tensor_tuple = types[0].as<TupleTypeNode>();
53 if (tensor_tuple == nullptr) {
54 reporter->GetDiagCtx().EmitFatal(
55 Diagnostic::Error(reporter->GetSpan())
56 << "einsum requires a tuple of tensors as the first argument, found "
57 << PrettyPrint(types[0]));
58 return false;
59 }
60
61 // Check the input tuple consists of tensors with consistent dtype.
62 const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
63 const DataType dtype = first->dtype;
64 std::vector<Array<PrimExpr>> input_shapes;
65 for (const Type& ele : tensor_tuple->fields) {
66 if (ele.as<IncompleteTypeNode>()) {
67 return false;
68 }
69
70 const auto& e = Downcast<TensorType>(ele);
71
72 const DataType& e_dtype = e->dtype;
73 if (e_dtype != dtype) {
74 throw Error("relay.einsum requires all tensors have the same dtype");
75 }
76 input_shapes.push_back(e->shape);
77 }
78
79 // Calculate output shape
80 Array<IndexExpr> oshape = topi::InferEinsumShape(param->equation, input_shapes);
81
82 auto rtype = TensorType(oshape, dtype);
83 reporter->Assign(types[1], rtype);
84 return true;
85}
86
87Array<te::Tensor> EinsumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
88 const Type& out_type) {
89 const EinsumAttrs* param = attrs.as<EinsumAttrs>();
90 ICHECK(param != nullptr);
91 return Array<te::Tensor>{topi::einsum(param->equation, inputs)};
92}
93
94Expr MakeEinsum(Expr data, String equation) {
95 auto attrs = make_object<EinsumAttrs>();
96 attrs->equation = std::move(equation);
97 static const Op& op = Op::Get("einsum");
98 return Call(op, {data}, Attrs(attrs), {});
99}
100
// Register MakeEinsum as the global function "relay.op._make.einsum" so it can be
// looked up by name (e.g. from language frontends).
TVM_REGISTER_GLOBAL("relay.op._make.einsum").set_body_typed(MakeEinsum);
102
// Register the "einsum" operator. It takes a single tuple-of-tensors argument,
// uses EinsumAttrs for its equation, infers types with EinsumRel and lowers with
// EinsumCompute.
// NOTE(review): einsum typically sums over contracted indices (a reduction);
// confirm kInjective is the intended TOpPattern for fusion/scheduling.
RELAY_REGISTER_OP("einsum")
    .describe(R"doc(Evaluates the Einstein summation convention
on the operands)doc" TVM_ADD_FILELINE)
    .set_attrs_type<EinsumAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tuple of Tensors", "The input list of tensors.")
    .set_support_level(11)
    .add_type_rel("Einsum", EinsumRel)
    .set_attr<FTVMCompute>("FTVMCompute", EinsumCompute)
    .set_attr<TOpPattern>("TOpPattern", kInjective);
113
114} // namespace relay
115} // namespace tvm
116