/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file x86/bnn.h
 * \brief x86 schedule for binary operations
 */
#ifndef TVM_TOPI_X86_BNN_H_
#define TVM_TOPI_X86_BNN_H_

#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/topi/detail/array_utils.h>
#include <tvm/topi/detail/fuse.h>
#include <tvm/topi/tags.h>

namespace tvm {
namespace topi {

using namespace tvm::te;

namespace x86 {
/*!
 * \brief Create an x86 schedule for binarize_pack
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_binarize_pack(const Target& target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);

  // Parallelize the outermost axis of the packed output across CPU threads.
  auto _schedule = [&](const Tensor& out) {
    s[out].parallel(out->op.as<ComputeOpNode>()->axis[0]);
  };

  // Only ops tagged "binarize_pack" are supported by this schedule.
  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    if (op->tag == "binarize_pack") {
      _schedule(op.output(0));
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}
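
// Usage sketch (illustrative only, not part of this header): the schedule
// expects the output of an op tagged "binarize_pack", e.g. the tensor produced
// by topi::binarize_pack from <tvm/topi/nn/bnn.h>. The shape, pack axis, and
// target string below are assumptions for the example.
//
//   Tensor data = placeholder({64, 1024}, DataType::Float(32), "data");
//   Tensor packed = binarize_pack(data, /*axis=*/1);  // 1024 bits -> 32 x uint32
//   Schedule sch = schedule_binarize_pack(Target("llvm"), {packed});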

/*!
 * \brief Create an x86 schedule for binary_dense
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_binary_dense(const Target& target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);

  // Schedule the binary_dense stage C = f(A, B): split the packed reduction
  // axis by 8, parallelize over the batch axis, and vectorize the inner
  // output axis of the final stage.
  auto _schedule = [&](const Tensor& A, const Tensor& B, const Tensor& C) {
    IterVar co, ci;
    s[C].split(s[C]->op.as<ComputeOpNode>()->reduce_axis[0], 8, &co, &ci);
    s[C].parallel(s[C]->op.as<ComputeOpNode>()->axis[0]);

    // Vectorize on C itself when it is the final output; otherwise on the
    // output stage that consumes it.
    Tensor out;
    if (detail::contains(s->outputs, C->op)) {
      out = C;
    } else {
      out = outs[0]->op.output(0);
    }

    IterVar xo, xi;
    s[out].split(out->op.as<ComputeOpNode>()->axis[1], 8, &xo, &xi);
    s[out].vectorize(xi);
  };

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output).
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag == "binary_dense") {
      auto output = op.output(0);
      auto data = op->InputTensors()[0];
      auto weight = op->InputTensors()[1];
      _schedule(data, weight, output);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}
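
// Usage sketch (illustrative only): binary_dense consumes bit-packed uint32
// tensors, e.g. produced by topi::binarize_pack. The shapes and target string
// are assumptions for the example; the output's inner axis (512 here) should
// be divisible by the vectorization factor of 8 used above.
//
//   Tensor data = placeholder({64, 32}, DataType::UInt(32), "data");
//   Tensor weight = placeholder({512, 32}, DataType::UInt(32), "weight");
//   Tensor dense = binary_dense(data, weight);  // op tagged "binary_dense"
//   Schedule sch = schedule_binary_dense(Target("llvm"), {dense});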

}  // namespace x86
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_X86_BNN_H_