/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file fold_scale_axis.cc
 *
 * \brief Fold axis scaling into weights of
 *  conv/dense operators.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/data_layout.h>

#include "../backend/utils.h"
#include "../op/tensor/transform.h"
#include "pass_utils.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {
/*!
 * \brief namespace of fold scale axis
 *
 * Use namespace to reduce potential naming conflicts.
 */

namespace fold_scale_axis {

using runtime::TypedPackedFunc;

// FoldScaleAxis algorithm:
//
// The general idea is to transform Expr to a tuple of
// (value, axes, scale), where the final result satisfies:
//
//   result = value
//   for i, k in enumerate(axes):
//     k-th dimension of result *= i-th dimension of scale
//
// Then we can propagate this signal along and fold the scale if necessary.
// However, it is possible that certain scales may never be consumed
// if no dense/conv2d follows the multiplication.
//
// In order to make sure all the scales we send out can be consumed eventually,
// we run a backward "preparation phase", which propagates the demand
// for potential axes scaling back to its input.
//
// The forward folding process is done in two steps:
// - Prepare phase: backward propagation of demand.
// - Transform phase: forward transformation.
//
// Similarly, the backward folding process is done in two steps:
// - Prepare phase: forward propagation of demand.
// - Transform phase: transformation by pushing the axes scale signal down to the inputs.
//
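//
// As an illustrative sketch (hypothetical Relay fragment, not taken from a
// real model), forward folding rewrites
//
//   %1 = multiply(%x, %scale)    // %scale is constant along the channel axis
//   %2 = nn.conv2d(%1, %w)
//
// into roughly
//
//   %1 = multiply(%w, reshaped %scale)   // scale folded into the weight
//   %2 = nn.conv2d(%x, %1)
//
// after which FoldConstant can pre-compute the scaled weight.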

/*!
 * \brief A sorted array of axes; can also be nullptr.
 *
 * nullptr means no scaling request can be done.
 */
using AxesSet = Array<Integer>;

class Message;

/*!
 * \brief Message propagated during the prepare phase.
 */
class MessageNode : public RelayNode {
 public:
  /*! \brief Axes for scaling */
  AxesSet axes;
  /*!
   * \brief Whether folding requires the scale to be a positive constant. This is necessary if some
   * operators (e.g. Relu) are present.
   */
  bool require_positive;

  static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message";
  TVM_DECLARE_FINAL_OBJECT_INFO(MessageNode, RelayNode);
};

class Message : public ObjectRef {
 public:
  /*!
   * \brief The constructor
   * \param axes Axes for scaling
   * \param require_positive If folding requires the scales to be positive
   *        values.
   */
  Message(const AxesSet& axes, bool require_positive);

  TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode);
};

Message::Message(const AxesSet& axes, bool require_positive) {
  auto n = make_object<MessageNode>();
  n->axes = axes;
  n->require_positive = require_positive;
  data_ = std::move(n);
}

/*!
 * \brief Merge two axis sets together by taking
 *  their intersection.
 *
 * \note The axes in an AxesSet should be sorted.
 *
 * \param lhs The left axes set.
 * \param rhs The right axes set.
 * \return The result of the intersection.
 */
AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;
  // This code relies on the axes in an AxesSet being sorted.
  AxesSet ret;
  size_t i = 0, j = 0;
  while (i < lhs.size() && j < rhs.size()) {
    if (lhs[i]->value < rhs[j]->value) {
      ++i;
    } else if (lhs[i]->value > rhs[j]->value) {
      ++j;
    } else {
      ret.push_back(lhs[i]);
      ++i;
      ++j;
    }
  }
  return ret;
}
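
// For example (hypothetical inputs), Intersect({0, 1}, {1, 2}) yields {1}.
// An undefined operand is absorbing: nullptr means no scaling request can be
// done, so intersecting anything with nullptr yields nullptr again.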

/*!
 * \brief Merge two messages together by taking their intersection.
 *
 * \param lhs The lhs message.
 * \param rhs The rhs message.
 * \return The result of the intersection.
 */
Message Intersect(const Message& lhs, const Message& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;
  auto axes = Intersect(lhs->axes, rhs->axes);
  return Message(axes, lhs->require_positive || rhs->require_positive);
}

/*!
 * \brief Preparation function for pass scale forward.
 * \param call The call node.
 * \param out_message Message from the output containing possible scaling on axes and whether
 *        a positive scale is required.
 * \return The message containing the resulting scaling on the axes of the input.
 */
using FForwardPrep =
    runtime::TypedPackedFunc<Array<Message>(const Call& call, const Message& out_message)>;

/*! \brief Axis scale tuple. */
class ScaledExprNode : public TempExprNode {
 public:
  /*! \brief The value */
  Expr value;
  /*! \brief The axes to scale; can be nullptr (means no scaling) */
  AxesSet axes = NullValue<AxesSet>();
  /*! \brief The scaling factor */
  Expr scale = NullValue<Expr>();

  Expr Realize() const final {
    ICHECK(!axes.defined()) << "outstanding scale";
    return value;
  }

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("value", &value);
    v->Visit("axes", &axes);
    v->Visit("scale", &scale);
  }

  static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(ScaledExprNode, TempExprNode);
};

using FForwardRewrite = TypedPackedFunc<Expr(const Call& ref_call, const Array<Expr>& new_args,
                                             const Message& message)>;

//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
class ForwardPrep : private MixedModeVisitor {
 public:
  std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
    this->Update(body, NullValue<Message>());
    this->VisitExpr(body);
    // flist_ is built in post-DFS order,
    // which is a special case of topological order.
    // We traverse the list in reverse to invoke the lazy functions.
    // This acts like a backprop of valid scale axis messages.
    for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) {
      (*it)();
    }
    // Return the created messages.
    return std::move(message_);
  }

 private:
  // The invoke list
  std::vector<std::function<void()>> flist_;
  // The message on each node.
  std::unordered_map<const Object*, Message> message_;
  // Update the message stored at a node.
  void Update(const Expr& node, const Message& message) {
    // We run intersection of messages:
    //
    //   %y = multiply(%x, %scale)
    //   %z1 = conv2d(%y, %w)
    //   %z2 = exp(%y)
    //
    // Consider the above code example:
    // because %z2 will propagate null to %y,
    // the AxesSet on %y is also null,
    // and the forward folding won't be triggered.
    const Object* key = node.get();
    if (message_.count(key)) {
      message_[key] = Intersect(message_[key], message);
    } else {
      message_[key] = message;
    }
  }

  // We intentionally override the following implementations from ExprVisitor.
  using MixedModeVisitor::VisitExpr_;

  // Visitor pattern override.
  void VisitExpr_(const TupleGetItemNode* op) final { MixedModeVisitor::VisitExpr_(op); }

  void VisitExpr_(const LetNode* op) final {
    ExprVisitor::VisitExpr_(op);
    // Do not pass the scale signal through:
    // by assigning NullValue<Message>
    // the scale signal cannot pass
    // through into these subexpressions.
    auto flazy = [this, op]() {
      this->Update(op->value, NullValue<Message>());
      this->Update(op->body, NullValue<Message>());
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const FunctionNode* op) final {
    ExprVisitor::VisitExpr_(op);
    auto flazy = [this, op] { this->Update(op->body, NullValue<Message>()); };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const CallNode* call) final {
    ExprVisitor::VisitExpr_(call);
    // function to be lazily invoked
    auto flazy = [this, call]() {
      static const auto& fprep = Op::GetAttrMap<FForwardPrep>("FScaleAxisForwardPrep");
      // Find the message sent to this node.
      auto it = message_.find(call);
      Message out_message;
      if (it != message_.end()) {
        out_message = it->second;
      } else {
        out_message = NullValue<Message>();
      }
      // Pass the message back to all the children it references.
      auto f = fprep.get(call->op, nullptr);
      if (f != nullptr) {
        Array<Message> in_messages = f(GetRef<Call>(call), out_message);
        ICHECK_EQ(in_messages.size(), call->args.size());
        for (size_t i = 0; i < call->args.size(); ++i) {
          this->Update(call->args[i], in_messages[i]);
        }
      } else {
        for (size_t i = 0; i < call->args.size(); ++i) {
          this->Update(call->args[i], NullValue<Message>());
        }
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const TupleNode* op) final {
    ExprVisitor::VisitExpr_(op);
    // Passing scale through a tuple is not supported for now.
    auto flazy = [this, op]() {
      for (const Expr& field : op->fields) {
        this->Update(field, NullValue<Message>());
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const IfNode* op) final {
    ExprVisitor::VisitExpr_(op);
    // Do not pass the scale signal through the condition or branches:
    // by assigning NullValue<Message>
    // the scale signal cannot pass
    // through into these subexpressions.
    auto flazy = [this, op]() {
      this->Update(op->cond, NullValue<Message>());
      this->Update(op->true_branch, NullValue<Message>());
      this->Update(op->false_branch, NullValue<Message>());
    };
    flist_.push_back(flazy);
  }
};

static bool IsIntInArray(const Array<Integer>& axis, int v) {
  for (size_t i = 0; i < axis.size(); i++) {
    if (axis[i] == v) return true;
  }
  return false;
}

static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
                               const Array<Integer>& axis) {
  Array<Integer> arr;
  for (size_t i = 0; i < shape.size(); i++) {
    if (IsIntInArray(axis, i)) {
      auto node = shape[i].as<IntImmNode>();
      if (!node) {
        // If the shape is not a constant, use the normal transform.
        return Expr();
      }
      arr.push_back(node->value);
    } else {
      arr.push_back(1);
    }
  }
  return MakeReshape(scale, std::move(arr));
}

// If there is only one axis, use expand_dims; otherwise, use reshape.
static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
                                       const Array<Integer>& axis) {
  if (axis.size() > 1) {
    return ReshapeToMatchAxis(scale, shape, axis);
  } else {
    return ExpandBiasToMatchAxis(scale, shape.size(), axis);
  }
}
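
// As a hypothetical example, a scale vector of length 4 folded into a tensor
// of shape (2, 4, 3, 3) on axis {1} takes the expand_dims path and becomes
// shape (4, 1, 1), while a blocked request such as axes {1, 4} on shape
// (2, 4, 3, 3, 8) takes the reshape path and becomes shape (1, 4, 1, 1, 8).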

//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------

// Intermediate operators
Array<Message> ReluForwardPrep(const Call& call, const Message& out_message) {
  if (out_message.defined()) {
    return {Message(out_message->axes, true)};
  }
  return {out_message};
}

Expr ReluForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, const Message& message) {
  const auto* input = new_args[0].as<ScaledExprNode>();
  if (input == nullptr) return Expr(nullptr);
  // return transformed relu
  auto rnode = make_object<ScaledExprNode>();
  rnode->value = Call(ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args);
  rnode->scale = input->scale;
  rnode->axes = input->axes;
  return Expr(rnode);
}
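
// Note: the identity relu(x * s) == relu(x) * s only holds when s is
// positive; a negative scale would change which elements relu clips to zero.
// This is why ReluForwardPrep marks its message with require_positive = true.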

RELAY_REGISTER_OP("nn.relu").set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);

RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FScaleAxisForwardRewrite",
                                                       ReluForwardRewrite);

RELAY_REGISTER_OP("nn.leaky_relu").set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);

RELAY_REGISTER_OP("nn.leaky_relu")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);

// AddSub
Array<Message> AddSubForwardPrep(const Call& call, const Message& out_message) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  auto none = NullValue<Message>();
  if (out_message.defined()) {
    if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) {
      return {out_message, none};
    } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) {
      return {none, out_message};
    }
  }
  return {none, none};
}

Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                          const Message& message) {
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  if (!slhs && !srhs) return Expr();
  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
  auto rnode = make_object<ScaledExprNode>();

  if (slhs != nullptr) {
    ICHECK(srhs == nullptr);
    ICHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
    Expr scale = ReshapeOrExpandToMatchAxis(slhs->scale, tlhs->shape, slhs->axes);
    if (!scale.defined()) {
      return Expr();
    }
    Expr rhs = Divide(new_args[1], scale);
    rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args);
    rnode->scale = slhs->scale;
    rnode->axes = slhs->axes;
  } else {
    ICHECK(srhs != nullptr);
    ICHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
    Expr scale = ReshapeOrExpandToMatchAxis(srhs->scale, trhs->shape, srhs->axes);
    if (!scale.defined()) {
      return Expr();
    }
    Expr lhs = Divide(new_args[0], scale);
    rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args);
    rnode->scale = srhs->scale;
    rnode->axes = srhs->axes;
  }
  return Expr(rnode);
}
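
// Concretely (hypothetical shapes), with a scaled input
// (value=%x, scale=%s, axes={1}) flowing into add(%input, %bias), the rewrite
// emits add(%x, %bias / reshaped %s), so that scaling the sum by %s afterwards
// reproduces the original %x * %s + %bias.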

RELAY_REGISTER_OP("add").set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);

RELAY_REGISTER_OP("add").set_attr<FForwardRewrite>("FScaleAxisForwardRewrite",
                                                   AddSubForwardRewrite);

RELAY_REGISTER_OP("subtract").set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);

RELAY_REGISTER_OP("subtract")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);

// Producer operators
// Multiply produces the scale-axis pair.
Expr MultiplyForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                            const Message& message) {
  if (!message.defined()) return Expr();
  const auto& expected_out_axes = message->axes;
  ICHECK(expected_out_axes.defined() && expected_out_axes.size());
  // TODO(tvm-team) allow same axes accumulation
  // not as important because it is less common in nn.
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  ICHECK(!slhs && !srhs);

  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
  Expr lhs = new_args[0];
  Expr rhs = new_args[1];
  auto rnode = make_object<ScaledExprNode>();

  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
      (!message->require_positive || IsAllPositiveConstant(rhs))) {
    rnode->value = lhs;
    rnode->scale = rhs;
    rnode->axes = expected_out_axes;
    return Expr(rnode);
  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
             (!message->require_positive || IsAllPositiveConstant(lhs))) {
    rnode->value = rhs;
    rnode->scale = lhs;
    rnode->axes = expected_out_axes;
    return Expr(rnode);
  } else {
    return Expr();
  }
}
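
// For example, if a consumer requested axes={1} and the graph contains
// multiply(%x, %s) where %s broadcasts along axis 1 of %x, this rewrite emits
// a ScaledExprNode (value=%x, scale=%s, axes={1}) instead of multiplying,
// deferring the multiplication until a conv/dense can absorb it.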

RELAY_REGISTER_OP("multiply")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);

// Consumer operators
// Conv sends out a requirement of axis folding.
template <typename ATTRS>
Array<Message> ConvForwardPrep(const Call& call, const ATTRS* param, const Message& out_message) {
  // TODO(tvm-team) support general data layout
  // by transforming weight
  ICHECK(param != nullptr);
  Layout data_layout(param->data_layout);
  Layout kernel_layout(param->kernel_layout);
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));

  ICHECK_GE(c_big_axis, 0);
  Message none = NullValue<Message>();
  // For now, we only support the simple pattern (no folded weight/data).
  // More general layouts can be supported under the current framework
  // by using a unified layout transformation;
  // we would only need to change the Prep and Mutate functions.
  //
  // Only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
  if (param->groups == 1 || is_depthwise_conv) {
    auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
    auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
    if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||     // simple layout
        (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) {  // blocked layout
      Array<Integer> arr{c_big_axis};
      if (c_small_axis >= 0) {
        arr.push_back(c_small_axis);
      }
      return {Message(arr, false), none};
    }
  }
  return {none, none};
}
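
// For instance, with data_layout = "NCHW" and kernel_layout = "OIHW" (a
// simple, non-blocked layout), c_big_axis is 1, so the data input receives a
// request for scaling on axis {1} while the weight receives none. A blocked
// layout such as "NCHW8c" additionally records the small channel axis.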

// Conv2D consumes the scale axis during transformation.
template <typename ATTRS>
Expr ConvForwardRewrite(const Call& ref_call, const ATTRS* param, const Array<Expr>& new_args,
                        const Message& message) {
  // If the data carries no scale, take the normal transform path.
  const auto* sdata = new_args[0].as<ScaledExprNode>();
  const auto* sweight = new_args[1].as<ScaledExprNode>();
  if (sdata == nullptr) return Expr();
  if (sweight != nullptr) return Expr();
  ICHECK(param != nullptr);
  Layout data_layout(param->data_layout);
  Layout kernel_layout(param->kernel_layout);
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
  ICHECK_GE(c_big_axis, 0);
  int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
  int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
  int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
  int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));

  bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
  bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
  ICHECK(is_simple || is_blocking);

  // Check that it is depthwise or full conv2d.
  bool is_depthwise_conv = IsDepthwiseConv(ref_call, param, kernel_layout);
  ICHECK(param->groups == 1 || is_depthwise_conv);

  Expr weight = new_args[1];

  // match the ic_axis
  if (is_depthwise_conv) {
    if (is_simple) {
      Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
      weight = Multiply(weight, scale);
    } else {
      weight = Multiply(weight,
                        ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
                                           {big_ko_axis, small_ko_axis}));
      if (!weight.defined()) return Expr();
    }

  } else {
    if (is_simple) {
      Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ki_axis});
      weight = Multiply(weight, scale);
    } else {
      weight = Multiply(weight,
                        ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
                                           {big_ki_axis, small_ki_axis}));
      if (!weight.defined()) return Expr();
    }
  }
  // return transformed conv
  return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}
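
// In the simple NCHW/OIHW case this amounts to (hypothetical fragment)
//
//   conv2d(%x * %s, %w)  ==>  conv2d(%x, %w * expanded %s)
//
// where %s is expanded along the big input-channel axis 'I' of the kernel,
// or along 'O' for depthwise convolution, whose kernel input-channel extent
// is one.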

Array<Message> PreConvForwardPrep(const Call& call, const Message& out_message) {
  if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
    const auto* param = call->attrs.as<Conv2DAttrs>();
    ICHECK(param != nullptr);
    return ConvForwardPrep(call, param, out_message);
  }
  const auto* param = call->attrs.as<Conv3DAttrs>();
  ICHECK(param != nullptr);
  return ConvForwardPrep(call, param, out_message);
}

Expr PreConvForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                           const Message& message) {
  if (backend::IsOp(ref_call.as<CallNode>(), "nn.conv2d")) {
    const auto* param = ref_call->attrs.as<Conv2DAttrs>();
    ICHECK(param != nullptr);
    return ConvForwardRewrite(ref_call, param, new_args, message);
  }
  const auto* param = ref_call->attrs.as<Conv3DAttrs>();
  ICHECK(param != nullptr);
  return ConvForwardRewrite(ref_call, param, new_args, message);
}

RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);

RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);

// Dense sends out a requirement of axis folding.
Array<Message> DenseForwardPrep(const Call& call, const Message& out_message) {
  return {Message({1}, false), NullValue<Message>()};
}

// Dense consumes the scale axis during transformation.
Expr DenseForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                         const Message& message) {
  const auto* sdata = new_args[0].as<ScaledExprNode>();
  const auto* sweight = new_args[1].as<ScaledExprNode>();
  if (sdata == nullptr) return Expr();
  if (sweight != nullptr) return Expr();

  Expr weight = Multiply(new_args[1], sdata->scale);
  return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}
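
// The dense weight layout is (units, input_dim) and sdata->scale has shape
// (input_dim,), so the plain Multiply broadcasts the scale across the
// reduction axis: dense(%x * %s, %w) ==> dense(%x, %w * %s).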

RELAY_REGISTER_OP("nn.dense").set_attr<FForwardPrep>("FScaleAxisForwardPrep", DenseForwardPrep);

RELAY_REGISTER_OP("nn.dense")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", DenseForwardRewrite);

Expr ForwardFoldScaleAxis(const Expr& data) {
  auto message = ForwardPrep().Prepare(data);
  for (const auto& m : message) {
    if (m.second.defined()) {
      // run optimization
      auto fcontext = [&](const Call& call) -> ObjectRef {
        auto it = message.find(call.get());
        if (it != message.end()) {
          return it->second;
        } else {
          return ObjectRef(nullptr);
        }
      };
      return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext);
    }
  }
  // no messages - no optimization
  return data;
}

//----------------------------------------
// Implement backward transformations.
//----------------------------------------
class BackwardTransformer;

/*!
 * \brief Preparation function for pass scale backward.
 * \param call The call node.
 * \param in_messages Messages from the input containing allowed input scaling and whether
 *        a positive scale is required.
 * \return Message containing the resulting scaling on the axes of the input.
 */
using FBackwardPrep = TypedPackedFunc<Message(const Call& call, const Array<Message>& in_messages)>;

using FBackwardTransform =
    TypedPackedFunc<Expr(const Call& call, const Message& message, const Expr& scale,
                         const BackwardTransformer& transformer)>;

//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------

class BackwardPrep : private MixedModeVisitor {
 public:
  // The message on each node.
  std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
    ref_counter_ = GetExprRefCount(body);
    this->VisitExpr(body);
    return std::move(message_);
  }

 private:
  // The message on each node.
  std::unordered_map<const Object*, Message> message_;
  // Reference counter of an internal expr.
  std::unordered_map<const Object*, size_t> ref_counter_;
  // Visit the expression.
  void VisitExpr_(const CallNode* call) {
    ExprVisitor::VisitExpr_(call);
    static const auto& fprep = Op::GetAttrMap<FBackwardPrep>("FScaleAxisBackwardPrep");
    auto f = fprep.get(call->op, nullptr);
    if (f == nullptr) return;
    auto rit = ref_counter_.find(call);
    ICHECK(rit != ref_counter_.end());
    // We only allow backward propagation of the scale
    // if the expression is referred to by only a single parent.
    if (rit->second != 1) return;
    Array<Message> in_messages = GetInMessages(call);
    Message out_message = f(GetRef<Call>(call), in_messages);
    if (out_message.defined()) {
      message_[call] = out_message;
    }
  }

  Array<Message> GetInMessages(const CallNode* call) {
    Array<Message> in_messages;
    for (Expr arg : call->args) {
      auto it = message_.find(arg.get());
      if (it != message_.end()) {
        in_messages.push_back(it->second);
      } else {
        in_messages.push_back(NullValue<Message>());
      }
    }
    return in_messages;
  }
};

/*
 * A hybrid approach is used here: the transformation
 * itself is recursive, but the traversal is non-recursive.
 */
class BackwardTransformerNode : public Object, private MixedModeMutator {
 public:
  using MixedModeMutator::Mutate;
  // Run backward transform.
  Expr Fold(Expr expr) {
    message_ = BackwardPrep().Prepare(expr);
    for (const auto& m : message_) {
      if (m.second.defined()) {
        // run optimization
        return this->Mutate(expr);
      }
    }
    // no messages - no optimization
    return expr;
  }

  /*!
   * \brief Transform the expr to consider the scaling.
   */
  Expr Transform(const Expr& expr, Message message, Expr scale);
  /*!
   * \brief Get the message propagated to the expr.
   * \param expr The expression.
   * \return The message containing the expected axes and whether a positive scale is required.
   */
  Message GetMessage(const Expr& expr) const {
    auto it = message_.find(expr.get());
    if (it != message_.end()) return it->second;
    return NullValue<Message>();
  }

  // solver is not serializable.
  void VisitAttrs(tvm::AttrVisitor* v) {}

  static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
  TVM_DECLARE_FINAL_OBJECT_INFO(BackwardTransformerNode, Object);

 private:
  // Valid axes on each node.
  std::unordered_map<const Object*, Message> message_;
  // Override mutation of call.
  Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
    return Transform(GetRef<Call>(call_node), NullValue<Message>(), NullValue<Expr>());
  }

 public:
  Expr NormalCallTransform(const CallNode* call_node) { return ExprMutator::VisitExpr_(call_node); }
};

class BackwardTransformer : public ObjectRef {
 public:
  BackwardTransformer() {}
  explicit BackwardTransformer(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
  BackwardTransformerNode* operator->() const {
    return static_cast<BackwardTransformerNode*>(get_mutable());
  }
  using ContainerType = BackwardTransformerNode;
};

/*!
 * \brief Transform the expr to consider the scaling.
 *
 * \param expr The input expression.
 * \param message The axes to scale.
 * \param scale The scale applied to the axes.
 * \return The result of transformation.
 */
Expr BackwardTransformerNode::Transform(const Expr& expr, Message message, Expr scale) {
  if (const CallNode* call_node = expr.as<CallNode>()) {
    static const auto& ftransform =
        Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
    auto f = ftransform.get(call_node->op, nullptr);
    const Call call = GetRef<Call>(call_node);
    // Only reuse the memoized result if there is no message.
    if (!message.defined()) {
      const auto it = memo_.find(call);
      if (it != memo_.end()) {
        return it->second;
      }
    }
    Expr new_expr = NullValue<Expr>();
    if (f != nullptr) {
      new_expr = f(call, message, scale, GetRef<BackwardTransformer>(this));
    } else {
      ICHECK(!message.defined()) << "outstanding scale";
      new_expr = NormalCallTransform(call.operator->());
    }
    memo_[call] = new_expr;
    return new_expr;
  } else {
    ICHECK(!message.defined()) << "outstanding scale";
    return this->Mutate(expr);
  }
}

//----------------------------------------------
// Per operator defs for FScaleAxisBackward
//----------------------------------------------

// Intermediate operators
Message ReluBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  if (in_messages[0].defined()) {
    return Message(in_messages[0]->axes, true);
  }
  return in_messages[0];
}

Expr ReluBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                           const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  Expr input = transformer->Transform(call->args[0], message, scale);
  return Call(call->op, {input}, call->attrs, call->type_args);
}

RELAY_REGISTER_OP("nn.relu").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);

RELAY_REGISTER_OP("nn.relu").set_attr<FBackwardTransform>("FScaleAxisBackwardTransform",
                                                          ReluBackwardTransform);

RELAY_REGISTER_OP("nn.leaky_relu")
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);

RELAY_REGISTER_OP("nn.leaky_relu")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);

// AddSub
Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  StructuralEqual equal;
  if (in_messages[0].defined() && MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
    return in_messages[0];
  } else if (in_messages[1].defined() &&
             MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) {
    return in_messages[1];
  } else if (in_messages[0].defined() && in_messages[1].defined() &&
             equal(in_messages[0]->axes, in_messages[1]->axes) && equal(tlhs->shape, trhs->shape)) {
    // add of two elements.
    return in_messages[0];
  } else {
    auto res = NullValue<Message>();
    return res;
  }
}

Expr AddSubBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                             const BackwardTransformer& transformer) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }

  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  StructuralEqual equal;

  if (lhs_message.defined() && rhs_message.defined()) {
    ICHECK(equal(lhs_message->axes, rhs_message->axes));
    ICHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(call->args[1], message, scale);
    return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else if (lhs_message.defined()) {
    ICHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
    Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, message->axes);
    if (!rhs_scale.defined()) {
      return transformer->NormalCallTransform(call.operator->());
    }
    rhs = Multiply(rhs, rhs_scale);
    return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else if (rhs_message.defined()) {
    ICHECK(equal(message->axes, rhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
    Expr rhs = transformer->Transform(call->args[1], message, scale);
    Expr lhs_scale = ReshapeOrExpandToMatchAxis(scale, trhs->shape, message->axes);
    if (!lhs_scale.defined()) {
      return transformer->NormalCallTransform(call.operator->());
    }
    lhs = Multiply(lhs, lhs_scale);
    return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else {
    LOG(FATAL) << "outstanding scale";
  }
}

RELAY_REGISTER_OP("add").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);

RELAY_REGISTER_OP("add").set_attr<FBackwardTransform>("FScaleAxisBackwardTransform",
                                                      AddSubBackwardTransform);

RELAY_REGISTER_OP("subtract").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);

RELAY_REGISTER_OP("subtract")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);

// Producer operators
// Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                               const BackwardTransformer& transformer) {
  ICHECK(!message.defined()) << "outstanding scale";
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  if (lhs_message.defined()) {
    ICHECK(lhs_message->axes.defined() && lhs_message->axes.size());
    // NOTE: we won't recursively call mutation on the scale part,
    // since there is no chance of scaling within the scale part.
    Expr rhs = call->args[1];
    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) &&
        (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) {
      return transformer->Transform(call->args[0], lhs_message, rhs);
    }
  } else if (rhs_message.defined()) {
    ICHECK(rhs_message->axes.defined() && rhs_message->axes.size());
    Expr lhs = call->args[0];
    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) &&
        (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) {
      return transformer->Transform(call->args[1], rhs_message, lhs);
    }
  }
  return transformer->NormalCallTransform(call.operator->());
}
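
// In the backward direction the roles flip. For a hypothetical
//
//   %1 = nn.conv2d(%x, %w)
//   %2 = multiply(%1, %s)
//
// the prepare phase marks %1 as scalable, and this transform strips the
// multiply by asking the conv2d itself to fold %s into %w.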

RELAY_REGISTER_OP("multiply")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);

// Consumer operators
// Conv sends out a requirement of axis folding.
template <typename ATTRS>
Message ConvBackwardPrep(const Call& call, const ATTRS* param, const Array<Message>& in_messages) {
  ICHECK(param != nullptr);
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));

  ICHECK_GE(c_big_axis, 0);
  // For now, we only support the simple pattern (no folded weight/data).
  // More general layouts can be supported under the current framework
  // by using a unified layout transformation;
  // we would only need to change the Prep and Mutate functions.
  //
  // Only handle depthwise or full conv.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
  if (param->groups == 1 || is_depthwise_conv) {
    auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
    auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
    if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||     // simple layout
        (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) {  // blocked layout
      Array<Integer> arr{c_big_axis};
      if (c_small_axis >= 0) {
        arr.push_back(c_small_axis);
      }
      return Message(arr, false);
    }
  }
  return NullValue<Message>();
}

// Conv consumes the scale axis during transformation.
template <typename ATTRS>
Expr ConvBackwardTransform(const Call& call, const ATTRS* param, const Message& message,
                           const Expr& scale, const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  ICHECK(param != nullptr);
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  ICHECK_GE(c_big_axis, 0);
  // For now, we only support the simple pattern (no folded weight/data).
  // TODO(tvm-team) support general data layout
  int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
  int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
  int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
  int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
  // Check that it is depthwise or full conv.
  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
  ICHECK(param->groups == 1 || is_depthwise_conv);
  bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
  bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
  ICHECK(is_simple || is_blocking);

  Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
  Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
  // Scale on the output channel axis (which equals the input channel for depthwise).
  Expr wscale;
  if (is_simple) {
    wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis});
  } else {
    wscale = ReshapeToMatchAxis(scale, weight->type_as<TensorTypeNode>()->shape,
                                {big_ko_axis, small_ko_axis});
    if (!wscale.defined()) {
      return transformer->NormalCallTransform(call.operator->());
    }
  }
  weight = Multiply(weight, wscale);
  return Call(call->op, {data, weight}, call->attrs, call->type_args);
}
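
// For example, with NCHW/OIHW and a full convolution, the backward fold is
// roughly
//
//   multiply(conv2d(%x, %w), %s)  ==>  conv2d(%x, %w * expanded %s)
//
// with %s expanded along the big output-channel axis 'O' of the kernel.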

Message PreConvBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
    const auto* param = call->attrs.as<Conv2DAttrs>();
    ICHECK(param != nullptr);
    return ConvBackwardPrep(call, param, in_messages);
  }
  const auto* param = call->attrs.as<Conv3DAttrs>();
  ICHECK(param != nullptr);
  return ConvBackwardPrep(call, param, in_messages);
}

Expr PreConvBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                              const BackwardTransformer& transformer) {
  if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
    const auto* param = call->attrs.as<Conv2DAttrs>();
    ICHECK(param != nullptr);
    return ConvBackwardTransform(call, param, message, scale, transformer);
  }
  const auto* param = call->attrs.as<Conv3DAttrs>();
  ICHECK(param != nullptr);
  return ConvBackwardTransform(call, param, message, scale, transformer);
}

RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);

RELAY_REGISTER_OP("nn.conv3d")
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);

Message BiasAddBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  const BiasAddAttrs* attrs = call->attrs.as<BiasAddAttrs>();
  ICHECK(attrs);
  if (in_messages[0].defined() && in_messages[0]->axes.size() == 1 &&
      attrs->axis == static_cast<int>(in_messages[0]->axes[0]->value)) {
    return in_messages[0];
  } else {
    return NullValue<Message>();
  }
}

Expr BiasAddBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                              const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  StructuralEqual equal;

  if (lhs_message.defined()) {
    ICHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
    rhs = Multiply(rhs, scale);
    return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else {
    LOG(FATAL) << "outstanding scale";
  }
}

RELAY_REGISTER_OP("nn.bias_add")
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", BiasAddBackwardPrep);

RELAY_REGISTER_OP("nn.bias_add")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", BiasAddBackwardTransform);

// Dense sends out a requirement of axis folding.
Message DenseBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  return Message({1}, false);
}

// Dense consumes the scale axis during transformation.
Expr DenseBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                            const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
  Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
  Expr wscale = ExpandBiasToMatchAxis(scale, 2, {0});
  weight = Multiply(weight, wscale);
  return Call(call->op, {data, weight}, call->attrs, call->type_args);
}
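
// Here the scale arrives on axis {1} of the dense output, i.e. the units
// axis, so ExpandBiasToMatchAxis reshapes it to (units, 1) before it is
// multiplied into the (units, input_dim) weight:
//
//   multiply(dense(%x, %w), %s)  ==>  dense(%x, %w * reshaped %s)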

RELAY_REGISTER_OP("nn.dense").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", DenseBackwardPrep);

RELAY_REGISTER_OP("nn.dense")
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", DenseBackwardTransform);

Expr BackwardFoldScaleAxis(const Expr& data) {
  return make_object<BackwardTransformerNode>()->Fold(data);
}

}  // namespace fold_scale_axis

namespace transform {

Pass ForwardFoldScaleAxis() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(relay::fold_scale_axis::ForwardFoldScaleAxis(f));
      };
  return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis").set_body_typed(ForwardFoldScaleAxis);

Pass BackwardFoldScaleAxis() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(relay::fold_scale_axis::BackwardFoldScaleAxis(f));
      };
  return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis").set_body_typed(BackwardFoldScaleAxis);

Pass FoldScaleAxis() {
  // The FoldScaleAxis pass contains the following three passes. Therefore, we can
  // register it as a sequential pass.
  Pass pass = Sequential({BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
                         "FoldScaleAxis");
  return pass;
}
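
// A minimal usage sketch (assuming `mod` is a type-checked IRModule):
//
//   IRModule mod = ...;                     // module containing the graph
//   mod = transform::FoldScaleAxis()(mod);  // backward + forward + FoldConstant
//
// or, from Python, relay.transform.FoldScaleAxis()(mod).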

TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis").set_body_typed(FoldScaleAxis);

}  // namespace transform

}  // namespace relay
}  // namespace tvm