1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file fold_scale_axis.cc |
22 | * |
23 | * \brief Fold axis scaling into weights of |
24 | * conv/dense operators. |
25 | */ |
26 | #include <tvm/relay/analysis.h> |
27 | #include <tvm/relay/attrs/nn.h> |
28 | #include <tvm/relay/expr_functor.h> |
29 | #include <tvm/relay/transform.h> |
30 | #include <tvm/tir/data_layout.h> |
31 | |
32 | #include "../backend/utils.h" |
33 | #include "../op/tensor/transform.h" |
34 | #include "pass_utils.h" |
35 | #include "pattern_utils.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | /*! |
40 | * \brief namespace of fold scale axis |
41 | * |
42 | * Use namespace to reduce potential naming conflict. |
43 | */ |
44 | |
45 | namespace fold_scale_axis { |
46 | |
47 | using runtime::TypedPackedFunc; |
48 | |
49 | // FoldScaleAxis algorithm: |
50 | // |
51 | // The general idea is to transform Expr to tuple of |
52 | // (value, axes, scale), where the final result satisfies: |
53 | // |
54 | // result = value |
55 | // for i, k in enumerate(axes): |
56 | // k-th dimension of result *= i-th dimension of scale |
57 | // |
58 | // Then we can propagate this signal along and fold the scale if necessary. |
59 | // However, it is possible that certain scale may never be consumed |
60 | // if there is no dense/conv2d that follows multiplication. |
61 | // |
62 | // In order to make sure all the scale we sent out can be consumed eventually, |
63 | // we run a backward "preparation phase", which propagates the demand |
64 | // of the potential axes scaling back to its input. |
65 | // |
66 | // Forward folding process is done in two steps: |
67 | // - Prepare phase: backward propagation of demand. |
68 | // - Transform phase: forward transformation, |
69 | // |
70 | // Similarly, backward folding process is done in two steps: |
71 | // - Prepare phase: forward propagation of demand. |
72 | // - Transform phase: transformation by push down the axes scale signal to inputs. |
73 | // |
74 | |
75 | /*! |
76 | * \brief sorted array axis, can also be nullptr. |
77 | * |
78 | * nullptr means no scaling request can be done. |
79 | */ |
80 | using AxesSet = Array<Integer>; |
81 | |
82 | class Message; |
83 | |
84 | /*! |
85 | * \brief Message propogated during the prepare phase. |
86 | */ |
/*!
 * \brief Message propogated during the prepare phase.
 *
 * A message records which axes of a value may carry a scale that can
 * still be folded, plus whether that scale must be a positive constant.
 */
class MessageNode : public RelayNode {
 public:
  /*! \brief Axes for scaling */
  AxesSet axes;
  /*!
   * \brief Whether folding requires the scale to be positive constant. This is necessary if some
   *  operators (e.g. Relu) is present.
   */
  bool require_positive;

  static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message";
  TVM_DECLARE_FINAL_OBJECT_INFO(MessageNode, RelayNode);
};
100 | |
/*!
 * \brief Managed reference to MessageNode.
 * \sa MessageNode
 */
class Message : public ObjectRef {
 public:
  /*!
   * \brief The constructor
   * \param axes Axes for scaling
   * \param require_positive If folding requires the scales to be positive
   *        values.
   */
  Message(const AxesSet& axes, bool require_positive);

  TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode);
};
113 | |
114 | Message::Message(const AxesSet& axes, bool require_positive) { |
115 | auto n = make_object<MessageNode>(); |
116 | n->axes = axes; |
117 | n->require_positive = require_positive; |
118 | data_ = std::move(n); |
119 | } |
120 | |
121 | /*! |
122 | * \brief Merge two axis set together by taking |
123 | * intersection. |
124 | * |
125 | * \note The axes in a AxesSet should be sorted. |
126 | * |
127 | * \param lhs The left axis. |
128 | * \param rhs The right axis. |
129 | * \return The result of the inersection. |
130 | */ |
131 | AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) { |
132 | if (!lhs.defined()) return lhs; |
133 | if (!rhs.defined()) return rhs; |
134 | // This code relies on axes in a AxesSet to be sorted. |
135 | AxesSet ret; |
136 | size_t i = 0, j = 0; |
137 | while (i < lhs.size() && j < rhs.size()) { |
138 | if (lhs[i]->value < rhs[j]->value) { |
139 | ++i; |
140 | } else if (lhs[i]->value > rhs[j]->value) { |
141 | ++j; |
142 | } else { |
143 | ret.push_back(lhs[i]); |
144 | ++i; |
145 | ++j; |
146 | } |
147 | } |
148 | return ret; |
149 | } |
150 | |
151 | /*! |
152 | * \brief Merge two messages together by taking intersection. |
153 | * |
154 | * \param lhs The lhs message. |
155 | * \param rhs The rhs message. |
156 | * \return The result of intersection. |
157 | */ |
158 | Message Intersect(const Message& lhs, const Message& rhs) { |
159 | if (!lhs.defined()) return lhs; |
160 | if (!rhs.defined()) return rhs; |
161 | auto axes = Intersect(lhs->axes, rhs->axes); |
162 | return Message(axes, lhs->require_positive || rhs->require_positive); |
163 | } |
164 | |
165 | /*! |
166 | * \brief Preparation function for pass scale forward. |
167 | * \param call The call node. |
168 | * \param out_message Message from the output containing possible scaling on axes and whether |
169 | * positive scale is required. |
170 | * \return The message containing the result scaling on axes of the input. |
171 | */ |
172 | using FForwardPrep = |
173 | runtime::TypedPackedFunc<Array<Message>(const Call& call, const Message& out_message)>; |
174 | |
175 | /*! \brief Axis scale tuple. */ |
/*!
 * \brief Axis scale tuple.
 *
 * Temporary expression that carries a value together with an
 * outstanding (not yet folded) per-axis scale. Realize() may only be
 * called once the scale has been consumed.
 */
class ScaledExprNode : public TempExprNode {
 public:
  /*! \brief The value */
  Expr value;
  /*! \brief The axes to scale, can be nullptr(means no-scaling) */
  AxesSet axes = NullValue<AxesSet>();
  /*! \brief The scaling factor */
  Expr scale = NullValue<Expr>();

  // Realizing with a pending scale would silently drop it, hence the check.
  Expr Realize() const final {
    ICHECK(!axes.defined()) << "outstanding scale";
    return value;
  }

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("value", &value);
    v->Visit("axes", &axes);
    v->Visit("scale", &scale);
  }

  static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(ScaledExprNode, TempExprNode);
};
199 | |
200 | using FForwardRewrite = TypedPackedFunc<Expr(const Call& ref_call, const Array<Expr>& new_args, |
201 | const Message& message)>; |
202 | |
203 | //---------------------------------------------- |
204 | // Generic Visitors for FScaleAxisForward |
205 | //---------------------------------------------- |
/*!
 * \brief Preparation visitor for the forward folding pass.
 *
 * Walks the expression in post-DFS order, then replays recorded lazy
 * callbacks in reverse (topological) order so that scaling-demand
 * messages propagate from consumers back to their inputs.
 */
class ForwardPrep : private MixedModeVisitor {
 public:
  // Run the preparation phase and return the message attached to each node.
  std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
    this->Update(body, NullValue<Message>());
    this->VisitExpr(body);
    // flist is added in the Post-DFS order
    // which is a special case of topological order.
    // We reversely traverse the list to invoke the lazy functions.
    // This act like a backprop of valid scale axis messages
    for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) {
      (*it)();
    }
    // return the created message;
    return std::move(message_);
  }

 private:
  // The invoke list
  std::vector<std::function<void()>> flist_;
  // The message on each node.
  std::unordered_map<const Object*, Message> message_;
  // Update the message stored at node.
  void Update(const Expr& node, const Message& message) {
    // We run intersection of messages:
    //
    // %y = multiply(%x, %scale)
    // %z1 = conv2d(%y, %w)
    // %z2 = exp(%y)
    //
    // Consider the above code example,
    // because %z2 will propagate null to %y,
    // the AxesSet on %y is also null,
    // and the forward folding won't be triggered.
    const Object* key = node.get();
    if (message_.count(key)) {
      message_[key] = Intersect(message_[key], message);
    } else {
      message_[key] = message;
    }
  }

  // We intended the following overrides on implementations from ExprVisitor.
  using MixedModeVisitor::VisitExpr_;

  // Visitor pattern override.
  void VisitExpr_(const TupleGetItemNode* op) final { MixedModeVisitor::VisitExpr_(op); }

  void VisitExpr_(const LetNode* op) final {
    ExprVisitor::VisitExpr_(op);
    // do pass through condition
    // by assigning NullValue<Message>
    // it means fuse signal cannot pass
    // through into these subexpressions.
    auto flazy = [this, op]() {
      this->Update(op->value, NullValue<Message>());
      this->Update(op->body, NullValue<Message>());
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const FunctionNode* op) final {
    ExprVisitor::VisitExpr_(op);
    // Function bodies block the scale signal as well.
    auto flazy = [this, op] { this->Update(op->body, NullValue<Message>()); };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const CallNode* call) final {
    ExprVisitor::VisitExpr_(call);
    // function to be lazily invoked
    auto flazy = [this, call]() {
      static const auto& fprep = Op::GetAttrMap<FForwardPrep>("FScaleAxisForwardPrep");
      // find the message send to this node.
      auto it = message_.find(call);
      Message out_message;
      if (it != message_.end()) {
        out_message = it->second;
      } else {
        out_message = NullValue<Message>();
      }
      // pass the message back to all the children it references.
      auto f = fprep.get(call->op, nullptr);
      if (f != nullptr) {
        Array<Message> in_messages = f(GetRef<Call>(call), out_message);
        ICHECK_EQ(in_messages.size(), call->args.size());
        for (size_t i = 0; i < call->args.size(); ++i) {
          this->Update(call->args[i], in_messages[i]);
        }
      } else {
        // Ops without a prep function cannot pass the signal through.
        for (size_t i = 0; i < call->args.size(); ++i) {
          this->Update(call->args[i], NullValue<Message>());
        }
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const TupleNode* op) final {
    ExprVisitor::VisitExpr_(op);
    // do not support pass scale through tuple for now.
    auto flazy = [this, op]() {
      for (const Expr& field : op->fields) {
        this->Update(field, NullValue<Message>());
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const IfNode* op) final {
    ExprVisitor::VisitExpr_(op);
    // do pass through condition
    // by assigning NullValue<Message>
    // it means fuse signal cannot pass
    // through into these subexpressions.
    auto flazy = [this, op]() {
      this->Update(op->cond, NullValue<Message>());
      this->Update(op->true_branch, NullValue<Message>());
      this->Update(op->false_branch, NullValue<Message>());
    };
    flist_.push_back(flazy);
  }
};
327 | |
328 | static bool IsIntInArray(const Array<Integer>& axis, int v) { |
329 | for (size_t i = 0; i < axis.size(); i++) { |
330 | if (axis[i] == v) return true; |
331 | } |
332 | return false; |
333 | } |
334 | |
335 | static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape, |
336 | const Array<Integer>& axis) { |
337 | Array<Integer> arr; |
338 | for (size_t i = 0; i < shape.size(); i++) { |
339 | if (IsIntInArray(axis, i)) { |
340 | auto node = shape[i].as<IntImmNode>(); |
341 | if (!node) { |
342 | // if the shape is not a constant, use normal transform |
343 | return Expr(); |
344 | } |
345 | arr.push_back(node->value); |
346 | } else { |
347 | arr.push_back(1); |
348 | } |
349 | } |
350 | return MakeReshape(scale, std::move(arr)); |
351 | } |
352 | |
353 | // if only one axis, use expand dim. Else, use reshape |
354 | static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array<PrimExpr>& shape, |
355 | const Array<Integer>& axis) { |
356 | if (axis.size() > 1) { |
357 | return ReshapeToMatchAxis(scale, shape, axis); |
358 | } else { |
359 | return ExpandBiasToMatchAxis(scale, shape.size(), axis); |
360 | } |
361 | } |
362 | |
363 | //---------------------------------------------- |
364 | // Per operator defs for FScaleAxisForward |
365 | //---------------------------------------------- |
366 | |
367 | // Intermediate operators |
368 | Array<Message> ReluForwardPrep(const Call& call, const Message& out_message) { |
369 | if (out_message.defined()) { |
370 | return {Message(out_message->axes, true)}; |
371 | } |
372 | return {out_message}; |
373 | } |
374 | |
375 | Expr ReluForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, const Message& message) { |
376 | const auto* input = new_args[0].as<ScaledExprNode>(); |
377 | if (input == nullptr) return Expr(nullptr); |
378 | // return transformed conv2d |
379 | auto rnode = make_object<ScaledExprNode>(); |
380 | rnode->value = Call(ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); |
381 | rnode->scale = input->scale; |
382 | rnode->axes = input->axes; |
383 | return Expr(rnode); |
384 | } |
385 | |
// Register forward prep/rewrite hooks: relu and leaky_relu share the same
// pass-through behavior with a positive-scale requirement.
RELAY_REGISTER_OP("nn.relu").set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);

RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FScaleAxisForwardRewrite",
                                                       ReluForwardRewrite);

RELAY_REGISTER_OP("nn.leaky_relu").set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);

RELAY_REGISTER_OP("nn.leaky_relu")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);
395 | |
396 | // AddSub |
397 | Array<Message> AddSubForwardPrep(const Call& call, const Message& out_message) { |
398 | const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); |
399 | const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); |
400 | auto none = NullValue<Message>(); |
401 | if (out_message.defined()) { |
402 | if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) { |
403 | return {out_message, none}; |
404 | } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) { |
405 | return {none, out_message}; |
406 | } |
407 | } |
408 | return {none, none}; |
409 | } |
410 | |
/*!
 * \brief Forward rewrite for add/subtract.
 *
 * Exactly one operand may carry a pending scale. The other operand is
 * divided by the (reshaped) scale so the addition stays numerically
 * equivalent, and the scale keeps propagating on the result.
 */
Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                          const Message& message) {
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  // Neither side carries a scale: nothing to fold.
  if (!slhs && !srhs) return Expr();
  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
  auto rnode = make_object<ScaledExprNode>();

  if (slhs != nullptr) {
    // Prep phase guarantees only one side is scaled and the broadcast matches.
    ICHECK(srhs == nullptr);
    ICHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
    Expr scale = ReshapeOrExpandToMatchAxis(slhs->scale, tlhs->shape, slhs->axes);
    if (!scale.defined()) {
      // Non-constant shape: cannot reshape the scale, bail out.
      return Expr();
    }
    // (a*s) + b == (a + b/s) * s
    Expr rhs = Divide(new_args[1], scale);
    rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args);
    rnode->scale = slhs->scale;
    rnode->axes = slhs->axes;
  } else {
    ICHECK(srhs != nullptr);
    ICHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
    Expr scale = ReshapeOrExpandToMatchAxis(srhs->scale, trhs->shape, srhs->axes);
    if (!scale.defined()) {
      return Expr();
    }
    // a + (b*s) == (a/s + b) * s
    Expr lhs = Divide(new_args[0], scale);
    rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args);
    rnode->scale = srhs->scale;
    rnode->axes = srhs->axes;
  }
  return Expr(rnode);
}
445 | |
// Register forward prep/rewrite hooks for the elementwise add/subtract ops.
RELAY_REGISTER_OP("add").set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);

RELAY_REGISTER_OP("add").set_attr<FForwardRewrite>("FScaleAxisForwardRewrite",
                                                   AddSubForwardRewrite);

RELAY_REGISTER_OP("subtract").set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);

RELAY_REGISTER_OP("subtract")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);
455 | |
456 | // Producer operators |
457 | // Multiply produces the scale-axis pair. |
// Producer operators
// Multiply produces the scale-axis pair.
//
// If one operand broadcasts onto the requested axes of the other (and is
// positive when required), the multiply is removed and replaced by a
// ScaledExpr carrying (value, scale, axes) for downstream consumers.
Expr MultiplyForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
                            const Message& message) {
  if (!message.defined()) return Expr();
  const auto& expected_out_axes = message->axes;
  ICHECK(expected_out_axes.defined() && expected_out_axes.size());
  // TODO(tvm-team) allow same axes accumulation
  // not as important because it is less common in nn.
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  // Inputs must not already carry an outstanding scale.
  ICHECK(!slhs && !srhs);

  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
  Expr lhs = new_args[0];
  Expr rhs = new_args[1];
  auto rnode = make_object<ScaledExprNode>();

  // Try rhs as the scale first, then lhs; MatchBroadcastToLeftAxes may
  // rewrite the candidate in place to squeeze broadcast dimensions.
  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
      (!message->require_positive || IsAllPositiveConstant(rhs))) {
    rnode->value = lhs;
    rnode->scale = rhs;
    rnode->axes = expected_out_axes;
    return Expr(rnode);
  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
             (!message->require_positive || IsAllPositiveConstant(lhs))) {
    rnode->value = rhs;
    rnode->scale = lhs;
    rnode->axes = expected_out_axes;
    return Expr(rnode);
  } else {
    return Expr();
  }
}
491 | |
// multiply is the producer: it only needs a rewrite hook, no prep.
RELAY_REGISTER_OP("multiply")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);
494 | |
495 | // Consumer operators |
496 | // Conv send out requirement of axis folding. |
// Consumer operators
// Conv send out requirement of axis folding.
//
// Shared between conv2d/conv3d via the ATTRS template parameter.
// Requests scaling on the input-channel axis (plus the blocked
// sub-channel axis 'c' when the layout is blocked).
template <typename ATTRS>
Array<Message> ConvForwardPrep(const Call& call, const ATTRS* param, const Message& out_message) {
  // TODO(tvm-team) support general data layout
  // by transforming weight
  ICHECK(param != nullptr);
  Layout data_layout(param->data_layout);
  Layout kernel_layout(param->kernel_layout);
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));

  ICHECK_GE(c_big_axis, 0);
  Message none = NullValue<Message>();
  // For now, we only support simple pattern (no folded weight/data)
  // More general layout can be supported under the current framework.
  // By using a unified layout transformation.
  // We only need to change the Prep and Mutate function.
  //
  // only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
  if (param->groups == 1 || is_depthwise_conv) {
    auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
    auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
    if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||     // simple layout
        (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) {  // blocked layout
      Array<Integer> arr{c_big_axis};
      if (c_small_axis >= 0) {
        arr.push_back(c_small_axis);
      }
      // Request scale on the data (first arg); weight gets no request.
      return {Message(arr, false), none};
    }
  }
  return {none, none};
}
531 | |
532 | // Conv2D consumes the scale axis during transformation. |
533 | template <typename ATTRS> |
534 | Expr ConvForwardRewrite(const Call& ref_call, const ATTRS* param, const Array<Expr>& new_args, |
535 | const Message& message) { |
536 | // if data do not have scale, normal transform path. |
537 | const auto* sdata = new_args[0].as<ScaledExprNode>(); |
538 | const auto* sweight = new_args[1].as<ScaledExprNode>(); |
539 | if (sdata == nullptr) return Expr(); |
540 | if (sweight != nullptr) return Expr(); |
541 | ICHECK(param != nullptr); |
542 | Layout data_layout(param->data_layout); |
543 | Layout kernel_layout(param->kernel_layout); |
544 | int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C')); |
545 | ICHECK_GE(c_big_axis, 0); |
546 | int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); |
547 | int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); |
548 | int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I')); |
549 | int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); |
550 | |
551 | bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0); |
552 | bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0); |
553 | ICHECK(is_simple || is_blocking); |
554 | |
555 | // Check it must be depthwise or full conv2d. |
556 | bool is_depthwise_conv = IsDepthwiseConv(ref_call, param, kernel_layout); |
557 | ICHECK(param->groups == 1 || is_depthwise_conv); |
558 | |
559 | Expr weight = new_args[1]; |
560 | |
561 | // match the ic_axis |
562 | if (is_depthwise_conv) { |
563 | if (is_simple) { |
564 | Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis}); |
565 | weight = Multiply(weight, scale); |
566 | } else { |
567 | weight = Multiply(weight, |
568 | ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape, |
569 | {big_ko_axis, small_ko_axis})); |
570 | if (!weight.defined()) return Expr(); |
571 | } |
572 | |
573 | } else { |
574 | if (is_simple) { |
575 | Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ki_axis}); |
576 | weight = Multiply(weight, scale); |
577 | } else { |
578 | weight = Multiply(weight, |
579 | ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape, |
580 | {big_ki_axis, small_ki_axis})); |
581 | if (!weight.defined()) return Expr(); |
582 | } |
583 | } |
584 | // return transformed conv |
585 | return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); |
586 | } |
587 | |
// Dispatch the forward prep to the right attribute type: conv2d uses
// Conv2DAttrs, everything else registered here is assumed to be conv3d.
Array<Message> PreConvForwardPrep(const Call& call, const Message& out_message) {
  if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
    const auto* param = call->attrs.as<Conv2DAttrs>();
    ICHECK(param != nullptr);
    return ConvForwardPrep(call, param, out_message);
  }
  const auto* param = call->attrs.as<Conv3DAttrs>();
  ICHECK(param != nullptr);
  return ConvForwardPrep(call, param, out_message);
}
598 | |
599 | Expr PreConvForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, |
600 | const Message& message) { |
601 | if (backend::IsOp(ref_call.as<CallNode>(), "nn.conv2d" )) { |
602 | const auto* param = ref_call->attrs.as<Conv2DAttrs>(); |
603 | ICHECK(param != nullptr); |
604 | return ConvForwardRewrite(ref_call, param, new_args, message); |
605 | } |
606 | const auto* param = ref_call->attrs.as<Conv3DAttrs>(); |
607 | ICHECK(param != nullptr); |
608 | return ConvForwardRewrite(ref_call, param, new_args, message); |
609 | } |
610 | |
// Register forward prep/rewrite hooks; conv2d and conv3d share the
// dispatching wrappers above.
RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);

RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);
620 | |
621 | // Dense send out requirement of axis folding. |
622 | Array<Message> DenseForwardPrep(const Call& call, const Message& out_message) { |
623 | return {Message({1}, false), NullValue<Message>()}; |
624 | } |
625 | |
626 | // Dense consumes the scale axis during transformation. |
627 | Expr DenseForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, |
628 | const Message& message) { |
629 | const auto* sdata = new_args[0].as<ScaledExprNode>(); |
630 | const auto* sweight = new_args[1].as<ScaledExprNode>(); |
631 | if (sdata == nullptr) return Expr(); |
632 | if (sweight != nullptr) return Expr(); |
633 | |
634 | Expr weight = Multiply(new_args[1], sdata->scale); |
635 | return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); |
636 | } |
637 | |
// Register forward prep/rewrite hooks for dense.
RELAY_REGISTER_OP("nn.dense").set_attr<FForwardPrep>("FScaleAxisForwardPrep", DenseForwardPrep);

RELAY_REGISTER_OP("nn.dense")
    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", DenseForwardRewrite);
642 | |
643 | Expr ForwardFoldScaleAxis(const Expr& data) { |
644 | auto message = ForwardPrep().Prepare(data); |
645 | for (const auto& m : message) { |
646 | if (m.second.defined()) { |
647 | // run optimization |
648 | auto fcontext = [&](const Call& call) -> ObjectRef { |
649 | auto it = message.find(call.get()); |
650 | if (it != message.end()) { |
651 | return it->second; |
652 | } else { |
653 | return ObjectRef(nullptr); |
654 | } |
655 | }; |
656 | return ForwardRewrite(data, "FScaleAxisForwardRewrite" , fcontext); |
657 | } |
658 | } |
659 | // no messages - no optimization |
660 | return data; |
661 | } |
662 | |
663 | //---------------------------------------- |
664 | // Implement backward transformations. |
665 | //---------------------------------------- |
666 | class BackwardTransformer; |
667 | |
668 | /*! |
669 | * \brief Preparation function for pass scale backward. |
670 | * \param call The call node. |
671 | * \param in_messages Messages from the input containing allowed input scaling and whether |
672 | * positive scale is required. |
673 | * \return Message containing the result scaling on axes of the input. |
674 | */ |
675 | using FBackwardPrep = TypedPackedFunc<Message(const Call& call, const Array<Message>& in_messages)>; |
676 | |
677 | using FBackwardTransform = |
678 | TypedPackedFunc<Expr(const Call& call, const Message& message, const Expr& scale, |
679 | const BackwardTransformer& transformer)>; |
680 | |
681 | //---------------------------------------------- |
682 | // Generic Visitors for FScaleAxisBackward |
683 | //---------------------------------------------- |
684 | |
/*!
 * \brief Preparation visitor for the backward folding pass.
 *
 * Propagates messages forward (input to output) so the transform phase
 * can later push scale signals back down to the inputs.
 */
class BackwardPrep : private MixedModeVisitor {
 public:
  // The message on each node.
  std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
    ref_counter_ = GetExprRefCount(body);
    this->VisitExpr(body);
    return std::move(message_);
  }

 private:
  // The message on each node.
  std::unordered_map<const Object*, Message> message_;
  // reference counter of an internal expr
  std::unordered_map<const Object*, size_t> ref_counter_;
  // Visit the expression.
  void VisitExpr_(const CallNode* call) {
    ExprVisitor::VisitExpr_(call);
    static const auto& fprep = Op::GetAttrMap<FBackwardPrep>("FScaleAxisBackwardPrep");
    auto f = fprep.get(call->op, nullptr);
    if (f == nullptr) return;
    auto rit = ref_counter_.find(call);
    ICHECK(rit != ref_counter_.end());
    // We only allow propagation of scale backward
    // if the expression is only referred by a single parent.
    if (rit->second != 1) return;
    Array<Message> in_messages = GetInMessages(call);
    Message out_message = f(GetRef<Call>(call), in_messages);
    if (out_message.defined()) {
      message_[call] = out_message;
    }
  }

  // Collect the messages already computed for each argument of call
  // (undefined when an argument has none).
  Array<Message> GetInMessages(const CallNode* call) {
    Array<Message> in_messages;
    for (Expr arg : call->args) {
      auto it = message_.find(arg.get());
      if (it != message_.end()) {
        in_messages.push_back(it->second);
      } else {
        in_messages.push_back(NullValue<Message>());
      }
    }
    return in_messages;
  }
};
730 | |
731 | /* |
732 | * Hybrid apporach is used with the transformation |
733 | * itself is recursive but the traversal is non-recursive |
734 | */ |
/*
 * Hybrid apporach is used with the transformation
 * itself is recursive but the traversal is non-recursive
 */
class BackwardTransformerNode : public Object, private MixedModeMutator {
 public:
  using MixedModeMutator::Mutate;
  // Run forward transform.
  // Only mutates when the prepare phase produced at least one message.
  Expr Fold(Expr expr) {
    message_ = BackwardPrep().Prepare(expr);
    for (const auto& m : message_) {
      if (m.second.defined()) {
        // run optimization
        return this->Mutate(expr);
      }
    }
    // no messages - no optimization
    return expr;
  }

  /*!
   * \brief Transform the expr to consider the scaling.
   */
  Expr Transform(const Expr& expr, Message message, Expr scale);
  /*!
   * \brief Get the message propogated to the expr.
   * \param expr The expresison.
   * \return The message containing the expected axes and whether positive scale is required.
   */
  Message GetMessage(const Expr& expr) const {
    auto it = message_.find(expr.get());
    if (it != message_.end()) return it->second;
    return NullValue<Message>();
  }

  // solver is not serializable.
  void VisitAttrs(tvm::AttrVisitor* v) {}

  static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
  TVM_DECLARE_FINAL_OBJECT_INFO(BackwardTransformerNode, Object);

 private:
  // Valid axes on each node.
  std::unordered_map<const Object*, Message> message_;
  // Override mutation of call.
  // Entry from the generic mutator: start a Transform with no pending scale.
  Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
    return Transform(GetRef<Call>(call_node), NullValue<Message>(), NullValue<Expr>());
  }

 public:
  // Fallback for calls whose op has no backward transform registered.
  Expr NormalCallTransform(const CallNode* call_node) { return ExprMutator::VisitExpr_(call_node); }
};
783 | |
/*!
 * \brief Managed reference to BackwardTransformerNode.
 * \sa BackwardTransformerNode
 */
class BackwardTransformer : public ObjectRef {
 public:
  BackwardTransformer() {}
  explicit BackwardTransformer(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
  // Mutable access: the transformer is stateful during a fold.
  BackwardTransformerNode* operator->() const {
    return static_cast<BackwardTransformerNode*>(get_mutable());
  }
  using ContainerType = BackwardTransformerNode;
};
793 | |
794 | /*! |
795 | * \brief Transform the expr to consider the scaling. |
796 | * |
797 | * \param expr The input expression. |
798 | * \param message The axes to scale. |
799 | * \param scale The scale applied to the axes. |
800 | * \return The result of transformation. |
801 | */ |
/*!
 * \brief Transform the expr to consider the scaling.
 *
 * \param expr The input expression.
 * \param message The axes to scale.
 * \param scale The scale applied to the axes.
 * \return The result of transformation.
 */
Expr BackwardTransformerNode::Transform(const Expr& expr, Message message, Expr scale) {
  if (const CallNode* call_node = expr.as<CallNode>()) {
    static const auto& ftransform =
        Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
    auto f = ftransform.get(call_node->op, nullptr);
    const Call call = GetRef<Call>(call_node);
    // Only reuse the memoized result when there is no pending scale
    // message; a call visited with a scale must be re-transformed.
    if (!message.defined()) {
      const auto it = memo_.find(call);
      if (it != memo_.end()) {
        return it->second;
      }
    }
    Expr new_expr = NullValue<Expr>();
    if (f != nullptr) {
      new_expr = f(call, message, scale, GetRef<BackwardTransformer>(this));
    } else {
      // No registered transform: the scale must already have been consumed.
      ICHECK(!message.defined()) << "outstanding scale";
      new_expr = NormalCallTransform(call.operator->());
    }
    memo_[call] = new_expr;
    return new_expr;
  } else {
    // Non-call nodes cannot absorb a scale; fall back to plain mutation.
    ICHECK(!message.defined()) << "outstanding scale";
    return this->Mutate(expr);
  }
}
829 | |
830 | //---------------------------------------------- |
831 | // Per operator defs for FScaleAxisForward |
832 | //---------------------------------------------- |
833 | |
834 | // Intermediate operators |
835 | Message ReluBackwardPrep(const Call& call, const Array<Message>& in_messages) { |
836 | if (in_messages[0].defined()) { |
837 | return Message(in_messages[0]->axes, true); |
838 | } |
839 | return in_messages[0]; |
840 | } |
841 | |
842 | Expr ReluBackwardTransform(const Call& call, const Message& message, const Expr& scale, |
843 | const BackwardTransformer& transformer) { |
844 | if (!message.defined()) { |
845 | return transformer->NormalCallTransform(call.operator->()); |
846 | } |
847 | Expr input = transformer->Transform(call->args[0], message, scale); |
848 | return Call(call->op, {input}, call->attrs, call->type_args); |
849 | } |
850 | |
// Register the backward prep/transform hooks for the ReLU family.
RELAY_REGISTER_OP("nn.relu" ).set_attr<FBackwardPrep>("FScaleAxisBackwardPrep" , ReluBackwardPrep);

RELAY_REGISTER_OP("nn.relu" ).set_attr<FBackwardTransform>("FScaleAxisBackwardTransform" ,
                                                          ReluBackwardTransform);

RELAY_REGISTER_OP("nn.leaky_relu" )
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep" , ReluBackwardPrep);

RELAY_REGISTER_OP("nn.leaky_relu" )
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform" , ReluBackwardTransform);
861 | |
862 | // AddSub |
863 | Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) { |
864 | const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); |
865 | const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); |
866 | StructuralEqual equal; |
867 | if (in_messages[0].defined() && MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { |
868 | return in_messages[0]; |
869 | } else if (in_messages[1].defined() && |
870 | MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) { |
871 | return in_messages[1]; |
872 | } else if (in_messages[0].defined() && in_messages[1].defined() && |
873 | equal(in_messages[0]->axes, in_messages[1]->axes) && equal(tlhs->shape, trhs->shape)) { |
874 | // add of two elements. |
875 | return in_messages[0]; |
876 | } else { |
877 | auto res = NullValue<Message>(); |
878 | return res; |
879 | } |
880 | } |
881 | |
882 | Expr AddSubBackwardTransform(const Call& call, const Message& message, const Expr& scale, |
883 | const BackwardTransformer& transformer) { |
884 | const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); |
885 | const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); |
886 | if (!message.defined()) { |
887 | return transformer->NormalCallTransform(call.operator->()); |
888 | } |
889 | |
890 | Message lhs_message = transformer->GetMessage(call->args[0]); |
891 | Message rhs_message = transformer->GetMessage(call->args[1]); |
892 | StructuralEqual equal; |
893 | |
894 | if (lhs_message.defined() && rhs_message.defined()) { |
895 | ICHECK(equal(lhs_message->axes, rhs_message->axes)); |
896 | ICHECK(equal(message->axes, lhs_message->axes)); |
897 | Expr lhs = transformer->Transform(call->args[0], message, scale); |
898 | Expr rhs = transformer->Transform(call->args[1], message, scale); |
899 | return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); |
900 | } else if (lhs_message.defined()) { |
901 | ICHECK(equal(message->axes, lhs_message->axes)); |
902 | Expr lhs = transformer->Transform(call->args[0], message, scale); |
903 | Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>()); |
904 | Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, message->axes); |
905 | if (!rhs_scale.defined()) { |
906 | return transformer->NormalCallTransform(call.operator->()); |
907 | } |
908 | rhs = Multiply(rhs, rhs_scale); |
909 | return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); |
910 | } else if (rhs_message.defined()) { |
911 | ICHECK(equal(message->axes, rhs_message->axes)); |
912 | Expr lhs = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>()); |
913 | Expr rhs = transformer->Transform(call->args[1], message, scale); |
914 | Expr lhs_scale = ReshapeOrExpandToMatchAxis(scale, trhs->shape, message->axes); |
915 | if (!lhs_scale.defined()) { |
916 | return transformer->NormalCallTransform(call.operator->()); |
917 | } |
918 | lhs = Multiply(lhs, lhs_scale); |
919 | return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); |
920 | } else { |
921 | LOG(FATAL) << "outstanding scale" ; |
922 | } |
923 | } |
924 | |
// Register the backward prep/transform hooks for add and subtract.
RELAY_REGISTER_OP("add" ).set_attr<FBackwardPrep>("FScaleAxisBackwardPrep" , AddSubBackwardPrep);

RELAY_REGISTER_OP("add" ).set_attr<FBackwardTransform>("FScaleAxisBackwardTransform" ,
                                                      AddSubBackwardTransform);

RELAY_REGISTER_OP("subtract" ).set_attr<FBackwardPrep>("FScaleAxisBackwardPrep" , AddSubBackwardPrep);

RELAY_REGISTER_OP("subtract" )
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform" , AddSubBackwardTransform);
934 | |
// Producer operators
// Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call, const Message& message, const Expr& scale,
                               const BackwardTransformer& transformer) {
  // A multiply is a producer of scales, so no scale may be pending on it.
  ICHECK(!message.defined()) << "outstanding scale" ;
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  if (lhs_message.defined()) {
    ICHECK(lhs_message->axes.defined() && lhs_message->axes.size());
    // NOTE we won't recursively call mutating on scale part.
    // since there won't be scale chance within scale part.
    Expr rhs = call->args[1];
    // MatchBroadcastToLeftAxes presumably rewrites rhs in place into a scale
    // over lhs's axes on success (it takes &rhs) — see pattern_utils; verify.
    // Positivity is enforced when a downstream consumer demanded it.
    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) &&
        (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) {
      return transformer->Transform(call->args[0], lhs_message, rhs);
    }
  } else if (rhs_message.defined()) {
    // Symmetric case: the rhs subtree expects the scale, lhs is the scale.
    ICHECK(rhs_message->axes.defined() && rhs_message->axes.size());
    Expr lhs = call->args[0];
    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) &&
        (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) {
      return transformer->Transform(call->args[1], rhs_message, lhs);
    }
  }
  // Could not fold the multiply as a scale producer; mutate normally.
  return transformer->NormalCallTransform(call.operator->());
}
963 | |
// Multiply only needs a transform hook: it produces scales rather than
// requesting them, so it registers no backward prep.
RELAY_REGISTER_OP("multiply" )
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform" , MultiplyBackwardTransform);
966 | |
// Consumer operators
// Conv send out requirement of axis folding.
template <typename ATTRS>
Message ConvBackwardPrep(const Call& call, const ATTRS* param, const Array<Message>& in_messages) {
  ICHECK(param != nullptr);
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  // Locate the (possibly blocked) output-channel axes in the output layout.
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));

  ICHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // More general layout can be supported under the current framework.
  // By using a unified layout transformation.
  // We only need to change the Prep and Mutate function.
  //
  // only handle depthwise or full conv.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
  if (param->groups == 1 || is_depthwise_conv) {
    auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
    auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
    // Either a fully simple layout (no blocked axes anywhere) or a fully
    // blocked one (blocked kernel AND blocked output channel) is accepted.
    if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||  // simple layout
        (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) {  // blocked layout
      Array<Integer> arr{c_big_axis};
      if (c_small_axis >= 0) {
        arr.push_back(c_small_axis);
      }
      // Request folding along the output-channel axes; no positivity needed.
      return Message(arr, false);
    }
  }
  return NullValue<Message>();
}
1000 | |
// Conv consumes the scale axis during transformation.
template <typename ATTRS>
Expr ConvBackwardTransform(const Call& call, const ATTRS* param, const Message& message,
                           const Expr& scale, const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  ICHECK(param != nullptr);
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  ICHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // TODO(tvm-team) support general data layout
  int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
  int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
  int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
  int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
  // Check it must be depthwise or full conv.
  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
  ICHECK(param->groups == 1 || is_depthwise_conv);
  bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
  bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
  ICHECK(is_simple || is_blocking);

  // Recurse into data and weight without any pending scale.
  Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
  Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
  // scale on input for depthwise.
  Expr wscale;
  if (is_simple) {
    // Broadcast the scale along the kernel's big output-channel axis.
    wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis});
  } else {
    // Blocked layout: reshape the scale over the (O, o) axis pair.
    wscale = ReshapeToMatchAxis(scale, weight->type_as<TensorTypeNode>()->shape,
                                {big_ko_axis, small_ko_axis});
    if (!wscale.defined()) {
      return transformer->NormalCallTransform(call.operator->());
    }
  }
  // Fold the scale into the kernel weights.
  weight = Multiply(weight, wscale);
  return Call(call->op, {data, weight}, call->attrs, call->type_args);
}
1042 | |
1043 | Message PreConvBackwardPrep(const Call& call, const Array<Message>& in_messages) { |
1044 | if (backend::IsOp(call.as<CallNode>(), "nn.conv2d" )) { |
1045 | const auto* param = call->attrs.as<Conv2DAttrs>(); |
1046 | ICHECK(param != nullptr); |
1047 | return ConvBackwardPrep(call, param, in_messages); |
1048 | } |
1049 | const auto* param = call->attrs.as<Conv3DAttrs>(); |
1050 | ICHECK(param != nullptr); |
1051 | return ConvBackwardPrep(call, param, in_messages); |
1052 | } |
1053 | |
1054 | Expr PreConvBackwardTransform(const Call& call, const Message& message, const Expr& scale, |
1055 | const BackwardTransformer& transformer) { |
1056 | if (backend::IsOp(call.as<CallNode>(), "nn.conv2d" )) { |
1057 | const auto* param = call->attrs.as<Conv2DAttrs>(); |
1058 | ICHECK(param != nullptr); |
1059 | return ConvBackwardTransform(call, param, message, scale, transformer); |
1060 | } |
1061 | const auto* param = call->attrs.as<Conv3DAttrs>(); |
1062 | ICHECK(param != nullptr); |
1063 | return ConvBackwardTransform(call, param, message, scale, transformer); |
1064 | } |
1065 | |
// Register the shared conv2d/conv3d backward hooks.
RELAY_REGISTER_OP("nn.conv2d" )
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep" , PreConvBackwardPrep);

RELAY_REGISTER_OP("nn.conv2d" )
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform" , PreConvBackwardTransform);

RELAY_REGISTER_OP("nn.conv3d" )
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep" , PreConvBackwardPrep);

RELAY_REGISTER_OP("nn.conv3d" )
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform" , PreConvBackwardTransform);
1077 | |
1078 | Message BiasAddBackwardPrep(const Call& call, const Array<Message>& in_messages) { |
1079 | const BiasAddAttrs* attrs = call->attrs.as<BiasAddAttrs>(); |
1080 | ICHECK(attrs); |
1081 | if (in_messages[0].defined() && in_messages[0]->axes.size() == 1 && |
1082 | attrs->axis == static_cast<int>(in_messages[0]->axes[0]->value)) { |
1083 | return in_messages[0]; |
1084 | } else { |
1085 | return NullValue<Message>(); |
1086 | } |
1087 | } |
1088 | |
1089 | Expr BiasAddBackwardTransform(const Call& call, const Message& message, const Expr& scale, |
1090 | const BackwardTransformer& transformer) { |
1091 | if (!message.defined()) { |
1092 | return transformer->NormalCallTransform(call.operator->()); |
1093 | } |
1094 | Message lhs_message = transformer->GetMessage(call->args[0]); |
1095 | Message rhs_message = transformer->GetMessage(call->args[1]); |
1096 | StructuralEqual equal; |
1097 | |
1098 | if (lhs_message.defined()) { |
1099 | ICHECK(equal(message->axes, lhs_message->axes)); |
1100 | Expr lhs = transformer->Transform(call->args[0], message, scale); |
1101 | Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>()); |
1102 | rhs = Multiply(rhs, scale); |
1103 | return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); |
1104 | } else { |
1105 | LOG(FATAL) << "outstanding scale" ; |
1106 | } |
1107 | } |
1108 | |
// Register the bias_add backward hooks.
RELAY_REGISTER_OP("nn.bias_add" )
    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep" , BiasAddBackwardPrep);

RELAY_REGISTER_OP("nn.bias_add" )
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform" , BiasAddBackwardTransform);
1114 | |
1115 | // Dense send out requirement of axis folding. |
1116 | Message DenseBackwardPrep(const Call& call, const Array<Message>& in_messages) { |
1117 | return Message({1}, false); |
1118 | } |
1119 | |
1120 | // Dense consumes the sacle axis during trasformation. |
1121 | Expr DenseBackwardTransform(const Call& call, const Message& message, const Expr& scale, |
1122 | const BackwardTransformer& transformer) { |
1123 | if (!message.defined()) { |
1124 | return transformer->NormalCallTransform(call.operator->()); |
1125 | } |
1126 | Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>()); |
1127 | Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>()); |
1128 | Expr wscale = ExpandBiasToMatchAxis(scale, 2, {0}); |
1129 | weight = Multiply(weight, wscale); |
1130 | return Call(call->op, {data, weight}, call->attrs, call->type_args); |
1131 | } |
1132 | |
// Register the dense backward hooks.
RELAY_REGISTER_OP("nn.dense" ).set_attr<FBackwardPrep>("FScaleAxisBackwardPrep" , DenseBackwardPrep);

RELAY_REGISTER_OP("nn.dense" )
    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform" , DenseBackwardTransform);
1137 | |
1138 | Expr BackwardFoldScaleAxis(const Expr& data) { |
1139 | return make_object<BackwardTransformerNode>()->Fold(data); |
1140 | } |
1141 | |
1142 | } // namespace fold_scale_axis |
1143 | |
1144 | namespace transform { |
1145 | |
1146 | Pass ForwardFoldScaleAxis() { |
1147 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
1148 | [=](Function f, IRModule m, PassContext pc) { |
1149 | return Downcast<Function>(relay::fold_scale_axis::ForwardFoldScaleAxis(f)); |
1150 | }; |
1151 | return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis" , {"InferType" }); |
1152 | } |
1153 | |
// Expose the pass constructor to the FFI / Python API.
TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis" ).set_body_typed(ForwardFoldScaleAxis);
1155 | |
1156 | Pass BackwardFoldScaleAxis() { |
1157 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
1158 | [=](Function f, IRModule m, PassContext pc) { |
1159 | return Downcast<Function>(relay::fold_scale_axis::BackwardFoldScaleAxis(f)); |
1160 | }; |
1161 | return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis" , {"InferType" }); |
1162 | } |
1163 | |
// Expose the pass constructor to the FFI / Python API.
TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis" ).set_body_typed(BackwardFoldScaleAxis);
1165 | |
1166 | Pass FoldScaleAxis() { |
1167 | // FoldScaleAxis pass contains the following three passes. Therefore, we can |
1168 | // register it as a sequential pass. |
1169 | Pass pass = Sequential({BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, |
1170 | "FoldScaleAxis" ); |
1171 | return pass; |
1172 | } |
1173 | |
// Expose the combined pass constructor to the FFI / Python API.
TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis" ).set_body_typed(FoldScaleAxis);
1175 | |
1176 | } // namespace transform |
1177 | |
1178 | } // namespace relay |
1179 | } // namespace tvm |
1180 | |