1#include <parser.h>
2
3#include <arith.h>
4#include <instrumentation.h>
5#include <ir_all_nodes.h>
6#include <ir_builder.h>
7#include <ir_iostream.h>
8#include <ops/all_ops.h>
9#include <type_inference.h>
10#include <type_promotion.h>
11#include <utils.h>
12
13#include <torch/csrc/jit/frontend/function_schema_parser.h>
14#include <torch/csrc/jit/ir/constants.h>
15
16#include <ATen/native/Activation.h>
17
18#include <c10/util/CallOnce.h>
19
20#include <unordered_map>
21#include <utility>
22
23namespace torch {
24namespace jit {
25
26typedef Value JitValue;
27typedef Node JitOp;
28
29namespace fuser {
30namespace cuda {
31
32constexpr auto kNumUnaryOps = 10;
33constexpr auto kNumUnaryFloatOps = 23;
34constexpr auto kNumUnaryIsOps = 6;
35
36constexpr auto kNumBinaryFloatOps = 3;
37constexpr auto kNumBinaryComparisonOps = 12;
38constexpr auto kNumBinaryCastOps = 19;
39
40constexpr auto kNumBinaryOpsWithAlpha = 6;
41constexpr auto kNumLerpOps = 2;
42constexpr auto kNumLayernormFwd = 2;
43constexpr auto kNumBatchnormFwd = 3;
44constexpr auto kNumBatchnormBwd = 2;
45constexpr auto kNumInstancenormFwd = 1;
46constexpr auto kNumSumToSize = 2;
47constexpr auto kNumAutocastOps = 2;
48constexpr auto kNumAliasDimOps = 2;
49constexpr auto kNumViewOps = 2;
50constexpr auto kNumVarOps = 2;
51constexpr auto kNumSoftmaxFwd = 2;
52constexpr auto kNumSoftmaxBwd = 2;
53constexpr auto kNumAminAmaxOps = 2;
54
55namespace {
56
57#define REGISTER_PARSE_RULE(op, func_body, ...) \
58 registerParseRule( \
59 op, \
60 [](const Node* node, std::unordered_map<size_t, ValueHolder>& value_map) \
61 -> void func_body, \
62 __VA_ARGS__)
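
// For reference, an invocation such as
//
//   REGISTER_PARSE_RULE(ptr_op, { ... }, isInputNonSizeZeroTensor, nullptr);
//
// expands to
//
//   registerParseRule(
//       ptr_op,
//       [](const Node* node,
//          std::unordered_map<size_t, ValueHolder>& value_map) -> void { ... },
//       isInputNonSizeZeroTensor,
//       nullptr);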
63
64const auto& reductionSizeAttr = Symbol::attr("profiled_reduction_size");
65const auto& viewSizeAttr = Symbol::attr("profiled_view_size");
66const auto& intListAttr = Symbol::attr("profiled_int_list");
67const auto& intAttr = Symbol::attr("profiled_int");
68const auto& boolListAttr = Symbol::attr("profiled_bool_list");
69const auto& boolAttr = Symbol::attr("profiled_bool");
70const auto& strAttr = Symbol::attr("profiled_str");
71const auto& ivalAttr = Symbol::attr("profiled_ival");
72const auto& profileFailedAttr = Symbol::attr("profile_failed");
73
74typedef Val* CgValue;
75typedef Expr* CgOp;
76
77Val* castTensoToDtype(CgValue self, JitValue* cast_val) {
78 auto cast_ival = toIValue(cast_val);
79 // we need static type for cast
80 TORCH_INTERNAL_ASSERT(cast_ival.has_value());
81 if (cast_ival->isInt()) {
82 auto dtype = cast_ival->toScalarType();
83
84 // We want to keep our internal fusion math in FP32
85 // Shape Inference will continue to propagate the right
86 // type to outputs unchanged.
87 if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16) {
88 dtype = at::ScalarType::Float;
89 }
90
91 return castOp(aten_to_data_type(dtype), self);
92 } else {
93 TORCH_INTERNAL_ASSERT(
94 cast_ival->isNone(),
95 "unrecognized dtype option, expect 'int' but got: ",
96 cast_ival->tagKind());
97
98 // return a copy if dtype is `None`
99 return set(self);
100 }
101}
102
103bool isReductionNonCompatibleTensor(
104 const std::shared_ptr<c10::TensorType>& tensor_type) {
105 return is_zero_dim_tensor(tensor_type) || is_zero_sized_tensor(tensor_type);
106}
107
108bool isInputNonSizeZeroTensor(const Node* node) {
109 for (const auto& val : node->inputs()) {
110 auto tensor_type = val->type()->cast<TensorType>();
111 if (tensor_type && is_zero_sized_tensor(tensor_type)) {
112 return false;
113 }
114 }
115 return true;
116}
117
118bool isScalarTypeCompatible(const Node* node, size_t offset) {
119 auto val = node->input(offset);
120 // return true if it's not specified
121 if (val->type()->isSubtypeOf(static_cast<c10::TypePtr>(NoneType::get()))) {
122 return true;
123 }
124 // return false if it's runtime value
125 if (val->node()->kind() != prim::Constant) {
126 return false;
127 }
128 auto dtype = toIValue(val)->toScalarType();
129
130 // we do NOT support half math type yet
131 if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16) {
132 return false;
133 }
134 return true;
135}
136
137// Note [ Permutation Bookkeeping and Propagation in Parser ]
138//
// The goals of supporting permutation propagation in the parser are to:
// 1. resolve conflicts and propagate permutations;
// 2. keep track of permutations on existing tensors;
142//
// The current requirement is that all parsing rules must support non-permuted
// inputs; some binary operations support inputs with arbitrary permutation,
// and a few operations support special inputs.
// In cases where "wrong" inputs are fed to an operation, we transpose them to
// a supported permutation. This allows us to progressively expand permutation
// support.
// Currently we bind all permuted codegen Vals in `ValueHolder`. This saves
// unnecessary transposes (not sure if it actually helps) since we can reuse
// permuted tensors.
152//
153// Parsing rule pattern:
154// a. ops that only support non-permuted inputs (e.g. sum)
155//
156// // Specifying `MemoryFormat::Contiguous` here to force all inputs to be in
157// // `Contiguous`
158// auto [format, self] = getConsistentValues(
159// MemoryFormat::Contiguous,
160// value_map[node->inputs()[0]->unique()]);
161// // ... use self
162//
163// b. format agnostic ops (e.g. PW unary/binary op like aten::add)
164//
165// // getConsistentValues -> return target format and copies of operands in
166// // the same format
167// auto [format, lhs, rhs] = getConsistentValues(
168// c10::nullopt,
169// value_map[node->inputs()[0]->unique()],
170// value_map[node->inputs()[1]->unique()]);
171//
172// // compute out
173// auto out = binaryOp(op_mapping[node->kind()], lhs, rhs);
174// // specify `format` for out when adding it to `value_map_`
175// value_map.emplace(node->output()->unique(), ValueHolder(out, format));
176//
177// c. ops that supports special permutation. e.g. aten::batch_norm with
178// channels-last inputs.
179
180struct MemoryFormat {
181 // indices of dimensions with increasing stride.
182 std::vector<int> permuted_order_;
183
184 // permutation_ encodes `permuted_order_` by concatenating all elements, with
185 // the exception for unpermuted tensor, where we special case permutation_ to
186 // be 0.
187 //
  // e.g. for a channels-last tensor, permutation_ would be (n-1)123...(n-2);
  // Note: we omit the leading '0' when applicable, and this encoding only
  // works with rank <= 10
191 // see [ Note: MemoryFormat and Stride Order ]
192 size_t permutation_ = 0;
193
194 // default to non-permuted tensor
195 MemoryFormat() = default;
196
197 // [ Note: MemoryFormat and Stride Order ]
198 // stride_order is extracted from
199 // `TensorType::stride_properties()::stride_index_`, it describes the
200 // index of axes from fastest to slowest.
  // For a 4d tensor, if we have stride_order = {x0, x1, x2, x3}, the i-th
  // fastest dimension would be stride_order[i].
203 //
204 // Look at comment for c10::Stride in aten/src/ATen/core/jit_type.h
205 //
  // eg0. for a rank 4 non-permuted tensor, stride_order would be {3, 2, 1, 0}:
  // the fastest dimension is axis-3, the next one is axis-2, etc. So it's a
  // non-permuted tensor.
209 // it should be encoded as permutation_ = 3210 (we special case it to 0)
210 //
  // eg1. for a rank 4 channels-last tensor, stride_order would be {1, 3, 2, 0}:
  // the fastest dimension is axis-1 (C), followed by axis-3, axis-2 and then
  // axis-0. So this is a channels-last tensor (NHWC memory layout of a logical
  // NCHW tensor).
  // it will be encoded as permutation_ = 1320
215 //
216 // eg2. for a rank 4 permuted tensor, stride_order can be {0, 3, 2, 1}
217 // it will be encoded as permutation_ = 321 (omitting leading '0')
218 void setPermutation(const std::vector<int>& stride_order) {
219 int rank = stride_order.size();
220 TORCH_INTERNAL_ASSERT(
221 rank <= 10, "MemoryFormat for permutation only supports rank <= 10");
222
    // store stride_order in `permuted_order_` for a simpler life, so we don't
    // have to decode `permutation_` when we want to apply/restore the
    // permutation.
225 permuted_order_ = stride_order;
226 bool has_permutation = false;
227 permutation_ = 0;
228 for (const auto i : c10::irange(rank)) {
229 permutation_ = permutation_ * 10 + stride_order[i];
230 if (!has_permutation && stride_order[i] != rank - 1 - i) {
231 has_permutation = true;
232 }
233 }
234
235 // special case permutation_ to reflect non-permuted tensor
236 if (!has_permutation) {
237 permutation_ = 0;
238 }
239 }
240
241 // returns the stride order for given MemoryFormat encoding permutation_
242 //
243 // see details for encoding in [ Note: MemoryFormat and Stride Order ]
244 std::vector<int> toStrideOrder() const {
245 std::vector<int> stride_order;
246 // return empty vector for no permutation
247 if (hasPermutation()) {
248 // be generous with reserved space
249 stride_order.reserve(10);
250 bool encountered_zero = false;
251 size_t permutation = permutation_;
252 while (permutation != 0) {
253 int order = static_cast<int>(permutation % 10);
254 permutation /= 10;
255 if (order == 0) {
256 encountered_zero = true;
257 }
258 stride_order.push_back(order);
259 }
260 if (!encountered_zero) {
261 // in case leading '0' is omitted, push it back
262 stride_order.push_back(0);
263 }
264 // since we use push_back, our stride_order is reversed.
265 std::reverse(stride_order.begin(), stride_order.end());
266 }
267 return stride_order;
268 }
269
270 // returns c10::nullopt when it's not safe to broadcast current permutation to
271 // rank
272 c10::optional<MemoryFormat> broadcastToRank(size_t rank) const {
273 auto ret = Contiguous();
274 if (hasPermutation()) {
275 auto stride_order = toStrideOrder();
276 auto cur_rank = stride_order.size();
277 // no op for (cur_rank == 0) || (cur_rank == rank)
278 if (cur_rank < rank) {
        // broadcasting to a higher rank can be done by:
        // 1. incrementing all existing stride order entries by rank_diff;
        // 2. pushing back decreasing elements from rank_diff - 1 down to 0;
        // where rank_diff = rank - cur_rank
283 //
284 // see [ Note: MemoryFormat and Stride Order]
285 // e.g.
286 // taking broadcasted bias for channels last as an example
287 // stride_order = {0, 2, 1} broadcasted to rank == 4 would give us
288 // rank_diff = 4 - 3 = 1
289 // take step 1 -> {1, 3, 2}
290 // take step 2 -> {1, 3, 2, 0}
291 int rank_diff = static_cast<int>(rank - cur_rank);
292 for (auto& val : stride_order) {
293 val += rank_diff;
294 }
295 for (int i = rank_diff - 1; i >= 0; i--) {
296 stride_order.push_back(i);
297 }
298 } else if (cur_rank > rank) {
        // shrink permutation to a lower rank. We can simply discard the
        // higher-rank stride order entries when they are not permuted into the
        // lower-rank bits, because in those instances we can't obey
        // broadcasting semantics while preserving the permutation. We check
        // the stride order and ensure that the lower `rank` bits are all
        // permuted within the lower rank. Afterwards, we update stride_order
        // by decrementing each entry by (cur_rank - rank) to reflect the
        // correct stride order.
306 //
307 // see [ Note: MemoryFormat and Stride Order]
        // e.g. for rank 4 channels last {1, 3, 2, 0}:
        //   1. the format can safely shrink to rank 3, since every element of
        //   {1, 3, 2} is >= (4-3); we drop the last (4-3) entries and
        //   decrement each remaining element by (4-3), which gives us
        //   {0, 2, 1};
        //   2. but when we shrink it to rank 2, we have {1, 3} where 1 < (4-2)
        //   and it can't be handled, so we return c10::nullopt.
314 int collapsed_ranks = static_cast<int>(cur_rank - rank);
315 for (size_t i = 0; i < rank; i++) {
316 if (stride_order[i] < collapsed_ranks) {
317 // illegal collapsing, return c10::nullopt
318 return c10::nullopt;
319 }
320 // update collapsed stride_order
321 stride_order[i] -= collapsed_ranks;
322 }
323 // discard higher rank stride order.
324 stride_order.resize(rank);
325 }
326 ret.setPermutation(stride_order);
327 }
328 return ret;
329 }
330
331 // returns non-permuted format
332 static MemoryFormat Contiguous() {
333 return MemoryFormat();
334 }
335
336 bool hasPermutation() const {
337 return permutation_ != 0;
338 }
339
340 bool isChannelsLast() const {
341 int rank = permuted_order_.size();
342
343 if (rank > 2 && permuted_order_[0] == 1 && permuted_order_[rank - 1] == 0) {
344 for (const auto i : c10::irange(rank - 2)) {
345 if (permuted_order_[i + 1] != rank - 1 - i) {
346 return false;
347 }
348 }
349 return true;
350 }
351 return false;
352 }
353
354 // returns transpose map to achieve permutation on non-permuted tensor
  // note: used for aten::permute API and codegen transpose API
356 std::vector<int64_t> apply() const {
357 std::vector<int64_t> ret;
358 if (hasPermutation()) {
359 ret.resize(permuted_order_.size());
360 std::copy(permuted_order_.rbegin(), permuted_order_.rend(), ret.begin());
361 }
362 return ret;
363 }
364
365 // returns transpose map to restore back to non-permuted tensor
366 // note: used for aten::permute API and codegen transpose API
367 std::vector<int64_t> restore() const {
368 std::vector<int64_t> ret;
369 if (hasPermutation()) {
370 int rank = permuted_order_.size();
371 ret.resize(rank);
372 for (const auto i : c10::irange(rank)) {
373 ret[permuted_order_[i]] = rank - 1 - i;
374 }
375 }
376 return ret;
377 }
378};
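
// Illustrative sketch (not part of the original source): a minimal, unused
// example of how the decimal stride-order encoding above behaves for a rank-4
// channels-last tensor. The assertions follow directly from setPermutation(),
// apply(), restore() and broadcastToRank() as defined in MemoryFormat.
[[maybe_unused]] void memoryFormatEncodingExample() {
  MemoryFormat nhwc;
  // stride order {1, 3, 2, 0}: axis-1 (C) is fastest, axis-0 (N) is slowest
  nhwc.setPermutation({1, 3, 2, 0});
  TORCH_INTERNAL_ASSERT(nhwc.permutation_ == 1320u);
  // transpose map that turns a non-permuted (NCHW) tensor into NHWC layout
  TORCH_INTERNAL_ASSERT(nhwc.apply() == std::vector<int64_t>({0, 2, 3, 1}));
  // transpose map that restores the NHWC-laid-out tensor back to NCHW
  TORCH_INTERNAL_ASSERT(nhwc.restore() == std::vector<int64_t>({0, 3, 1, 2}));
  // shrinking to rank 3 keeps the permutation ({0, 2, 1}) ...
  TORCH_INTERNAL_ASSERT(
      nhwc.broadcastToRank(3)->toStrideOrder() == std::vector<int>({0, 2, 1}));
  // ... while shrinking to rank 2 is not representable and yields nullopt
  TORCH_INTERNAL_ASSERT(!nhwc.broadcastToRank(2).has_value());
}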
379
380struct MemoryCompare {
381 bool operator()(const MemoryFormat& format0, const MemoryFormat& format1)
382 const {
383 return format0.permutation_ < format1.permutation_;
384 }
385};
386
387typedef std::map<MemoryFormat, CgValue, MemoryCompare> MemoryFormatMap;
388
389MemoryFormat operator+(const MemoryFormat& a, const MemoryFormat& b) {
390 // Note: TensorIterator logic uses first input to dominate output MemoryFormat
391 // so instead of `a.permutation_ >= b.permutation_ ? a : b;`, we use:
392 return a;
}
394
//! ValueHolder holds multiple copies of a tensor view, one per permutation
//! `MemoryFormat`. This mainly serves two purposes:
397//!
398//! 1. reuse permuted tensor views among consumers
399//! 2. bookkeeping for permuted tensor views in input/output tensors
400//!
401//! refer to Note [ Permutation Bookkeeping and Propagation in Parser ]
402class ValueHolder {
403 public:
404 // checks if given Val in target format exists.
405 bool hasValue(const MemoryFormat& format) const {
406 return vals_.count(format) != 0;
407 }
408
409 // returns Val in target format.
410 CgValue value(const MemoryFormat& format) const {
411 auto iter_val = vals_.find(format);
412 TORCH_INTERNAL_ASSERT(
        iter_val != vals_.end(),
        "accessing a Val that does not exist for the requested MemoryFormat");
414 return iter_val->second;
415 }
416
417 // returns Val in target format if it exists, otherwise, transpose an existing
418 // copy and add that to bookkeeping.
419 CgValue maybeConvertValue(const MemoryFormat& format) {
420 auto cur_rank = rank();
    // for a scalar tensor (cur_rank == 0) the memory format doesn't carry
    // meaning, so we just return the value as-is; the same goes for a
    // non-tensor value, where cur_rank == -1
424 if (cur_rank <= 0) {
425 return std::get<1>(getEntry());
426 }
427 MemoryFormat format_s;
428 CgValue value_s = nullptr;
429 std::tie(format_s, value_s) = getEntry();
430
431 auto opt_format_d = format.broadcastToRank(static_cast<size_t>(cur_rank));
432 TORCH_INTERNAL_ASSERT(
433 opt_format_d.has_value(),
434 "maybeConvertValue requested for illegal permutation");
435 MemoryFormat format_d = opt_format_d.value();
436
437 auto iter_val = vals_.find(format_d);
438 if (iter_val != vals_.end()) {
439 return iter_val->second;
440 }
441 auto val = convertValue(format_d, format_s, value_s);
442 vals_[format_d] = val;
443 return val;
444 }
445
446 int rank() const {
447 if (!is_tensor_view_) {
448 return -1;
449 } else {
450 auto v = std::get<1>(getEntry());
451 TORCH_INTERNAL_ASSERT(
452 v->isA<TensorView>(), "can only access rank of TensorView");
453 return static_cast<int>(v->as<TensorView>()->nDims());
454 }
455 }
456
457 // TODO: delete this and update accessor for value_map(_)
458 ValueHolder() {
    TORCH_INTERNAL_ASSERT(false, "can't default-construct ValueHolder");
460 }
461
462 ValueHolder(CgValue val, MemoryFormat format = MemoryFormat()) {
463 vals_[format] = val;
464 if (val->isA<TensorView>()) {
465 is_tensor_view_ = true;
466 }
467 }
468
469 // returns the MemoryFormat and codegen Val with the highest precedence among
470 // existing copies.
471 std::tuple<MemoryFormat, CgValue> getEntry() const {
472 TORCH_CHECK(!vals_.empty(), "ValueHolder::getEntry() on empty vals_");
    // return the last entry; this allows us to prioritize permuted (e.g.
    // channels-last) tensors over non-permuted tensors
475 return *vals_.rbegin();
476 }
477
478 // TODO: code cleaning in parser so we don't need these.
479 // returns Val*, keeping them here just so we have less code change.
480 CgValue operator*() const {
481 return std::get<1>(getEntry());
482 }
483 CgValue operator->() const {
484 return std::get<1>(getEntry());
485 }
486 operator CgValue() const {
487 return std::get<1>(getEntry());
488 }
489
490 private:
491 // helper function to convert value_s @ format_s to format_d
492 CgValue convertValue(
493 MemoryFormat format_d,
494 MemoryFormat format_s,
495 CgValue value_s) {
496 TORCH_INTERNAL_ASSERT(
497 value_s->isA<TensorView>(), "cannot convert non-TensorView");
498 auto tv = value_s->as<TensorView>();
499 // TODO: we could probably merge the two if it has perf impact on generated
500 // kernel
501
502 // restore source permutation
503 if (format_s.hasPermutation()) {
504 tv = permute(tv, format_s.restore());
505 }
506 // apply destination permutation
507 if (format_d.hasPermutation()) {
508 tv = permute(tv, format_d.apply());
509 }
510 return tv;
511 }
512
513 private:
514 // container to hold all copies of value in different MemoryFormat
515 // std::unordered_map<MemoryFormat, CgValue> vals_;
516 MemoryFormatMap vals_;
517
  // true when the held Val is a TensorView; false for scalars / non-tensors
519 bool is_tensor_view_ = false;
520};
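
// Illustrative sketch (not part of the original source): ValueHolder treats
// non-TensorView values as rank -1, and maybeConvertValue() is a no-op for
// them. The Fusion/FusionGuard setup mirrors what IrParser::parse() does
// below; the function itself is unused and only serves as documentation.
[[maybe_unused]] void valueHolderScalarExample() {
  Fusion fusion;
  FusionGuard fg(&fusion);
  ValueHolder holder(IrBuilder::create<Double>(1.0));
  TORCH_INTERNAL_ASSERT(holder.rank() == -1);
  // requesting any format simply returns the stored scalar Val
  TORCH_INTERNAL_ASSERT(
      holder.maybeConvertValue(MemoryFormat::Contiguous()) == *holder);
}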
521
522template <class Func, class... Values>
523auto iterate(Func f, ValueHolder& val) {
524 return f(val);
525}
526
527template <class Func, class... Values>
528auto iterate(Func f, ValueHolder& val, Values&... vals) {
529 return f(val, iterate(f, vals...));
530}
531
532// iterate through all vals and return the output MemoryFormat and copies of
533// vals.
// 1. When `forced_format == c10::nullopt`, the target MemoryFormat is taken
//    from the first val in `vals`; this is to achieve behavior coherent with
//    the eager TensorIterator;
// 2. The target can be overridden via specifying `forced_format`.
538//
539// Note: take `Values&` by reference, since `maybeConvertValue` needs to modify
540// the entry and we want that to be updated in `value_map_`
541template <class... Values>
542std::pair<MemoryFormat, std::list<CgValue>> getConsistentValues(
543 c10::optional<MemoryFormat> forced_format,
544 Values&... vals) {
545 MemoryFormat format;
546 if (forced_format.has_value()) {
547 format = forced_format.value();
548 } else {
549 // check for identical nDim on vals
550 auto rank_func = [](const ValueHolder& val, int rank = 0) {
551 int v_rank = val.rank();
552 v_rank = std::max(0, v_rank);
553 if (rank == 0) {
554 return v_rank;
555 } else if (v_rank == 0) {
556 return rank;
557 } else if (rank == -1 || v_rank != rank) {
558 return -1;
559 }
560 return rank;
561 };
562 int rank = iterate(rank_func, vals...);
563
564 // TODO: this is not needed as we are only using the first val
565 // only apply permutation when all inputs are of identical rank, since
566 // permutation could have changed semantics among broadcasted tensors.
567 // Consider pointwise operation between two tensor [N, C, H, W] + [H, W]
568 if (rank > 0) {
569 auto format_func = [](const ValueHolder& val,
570 MemoryFormat f = MemoryFormat::Contiguous()) {
571 return std::get<0>(val.getEntry()) + f;
572 };
573 format = iterate(format_func, vals...);
574 } else {
575 format = MemoryFormat::Contiguous();
576 }
577 }
578
579 auto convert_func = [format](
580 ValueHolder& val, std::list<CgValue> list_val = {}) {
581 list_val.push_front(val.maybeConvertValue(format));
582 return list_val;
583 };
584 auto list_val = iterate(convert_func, vals...);
585
586 return std::make_pair(format, list_val);
587}
588
589// iterate through all vals and return the output MemoryFormat and copies of
590// vals.
// 1. When `forced_format == c10::nullopt`, the target MemoryFormat is taken
//    from the first val in `vals` with the highest rank (coherent with eager
//    TensorIterator behavior), falling back to Contiguous when the operand
//    permutations are not pairwise compatible;
// 2. The target can be overridden via specifying `forced_format`.
595//
596// Note: take `Values&` by reference, since `maybeConvertValue` needs to modify
597// the entry and we want that to be updated in `value_map_`
598template <class... Values>
599std::pair<MemoryFormat, std::list<CgValue>> getPWFormatValues(
600 c10::optional<MemoryFormat> forced_format,
601 Values&... vals) {
602 MemoryFormat format;
603 if (forced_format.has_value()) {
604 format = forced_format.value();
605 } else {
606 // get maximum rank on vals
607 std::vector<MemoryFormat> formats;
608 std::vector<int> ranks;
609 auto max_rank_func = [&ranks](const ValueHolder& val, int rank = 0) {
610 int v_rank = val.rank();
611 ranks.push_back(v_rank);
612 return std::max(rank, v_rank);
613 };
614 int max_rank = iterate(max_rank_func, vals...);
615
    // go through all permutations; to keep consistency with TensorIterator
    // behavior, the first tensor with the highest rank dictates the output
    // permutation
619 auto format_func = [&formats, &max_rank](
620 const ValueHolder& val,
621 MemoryFormat f = MemoryFormat::Contiguous()) {
622 auto cur_format = std::get<0>(val.getEntry());
623 formats.push_back(cur_format);
624 return val.rank() == max_rank ? cur_format : f;
625 };
626 format = iterate(format_func, vals...);
627
    // we need to do a pair-wise comparison to ensure that all permutations are
    // compatible, since a permutation could change semantics among broadcast
    // tensors. Consider a pointwise operation between three tensors:
    // [N, C, H, W] + [C, H, W] + [H, W]
632 for (size_t i = 0; i < formats.size() && format.hasPermutation(); i++) {
633 for (size_t j = 0; j < formats.size(); j++) {
634 // don't compare scalar tensor or scalar
635 if (ranks[i] <= 0 || ranks[j] <= 0 || i == j) {
636 continue;
637 }
638 size_t lower_rank = std::min(ranks[i], ranks[j]);
639 auto i_format = formats[i].broadcastToRank(lower_rank);
640 auto j_format = formats[j].broadcastToRank(lower_rank);
641
642 // breaks permutation if any:
643 // 1. i_format can't be broadcasted to lower_rank;
644 // 2. j_format can't be broadcasted to lower_rank;
645 if (!i_format.has_value() || !j_format.has_value()) {
646 format = MemoryFormat::Contiguous();
647 }
648 }
649 }
650 }
651
652 auto convert_func = [format](
653 ValueHolder& val, std::list<CgValue> list_val = {}) {
654 list_val.push_front(val.maybeConvertValue(format));
655 return list_val;
656 };
657 auto list_val = iterate(convert_func, vals...);
658
659 return std::make_pair(format, list_val);
660}
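
// Illustrative sketch (not part of the original source) contrasting the two
// helpers above, assuming TensorViewBuilder's ndims()/build() interface for
// creating symbolic tensors. With a channels-last rank-4 tensor and a rank-3
// operand, getConsistentValues() falls back to Contiguous because the ranks
// differ, while getPWFormatValues() keeps the channels-last format of the
// highest-rank operand since the rank-3 operand can be broadcast into it.
[[maybe_unused]] void formatHelperContrastExample() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  MemoryFormat nhwc;
  nhwc.setPermutation({1, 3, 2, 0});
  ValueHolder x(TensorViewBuilder().ndims(4).build(), nhwc);
  ValueHolder bias(TensorViewBuilder().ndims(3).build());

  // ranks differ (4 vs 3) -> the permutation is dropped
  auto consistent = getConsistentValues(c10::nullopt, x, bias);
  TORCH_INTERNAL_ASSERT(!consistent.first.hasPermutation());

  // highest-rank operand dictates the format -> channels-last is kept
  auto pointwise = getPWFormatValues(c10::nullopt, x, bias);
  TORCH_INTERNAL_ASSERT(pointwise.first.hasPermutation());
}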
661
662typedef void (
663 *ParseFuncPtr)(const Node*, std::unordered_map<size_t, ValueHolder>&);
664typedef bool (*MergeQueryFuncPtr)(const Node*);
665
666// TODO: add a mutex to make it thread safe.
667class IrParser {
668 enum class OperatorType {
669 ElementWise,
670 Reduction,
671 ReductionToSize,
672 Normalization
673 };
674 typedef OperatorType (*OperatorTypeFuncPtr)(const Node*);
675
676 class RegistrationEntry {
677 public:
678 RegistrationEntry(
679 ParseFuncPtr parse_f,
680 MergeQueryFuncPtr merge_f = nullptr,
681 OperatorTypeFuncPtr type_f = nullptr)
682 : parse_f_(parse_f), merge_f_(merge_f), type_f_(type_f) {}
683
684 void parse(
685 const Node* node,
686 std::unordered_map<size_t, ValueHolder>& values) const {
687 parse_f_(node, values);
688 }
689
690 bool isCompatible(const Node* node) const {
691 if (merge_f_ == nullptr) {
692 return true;
693 }
694 return merge_f_(node);
695 }
696
697 bool isType(const Node* node, OperatorType type) const {
698 auto n_type =
699 type_f_ == nullptr ? OperatorType::ElementWise : type_f_(node);
700 return n_type == type;
701 }
702
703 private:
704 ParseFuncPtr parse_f_;
705 MergeQueryFuncPtr merge_f_;
706 OperatorTypeFuncPtr type_f_;
707 };
708
709 public:
710 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
711 IrParser(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
712 initRegistry();
713 }
714
715 std::unique_ptr<Fusion> parse() {
716 auto fusion = std::make_unique<Fusion>();
717 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
718 FusionGuard fg(fusion.get());
719 auto block = graph_->block();
720
721 std::unordered_map<Val*, MemoryFormat> permuted_tensors;
722 // register all inputs;
723 for (auto val : block->inputs()) {
724 TORCH_INTERNAL_ASSERT(
725 registerValue(val),
          "Failure when registering value: ",
727 *(val->node()),
728 " with type: ",
729 val->type()->repr_str());
730 MemoryFormat format;
731 Val* operand = nullptr;
732 std::tie(format, operand) = value_map_[val->unique()].getEntry();
733 fusion->addInput(operand);
734
735 // mark input tensor as permuted;
736 if (format.hasPermutation()) {
737 permuted_tensors.insert({operand, format});
738 }
739
740 auto opt_dtype = operand->getDataType();
741 // computation promotion, we cast fp16 or bf16 inputs to fp32 and use
742 // promoted type in the computation.
743 if (opt_dtype.has_value() &&
744 (opt_dtype.value() == DataType::Half ||
745 opt_dtype.value() == DataType::BFloat16)) {
746 Val* promoted_val = castOp(DataType::Float, operand);
747 value_map_[val->unique()] = ValueHolder(promoted_val, format);
748 }
749 }
750
751 // compose nodes in topo order;
752 for (const JitOp* node : block->nodes()) {
753 processJitNode(node);
754 }
755
756 // mark output;
757 for (auto jit_output : block->outputs()) {
758 MemoryFormat format;
759 Val* operand = nullptr;
760 std::tie(format, operand) = value_map_[jit_output->unique()].getEntry();
761 TensorView* out = operand->as<TensorView>();
      // demote the output dtype to match the PyTorch JIT graph.
763 auto tensor_type = jit_output->type()->cast<TensorType>();
764 TORCH_INTERNAL_ASSERT(
765 tensor_type, "output of fusion group is not TensorType.");
766 if (tensor_type->scalarType().has_value()) {
767 out = optionalCastStrict(
768 aten_to_data_type(*tensor_type->scalarType()), out)
769 ->as<TensorView>();
770 }
771
772 if (out->isFusionOutput()) {
773 // TODO: This is wasted memory bandwidth, we need to copy since we can't
774 // output a tensor twice.
775 out = set(out);
776 }
777
778 fusion->addOutput(out);
779
780 // mark output tensor as permuted;
781 if (format.hasPermutation()) {
782 permuted_tensors.insert({out, format});
783 }
784 }
785
786 for (const auto& i : c10::irange(fusion->inputs().size())) {
787 const auto& entry = permuted_tensors.find(fusion->inputs()[i]);
788 if (entry != permuted_tensors.end()) {
789 fusion->setPermutationOnInput(i, entry->second.apply());
790 }
791 }
792 for (const auto& i : c10::irange(fusion->outputs().size())) {
793 const auto& entry = permuted_tensors.find(fusion->outputs()[i]);
794 if (entry != permuted_tensors.end()) {
795 fusion->setPermutationOnOutput(i, entry->second.restore());
796 }
797 }
798 return fusion;
799 }
800
801 static bool lookupInSymbolSet(const Node* node) {
802 initRegistry();
803
804 std::lock_guard<std::mutex> lock(parser_mutex_);
805 return parser_symbol_set_.count(node->kind()) != 0;
806 }
807
808 // return nullptr if entry does not exist
809 static const RegistrationEntry* lookupInRegistry(const Node* node) {
810 std::lock_guard<std::mutex> lock(parser_mutex_);
811
812 if (parser_skip_set_.count(node->kind()) != 0) {
813 return nullptr;
814 }
815 // we need to use maybeSchema for nodes like prim::Constant, which doesn't
816 // have a schema
817 auto schema_ptr = node->maybeSchema();
818 if (schema_ptr != nullptr) {
819 // search cached entry first
820 auto cache_it = cached_registry_lookup_.find(schema_ptr);
821 if (cache_it != cached_registry_lookup_.end()) {
822 return cache_it->second;
823 } else {
824 // match signature
825 auto schema_str = canonicalSchemaString(*schema_ptr);
826
827 auto iter = jit_operator_registry_.find(schema_str);
828 if (iter != jit_operator_registry_.end()) {
829 // update cache entry
830 cached_registry_lookup_.insert(cache_it, {schema_ptr, &iter->second});
831 return &iter->second;
832 }
833 }
834 }
835 return nullptr;
836 }
837
838 static bool querySkipSymbolSet(c10::Symbol symbol, bool flip) {
839 initRegistry();
840
841 std::lock_guard<std::mutex> lock(parser_mutex_);
    // Note: unlike `parser_symbol_set_`, `parser_skip_set_` is not populated
    // during registration; it is only modified through this function.
844 bool ret = parser_skip_set_.count(symbol) != 0;
845 if (flip) {
846 if (ret) {
847 parser_skip_set_.erase(symbol);
848 } else {
849 parser_skip_set_.insert(symbol);
850 }
851 }
852 return ret;
853 }
854
855 static void initRegistry() {
856 c10::call_once(once_flag_, []() {
857 std::lock_guard<std::mutex> lock(parser_mutex_);
858 registerJitOperator();
859 });
860 }
861
862 static bool canParseNode(const Node* node) {
863 initRegistry();
864
865 // match signature.
866 auto schema_ptr = node->maybeSchema();
867 if (schema_ptr == nullptr) {
868 return false;
869 }
870 auto reg_entry = lookupInRegistry(node);
871 return reg_entry != nullptr && reg_entry->isCompatible(node);
872 }
873
874 static bool isReductionToSizeNode(const Node* node) {
875 initRegistry();
876
877 auto reg_entry = lookupInRegistry(node);
878 return reg_entry != nullptr &&
879 reg_entry->isType(node, OperatorType::ReductionToSize);
880 }
881
882 static bool isReductionNode(const Node* node) {
883 initRegistry();
884
885 auto reg_entry = lookupInRegistry(node);
886 return reg_entry != nullptr &&
887 (reg_entry->isType(node, OperatorType::Reduction) ||
888 reg_entry->isType(node, OperatorType::ReductionToSize));
889 }
890
891 static bool isNormalizationNode(const Node* node) {
892 initRegistry();
893
894 auto reg_entry = lookupInRegistry(node);
895 return reg_entry != nullptr &&
896 reg_entry->isType(node, OperatorType::Normalization);
897 }
898
899 static bool isElementWiseNode(const Node* node) {
900 initRegistry();
901
902 auto reg_entry = lookupInRegistry(node);
903 return reg_entry != nullptr &&
904 reg_entry->isType(node, OperatorType::ElementWise);
905 }
906
  // TODO: is_reduction is too hacky here. We should categorize operation types
  // based on their memory access patterns, which would affect the fusion
  // strategy and partition logic.
910 static void registerParseRule(
911 std::shared_ptr<Operator>& op,
912 ParseFuncPtr parse_fn,
913 MergeQueryFuncPtr merge_query_fn = nullptr,
914 OperatorTypeFuncPtr type_fn = nullptr) {
915 auto op_name = op->schema().name();
916 parser_symbol_set_.insert(c10::Symbol::fromQualString(op_name));
    // We blindly attempt to profile the in-place version of each supported op;
    // this ensures that in-place removal in the fusion partition has the
    // profile information readily available after the pass.
920 parser_symbol_set_.insert(c10::Symbol::fromQualString(op_name + '_'));
921 jit_operator_registry_.emplace(
922 std::piecewise_construct,
923 std::forward_as_tuple(canonicalSchemaString(op->schema())),
924 std::forward_as_tuple(parse_fn, merge_query_fn, type_fn));
925 }
926
927 private:
928 static void registerJitOperator() {
    // Register a parse function for each JIT operator.
    // This is a one-time lookup; our hash registry indexes on the pointer in
    // the OperatorRegistry.
932
933 std::array<const char*, kNumBinaryOpsWithAlpha> BinaryOpWithAlpha = {
934 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
935 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
936 "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
937 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
938 "aten::rsub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
939 "aten::rsub(Tensor self, Scalar other, Scalar alpha) -> Tensor"};
940 for (auto signature : BinaryOpWithAlpha) {
941 auto ptr_op = getOperatorForLiteral(signature);
942 REGISTER_PARSE_RULE(
943 ptr_op,
944 {
945 using BinaryOpWithAlphaType = Val* (*)(Val*, Val*, Val*);
946 static std::unordered_map<
947 Symbol,
948 std::pair<BinaryOpType, BinaryOpWithAlphaType>>
949 op_mapping(
950 {{aten::add,
951 std::make_pair(
952 BinaryOpType::Add,
953 static_cast<BinaryOpWithAlphaType>(&add_alpha))},
954 {aten::sub,
955 std::make_pair(
956 BinaryOpType::Sub,
957 static_cast<BinaryOpWithAlphaType>(&sub_alpha))},
958 {aten::rsub,
959 std::make_pair(
960 BinaryOpType::Sub,
961 static_cast<BinaryOpWithAlphaType>(&sub_alpha))}});
962 // TODO: handle scaling factor when it's not constant 1;
963 MemoryFormat format;
964 std::list<Val*> list_val;
965 std::tie(format, list_val) = getPWFormatValues(
966 c10::nullopt,
967 value_map[node->inputs()[0]->unique()],
968 value_map[node->inputs()[1]->unique()]);
969 auto lhs = list_val.front();
970 list_val.pop_front();
971 auto rhs = list_val.front();
972 list_val.pop_front();
973 Val* alpha = value_map[node->inputs()[2]->unique()];
974
975 auto out = alpha->isOneInt()
976 ? binaryOp(
977 op_mapping[node->kind()].first,
978 node->kind() == aten::rsub ? rhs : lhs,
979 node->kind() == aten::rsub ? lhs : rhs,
980 TypePromotion::default_op_config)
981 : (node->kind() == aten::rsub
982 ? op_mapping[node->kind()].second(rhs, lhs, alpha)
983 : op_mapping[node->kind()].second(lhs, rhs, alpha));
984 value_map.emplace(
985 node->output()->unique(), ValueHolder(out, format));
986 },
987 isInputNonSizeZeroTensor,
988 nullptr);
989 }
990
991 std::array<const char*, kNumBinaryFloatOps> BinaryFloatOp = {
992 "aten::div(Tensor self, Tensor other) -> Tensor",
993 "aten::div(Tensor self, Scalar other) -> Tensor",
994 "aten::atan2(Tensor self, Tensor other) -> Tensor"};
995 for (auto signature : BinaryFloatOp) {
996 auto ptr_op = getOperatorForLiteral(signature);
997 REGISTER_PARSE_RULE(
998 ptr_op,
999 {
1000 static std::unordered_map<Symbol, BinaryOpType> op_mapping(
1001 {{aten::div, BinaryOpType::Div},
1002 {aten::atan2, BinaryOpType::Atan2}});
1003
1004 MemoryFormat format;
1005 std::list<Val*> list_val;
1006 std::tie(format, list_val) = getPWFormatValues(
1007 c10::nullopt,
1008 value_map[node->inputs()[0]->unique()],
1009 value_map[node->inputs()[1]->unique()]);
1010 auto lhs = list_val.front();
1011 list_val.pop_front();
1012 auto rhs = list_val.front();
1013 list_val.pop_front();
1014
1015 auto out = binaryOp(
1016 op_mapping[node->kind()],
1017 lhs,
1018 rhs,
1019 TypePromotion::float_op_config);
1020 value_map.emplace(
1021 node->output()->unique(), ValueHolder(out, format));
1022 },
1023 isInputNonSizeZeroTensor,
1024 nullptr);
1025 }
1026
1027 std::array<const char*, kNumBinaryCastOps> BinaryCastOp = {
1028 "aten::mul(Tensor self, Tensor other) -> Tensor",
1029 "aten::mul(Tensor self, Scalar other) -> Tensor",
1030 "aten::max(Tensor self, Tensor other) -> Tensor",
1031 "aten::min(Tensor self, Tensor other) -> Tensor",
1032 "aten::pow(Tensor self, Tensor exponent) -> Tensor",
1033 "aten::pow(Tensor self, Scalar exponent) -> Tensor",
1034 "aten::pow(Scalar self, Tensor exponent) -> Tensor",
1035 "aten::remainder(Tensor self, Tensor other) -> Tensor",
1036 "aten::fmod(Tensor self, Tensor other) -> Tensor",
1037 "aten::bitwise_and(Tensor self, Tensor other) -> Tensor",
1038 "aten::__and__(Tensor self, Tensor other) -> Tensor",
1039 "aten::bitwise_or(Tensor self, Tensor other) -> Tensor",
1040 "aten::__or__(Tensor self, Tensor other) -> Tensor",
1041 "aten::bitwise_xor(Tensor self, Tensor other) -> Tensor",
1042 "aten::__xor__(Tensor self, Tensor other) -> Tensor",
1043 "aten::bitwise_left_shift(Tensor self, Tensor other) -> Tensor",
1044 "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
1045 "aten::bitwise_right_shift(Tensor self, Tensor other) -> Tensor",
1046 "aten::__rshift__(Tensor self, Tensor other) -> Tensor"};
1047 for (auto signature : BinaryCastOp) {
1048 auto ptr_op = getOperatorForLiteral(signature);
1049 REGISTER_PARSE_RULE(
1050 ptr_op,
1051 {
1052 static std::unordered_map<Symbol, BinaryOpType> op_mapping(
1053 {{aten::mul, BinaryOpType::Mul},
1054 {aten::min, BinaryOpType::Min},
1055 {aten::max, BinaryOpType::Max},
1056 {aten::pow, BinaryOpType::Pow},
1057 {aten::remainder, BinaryOpType::Remainder},
1058 {aten::fmod, BinaryOpType::Fmod},
1059 {aten::bitwise_and, BinaryOpType::And},
1060 {aten::__and__, BinaryOpType::And},
1061 {aten::bitwise_or, BinaryOpType::Or},
1062 {aten::__or__, BinaryOpType::Or},
1063 {aten::bitwise_xor, BinaryOpType::Xor},
1064 {aten::__xor__, BinaryOpType::Xor},
1065 {aten::bitwise_left_shift, BinaryOpType::Lshift},
1066 {aten::__lshift__, BinaryOpType::Lshift},
1067 {aten::bitwise_right_shift, BinaryOpType::Rshift},
1068 {aten::__rshift__, BinaryOpType::Rshift}});
1069
1070 MemoryFormat format;
1071 std::list<Val*> list_val;
1072 std::tie(format, list_val) = getPWFormatValues(
1073 c10::nullopt,
1074 value_map[node->inputs()[0]->unique()],
1075 value_map[node->inputs()[1]->unique()]);
1076 auto lhs = list_val.front();
1077 list_val.pop_front();
1078 auto rhs = list_val.front();
1079 list_val.pop_front();
1080
1081 auto out = binaryOp(
1082 op_mapping[node->kind()],
1083 lhs,
1084 rhs,
1085 TypePromotion::default_op_config);
1086 value_map.emplace(
1087 node->output()->unique(), ValueHolder(out, format));
1088 },
1089 isInputNonSizeZeroTensor,
1090 nullptr);
1091 }
1092
1093 std::array<const char*, kNumBinaryComparisonOps> BinaryOp = {
1094 "aten::eq(Tensor self, Tensor other) -> Tensor",
1095 "aten::eq(Tensor self, Scalar other) -> Tensor",
1096 "aten::ne(Tensor self, Tensor other) -> Tensor",
1097 "aten::ne(Tensor self, Scalar other) -> Tensor",
1098 "aten::ge(Tensor self, Tensor other) -> Tensor",
1099 "aten::ge(Tensor self, Scalar other) -> Tensor",
1100 "aten::gt(Tensor self, Tensor other) -> Tensor",
1101 "aten::gt(Tensor self, Scalar other) -> Tensor",
1102 "aten::le(Tensor self, Tensor other) -> Tensor",
1103 "aten::le(Tensor self, Scalar other) -> Tensor",
1104 "aten::lt(Tensor self, Tensor other) -> Tensor",
1105 "aten::lt(Tensor self, Scalar other) -> Tensor"};
1106 for (auto signature : BinaryOp) {
1107 auto ptr_op = getOperatorForLiteral(signature);
1108 REGISTER_PARSE_RULE(
1109 ptr_op,
1110 {
1111 static std::unordered_map<Symbol, BinaryOpType> op_mapping(
1112 {{aten::lt, BinaryOpType::LT},
1113 {aten::le, BinaryOpType::LE},
1114 {aten::gt, BinaryOpType::GT},
1115 {aten::ge, BinaryOpType::GE},
1116 {aten::ne, BinaryOpType::NE},
1117 {aten::eq, BinaryOpType::Eq}});
1118
1119 MemoryFormat format;
1120 std::list<Val*> list_val;
1121 std::tie(format, list_val) = getPWFormatValues(
1122 c10::nullopt,
1123 value_map[node->inputs()[0]->unique()],
1124 value_map[node->inputs()[1]->unique()]);
1125 auto lhs = list_val.front();
1126 list_val.pop_front();
1127 auto rhs = list_val.front();
1128 list_val.pop_front();
1129
1130 auto out = binaryOp(
1131 op_mapping[node->kind()],
1132 lhs,
1133 rhs,
1134 TypePromotion::comparison_op_config);
1135 value_map.emplace(
1136 node->output()->unique(), ValueHolder(out, format));
1137 },
1138 isInputNonSizeZeroTensor,
1139 nullptr);
1140 }
1141
1142 std::array<const char*, kNumUnaryOps> UnaryOp = {
1143 "aten::abs(Tensor self) -> Tensor",
1144 "aten::bitwise_not(Tensor self) -> Tensor",
1145 "aten::ceil(Tensor self) -> Tensor",
1146 "aten::floor(Tensor self) -> Tensor",
1147 "aten::frac(Tensor self) -> Tensor",
1148 "aten::neg(Tensor self) -> Tensor",
1149 "aten::relu(Tensor self) -> Tensor",
1150 "aten::round(Tensor self) -> Tensor",
1151 "aten::silu(Tensor self) -> Tensor",
1152 "aten::trunc(Tensor self) -> Tensor",
1153 };
1154 for (auto signature : UnaryOp) {
1155 auto ptr_op = getOperatorForLiteral(signature);
1156 REGISTER_PARSE_RULE(
1157 ptr_op,
1158 {
1159 static std::unordered_map<Symbol, UnaryOpType> op_mapping({
1160 {aten::abs, UnaryOpType::Abs},
1161 {aten::bitwise_not, UnaryOpType::Not},
1162 {aten::ceil, UnaryOpType::Ceil},
1163 {aten::floor, UnaryOpType::Floor},
1164 {aten::frac, UnaryOpType::Frac},
1165 {aten::neg, UnaryOpType::Neg},
1166 {aten::relu, UnaryOpType::Relu},
1167 {aten::round, UnaryOpType::Round},
1168 {aten::silu, UnaryOpType::Silu},
1169 {aten::trunc, UnaryOpType::Trunc},
1170 });
1171 MemoryFormat format;
1172 std::list<Val*> list_val;
1173 std::tie(format, list_val) = getConsistentValues(
1174 c10::nullopt, value_map[node->inputs()[0]->unique()]);
1175 auto operand = list_val.front();
1176 list_val.pop_front();
1177 auto out = unaryOp(op_mapping[node->kind()], operand);
1178 value_map.emplace(
1179 node->output()->unique(), ValueHolder(out, format));
1180 },
1181 isInputNonSizeZeroTensor,
1182 nullptr);
1183 }
1184
1185 std::array<const char*, kNumUnaryFloatOps> UnaryFloatOp = {
1186 "aten::log(Tensor self) -> Tensor",
1187 "aten::log10(Tensor self) -> Tensor",
1188 "aten::log1p(Tensor self) -> Tensor",
1189 "aten::log2(Tensor self) -> Tensor",
1190 "aten::lgamma(Tensor self) -> Tensor",
1191 "aten::exp(Tensor self) -> Tensor",
1192 "aten::expm1(Tensor self) -> Tensor",
1193 "aten::erf(Tensor self) -> Tensor",
1194 "aten::erfc(Tensor self) -> Tensor",
1195 "aten::cos(Tensor self) -> Tensor",
1196 "aten::acos(Tensor self) -> Tensor",
1197 "aten::cosh(Tensor self) -> Tensor",
1198 "aten::sin(Tensor self) -> Tensor",
1199 "aten::asin(Tensor self) -> Tensor",
1200 "aten::sinh(Tensor self) -> Tensor",
1201 "aten::tan(Tensor self) -> Tensor",
1202 "aten::atan(Tensor self) -> Tensor",
1203 "aten::tanh(Tensor self) -> Tensor",
1204 "aten::atanh(Tensor self) -> Tensor",
1205 "aten::sqrt(Tensor self) -> Tensor",
1206 "aten::rsqrt(Tensor self) -> Tensor",
1207 "aten::reciprocal(Tensor self) -> Tensor",
1208 "aten::sigmoid(Tensor self) -> Tensor"};
1209 for (auto signature : UnaryFloatOp) {
1210 auto ptr_op = getOperatorForLiteral(signature);
1211 REGISTER_PARSE_RULE(
1212 ptr_op,
1213 {
1214 static std::unordered_map<Symbol, UnaryOpType> op_mapping({
1215 {aten::log, UnaryOpType::Log},
1216 {aten::log10, UnaryOpType::Log10},
1217 {aten::log1p, UnaryOpType::Log1p},
1218 {aten::log2, UnaryOpType::Log2},
1219 {aten::lgamma, UnaryOpType::Lgamma},
1220 {aten::exp, UnaryOpType::Exp},
1221 {aten::expm1, UnaryOpType::Expm1},
1222 {aten::erf, UnaryOpType::Erf},
1223 {aten::erfc, UnaryOpType::Erfc},
1224 {aten::cos, UnaryOpType::Cos},
1225 {aten::acos, UnaryOpType::Acos},
1226 {aten::cosh, UnaryOpType::Cosh},
1227 {aten::sin, UnaryOpType::Sin},
1228 {aten::asin, UnaryOpType::Asin},
1229 {aten::sinh, UnaryOpType::Sinh},
1230 {aten::tan, UnaryOpType::Tan},
1231 {aten::tanh, UnaryOpType::Tanh},
1232 {aten::atan, UnaryOpType::Atan},
1233 {aten::atanh, UnaryOpType::Atanh},
1234 {aten::sqrt, UnaryOpType::Sqrt},
1235 {aten::rsqrt, UnaryOpType::Rsqrt},
1236 {aten::reciprocal, UnaryOpType::Reciprocal},
1237 {aten::sigmoid, UnaryOpType::Sigmoid},
1238 });
1239 MemoryFormat format;
1240 std::list<Val*> list_val;
1241 std::tie(format, list_val) = getConsistentValues(
1242 c10::nullopt, value_map[node->inputs()[0]->unique()]);
1243 auto operand = list_val.front();
1244 list_val.pop_front();
1245 auto out = unaryOp(
1246 op_mapping[node->kind()],
1247 operand,
1248 TypePromotion::float_op_config);
1249 value_map.emplace(
1250 node->output()->unique(), ValueHolder(out, format));
1251 },
1252 isInputNonSizeZeroTensor,
1253 nullptr);
1254 }
1255
1256 std::array<const char*, kNumUnaryIsOps> UnaryIsOp = {
1257 "aten::isfinite(Tensor self) -> Tensor",
1258 "aten::isinf(Tensor self) -> Tensor",
1259 "aten::isnan(Tensor self) -> Tensor",
1260 "aten::isneginf(Tensor self) -> Tensor",
1261 "aten::isposinf(Tensor self) -> Tensor",
1262 "aten::isreal(Tensor self) -> Tensor"};
1263 for (auto signature : UnaryIsOp) {
1264 auto ptr_op = getOperatorForLiteral(signature);
1265 REGISTER_PARSE_RULE(
1266 ptr_op,
1267 {
1268 static std::unordered_map<Symbol, UnaryOpType> op_mapping({
1269 {aten::isfinite, UnaryOpType::IsFinite},
1270 {aten::isinf, UnaryOpType::IsInf},
1271 {aten::isnan, UnaryOpType::IsNan},
1272 {aten::isneginf, UnaryOpType::IsNegInf},
1273 {aten::isposinf, UnaryOpType::IsPosInf},
1274 {aten::isreal, UnaryOpType::IsReal},
1275 });
1276 MemoryFormat format;
1277 std::list<Val*> list_val;
1278 std::tie(format, list_val) = getConsistentValues(
1279 c10::nullopt, value_map[node->inputs()[0]->unique()]);
1280 auto operand = list_val.front();
1281 list_val.pop_front();
1282 auto out = unaryIsOp(op_mapping[node->kind()], operand);
1283 value_map.emplace(
1284 node->output()->unique(), ValueHolder(out, format));
1285 },
1286 isInputNonSizeZeroTensor,
1287 nullptr);
1288 }
1289
1290 {
1291 auto ptr_op = getOperatorForLiteral(
1292 "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
1293 REGISTER_PARSE_RULE(
1294 ptr_op,
1295 {
1296 MemoryFormat format;
1297 std::list<Val*> list_val;
1298 std::tie(format, list_val) = getConsistentValues(
1299 c10::nullopt, value_map[node->inputs()[0]->unique()]);
1300 auto operand = list_val.front();
1301 list_val.pop_front();
1302
1303 if (!node->input(3)->type()->isSubtypeOf(
1304 static_cast<c10::TypePtr>(NoneType::get()))) {
1305 auto device = constant_as<c10::Device>(node->input(3));
1306 TORCH_INTERNAL_ASSERT(
1307 device.has_value() && device->is_cuda(),
1308 "rand_like in nvfuser is not on cuda device");
1309 auto input_tensor_type =
1310 node->input(0)->type()->cast<TensorType>();
1311 // device->index() == -1 indicating that we don't change device
1312 // index
1313 if (device->index() != -1 && input_tensor_type) {
1314 auto input_device = input_tensor_type->device();
1315 // we expect device index to be consistent with input and it
1316 // should have already been handled by partition
1317 TORCH_INTERNAL_ASSERT(
1318 !input_device.has_value() ||
1319 input_device->index() == device->index(),
1320 "rand_like in nvfuser is not on cuda device");
1321 }
1322 }
1323
1324 auto out = rand_like(operand);
1325 value_map.emplace(
1326 node->output()->unique(), ValueHolder(out, format));
1327 },
1328 [](const Node* node) -> bool {
1329 if (!isInputNonSizeZeroTensor(node)) {
1330 return false;
1331 }
1332 if (!node->input(1)->type()->isSubtypeOf(
1333 static_cast<c10::TypePtr>(NoneType::get())) ||
1334 !node->input(2)->type()->isSubtypeOf(
1335 static_cast<c10::TypePtr>(NoneType::get())) ||
1336 !node->input(5)->type()->isSubtypeOf(
1337 static_cast<c10::TypePtr>(NoneType::get()))) {
1338 return false;
1339 }
1340 return true;
1341 },
1342 nullptr);
1343 }
1344
1345 {
1346 auto ptr_op = getOperatorForLiteral(
1347 "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor");
1348 REGISTER_PARSE_RULE(
1349 ptr_op,
1350 {
1351 MemoryFormat format;
1352 std::list<Val*> list_val;
1353 std::tie(format, list_val) = getConsistentValues(
1354 c10::nullopt, value_map[node->inputs()[0]->unique()]);
1355 auto operand = list_val.front()->as<TensorView>();
1356 list_val.pop_front();
1357 auto& beta = value_map[node->inputs()[1]->unique()];
1358 auto& threshold = value_map[node->inputs()[2]->unique()];
1359 auto out = softplus(operand, beta, threshold);
1360 value_map.emplace(
1361 node->output()->unique(), ValueHolder(out, format));
1362 },
1363 isInputNonSizeZeroTensor,
1364 nullptr);
1365 }
1366
1367 {
1368 auto ptr_op = getOperatorForLiteral(
1369 "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor");
1370 REGISTER_PARSE_RULE(
1371 ptr_op,
1372 {
1373 MemoryFormat format;
1374 std::list<Val*> list_val;
1375 std::tie(format, list_val) = getConsistentValues(
1376 c10::nullopt, value_map[node->inputs()[0]->unique()]);
1377 auto operand = list_val.front();
1378 list_val.pop_front();
1379 auto& th = value_map[node->inputs()[1]->unique()];
1380 auto& value = value_map[node->inputs()[2]->unique()];
1381
1382 auto out = threshold(operand, th, value);
1383 value_map.emplace(
1384 node->output()->unique(), ValueHolder(out, format));
1385 },
1386 isInputNonSizeZeroTensor,
1387 nullptr);
1388 }
1389
1390 { // LTC uses threshold_backward for relu_backward
1391 auto ptr_op = getOperatorForLiteral(
1392 "aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor");
1393 REGISTER_PARSE_RULE(
1394 ptr_op,
1395 {
1396 MemoryFormat format;
1397 std::list<Val*> list_val;
1398 std::tie(format, list_val) = getPWFormatValues(
1399 c10::nullopt,
1400 value_map[node->inputs()[0]->unique()],
1401 value_map[node->inputs()[1]->unique()]);
1402 auto grad_output = list_val.front();
1403 list_val.pop_front();
1404 auto input = list_val.front();
1405 auto& threshold = value_map[node->inputs()[2]->unique()];
1406
1407 auto comparison = binaryOp(
1408 BinaryOpType::GT,
1409 input,
1410 threshold,
1411 TypePromotion::comparison_op_config);
1412 auto mask = castOp(input->getDataType().value(), comparison);
1413 auto out = mul(grad_output, mask);
1414
1415 value_map.emplace(
1416 node->output()->unique(), ValueHolder(out, format));
1417 },
1418 isInputNonSizeZeroTensor,
1419 nullptr);
1420 }
1421
1422 {
1423 auto ptr_op = getOperatorForLiteral(
1424 "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor");
1425 REGISTER_PARSE_RULE(
1426 ptr_op,
1427 {
1428 MemoryFormat format;
1429 std::list<Val*> list_val;
1430 std::tie(format, list_val) = getConsistentValues(
1431 c10::nullopt, value_map[node->inputs()[0]->unique()]);
1432 auto operand = list_val.front();
1433 list_val.pop_front();
1434 Val* min = value_map.count(node->inputs()[1]->unique()) != 0
1435 ? *value_map[node->inputs()[1]->unique()]
1436 : nullptr;
1437 Val* max = value_map.count(node->inputs()[2]->unique()) != 0
1438 ? *value_map[node->inputs()[2]->unique()]
1439 : nullptr;
1440
1441 Val* out = clamp(operand, min, max);
1442 value_map.emplace(
1443 node->output()->unique(), ValueHolder(out, format));
1444 },
1445 isInputNonSizeZeroTensor,
1446 nullptr);
1447 }
1448
1449 {
1450 auto ptr_op = getOperatorForLiteral(
1451 "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor");
1452 REGISTER_PARSE_RULE(
1453 ptr_op,
1454 {
1455 MemoryFormat format;
1456 std::list<Val*> list_val;
1457 std::tie(format, list_val) = getPWFormatValues(
1458 c10::nullopt,
1459 value_map[node->inputs()[0]->unique()],
1460 value_map[node->inputs()[1]->unique()],
1461 value_map[node->inputs()[2]->unique()]);
1462 auto condition = list_val.front();
1463 list_val.pop_front();
1464 auto x = list_val.front();
1465 list_val.pop_front();
1466 auto y = list_val.front();
1467 list_val.pop_front();
1468
1469 auto out = where(condition, x, y);
1470 value_map.emplace(
1471 node->output()->unique(), ValueHolder(out, format));
1472 },
1473 isInputNonSizeZeroTensor,
1474 nullptr);
1475 }
1476
1477 {
1478 std::array<const char*, kNumLerpOps> LerpOp = {
1479 "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
1480 "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor"};
1481 for (auto signature : LerpOp) {
1482 auto ptr_op = getOperatorForLiteral(signature);
1483 REGISTER_PARSE_RULE(
1484 ptr_op,
1485 {
1486 MemoryFormat format;
1487 std::list<Val*> list_val;
1488 std::tie(format, list_val) = getPWFormatValues(
1489 c10::nullopt,
1490 value_map[node->inputs()[0]->unique()],
1491 value_map[node->inputs()[1]->unique()],
1492 value_map[node->inputs()[2]->unique()]);
1493 auto self = list_val.front();
1494 list_val.pop_front();
1495 auto end = list_val.front();
1496 list_val.pop_front();
1497 auto weight = list_val.front();
1498 list_val.pop_front();
1499
1500 auto out = lerp(self, end, weight);
1501 value_map.emplace(
1502 node->output()->unique(), ValueHolder(out, format));
1503 },
1504 isInputNonSizeZeroTensor,
1505 nullptr);
1506 }
1507 }
1508
1509 {
1510 auto ptr_op = getOperatorForLiteral(
1511 "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor");
1512 REGISTER_PARSE_RULE(
1513 ptr_op,
1514 {
1515 MemoryFormat format;
1516 std::list<Val*> list_val;
1517 std::tie(format, list_val) = getPWFormatValues(
1518 c10::nullopt,
1519 value_map[node->inputs()[0]->unique()],
1520 value_map[node->inputs()[1]->unique()],
1521 value_map[node->inputs()[2]->unique()],
1522 value_map[node->inputs()[3]->unique()]);
1523 auto self = list_val.front();
1524 list_val.pop_front();
1525 auto tensor1 = list_val.front();
1526 list_val.pop_front();
1527 auto tensor2 = list_val.front();
1528 list_val.pop_front();
1529 auto value = list_val.front();
1530 list_val.pop_front();
1531
1532 auto out = addcmul(self, tensor1, tensor2, value);
1533 value_map.emplace(
1534 node->output()->unique(), ValueHolder(out, format));
1535 },
1536 isInputNonSizeZeroTensor,
1537 nullptr);
1538 }
1539
1540 {
1541 auto ptr_op = getOperatorForLiteral(
1542 "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)");
1543 REGISTER_PARSE_RULE(
1544 ptr_op,
1545 {
1546 MemoryFormat format;
1547 std::list<Val*> list_val;
1548 std::tie(format, list_val) = getConsistentValues(
1549 c10::nullopt,
1550 value_map[node->inputs()[0]->unique()],
1551 value_map[node->inputs()[1]->unique()]);
1552 auto input = list_val.front();
1553 list_val.pop_front();
1554 auto prob = list_val.front();
1555 list_val.pop_front();
1556 auto train = constant_as<bool>(node->input(2));
1557
1558 TORCH_INTERNAL_ASSERT(
1559 train.has_value(), "dropout needs constant `train` flag");
1560
1561 if (train.value()) {
1562 auto result = dropout(input->as<TensorView>(), prob);
1563
1564 value_map.emplace(
1565 node->output(0)->unique(),
1566 ValueHolder(result.output, format));
1567 value_map.emplace(
1568 node->output(1)->unique(), ValueHolder(result.mask, format));
1569 } else {
1570 value_map.emplace(node->output(0)->unique(), input);
1571 value_map.emplace(
1572 node->output(1)->unique(),
1573 ValueHolder(TensorViewBuilder().build(), format));
1574 }
1575 },
1576 [](const Node* node) -> bool {
1577 if (!isInputNonSizeZeroTensor(node)) {
1578 return false;
1579 }
1580 if (node->inputs()[2]->node()->kind() != prim::Constant) {
1581 return false;
1582 }
1583 return true;
1584 },
1585 nullptr);
1586 }
1587
1588 {
1589 auto ptr_op = getOperatorForLiteral(
1590 "aten::dropout(Tensor input, float p, bool train) -> Tensor");
1591 REGISTER_PARSE_RULE(
1592 ptr_op,
1593 {
1594 MemoryFormat format;
1595 std::list<Val*> list_val;
1596 std::tie(format, list_val) = getConsistentValues(
1597 c10::nullopt,
1598 value_map[node->inputs()[0]->unique()],
1599 value_map[node->inputs()[1]->unique()]);
1600 auto input = list_val.front();
1601 list_val.pop_front();
1602 auto prob = list_val.front();
1603 list_val.pop_front();
1604
1605 auto train = constant_as<bool>(node->input(2));
1606 TORCH_INTERNAL_ASSERT(
1607 train.has_value(), "dropout needs constant `train` flag");
1608
1609 if (train.value()) {
1610 auto result = dropout(input->as<TensorView>(), prob);
1611
1612 value_map.emplace(
1613 node->output()->unique(), ValueHolder(result.output, format));
1614 } else {
1615 value_map.emplace(
1616 node->output()->unique(), ValueHolder(input, format));
1617 }
1618 },
1619 [](const Node* node) -> bool {
1620 if (!isInputNonSizeZeroTensor(node)) {
1621 return false;
1622 }
1623 if (node->inputs()[2]->node()->kind() != prim::Constant) {
1624 return false;
1625 }
1626 return true;
1627 },
1628 nullptr);
1629 }
1630
1631 {
1632 auto ptr_op = getOperatorForLiteral(
1633 "aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor");
1634 REGISTER_PARSE_RULE(
1635 ptr_op,
1636 {
1637 MemoryFormat format;
1638 std::list<Val*> list_val;
1639 std::tie(format, list_val) = getPWFormatValues(
1640 c10::nullopt,
1641 value_map[node->inputs()[0]->unique()],
1642 value_map[node->inputs()[1]->unique()],
1643 value_map[node->inputs()[2]->unique()]);
1644 auto grad = list_val.front();
1645 list_val.pop_front();
1646 auto mask = list_val.front();
1647 list_val.pop_front();
1648 auto scale = list_val.front();
1649 list_val.pop_front();
1650
1651 auto output = dropout_backward(
1652 grad->as<TensorView>(), mask->as<TensorView>(), scale);
1653 value_map.emplace(
1654 node->output()->unique(), ValueHolder(output, format));
1655 },
1656 isInputNonSizeZeroTensor,
1657 nullptr);
1658 }
1659
1660 {
1661 std::array<const char*, kNumInstancenormFwd> InstanceNormFwd = {
1662 "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor"};
1663 for (auto signature : InstanceNormFwd) {
1664 auto ptr_op = getOperatorForLiteral(signature);
1665 REGISTER_PARSE_RULE(
1666 ptr_op,
1667 {
1668 // TODO: handle channels last
1669 MemoryFormat format;
1670 std::list<Val*> list_val;
1671 std::tie(format, list_val) = getConsistentValues(
1672 MemoryFormat::Contiguous(),
1673 value_map[node->inputs()[0]->unique()]);
1674 auto input_t = list_val.front();
1675 list_val.pop_front();
1676 auto input = input_t->as<TensorView>();
1677
1678 TensorView* weight = nullptr;
1679 if (!node->input(1)->type()->isSubtypeOf(
1680 static_cast<c10::TypePtr>(NoneType::get()))) {
1681 weight = value_map[node->input(1)->unique()]->as<TensorView>();
1682 }
1683
1684 TensorView* bias = nullptr;
1685 if (!node->input(2)->type()->isSubtypeOf(
1686 static_cast<c10::TypePtr>(NoneType::get()))) {
1687 bias = value_map[node->input(2)->unique()]->as<TensorView>();
1688 }
1689
1690 TensorView* running_mean = nullptr;
1691 if (!node->input(3)->type()->isSubtypeOf(
1692 static_cast<c10::TypePtr>(NoneType::get()))) {
1693 running_mean =
1694 value_map[node->input(3)->unique()]->as<TensorView>();
1695 }
1696
1697 TensorView* running_var = nullptr;
1698 if (!node->input(4)->type()->isSubtypeOf(
1699 static_cast<c10::TypePtr>(NoneType::get()))) {
1700 running_var =
1701 value_map[node->input(4)->unique()]->as<TensorView>();
1702 }
1703
1704 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1705 auto use_input_stats = constant_as<bool>(node->input(5));
1706 TORCH_INTERNAL_ASSERT(
1707 use_input_stats.has_value(),
1708 "The use_input_stats (bool) parameter is required.");
1709 const bool kUseInputStats = use_input_stats.value();
1710
1711 Val* momentum_ptr = nullptr;
1712 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1713 if (auto momentum = constant_as<float>(node->input(6))) {
1714 momentum_ptr = IrBuilder::create<Double>(momentum.value());
1715 } else {
1716 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1717 momentum_ptr = value_map[node->input(6)->unique()];
1718 }
1719
1720 Val* eps_ptr = nullptr;
1721 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1722 if (auto eps = constant_as<float>(node->input(7))) {
1723 eps_ptr = IrBuilder::create<Double>(eps.value());
1724 } else {
1725 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1726 eps_ptr = value_map[node->input(7)->unique()];
1727 }
1728
1729 auto result = instance_norm(
1730 input,
1731 weight,
1732 bias,
1733 running_mean,
1734 running_var,
1735 kUseInputStats,
1736 momentum_ptr,
1737 eps_ptr);
1738
1739 if (node->kind() ==
1740 c10::Symbol::fromQualString("aten::instance_norm")) {
1741 value_map.emplace(node->output()->unique(), result.output);
1742 }
1743 },
1744 [](const Node* node) -> bool {
1745 if (isReductionNonCompatibleTensor(
1746 node->input(0)->type()->cast<TensorType>())) {
1747 return false;
1748 }
1749 return true;
1750 },
1751 [](const Node* node) -> OperatorType {
1752 return OperatorType::Normalization;
1753 });
1754 }
1755 }
1756
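  // batch_norm forward variants. Permuted inputs other than channels-last are
  // forced back to contiguous before lowering.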
1757 {
1758 std::array<const char*, kNumBatchnormFwd> BatchNormFwd = {
1759 "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)",
1760 "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
1761 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"};
1762 for (auto signature : BatchNormFwd) {
1763 auto ptr_op = getOperatorForLiteral(signature);
1764 REGISTER_PARSE_RULE(
1765 ptr_op,
1766 {
1767 MemoryFormat format;
1768 Val* operand = nullptr;
1769 std::tie(format, operand) =
1770 value_map[node->input(0)->unique()].getEntry();
1771 if (format.hasPermutation() && !format.isChannelsLast()) {
1772 format = MemoryFormat::Contiguous();
1773 operand = value_map[node->input(0)->unique()].maybeConvertValue(
1774 format);
1775 }
1776 auto input = operand->as<TensorView>();
1777
1778 TensorView* weight = nullptr;
1779 if (!node->input(1)->type()->isSubtypeOf(
1780 static_cast<c10::TypePtr>(NoneType::get()))) {
1781 weight = value_map[node->input(1)->unique()]->as<TensorView>();
1782 }
1783
1784 TensorView* bias = nullptr;
1785 if (!node->input(2)->type()->isSubtypeOf(
1786 static_cast<c10::TypePtr>(NoneType::get()))) {
1787 bias = value_map[node->input(2)->unique()]->as<TensorView>();
1788 }
1789
1790 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1791 auto training = constant_as<bool>(node->input(5));
1792 TORCH_INTERNAL_ASSERT(
1793 training.has_value(),
1794 "The training (bool) parameter is required.");
1795 const bool kTraining = training.value();
1796
1797 TensorView* running_mean = nullptr;
1798 if (!node->input(3)->type()->isSubtypeOf(
1799 static_cast<c10::TypePtr>(NoneType::get()))) {
1800 running_mean =
1801 value_map[node->input(3)->unique()]->as<TensorView>();
1802 }
1803
1804 TensorView* running_var = nullptr;
1805 if (!node->input(4)->type()->isSubtypeOf(
1806 static_cast<c10::TypePtr>(NoneType::get()))) {
1807 running_var =
1808 value_map[node->input(4)->unique()]->as<TensorView>();
1809 }
1810
1811 Val* momentum_ptr = nullptr;
1812 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1813 if (auto momentum = constant_as<float>(node->input(6))) {
1814 momentum_ptr = IrBuilder::create<Double>(momentum.value());
1815 } else {
1816 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1817 momentum_ptr = value_map[node->input(6)->unique()];
1818 }
1819
1820 Val* eps_ptr = nullptr;
1821 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1822 if (auto eps = constant_as<float>(node->input(7))) {
1823 eps_ptr = IrBuilder::create<Double>(eps.value());
1824 } else {
1825 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1826 eps_ptr = value_map[node->input(7)->unique()];
1827 }
1828
1829 auto result = batch_norm(
1830 input,
1831 weight,
1832 bias,
1833 running_mean,
1834 running_var,
1835 kTraining,
1836 momentum_ptr,
1837 eps_ptr,
1838 format.isChannelsLast());
1839
1840 if (node->kind() ==
1841 c10::Symbol::fromQualString("aten::native_batch_norm") ||
1842 node->kind() ==
1843 c10::Symbol::fromQualString(
1844 "aten::_batch_norm_impl_index")) {
              // TODO: outputs 3 & 4 are not created;
              // we are not creating these outputs because codegen
              // currently lacks support for them.
1848 value_map.emplace(
1849 node->output(0)->unique(),
1850 ValueHolder(result.output, format));
1851 value_map.emplace(node->output(1)->unique(), result.mean);
1852 value_map.emplace(node->output(2)->unique(), result.invstd);
1853 } else if (
1854 node->kind() ==
1855 c10::Symbol::fromQualString("aten::batch_norm")) {
1856 value_map.emplace(
1857 node->output()->unique(),
1858 ValueHolder(result.output, format));
1859 }
1860 },
1861 [](const Node* node) -> bool {
1862 if (isReductionNonCompatibleTensor(
1863 node->input(0)->type()->cast<TensorType>())) {
1864 return false;
1865 }
1866 if (node->input(5)->node()->kind() != prim::Constant) {
1867 return false;
1868 }
1869 return true;
1870 },
1871 [](const Node* node) -> OperatorType {
1872 return OperatorType::Normalization;
1873 });
1874 }
1875 }
1876
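  // batch_norm backward. The two registered signatures order their inputs
  // differently, so we first map them onto a common set of JIT values before
  // lowering.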
1877 {
1878 std::array<const char*, kNumBatchnormBwd> BatchNormBwd = {
1879 "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)",
1880 "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)"};
1881 for (auto signature : BatchNormBwd) {
1882 auto ptr_op = getOperatorForLiteral(signature);
1883 REGISTER_PARSE_RULE(
1884 ptr_op,
1885 {
1886 JitValue* ts_input = nullptr;
            JitValue* ts_grad_output = nullptr;
1888 JitValue* ts_weight = nullptr;
1889 JitValue* ts_r_mean = nullptr;
1890 JitValue* ts_r_var = nullptr;
1891 JitValue* ts_save_mean = nullptr;
1892 JitValue* ts_save_invstd = nullptr;
1893 JitValue* ts_train = nullptr;
1894 JitValue* ts_eps = nullptr;
1895 JitValue* ts_mask = nullptr;
1896 if (node->kind() ==
1897 c10::Symbol::fromQualString(
1898 "aten::_batch_norm_impl_index_backward")) {
1899 ts_input = node->input(1);
1900 ts_grad_output = node->input(2);
1901 ts_weight = node->input(3);
1902 ts_r_mean = node->input(4);
1903 ts_r_var = node->input(5);
1904 ts_save_mean = node->input(6);
1905 ts_save_invstd = node->input(7);
1906 ts_train = node->input(8);
1907 ts_eps = node->input(9);
1908 ts_mask = node->input(10);
1909 } else if (
1910 node->kind() ==
1911 c10::Symbol::fromQualString(
1912 "aten::native_batch_norm_backward")) {
1913 ts_grad_output = node->input(0);
1914 ts_input = node->input(1);
1915 ts_weight = node->input(2);
1916 ts_r_mean = node->input(3);
1917 ts_r_var = node->input(4);
1918 ts_save_mean = node->input(5);
1919 ts_save_invstd = node->input(6);
1920 ts_train = node->input(7);
1921 ts_eps = node->input(8);
1922 ts_mask = node->input(9);
1923 } else {
1924 TORCH_INTERNAL_ASSERT(
1925 false,
                "Forgot to register the key for the batch norm variant: ",
1927 node->kind().toDisplayString());
1928 }
1929
1930 // discard impl_index and reservedSpace since we don't use them
1931 MemoryFormat format;
1932 std::list<Val*> list_val;
1933 std::tie(format, list_val) = getConsistentValues(
1934 c10::nullopt,
1935 value_map[ts_input->unique()],
1936 value_map[ts_grad_output->unique()]);
1937 if (format.hasPermutation() && !format.isChannelsLast()) {
1938 std::tie(format, list_val) = getConsistentValues(
1939 MemoryFormat::Contiguous(),
1940 value_map[ts_input->unique()],
1941 value_map[ts_grad_output->unique()]);
1942 }
1943 auto operand0 = list_val.front();
1944 list_val.pop_front();
1945 auto operand1 = list_val.front();
1946 list_val.pop_front();
1947 auto input = operand0->as<TensorView>();
1948 auto grad_out = operand1->as<TensorView>();
1949
1950 TensorView* weight = nullptr;
1951 if (!ts_weight->type()->isSubtypeOf(
1952 static_cast<c10::TypePtr>(NoneType::get()))) {
1953 weight = value_map[ts_weight->unique()]->as<TensorView>();
1954 }
1955
1956 TensorView* running_mean = nullptr;
1957 if (!ts_r_mean->type()->isSubtypeOf(
1958 static_cast<c10::TypePtr>(NoneType::get()))) {
1959 running_mean = value_map[ts_r_mean->unique()]->as<TensorView>();
1960 }
1961
1962 TensorView* running_var = nullptr;
1963 if (!ts_r_var->type()->isSubtypeOf(
1964 static_cast<c10::TypePtr>(NoneType::get()))) {
1965 running_var = value_map[ts_r_var->unique()]->as<TensorView>();
1966 }
1967
1968 TensorView* save_mean = nullptr;
1969 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1970 if (!ts_save_mean->type()->isSubtypeOf(
1971 static_cast<c10::TypePtr>(NoneType::get()))) {
1972 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1973 save_mean = value_map[ts_save_mean->unique()]->as<TensorView>();
1974 }
1975
1976 TensorView* save_invstd = nullptr;
1977 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1978 if (!ts_save_invstd->type()->isSubtypeOf(
1979 static_cast<c10::TypePtr>(NoneType::get()))) {
1980 save_invstd =
1981 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1982 value_map[ts_save_invstd->unique()]->as<TensorView>();
1983 }
1984
1985 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1986 auto training = constant_as<bool>(ts_train);
1987 TORCH_INTERNAL_ASSERT(
1988 training.has_value(),
1989 "The training (bool) parameter is required.");
1990 const bool kTraining = training.value();
1991
1992 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1993 Val* eps_ptr = nullptr;
1994 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1995 if (auto eps = constant_as<float>(ts_eps)) {
1996 eps_ptr = IrBuilder::create<Double>(eps.value());
1997 } else {
1998 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1999 eps_ptr = value_map[ts_eps->unique()];
2000 }
2001
2002 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
2003 auto out_mask_list = constant_as<c10::List<bool>>(ts_mask);
2004 TORCH_INTERNAL_ASSERT(
2005 out_mask_list.has_value(),
                "The output mask is required for batch_norm_backward.");
2007 std::vector<bool> output_mask;
2008 for (const auto value : out_mask_list->vec()) {
2009 output_mask.emplace_back(static_cast<bool>(value));
2010 }
2011
2012 // TODO: merge this loop below.
2013 if (kTraining) {
2014 TORCH_INTERNAL_ASSERT(
2015 save_mean != nullptr && save_invstd != nullptr,
2016 "When training=True, save_mean and save_invstd are required.");
2017 } else {
              // TODO: is this really a legitimate assumption? We could run
              // with track_running_stats == false && training == false, which
              // should simply fall through to the case above.
              TORCH_INTERNAL_ASSERT(
                  running_mean != nullptr && running_var != nullptr,
                  "When training=False, running_mean and running_var are required.");
2024 }
2025
2026 auto grads = batch_norm_backward(
2027 input,
2028 grad_out,
2029 weight,
2030 running_mean,
2031 running_var,
2032 save_mean,
2033 save_invstd,
2034 kTraining,
2035 eps_ptr,
2036 output_mask,
2037 format.isChannelsLast());
2038
2039 if (output_mask[0]) {
2040 TORCH_INTERNAL_ASSERT(grads.grad_input != nullptr);
2041 value_map.emplace(
2042 node->output(0)->unique(),
2043 ValueHolder(grads.grad_input, format));
2044 } else {
2045 TORCH_INTERNAL_ASSERT(grads.grad_input == nullptr);
2046 value_map.emplace(
2047 node->output(0)->unique(),
2048 ValueHolder(TensorViewBuilder().build(), format));
2049 }
2050
2051 if (output_mask[1]) {
2052 TORCH_INTERNAL_ASSERT(grads.grad_weight != nullptr);
2053 value_map.emplace(node->output(1)->unique(), grads.grad_weight);
2054 } else {
2055 TORCH_INTERNAL_ASSERT(grads.grad_weight == nullptr);
2056 value_map.emplace(
2057 node->output(1)->unique(), TensorViewBuilder().build());
2058 }
2059
2060 if (output_mask[2]) {
2061 TORCH_INTERNAL_ASSERT(grads.grad_bias != nullptr);
2062 value_map.emplace(node->output(2)->unique(), grads.grad_bias);
2063 } else {
2064 TORCH_INTERNAL_ASSERT(grads.grad_bias == nullptr);
2065 value_map.emplace(
2066 node->output(2)->unique(), TensorViewBuilder().build());
2067 }
2068 },
2069 [](const Node* node) -> bool {
2070 if (isReductionNonCompatibleTensor(
2071 node->input(1)->type()->cast<TensorType>())) {
2072 return false;
2073 }
2074 if (node->kind() ==
2075 c10::Symbol::fromQualString(
2076 "aten::_batch_norm_impl_index_backward")) {
2077 if (node->inputs()[8]->node()->kind() != prim::Constant) {
2078 return false;
2079 }
2080 if (node->inputs()[10]->node()->kind() != prim::Constant) {
2081 return false;
2082 }
2083 } else if (
2084 node->kind() ==
2085 c10::Symbol::fromQualString(
2086 "aten::native_batch_norm_backward")) {
2087 if (node->inputs()[7]->node()->kind() != prim::Constant) {
2088 return false;
2089 }
2090 if (node->inputs()[9]->node()->kind() != prim::Constant) {
2091 return false;
2092 }
2093 } else {
2094 TORCH_INTERNAL_ASSERT(
2095 false,
                  "Forgot to update the profiled constant check for ",
2097 node->kind().toDisplayString());
2098 }
2099 return true;
2100 },
2101 [](const Node* node) -> OperatorType {
2102 return OperatorType::Normalization;
2103 });
2104 }
2105 }
2106
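  // layer_norm forward: native_layer_norm exposes (output, mean, invstd),
  // while layer_norm exposes only the output.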
2107 {
2108 std::array<const char*, kNumLayernormFwd> LayerNormFwd = {
2109 "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
2110 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"};
2111 for (auto signature : LayerNormFwd) {
2112 auto ptr_op = getOperatorForLiteral(signature);
2113 REGISTER_PARSE_RULE(
2114 ptr_op,
2115 {
2116 MemoryFormat format;
2117 std::list<Val*> list_val;
2118 std::tie(format, list_val) = getConsistentValues(
2119 MemoryFormat::Contiguous(),
2120 value_map[node->inputs()[0]->unique()]);
2121 auto input_t = list_val.front();
2122 list_val.pop_front();
2123 auto input = input_t->as<TensorView>();
2124
2125 auto norm_shape_optional =
2126 constant_as<c10::List<int64_t>>(node->input(1));
2127 TORCH_INTERNAL_ASSERT(
2128 norm_shape_optional.has_value(),
                "The normalized_shape list is required.");
2130 auto norm_shape = norm_shape_optional->vec();
2131
2132 TensorView* weight = nullptr;
2133 if (!node->input(2)->type()->isSubtypeOf(
2134 static_cast<c10::TypePtr>(NoneType::get()))) {
2135 weight = value_map[node->input(2)->unique()]->as<TensorView>();
2136 }
2137
2138 TensorView* bias = nullptr;
2139 if (!node->input(3)->type()->isSubtypeOf(
2140 static_cast<c10::TypePtr>(NoneType::get()))) {
2141 bias = value_map[node->input(3)->unique()]->as<TensorView>();
2142 }
2143
2144 Val* eps_ptr = nullptr;
2145 if (auto eps = constant_as<float>(node->input(4))) {
2146 eps_ptr = IrBuilder::create<Double>(eps.value());
2147 } else {
2148 eps_ptr = value_map[node->input(4)->unique()];
2149 }
2150
2151 auto result =
2152 layer_norm(input, norm_shape, weight, bias, eps_ptr);
2153
2154 if (node->kind() ==
2155 c10::Symbol::fromQualString("aten::native_layer_norm")) {
2156 value_map.emplace(node->output(0)->unique(), result.output);
2157 value_map.emplace(node->output(1)->unique(), result.mean);
2158 value_map.emplace(node->output(2)->unique(), result.invstd);
2159 } else if (
2160 node->kind() ==
2161 c10::Symbol::fromQualString("aten::layer_norm")) {
2162 value_map.emplace(node->output()->unique(), result.output);
2163 }
2164 },
2165 // TODO: #ProfileIValue List should update this
2166 [](const Node* node) -> bool {
2167 if (isReductionNonCompatibleTensor(
2168 node->input(0)->type()->cast<TensorType>())) {
2169 return false;
2170 }
2171 if (node->inputs()[1]->node()->kind() != prim::Constant) {
2172 return false;
2173 }
2174 return true;
2175 },
2176 [](const Node* node) -> OperatorType {
2177 return OperatorType::Normalization;
2178 });
2179 }
2180 }
2181
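  // layer_norm backward: grad outputs are gated by the constant output_mask;
  // masked-out grads are replaced with placeholder TensorViews.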
2182 {
2183 auto ptr_op = getOperatorForLiteral(
2184 "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)");
2185 REGISTER_PARSE_RULE(
2186 ptr_op,
2187 {
2188 MemoryFormat format;
2189 std::list<Val*> list_val;
2190 std::tie(format, list_val) = getConsistentValues(
2191 MemoryFormat::Contiguous(),
2192 value_map[node->inputs()[0]->unique()],
2193 value_map[node->inputs()[1]->unique()]);
2194 auto grad_out_t = list_val.front();
2195 list_val.pop_front();
2196 auto input_t = list_val.front();
2197 list_val.pop_front();
2198 auto grad_out = grad_out_t->as<TensorView>();
2199 auto input = input_t->as<TensorView>();
2200
2201 auto norm_shape_optional =
2202 constant_as<c10::List<int64_t>>(node->input(2));
2203 TORCH_INTERNAL_ASSERT(
2204 norm_shape_optional.has_value(),
              "The normalized_shape list is required.");
2206 auto norm_shape = norm_shape_optional->vec();
2207
2208 auto mean = value_map[node->input(3)->unique()]->as<TensorView>();
2209 auto rstd = value_map[node->input(4)->unique()]->as<TensorView>();
2210
2211 TensorView* weight = nullptr;
2212 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
2213 if (!node->input(5)->type()->isSubtypeOf(
2214 static_cast<c10::TypePtr>(NoneType::get()))) {
2215 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
2216 weight = value_map[node->input(5)->unique()]->as<TensorView>();
2217 }
2218
2219 TensorView* bias = nullptr;
2220 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
2221 if (!node->input(6)->type()->isSubtypeOf(
2222 static_cast<c10::TypePtr>(NoneType::get()))) {
2223 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
2224 bias = value_map[node->input(6)->unique()]->as<TensorView>();
2225 }
2226
2227 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
2228 auto output_mask_optional =
2229 constant_as<c10::List<bool>>(node->input(7));
2230 TORCH_INTERNAL_ASSERT(
2231 output_mask_optional.has_value(),
              "The output mask is required for layer_norm_backward.");
2233 std::vector<bool> output_mask = output_mask_optional->vec();
2234
2235 auto grad = layer_norm_backward(
2236 grad_out,
2237 input,
2238 norm_shape,
2239 mean,
2240 rstd,
2241 weight,
2242 bias,
2243 output_mask);
2244
2245 if (output_mask[0]) {
2246 TORCH_INTERNAL_ASSERT(grad.grad_input != nullptr);
2247 value_map.emplace(node->output(0)->unique(), grad.grad_input);
2248 } else {
2249 TORCH_INTERNAL_ASSERT(grad.grad_input == nullptr);
2250 value_map.emplace(
2251 node->output(0)->unique(), TensorViewBuilder().build());
2252 }
2253
2254 if (output_mask[1] && weight != nullptr) {
2255 TORCH_INTERNAL_ASSERT(grad.grad_weight != nullptr);
2256 value_map.emplace(node->output(1)->unique(), grad.grad_weight);
2257 } else {
2258 TORCH_INTERNAL_ASSERT(grad.grad_weight == nullptr);
2259 value_map.emplace(
2260 node->output(1)->unique(), TensorViewBuilder().build());
2261 }
2262
2263 if (output_mask[2] && bias != nullptr) {
2264 TORCH_INTERNAL_ASSERT(grad.grad_bias != nullptr);
2265 value_map.emplace(node->output(2)->unique(), grad.grad_bias);
2266 } else {
2267 TORCH_INTERNAL_ASSERT(grad.grad_bias == nullptr);
2268 value_map.emplace(
2269 node->output(2)->unique(), TensorViewBuilder().build());
2270 }
2271 },
2272 // TODO: #ProfileIValue List should update this
2273 [](const Node* node) -> bool {
2274 if (isReductionNonCompatibleTensor(
2275 node->input(0)->type()->cast<TensorType>())) {
2276 return false;
2277 }
2278 if (node->inputs()[2]->node()->kind() != prim::Constant) {
2279 return false;
2280 }
2281 if (node->inputs()[7]->node()->kind() != prim::Constant) {
2282 return false;
2283 }
2284 return true;
2285 },
2286 [](const Node* node) -> OperatorType {
2287 return OperatorType::Normalization;
2288 });
2289 }
2290
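  // softmax / log_softmax forward: the optional dtype argument is handled by
  // casting the input before the normalization.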
2291 {
2292 std::array<const char*, kNumSoftmaxFwd> SoftmaxFwd = {
2293 "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
2294 "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"};
2295 for (auto signature : SoftmaxFwd) {
2296 auto ptr_op = getOperatorForLiteral(signature);
2297 REGISTER_PARSE_RULE(
2298 ptr_op,
2299 {
2300 MemoryFormat format;
2301 std::list<Val*> list_val;
2302 std::tie(format, list_val) = getConsistentValues(
2303 MemoryFormat::Contiguous(),
2304 value_map[node->inputs()[0]->unique()]);
2305 auto input_t = list_val.front();
2306 list_val.pop_front();
2307 auto input = input_t->as<TensorView>();
2308
2309 auto dim_value = constant_as<int>(node->input(1));
2310 TORCH_INTERNAL_ASSERT(
2311 dim_value.has_value(), "dim in softmax is not valid");
2312
2313 auto data_type = DataType::Null;
2314 if (const auto opt_ivalue = toIValue(node->input(2))) {
2315 if (!opt_ivalue->isNone()) {
2316 data_type = aten_to_data_type(opt_ivalue->toScalarType());
2317 }
2318 }
2319
2320 input = (data_type != DataType::Null)
2321 ? optionalCastStrict(data_type, input)->as<TensorView>()
2322 : input;
2323
2324 bool is_log_softmax = node->kind() ==
2325 c10::Symbol::fromQualString("aten::log_softmax");
2326
2327 auto output = (is_log_softmax)
2328 ? log_softmax(input, dim_value.value())
2329 : softmax(input, dim_value.value());
2330
2331 value_map.emplace(node->output()->unique(), output);
2332 },
2333 [](const Node* node) -> bool {
2334 if (isReductionNonCompatibleTensor(
2335 node->input(0)->type()->cast<TensorType>())) {
2336 return false;
2337 }
2338 if (node->inputs()[1]->node()->kind() != prim::Constant) {
2339 return false;
2340 }
2341 if (!isScalarTypeCompatible(node, 2)) {
2342 return false;
2343 }
2344 return true;
2345 },
2346 [](const Node* node) -> OperatorType {
2347 return OperatorType::Normalization;
2348 });
2349 }
2350 }
2351
2352 { // LTC uses this op for softmax
2353 auto ptr_op = getOperatorForLiteral(
2354 "aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor");
2355 REGISTER_PARSE_RULE(
2356 ptr_op,
2357 {
2358 MemoryFormat format;
2359 std::list<Val*> list_val;
2360 std::tie(format, list_val) = getConsistentValues(
2361 MemoryFormat::Contiguous(),
2362 value_map[node->inputs()[0]->unique()]);
2363 auto input_t = list_val.front();
2364 list_val.pop_front();
2365 auto input = input_t->as<TensorView>();
2366
2367 auto dim_value = constant_as<int>(node->input(1));
2368 TORCH_INTERNAL_ASSERT(
2369 dim_value.has_value(), "dim in softmax is not valid");
2370
2371 auto output = softmax(input, dim_value.value());
2372 value_map.emplace(node->output()->unique(), output);
2373 },
2374 [](const Node* node) -> bool {
2375 if (isReductionNonCompatibleTensor(
2376 node->input(0)->type()->cast<TensorType>())) {
2377 return false;
2378 }
2379 if (node->inputs()[1]->node()->kind() != prim::Constant) {
2380 return false;
2381 }
2382 if (node->inputs()[2]->node()->kind() != prim::Constant) {
2383 return false;
2384 } else {
2385 const auto half_to_float = constant_as<bool>(node->input(2));
2386 TORCH_INTERNAL_ASSERT(
2387 half_to_float.has_value(), "Bool half_to_float is not valid");
2388 auto input_tensor_type =
2389 node->input(0)->type()->cast<TensorType>();
2390 if (half_to_float.value() &&
2391 input_tensor_type->scalarType() != at::ScalarType::Half) {
2392 return false;
2393 }
2394 }
2395 return true;
2396 },
2397 [](const Node* node) -> OperatorType {
2398 return OperatorType::Normalization;
2399 });
2400 }
2401
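  // softmax / log_softmax backward. The input_dtype argument is ignored here;
  // type_inference takes care of it.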
2402 {
2403 std::array<const char*, kNumSoftmaxBwd> SoftmaxBwd = {
2404 "aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor",
2405 "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor"};
2406 for (auto signature : SoftmaxBwd) {
2407 auto ptr_op = getOperatorForLiteral(signature);
2408 REGISTER_PARSE_RULE(
2409 ptr_op,
2410 {
2411 MemoryFormat format;
2412 std::list<Val*> list_val;
2413 std::tie(format, list_val) = getConsistentValues(
2414 MemoryFormat::Contiguous(),
2415 value_map[node->inputs()[0]->unique()],
2416 value_map[node->inputs()[1]->unique()]);
2417 auto grad_output_t = list_val.front();
2418 list_val.pop_front();
2419 auto grad_output = grad_output_t->as<TensorView>();
2420
2421 auto output_t = list_val.front();
2422 list_val.pop_front();
2423 auto output = output_t->as<TensorView>();
2424
2425 auto dim_value = constant_as<int>(node->input(2));
2426 TORCH_INTERNAL_ASSERT(
2427 dim_value.has_value(), "dim in softmax is not valid");
2428
2429 // input_dtype here is ignored! type_inference handles it
2430 bool is_log_softmax = node->kind() ==
2431 c10::Symbol::fromQualString(
2432 "aten::_log_softmax_backward_data");
2433 auto grad_input = (is_log_softmax)
2434 ? log_softmax_backward(grad_output, output, dim_value.value())
2435 : softmax_backward(grad_output, output, dim_value.value());
2436
2437 value_map.emplace(node->output()->unique(), grad_input);
2438 },
2439 [](const Node* node) -> bool {
2440 if (isReductionNonCompatibleTensor(
2441 node->input(0)->type()->cast<TensorType>())) {
2442 return false;
2443 }
2444 if (node->inputs()[2]->node()->kind() != prim::Constant) {
2445 return false;
2446 }
2447 if (node->inputs()[3]->node()->kind() != prim::Constant) {
2448 return false;
2449 }
2450 return true;
2451 },
2452 [](const Node* node) -> OperatorType {
2453 return OperatorType::Normalization;
2454 });
2455 }
2456 }
2457
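  // var / std reductions: an empty dim list means reduce over all dimensions.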
2458 {
2459 std::array<const char*, kNumVarOps> Variance = {
2460 "aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor",
2461 "aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor"};
2462 for (auto signature : Variance) {
2463 auto ptr_op = getOperatorForLiteral(signature);
2464 REGISTER_PARSE_RULE(
2465 ptr_op,
2466 {
2467 MemoryFormat format;
2468 std::list<Val*> list_val;
2469 std::tie(format, list_val) = getConsistentValues(
2470 MemoryFormat::Contiguous(),
2471 value_map[node->inputs()[0]->unique()]);
2472 auto input_t = list_val.front();
2473 list_val.pop_front();
2474 auto input = input_t->as<TensorView>();
2475
2476 bool is_variance =
2477 node->kind() == c10::Symbol::fromQualString("aten::var");
2478
2479 auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
2480 TORCH_INTERNAL_ASSERT(
2481 dims_list.has_value(), "Cannot fuse with dynamic axes");
2482 std::vector<int> dims;
2483 if (!dims_list->empty()) {
2484 for (const auto dim : dims_list->vec()) {
2485 dims.emplace_back(static_cast<int>(dim));
2486 }
2487 } else {
2488 dims.resize(input->as<TensorView>()->nDims());
2489 std::iota(dims.begin(), dims.end(), 0);
2490 }
2491
2492 auto unbiased = constant_as<bool>(node->input(2));
2493 TORCH_INTERNAL_ASSERT(
2494 unbiased.has_value(), "Cannot fuse with dynamic unbiased");
2495
2496 auto keepdim = constant_as<bool>(node->input(3));
2497 TORCH_INTERNAL_ASSERT(
2498 keepdim.has_value(), "Cannot fuse with dynamic keepdim");
2499
2500 auto output = (is_variance)
2501 ? variance(input, dims, unbiased.value(), keepdim.value())
2502 : standard_deviation(
2503 input, dims, unbiased.value(), keepdim.value());
2504 value_map.emplace(node->output()->unique(), output);
2505 },
2506 [](const Node* node) -> bool {
2507 if (isReductionNonCompatibleTensor(
2508 node->input(0)->type()->cast<TensorType>())) {
2509 return false;
2510 }
2511 return true;
2512 },
2513 [](const Node* node) -> OperatorType {
2514 return OperatorType::Normalization;
2515 });
2516 }
2517 }
2518
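  // sum reduction: an empty dim list reduces over all dimensions; an explicit
  // output dtype is only accepted if it is floating point.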
2519 {
2520 auto ptr_op = getOperatorForLiteral(
2521 "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
2522 REGISTER_PARSE_RULE(
2523 ptr_op,
2524 {
2525 // TODO: support channels last in sum
2526 MemoryFormat format;
2527 std::list<Val*> list_val;
2528 std::tie(format, list_val) = getConsistentValues(
2529 MemoryFormat::Contiguous(),
2530 value_map[node->inputs()[0]->unique()]);
2531 auto self = list_val.front();
2532 list_val.pop_front();
2533 auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
2534 TORCH_INTERNAL_ASSERT(
2535 dims_list.has_value(),
2536 "aten::sum cannot be fused with dynamic axes");
2537 std::vector<int> dims;
2538 if (!dims_list->empty()) {
2539 for (const auto dim : dims_list->vec()) {
2540 dims.emplace_back(static_cast<int>(dim));
2541 }
2542 } else {
2543 dims.resize(self->as<TensorView>()->nDims());
2544 std::iota(dims.begin(), dims.end(), 0);
2545 }
2546 auto keepdim = constant_as<bool>(node->input(2));
2547 TORCH_INTERNAL_ASSERT(
2548 keepdim.has_value(),
2549 "aten::sum cannot be fused with dynamic keepdim");
2550 auto out = sum(self->as<TensorView>(), dims, keepdim.value());
2551 value_map.emplace(node->output()->unique(), out);
2552 },
2553 [](const Node* node) -> bool {
2554 if (isReductionNonCompatibleTensor(
2555 node->input(0)->type()->cast<TensorType>())) {
2556 return false;
2557 }
2558 // TODO: support cast of output types
2559 if (!node->inputs()[3]->type()->isSubtypeOf(
2560 static_cast<c10::TypePtr>(NoneType::get()))) {
            // We can only handle floating-point output types (e.g. half,
            // float, double);
2562 if (const auto opt_ivalue = toIValue(node->input(3))) {
2563 const auto scalar_type = opt_ivalue->toScalarType();
2564 if (!at::isFloatingType(scalar_type)) {
2565 return false;
2566 }
2567 }
2568 }
2569 // we don't support dynamic reduction axes;
2570 if (node->inputs()[1]->node()->kind() != prim::Constant) {
2571 return false;
2572 }
2573 // we don't support dynamic keepdim yet;
2574 if (node->inputs()[2]->node()->kind() != prim::Constant) {
2575 return false;
2576 }
2577 return true;
2578 },
2579 [](const Node* node) -> OperatorType {
2580 return OperatorType::Reduction;
2581 });
2582 }
2583
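  // mean is lowered as a sum divided by the product of the reduced extents.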
2584 {
2585 auto ptr_op = getOperatorForLiteral(
2586 "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor");
2587 REGISTER_PARSE_RULE(
2588 ptr_op,
2589 {
2590 MemoryFormat format;
2591 std::list<Val*> list_val;
2592 std::tie(format, list_val) = getConsistentValues(
2593 MemoryFormat::Contiguous(),
2594 value_map[node->inputs()[0]->unique()]);
2595 auto operand = list_val.front();
2596 list_val.pop_front();
2597 auto self = operand->as<TensorView>();
2598 auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
2599 TORCH_INTERNAL_ASSERT(
2600 dims_list.has_value(),
2601 "aten::mean cannot be fused with dynamic axes");
2602 std::vector<int> dims;
2603 if (!dims_list->empty()) {
2604 for (const auto dim : dims_list->vec()) {
2605 dims.emplace_back(static_cast<int>(dim));
2606 }
2607 } else {
2608 dims.resize(self->as<TensorView>()->nDims());
2609 std::iota(dims.begin(), dims.end(), 0);
2610 }
2611 auto keepdim = constant_as<bool>(node->input(2));
2612 TORCH_INTERNAL_ASSERT(
2613 keepdim.has_value(),
2614 "aten::mean cannot be fused with dynamic keepdim");
2615 auto o_sum = sum(self, dims, keepdim.value());
2616 Val* num_features = IrBuilder::create<Double>(1);
2617 for (auto axis : dims) {
2618 if (axis < 0) {
2619 axis += int(self->nDims());
2620 }
2621 num_features =
2622 mul(num_features, self->domain()->domain()[axis]->extent());
2623 }
2624 auto out = div(o_sum, num_features);
2625 value_map.emplace(node->output()->unique(), out);
2626 },
2627 [](const Node* node) -> bool {
2628 if (isReductionNonCompatibleTensor(
2629 node->input(0)->type()->cast<TensorType>())) {
2630 return false;
2631 }
2632 // TODO: support cast of output types
2633 if (!node->inputs()[3]->type()->isSubtypeOf(
2634 static_cast<c10::TypePtr>(NoneType::get()))) {
            // We can only handle floating-point output types (e.g. half,
            // float, double);
2636 if (const auto opt_ivalue = toIValue(node->input(3))) {
2637 const auto scalar_type = opt_ivalue->toScalarType();
2638 if (!at::isFloatingType(scalar_type)) {
2639 return false;
2640 }
2641 }
2642 }
2643 // we don't support dynamic reduction axes;
2644 if (node->inputs()[1]->node()->kind() != prim::Constant) {
2645 return false;
2646 }
2647 // we don't support dynamic keepdim yet;
2648 if (node->inputs()[2]->node()->kind() != prim::Constant) {
2649 return false;
2650 }
2651 return true;
2652 },
2653 [](const Node* node) -> OperatorType {
2654 return OperatorType::Reduction;
2655 });
2656 }
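  // sum_to_size / _grad_sum_to_size: an empty target size is a no-op and just
  // forwards (aliases) the input.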
2657 {
2658 std::array<const char*, kNumSumToSize> SumToSize = {
2659 "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)",
2660 "aten::sum_to_size(Tensor self, int[] size) -> Tensor"};
2661 for (auto signature : SumToSize) {
2662 auto ptr_op = getOperatorForLiteral(signature);
2663 REGISTER_PARSE_RULE(
2664 ptr_op,
2665 {
2666 MemoryFormat format;
2667 std::list<Val*> list_val;
2668 std::tie(format, list_val) = getConsistentValues(
2669 MemoryFormat::Contiguous(),
2670 value_map[node->inputs()[0]->unique()]);
2671 auto self = list_val.front();
2672 list_val.pop_front();
2673 auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
2674 TORCH_INTERNAL_ASSERT(
2675 size_to.has_value(),
                "sum_to_size/_grad_sum_to_size cannot be fused with dynamic sizes");
2677 if (!size_to->empty()) {
2678 auto input = self->as<TensorView>();
2679 auto out = sum_to(input, size_to->vec());
              // this copy is not strictly necessary, but making a copy avoids
              // a tricky computational graph where handling the no-op could be
              // challenging.
2682 if (out == input) {
2683 out = set(input);
2684 }
2685 value_map.emplace(node->output()->unique(), out);
2686 } else {
2687 // We are introducing alias here!
2688 value_map.emplace(node->output()->unique(), self);
2689 }
2690 },
2691 [](const Node* node) -> bool {
2692 if (isReductionNonCompatibleTensor(
2693 node->input(0)->type()->cast<TensorType>())) {
2694 return false;
2695 }
2696 // we don't support dynamic reduction axes;
2697 if (node->inputs()[1]->node()->kind() != prim::Constant) {
2698 return false;
2699 }
2700 return true;
2701 },
2702 [](const Node* node) -> OperatorType {
2703 auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
            // technically size_to->empty() should never occur, as the
            // specialized _grad_sum_to_size should have been removed by an
            // optimization pass
2706 if (size_to->empty()) {
2707 return OperatorType::ElementWise;
2708 } else {
2709 return OperatorType::ReductionToSize;
2710 }
2711 });
2712 }
2713 }
2714
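  // Autocast ops are parsed as a plain set() (copy); the output dtype is left
  // to type inference rather than materialized as a cast here.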
2715 {
2716 std::array<const char*, kNumAutocastOps> AutocastOps = {
2717 "aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)",
2718 "aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)"};
2719 for (auto signature : AutocastOps) {
2720 auto ptr_op = getOperatorForLiteral(signature);
2721 REGISTER_PARSE_RULE(
2722 ptr_op,
2723 {
2724 MemoryFormat format;
2725 std::list<Val*> list_val;
2726 std::tie(format, list_val) = getConsistentValues(
2727 c10::nullopt, value_map[node->inputs()[0]->unique()]);
2728 auto self = list_val.front();
2729 list_val.pop_front();
2730
2731 auto out = set(self);
2732 value_map.emplace(
2733 node->output()->unique(), ValueHolder(out, format));
2734 },
2735 isInputNonSizeZeroTensor,
2736 nullptr);
2737 }
2738 }
2739
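  // aten::_to_copy is only fused when it merely changes the dtype: the dtype
  // must be a constant, and layout, device, pin_memory and memory_format must
  // all be None.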
2740 {
2741 auto ptr_op = getOperatorForLiteral(
2742 "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor");
2743 REGISTER_PARSE_RULE(
2744 ptr_op,
2745 {
2746 MemoryFormat format;
2747 std::list<Val*> list_val;
2748 std::tie(format, list_val) = getConsistentValues(
2749 c10::nullopt, value_map[node->inputs()[0]->unique()]);
2750 auto self = list_val.front();
2751 list_val.pop_front();
2752
2753 auto out = castTensoToDtype(self, node->input(1));
2754
2755 value_map.emplace(
2756 node->output()->unique(), ValueHolder(out, format));
2757 },
2758 [](const Node* node) -> bool {
2759 if (!isInputNonSizeZeroTensor(node)) {
2760 return false;
2761 }
2762 if (node->inputs()[1]->node()->kind() != prim::Constant) {
2763 return false;
2764 }
          // we do not support explicit layout on output
2766 if (!node->inputs()[2]->type()->isSubtypeOf(
2767 static_cast<c10::TypePtr>(NoneType::get()))) {
2768 return false;
2769 }
          // we do not support explicit device on output
2771 if (!node->inputs()[3]->type()->isSubtypeOf(
2772 static_cast<c10::TypePtr>(NoneType::get()))) {
2773 return false;
2774 }
          // we do not support explicit pin_memory on output
2776 if (!node->inputs()[4]->type()->isSubtypeOf(
2777 static_cast<c10::TypePtr>(NoneType::get()))) {
2778 return false;
2779 }
2780 // we do not support explicit memory_format on output
2781 if (!node->inputs()[6]->type()->isSubtypeOf(
2782 static_cast<c10::TypePtr>(NoneType::get()))) {
2783 return false;
2784 }
2785 return true;
2786 },
2787 nullptr);
2788 }
2789
  // Limit the aten::to implementation to only changing the dtype of a tensor
2791 {
2792 auto ptr_op = getOperatorForLiteral(
2793 "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
2794 REGISTER_PARSE_RULE(
2795 ptr_op,
2796 {
2797 MemoryFormat format;
2798 std::list<Val*> list_val;
2799 std::tie(format, list_val) = getConsistentValues(
2800 c10::nullopt, value_map[node->inputs()[0]->unique()]);
2801 auto self = list_val.front();
2802 list_val.pop_front();
2803
2804 auto out = castTensoToDtype(self, node->input(1));
2805
2806 value_map.emplace(
2807 node->output()->unique(), ValueHolder(out, format));
2808 },
2809 [](const Node* node) -> bool {
2810 if (!isInputNonSizeZeroTensor(node)) {
2811 return false;
2812 }
2813 if (node->inputs()[1]->node()->kind() != prim::Constant) {
2814 return false;
2815 }
2816 // we do not support explicit memory_format on output
2817 if (!node->inputs()[4]->type()->isSubtypeOf(
2818 static_cast<c10::TypePtr>(NoneType::get()))) {
2819 return false;
2820 }
2821 return true;
2822 },
2823 nullptr);
2824 }
2825
2826 {
2827 auto ptr_op = getOperatorForLiteral(
2828 "aten::type_as(Tensor self, Tensor other) -> Tensor");
2829 REGISTER_PARSE_RULE(
2830 ptr_op,
2831 {
2832 MemoryFormat format;
2833 std::list<Val*> list_val;
2834 std::tie(format, list_val) = getConsistentValues(
2835 c10::nullopt, value_map[node->inputs()[0]->unique()]);
2836 auto self = list_val.front();
2837 list_val.pop_front();
2838
          // TODO: switch to the PyTorch dtype as it's closer to the truth.
          // For now, the reality is that PyTorch IR profiling information can
          // be missing even with the profiling executor, due to upstream
          // transformations that happen between profiling runs and the fusion
          // pass.
2843 auto opt_dtype =
2844 value_map[node->inputs()[1]->unique()]->getDataType();
2845 TORCH_INTERNAL_ASSERT(opt_dtype.has_value());
2846
2847 auto out = castOp(opt_dtype.value(), self);
2848 value_map.emplace(
2849 node->output()->unique(), ValueHolder(out, format));
2850 },
2851 isInputNonSizeZeroTensor,
2852 nullptr);
2853 }
2854
2855 {
    // We are not fusing `linear` yet, because we can't codegen an efficient
    // gemm. However, we still register it here so that the profiling executor
    // inserts profile nodes for it.
    // During the fusion pass, we decompose linear into gemm + elementwise ops.
2860 auto ptr_op = getOperatorForLiteral(
2861 "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor");
2862 REGISTER_PARSE_RULE(
2863 ptr_op,
2864 {
          // this entry is registered so that we profile the input tensors;
2866 TORCH_INTERNAL_ASSERT(false, "not implemented yet");
2867 },
2868 [](const Node* node) -> bool {
          // We only profile the `linear` layer; we do not fuse it.
2870 return false;
2871 });
2872 }
2873
2874 {
2875 auto ptr_op = getOperatorForLiteral(
2876 "prim::add_optional(Tensor(a) input, Tensor? bias) -> Tensor(a)");
2877 REGISTER_PARSE_RULE(
2878 ptr_op,
2879 {
          // this entry is registered so that we profile the input tensors;
2881 if (node->input(1)->type()->isSubtypeOf(
2882 static_cast<c10::TypePtr>(NoneType::get()))) {
2883 // forwarding the value;
2884 value_map.emplace(
2885 node->output()->unique(),
2886 value_map[node->inputs()[0]->unique()]);
2887 } else {
2888 MemoryFormat format;
2889 std::list<Val*> list_val;
2890 std::tie(format, list_val) = getPWFormatValues(
2891 c10::nullopt,
2892 value_map[node->inputs()[0]->unique()],
2893 value_map[node->inputs()[1]->unique()]);
2894 auto lhs = list_val.front();
2895 list_val.pop_front();
2896 auto rhs = list_val.front();
2897 list_val.pop_front();
2898
2899 auto out = binaryOp(
2900 BinaryOpType::Add,
2901 lhs,
2902 rhs,
2903 TypePromotion::default_op_config);
2904 value_map.emplace(
2905 node->output()->unique(), ValueHolder(out, format));
2906 }
2907 },
2908 isInputNonSizeZeroTensor,
2909 nullptr);
2910 }
2911
2912 {
2913 auto ptr_op = getOperatorForLiteral(
2914 "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor");
2915 REGISTER_PARSE_RULE(
2916 ptr_op,
2917 {
2918 MemoryFormat format;
2919 std::list<Val*> list_val;
2920 std::tie(format, list_val) = getConsistentValues(
2921 c10::nullopt, value_map[node->inputs()[0]->unique()]);
2922 auto self = list_val.front()->as<TensorView>();
2923 list_val.pop_front();
2924
2925 Val* negative_slope = value_map[node->inputs()[1]->unique()];
2926
2927 auto out = leaky_relu(self, negative_slope);
2928 value_map.emplace(
2929 node->output()->unique(), ValueHolder(out, format));
2930 },
2931 [](const Node* node) -> bool {
2932 if (!isInputNonSizeZeroTensor(node)) {
2933 return false;
2934 }
2935 return true;
2936 },
2937 nullptr);
2938 }
2939
2940 {
2941 auto ptr_op = getOperatorForLiteral(
2942 "aten::gelu(Tensor self, *, str approximate='none') -> Tensor");
2943 REGISTER_PARSE_RULE(
2944 ptr_op,
2945 {
2946 MemoryFormat format;
2947 std::list<Val*> list_val;
2948 std::tie(format, list_val) = getConsistentValues(
2949 c10::nullopt, value_map[node->inputs()[0]->unique()]);
2950 auto self = list_val.front()->as<TensorView>();
2951 list_val.pop_front();
2952
2953 auto approximate = constant_as<std::string>(node->input(1));
2954 TORCH_INTERNAL_ASSERT(
2955 approximate.has_value(),
2956 "The approximate parameter is required.");
2957 const auto kTanhGelu =
2958 at::native::get_gelutype_enum(approximate.value()) ==
2959 at::native::GeluType::Tanh;
2960
2961 auto out = (kTanhGelu) ? tanh_gelu(self) : gelu(self);
2962 value_map.emplace(
2963 node->output()->unique(), ValueHolder(out, format));
2964 },
2965 [](const Node* node) -> bool {
2966 if (!isInputNonSizeZeroTensor(node)) {
2967 return false;
2968 }
2969 if (node->input(1)->node()->kind() != prim::Constant) {
2970 return false;
2971 }
2972 return true;
2973 },
2974 nullptr);
2975 }
2976
2977 {
2978 auto ptr_op = getOperatorForLiteral(
2979 "aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor");
2980 REGISTER_PARSE_RULE(
2981 ptr_op,
2982 {
2983 MemoryFormat format;
2984 std::list<Val*> list_val;
2985 std::tie(format, list_val) = getPWFormatValues(
2986 c10::nullopt,
2987 value_map[node->inputs()[0]->unique()],
2988 value_map[node->inputs()[1]->unique()]);
2989 auto grad_out = list_val.front()->as<TensorView>();
2990 list_val.pop_front();
2991 auto self = list_val.front()->as<TensorView>();
2992 list_val.pop_front();
2993
2994 auto approximate = constant_as<std::string>(node->input(2));
2995 TORCH_INTERNAL_ASSERT(
2996 approximate.has_value(),
2997 "The approximate parameter is required.");
2998 const auto kTanhGelu =
2999 at::native::get_gelutype_enum(approximate.value()) ==
3000 at::native::GeluType::Tanh;
3001
3002 auto grad_in = (kTanhGelu) ? tanh_gelu_backward(grad_out, self)
3003 : gelu_backward(grad_out, self);
3004 value_map.emplace(
3005 node->output()->unique(), ValueHolder(grad_in, format));
3006 },
3007 [](const Node* node) -> bool {
3008 if (!isInputNonSizeZeroTensor(node)) {
3009 return false;
3010 }
3011 if (node->input(2)->node()->kind() != prim::Constant) {
3012 return false;
3013 }
3014 return true;
3015 },
3016 nullptr);
3017 }
3018
3019 {
3020 auto ptr_op = getOperatorForLiteral(
3021 "aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor");
3022 REGISTER_PARSE_RULE(
3023 ptr_op,
3024 {
3025 MemoryFormat format;
3026 std::list<Val*> list_val;
3027 std::tie(format, list_val) = getPWFormatValues(
3028 c10::nullopt,
3029 value_map[node->inputs()[0]->unique()],
3030 value_map[node->inputs()[1]->unique()]);
3031 auto grad_out = list_val.front()->as<TensorView>();
3032 list_val.pop_front();
3033 auto self = list_val.front()->as<TensorView>();
3034 list_val.pop_front();
3035
3036 auto grad_in = tanh_backward(grad_out, self);
3037 value_map.emplace(
3038 node->output()->unique(), ValueHolder(grad_in, format));
3039 },
3040 isInputNonSizeZeroTensor,
3041 nullptr);
3042 }
3043
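  // amax / amin reductions: an empty dim list reduces over all dimensions.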
3044 {
    std::array<const char*, kNumAminAmaxOps> AminAmaxOps = {
        "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor",
        "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"};
    for (auto signature : AminAmaxOps) {
3049 auto ptr_op = getOperatorForLiteral(signature);
3050 REGISTER_PARSE_RULE(
3051 ptr_op,
3052 {
3053 MemoryFormat format;
3054 std::list<Val*> list_val;
3055 std::tie(format, list_val) = getConsistentValues(
3056 MemoryFormat::Contiguous(),
3057 value_map[node->inputs()[0]->unique()]);
3058 auto self = list_val.front();
3059 list_val.pop_front();
3060 auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
3061 TORCH_INTERNAL_ASSERT(
3062 dims_list.has_value(),
3063 "aten::amax/amin cannot be fused with dynamic axes");
3064 std::vector<int> dims;
3065 if (!dims_list->empty()) {
3066 for (const auto dim : dims_list->vec()) {
3067 dims.emplace_back(static_cast<int>(dim));
3068 }
3069 } else {
3070 dims.resize(self->as<TensorView>()->nDims());
3071 std::iota(dims.begin(), dims.end(), 0);
3072 }
3073 auto keepdim = constant_as<bool>(node->input(2));
3074 TORCH_INTERNAL_ASSERT(
3075 keepdim.has_value(),
3076 "aten::amax/amin cannot be fused with dynamic keepdim");
3077
3078 TensorView* out = nullptr;
3079 if (node->kind() == c10::Symbol::fromQualString("aten::amax")) {
3080 out = max(self->as<TensorView>(), dims, keepdim.value());
3081 } else if (
3082 node->kind() == c10::Symbol::fromQualString("aten::amin")) {
3083 out = min(self->as<TensorView>(), dims, keepdim.value());
3084 } else {
3085 TORCH_INTERNAL_ASSERT(
3086 false, "unrecognized operation in aten::amax/amin");
3087 }
3088 value_map.emplace(node->output()->unique(), out);
3089 },
3090 [](const Node* node) -> bool {
3091 if (isReductionNonCompatibleTensor(
3092 node->input(0)->type()->cast<TensorType>())) {
3093 return false;
3094 }
3095 // we don't support dynamic reduction axes;
3096 if (node->inputs()[1]->node()->kind() != prim::Constant) {
3097 return false;
3098 }
3099 // we don't support dynamic keepdim yet;
3100 if (node->inputs()[2]->node()->kind() != prim::Constant) {
3101 return false;
3102 }
3103 return true;
3104 },
3105 [](const Node* node) -> OperatorType {
3106 return OperatorType::Reduction;
3107 });
3108 }
3109 }
3110
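  // reshape_copy / view_copy: requires concrete input sizes and a constant
  // target shape with no inferred (-1) dimension.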
3111 {
3112 std::array<const char*, kNumViewOps> ViewOps = {
3113 "prim::reshape_copy(Tensor self, int[] shape) -> Tensor",
3114 "prim::view_copy(Tensor self, int[] size) -> Tensor"};
3115 for (auto signature : ViewOps) {
3116 auto ptr_op = getOperatorForLiteral(signature);
3117 REGISTER_PARSE_RULE(
3118 ptr_op,
3119 {
3120 auto self_value = node->inputs()[0];
3121 MemoryFormat format;
3122 std::list<Val*> list_val;
3123 std::tie(format, list_val) = getConsistentValues(
3124 MemoryFormat::Contiguous(), value_map[self_value->unique()]);
3125 auto self = list_val.front()->as<TensorView>();
3126 list_val.pop_front();
3127
3128 auto self_type = self_value->type()->cast<c10::TensorType>();
3129 TORCH_INTERNAL_ASSERT(self_type != nullptr);
3130 auto self_sizes = getTensorSizes(self_type);
3131
3132 auto view_sizes = constant_as<c10::List<int64_t>>(node->input(1));
3133 TORCH_INTERNAL_ASSERT(
3134 view_sizes.has_value(), "The size parameter is required.");
3135
3136 auto output = view(self, self_sizes, view_sizes->vec());
3137 value_map.emplace(node->output()->unique(), output);
3138 },
3139 [](const Node* node) -> bool {
3140 auto self_value = node->inputs()[0];
3141 auto tensor_type = self_value->type()->cast<c10::TensorType>();
3142 if (tensor_type == nullptr) {
3143 return false;
3144 }
3145 if (!tensor_type->sizes().concrete_sizes().has_value()) {
3146 // Shape information for input tensor is required.
3147 return false;
3148 }
3149
3150 if (!isInputNonSizeZeroTensor(node)) {
3151 return false;
3152 }
3153 // Reject fusing node if view_sizes contains an inferred dimension
3154 auto view_sizes = constant_as<c10::List<int64_t>>(node->input(1));
3155 if (!view_sizes.has_value()) {
3156 // The size parameter is required.
3157 return false;
3158 }
3159
3160 for (auto axis_size : view_sizes->vec()) {
3161 if (axis_size == -1) {
3162 return false;
3163 }
3164 }
3165 return true;
3166 },
3167 nullptr);
3168 }
3169 }
3170
3171 {
3172 auto flatten_op = getOperatorForLiteral(
3173 "prim::flatten_copy(Tensor self, int start_dim, int end_dim) -> Tensor");
3174 REGISTER_PARSE_RULE(
3175 flatten_op,
3176 {
3177 auto self_value = node->inputs()[0];
3178 MemoryFormat format;
3179 std::list<Val*> list_val;
3180 std::tie(format, list_val) = getConsistentValues(
3181 MemoryFormat::Contiguous(), value_map[self_value->unique()]);
3182 auto self = list_val.front()->as<TensorView>();
3183 list_val.pop_front();
3184
3185 auto start_dim_value = constant_as<int>(node->input(1));
3186 TORCH_INTERNAL_ASSERT(
3187 start_dim_value.has_value(), "start_dim is not valid");
3188 auto end_dim_value = constant_as<int>(node->input(2));
3189 TORCH_INTERNAL_ASSERT(
3190 end_dim_value.has_value(), "end_dim is not valid");
3191
3192 TensorView* output =
3193 flatten(self, start_dim_value.value(), end_dim_value.value());
3194 value_map.emplace(node->output()->unique(), output);
3195 },
3196 [](const Node* node) -> bool {
3197 // we don't support dynamic start_dim;
3198 if (node->inputs()[1]->node()->kind() != prim::Constant) {
3199 return false;
3200 }
3201 // we don't support dynamic end_dim yet;
3202 if (node->inputs()[2]->node()->kind() != prim::Constant) {
3203 return false;
3204 }
3205 return true;
3206 },
3207 nullptr);
3208 }
3209
3210 {
3211 auto ptr_op =
3212 getOperatorForLiteral("prim::squeeze_copy(Tensor self) -> Tensor");
3213 REGISTER_PARSE_RULE(
3214 ptr_op,
3215 {
3216 auto self_value = node->inputs()[0];
3217 MemoryFormat format;
3218 std::list<Val*> list_val;
3219 std::tie(format, list_val) = getConsistentValues(
3220 MemoryFormat::Contiguous(), value_map[self_value->unique()]);
3221 auto self = list_val.front()->as<TensorView>();
3222 list_val.pop_front();
3223
3224 auto self_type = self_value->type()->cast<c10::TensorType>();
3225 TORCH_INTERNAL_ASSERT(self_type != nullptr);
3226 auto self_sizes = getTensorSizes(self_type);
3227
3228 TensorView* output = nullptr;
3229 if (self_sizes.empty()) {
3230 // squeeze on scalar tensor should just return itself;
3231 output = set(self);
3232 } else {
3233 output = squeeze(self, self_sizes);
3234 }
3235 value_map.emplace(node->output()->unique(), output);
3236 },
3237 [](const Node* node) -> bool {
3238 // Shape information for input tensor is required.
3239 auto self_value = node->inputs()[0];
3240 auto tensor_type = self_value->type()->cast<c10::TensorType>();
3241 if (tensor_type == nullptr) {
3242 return false;
3243 }
3244 if (!isInputNonSizeZeroTensor(node)) {
3245 return false;
3246 }
3247 return tensor_type->sizes().concrete_sizes().has_value();
3248 },
3249 nullptr);
3250 }
3251
3252 {
3253 std::array<const char*, kNumAliasDimOps> AliasOpWithDim = {
3254 "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor",
3255 "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor"};
3256 for (auto signature : AliasOpWithDim) {
3257 auto ptr_op = getOperatorForLiteral(signature);
3258 REGISTER_PARSE_RULE(
3259 ptr_op,
3260 {
3261 auto self_value = node->inputs()[0];
3262 MemoryFormat format;
3263 std::list<Val*> list_val;
3264 std::tie(format, list_val) = getConsistentValues(
3265 MemoryFormat::Contiguous(),
3266 value_map[node->inputs()[0]->unique()]);
3267 auto self = list_val.front()->as<TensorView>();
3268 list_val.pop_front();
3269
3270 auto dim_value = constant_as<int>(node->input(1));
3271 TORCH_INTERNAL_ASSERT(dim_value.has_value(), "dim is not valid");
3272
3273 TensorView* output = nullptr;
3274 if (node->kind() == prim::unsqueeze_copy) {
3275 output = unsqueeze(self, dim_value.value());
3276 } else {
3277 auto self_type = self_value->type()->cast<c10::TensorType>();
3278 TORCH_INTERNAL_ASSERT(self_type != nullptr);
3279 auto self_sizes = getTensorSizes(self_type);
3280 if (self_sizes.empty()) {
3281 // squeeze on scalar tensor should just return itself;
3282 output = set(self);
3283 } else {
3284 output = squeeze(self, self_sizes, dim_value.value());
3285 }
3286 }
3287 value_map.emplace(node->output()->unique(), output);
3288 },
3289 [](const Node* node) -> bool {
3290 // Shape information for input tensor is required.
3291 auto self_value = node->inputs()[0];
3292 auto tensor_type = self_value->type()->cast<c10::TensorType>();
3293 if (tensor_type == nullptr) {
3294 return false;
3295 }
3296 if (!isInputNonSizeZeroTensor(node)) {
3297 return false;
3298 }
3299 if (node->input(1)->node()->kind() != prim::Constant) {
3300 return false;
3301 }
          return tensor_type->sizes().concrete_sizes().has_value();
3304 },
3305 nullptr);
3306 }
3307 }
3308
3309 {
3310 auto ptr_op = getOperatorForLiteral(
3311 "prim::expand_as_copy(Tensor self, Tensor other) -> Tensor");
3312 REGISTER_PARSE_RULE(
3313 ptr_op,
3314 {
3315 MemoryFormat format;
3316 std::list<Val*> list_val;
3317 std::tie(format, list_val) = getPWFormatValues(
3318 c10::nullopt,
3319 value_map[node->inputs()[0]->unique()],
3320 value_map[node->inputs()[1]->unique()]);
3321 auto self = list_val.front()->as<TensorView>();
3322 list_val.pop_front();
3323 auto other = list_val.front()->as<TensorView>();
3324 list_val.pop_front();
3325
3326 auto output = expand_as(self, other);
3327 value_map.emplace(
3328 node->output()->unique(), ValueHolder(output, format));
3329 },
3330 [](const Node* node) -> bool {
3331 if (!isInputNonSizeZeroTensor(node)) {
3332 return false;
3333 }
3334
3335 return true;
3336 },
3337 nullptr);
3338 }
3339
3340 {
3341 auto ptr_op = getOperatorForLiteral(
3342 "prim::expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor");
3343 REGISTER_PARSE_RULE(
3344 ptr_op,
3345 {
3346 auto self_value = node->inputs()[0];
3347 MemoryFormat format;
3348 std::list<Val*> list_val;
3349 std::tie(format, list_val) = getConsistentValues(
3350 MemoryFormat::Contiguous(), value_map[self_value->unique()]);
3351 auto self = list_val.front()->as<TensorView>();
3352 list_val.pop_front();
3353
3354 auto expand_sizes = constant_as<c10::List<int64_t>>(node->input(1));
3355 TORCH_INTERNAL_ASSERT(
3356 expand_sizes.has_value(), "The size parameter is required.");
3357
3358 std::vector<CgValue> expand_sizes_vec;
3359 for (const int64_t& size : expand_sizes.value()) {
3360 expand_sizes_vec.push_back(IrBuilder::create<Int>(size));
3361 }
3362
3363 // TODO: we should be able to support dynamic expand values
3364 auto output = expand(self, expand_sizes_vec);
3365 value_map.emplace(node->output()->unique(), output);
3366 },
3367 [](const Node* node) -> bool {
3368 if (!isInputNonSizeZeroTensor(node)) {
3369 return false;
3370 }
3371 // expand_sizes needs to be constant
3372 auto expand_sizes = constant_as<c10::List<int64_t>>(node->input(1));
3373 if (!expand_sizes.has_value()) {
3374 return false;
3375 }
3376
3377 return true;
3378 },
3379 nullptr);
3380 }
3381
3382 {
3383 auto ptr_op = getOperatorForLiteral(
3384 "prim::permute_copy.int(Tensor(a) self, int[] dims) -> Tensor");
3385 REGISTER_PARSE_RULE(
3386 ptr_op,
3387 {
3388 MemoryFormat format;
3389 std::list<Val*> list_val;
3390 std::tie(format, list_val) = getConsistentValues(
3391 c10::nullopt, value_map[node->inputs()[0]->unique()]);
3392 auto self_t = list_val.front();
3393 list_val.pop_front();
3394 auto self = self_t->as<TensorView>();
3395
3396 auto dims = constant_as<c10::List<int64_t>>(node->input(1));
3397 TORCH_INTERNAL_ASSERT(
3398 dims.has_value(), "The dims parameter is required.");
3399 TORCH_INTERNAL_ASSERT(
3400 dims.value().size() == self->getMaybeRFactorDomain().size());
3401
3402 auto output = permute(self, dims->vec());
3403 value_map.emplace(
3404 node->output()->unique(), ValueHolder(output, format));
3405 },
3406 [](const Node* node) -> bool {
3407 if (!isInputNonSizeZeroTensor(node)) {
3408 return false;
3409 }
3410 auto dims = constant_as<c10::List<int64_t>>(node->input(1));
3411 if (!dims.has_value()) {
3412 return false;
3413 }
3414
3415 return true;
3416 },
3417 nullptr);
3418 }
3419
3420 {
3421 auto ptr_op = getOperatorForLiteral(
3422 "prim::transpose_copy.int(Tensor(a) self, int dim0, int dim1) -> Tensor");
3423 REGISTER_PARSE_RULE(
3424 ptr_op,
3425 {
3426 MemoryFormat format;
3427 std::list<Val*> list_val;
3428 std::tie(format, list_val) = getConsistentValues(
3429 c10::nullopt, value_map[node->inputs()[0]->unique()]);
3430 auto self_t = list_val.front();
3431 list_val.pop_front();
3432 auto self = self_t->as<TensorView>();
3433
3434 auto dim0 = constant_as<int>(node->input(1));
3435 TORCH_INTERNAL_ASSERT(
3436 dim0.has_value(), "dim0 in transpose is not valid.");
3437
3438 auto dim1 = constant_as<int>(node->input(2));
3439 TORCH_INTERNAL_ASSERT(
3440 dim1.has_value(), "dim1 in transpose is not valid.");
3441
3442 auto output = transpose(self, dim0.value(), dim1.value());
3443 value_map.emplace(
3444 node->output()->unique(), ValueHolder(output, format));
3445 },
3446 [](const Node* node) -> bool {
3447 if (!isInputNonSizeZeroTensor(node)) {
3448 return false;
3449 }
3450 if (node->input(1)->node()->kind() != prim::Constant) {
3451 return false;
3452 }
3453 if (node->input(2)->node()->kind() != prim::Constant) {
3454 return false;
3455 }
3456 return true;
3457 },
3458 nullptr);
3459 }
3460
3461 {
3462 auto ptr_op =
3463 getOperatorForLiteral("prim::t_copy(Tensor(a) self) -> Tensor");
3464 REGISTER_PARSE_RULE(
3465 ptr_op,
3466 {
3467 MemoryFormat format;
3468 std::list<Val*> list_val;
3469 std::tie(format, list_val) = getConsistentValues(
3470 c10::nullopt, value_map[node->inputs()[0]->unique()]);
3471 auto self_t = list_val.front();
3472 list_val.pop_front();
3473 auto self = self_t->as<TensorView>();
3474
3475 TORCH_INTERNAL_ASSERT(self->getMaybeRFactorDomain().size() <= 2);
3476
3477 auto output = transpose(self);
3478 value_map.emplace(
3479 node->output()->unique(), ValueHolder(output, format));
3480 },
3481 [](const Node* node) -> bool {
3482 if (!isInputNonSizeZeroTensor(node)) {
3483 return false;
3484 }
3485
3486 return true;
3487 },
3488 nullptr);
3489 }
3490 }
3491
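  // Lowers a single JIT node into the fusion IR: constant nodes have their
  // outputs registered as codegen scalars; all other nodes are dispatched to
  // the parse rule registered for their schema.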
3492 void processJitNode(const JitOp* node) {
3493 if (node->kind() == prim::Constant) {
3494    // The partition pass doesn't take constant nodes explicitly; instead it
3495    // copies constants into the subgraph, so we need to register them in the codegen IR;
3496 for (auto output : node->outputs()) {
3497 TORCH_INTERNAL_ASSERT(
3498 registerScalar(output),
3499 "registration of output failed at index ",
3500 output->offset(),
3501 " for node ",
3502 *node);
3503 }
3504 } else {
3505 auto reg_entry = lookupInRegistry(node);
3506 TORCH_INTERNAL_ASSERT(
3507 reg_entry != nullptr,
3508 "CudaFusionGroup Parser doesn't handle node: ",
3509 canonicalSchemaString(node->schema()));
3510 reg_entry->parse(node, value_map_);
3511 }
3512 }
3513
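  // Registers a JIT value in value_map_, either as an input tensor or as a
  // scalar; returns false if the value's type is not supported.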
3514 bool registerValue(const JitValue* val) {
3515 return registerInputTensor(val) || registerScalar(val);
3516 }
3517
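  // Registers a scalar JIT value (complex, float, int, or bool) in value_map_,
  // using its constant value when available and a symbolic placeholder
  // otherwise. String/Device/None values and constant lists are accepted but
  // not registered with the codegen IR.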
3518 bool registerScalar(const JitValue* val) {
3519 if (val->type()->isSubtypeOf(
3520 static_cast<c10::TypePtr>(ComplexType::get()))) {
3521 CgValue cg_val = nullptr;
3522 if (auto ival = constant_as<c10::complex<double>>(val)) {
3523 cg_val = IrBuilder::create<ComplexDouble>(ival.value());
3524 } else {
3525 cg_val = IrBuilder::create<ComplexDouble>();
3526 }
3527 value_map_.emplace(val->unique(), cg_val);
3528 return true;
3529 } else if (val->type()->isSubtypeOf(
3530 static_cast<c10::TypePtr>(FloatType::get()))) {
3531 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3532 CgValue cg_val;
3533 if (auto ival = constant_as<double>(val)) {
3534 cg_val = IrBuilder::create<Double>(ival.value());
3535 } else {
3536 cg_val = IrBuilder::create<Double>();
3537 }
3538 value_map_.emplace(val->unique(), cg_val);
3539 return true;
3540 } else if (val->type()->isSubtypeOf(
3541 static_cast<c10::TypePtr>(IntType::get()))) {
3542 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3543 CgValue cg_val;
3544 if (auto ival = constant_as<int64_t>(val)) {
3545 cg_val = IrBuilder::create<Int>(ival.value());
3546 } else {
3547 cg_val = IrBuilder::create<Int>();
3548 }
3549 value_map_.emplace(val->unique(), cg_val);
3550 return true;
3551 } else if (val->type()->isSubtypeOf(
3552 static_cast<c10::TypePtr>(BoolType::get()))) {
3553 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3554 CgValue cg_val;
3555 if (auto ival = constant_as<bool>(val)) {
3556 cg_val = IrBuilder::create<Bool>(ival.value());
3557 } else {
3558 cg_val = IrBuilder::create<Bool>();
3559 }
3560 value_map_.emplace(val->unique(), cg_val);
3561 return true;
3562 } else if (
3563 val->type()->isSubtypeOf(
3564 static_cast<c10::TypePtr>(StringType::get())) ||
3565 val->type()->isSubtypeOf(
3566 static_cast<c10::TypePtr>(DeviceObjType::get())) ||
3567 val->type()->isSubtypeOf(static_cast<c10::TypePtr>(NoneType::get()))) {
3568      // TODO: should we consider adding support for NoneType?
3569      // Note: String/Device scalars are only used in parsing rules; do not
3570      // register them with the codegen IR.
3571 return true;
3572 } else if (val->type()->cast<ListType>()) {
3573      // TODO: we don't support list types in codegen yet;
3574      // This is a workaround to allow reduction axes to be passed as a constant list;
3575      // We simply skip the conversion when the list value is a constant;
3576 auto ivalue = toIValue(val);
3577 TORCH_INTERNAL_ASSERT(
3578 ivalue.has_value(),
3579 "List[T] is not supported as an argument by NvFuser. Use a Constant List.");
3580 return true;
3581 }
3582 return false;
3583 }
3584
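  // Registers an input tensor in value_map_, detecting a permuted (e.g.
  // channels-last) layout from its stride properties so the fusion IR can
  // treat it as a contiguous, permuted tensor.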
3585 bool registerInputTensor(const JitValue* val) {
3586 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3587 CgValue cg_val;
3588 // Don't register if we don't support the type
3589 if (auto tensor_type = val->type()->cast<c10::TensorType>()) {
3590 if (!tensor_type->scalarType().has_value()) {
3591 return false;
3592 }
3593
3594 if (aten_to_data_type(tensor_type->scalarType().value()) ==
3595 DataType::Null) {
3596 return false;
3597 }
3598
3599 // check for NHWC contiguous tensor
3600 TORCH_CHECK(tensor_type->dim().has_value(), "rank missing");
3601 const auto n_dim = tensor_type->dim().value();
3602
3603 MemoryFormat format;
3604 std::vector<int> stride_index;
3605 for (const auto i : c10::irange(n_dim)) {
3606 const auto& stride_property_i = tensor_type->stride_properties()[i];
3607 if (stride_property_i->stride_index_.has_value()) {
3608 stride_index.emplace_back(stride_property_i->stride_index_.value());
3609 }
3610 }
3611
3612      // only set the permutation when all stride indices are available
3613 if (stride_index.size() == n_dim) {
3614 format.setPermutation(stride_index);
3615 }
3616
3617 // construct permuted tensor_type
3618 if (format.hasPermutation()) {
3619 auto opt_s_vec = tensor_type->symbolic_sizes().sizes();
3620 TORCH_CHECK(opt_s_vec.has_value(), "missing rank of symbolic sizes");
3621 std::vector<c10::ShapeSymbol> s_vec = opt_s_vec.value();
3622 // apply permutation
3623 auto permutation = format.apply();
3624 for (auto new_axis : c10::irange(permutation.size())) {
3625 auto old_axis = permutation.at(new_axis);
3626 s_vec[new_axis] = opt_s_vec.value()[old_axis];
3627 }
3628
3629        // copy the stride properties because we need to permute them
3630 auto opt_stride_vec = tensor_type->stride_properties().sizes();
3631 TORCH_CHECK(opt_stride_vec.has_value(), "missing stride properties");
3632 auto nhwc_stride_vec = opt_stride_vec.value();
3633        // Make the tensor contiguous after permutation.
3634        // Note that we are only updating stride_properties.stride_index, since
3635        // the contiguous_ and stride_ values should remain the same after the
3636        // permutation.
3637 for (const auto i : c10::irange(n_dim)) {
3638 nhwc_stride_vec[i]->stride_index_ = n_dim - i - 1;
3639 }
3640
3641 tensor_type = c10::TensorType::create(
3642 tensor_type->scalarType(),
3643 tensor_type->device(),
3644 s_vec,
3645 nhwc_stride_vec,
3646 tensor_type->requires_grad(),
3647 tensor_type->undefined());
3648 }
3649
3650 cg_val = IrBuilder::create<TensorView>(tensor_type);
3651 if (is_cpu_scalar(*tensor_type)) {
3652 cg_val->as<TensorView>()->setCpuScalar(true);
3653 }
3654 value_map_.emplace(val->unique(), ValueHolder(cg_val, format));
3655 return true;
3656 }
3657 return false;
3658 }
3659
3660 std::shared_ptr<Graph> graph_;
3661
3662 // maps from JitValue::unique() to fusion Val;
3663 std::unordered_map<size_t, ValueHolder> value_map_;
3664
3665 static std::unordered_set<Symbol> parser_symbol_set_;
3666 static std::unordered_set<Symbol> parser_skip_set_;
3667 static std::mutex parser_mutex_;
3668
3669 // parsing rule registry.
3670 static std::unordered_map<std::string, RegistrationEntry>
3671 jit_operator_registry_; // NOLINT
3672
3673  // points to cached entries stored in `jit_operator_registry_`
3674 static std::unordered_map<const FunctionSchema*, const RegistrationEntry*>
3675 cached_registry_lookup_; // NOLINT
3676
3677 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
3678 static c10::once_flag once_flag_;
3679};
3680std::unordered_set<Symbol> IrParser::parser_symbol_set_; // NOLINT
3681std::unordered_set<Symbol> IrParser::parser_skip_set_; // NOLINT
3682std::mutex IrParser::parser_mutex_;
3683std::unordered_map<std::string, IrParser::RegistrationEntry>
3684 IrParser::jit_operator_registry_; // NOLINT
3685std::unordered_map<const FunctionSchema*, const IrParser::RegistrationEntry*>
3686 IrParser::cached_registry_lookup_; // NOLINT
3687
3688// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
3689c10::once_flag IrParser::once_flag_;
3690
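// Inserts a ProfileIValueOp before `node` that observes the input at `offset`,
// and reroutes that input through the profiling node's output.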
3691ProfileIValueOp* insertProfileIValueOp(
3692 Node* node,
3693 size_t offset,
3694 ProfilingRecord* pr) {
3695 auto in_val = node->input(offset);
3696 auto pn = pr->createProfileIValueNode(in_val);
3697 pn->insertBefore(node);
3698 node->replaceInput(offset, pn->output());
3699 return pn;
3700}
3701
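// The profileXxx helpers below share one pattern: each installs a callback on
// a ProfileIValueOp that records the observed value as a node attribute on the
// first run and, when later runs disagree, marks the profile as failed and
// removes the attribute so the GUARD logic can handle it.
// profileReductionSize records an int list (an empty list when the value is
// None).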
3702void profileReductionSize(ProfilingRecord* pr, Node* node, size_t offset) {
3703 auto pn = insertProfileIValueOp(node, offset, pr);
3704
3705 const auto ivalue_profiler = [pr, pn](Stack& stack) {
3706 std::lock_guard<std::mutex> lock(pr->mutex_);
3707
3708 // TODO: we don't care about merging multiple profiling runs as we don't
3709 // support it at all;
3710 int64_t frame_id = 0;
3711 pop(stack, frame_id);
3712 IValue value;
3713 pop(stack, value);
3714
3715 std::vector<int64_t> size_vec;
3716 if (value.isIntList()) {
3717 size_vec = value.toIntVector();
3718 } else if (value.isNone()) {
3719 size_vec.clear();
3720 } else {
3721 TORCH_INTERNAL_ASSERT(
3722 false,
3723 "profileReductionSize does not support data type: ",
3724 value.tagKind());
3725 }
3726    // Stop updating the profiled attribute once profiling has failed
3727 if (!pn->hasAttribute(profileFailedAttr)) {
3728 if (!pn->hasAttribute(reductionSizeAttr)) {
3729 pn->is_(reductionSizeAttr, size_vec);
3730 } else {
3731 auto profiled_ints = pn->is(reductionSizeAttr);
3732 if (profiled_ints.size() != size_vec.size() ||
3733 !std::equal(
3734 profiled_ints.begin(), profiled_ints.end(), size_vec.begin())) {
3735 TORCH_WARN_ONCE(
3736 __FUNCTION__,
3737 " sees varying value in profiling, ignoring and this should be handled by GUARD logic");
3738 pn->s_(profileFailedAttr, "varying profile values");
3739 pn->removeAttribute(reductionSizeAttr);
3740 }
3741 }
3742 } else {
3743 TORCH_INTERNAL_ASSERT(
3744 !pn->hasAttribute(reductionSizeAttr),
3745 "profiled attribute should have been removed when profiling is marked as failed");
3746 }
3747 push(stack, value);
3748 };
3749 pn->setCallback(ivalue_profiler);
3750}
3751
3752void profileViewSize(ProfilingRecord* pr, Node* node, size_t offset) {
3753 auto pn = insertProfileIValueOp(node, offset, pr);
3754
3755 const auto ivalue_profiler = [pr, pn](Stack& stack) {
3756 std::lock_guard<std::mutex> lock(pr->mutex_);
3757
3758 // TODO: we don't care about merging multiple profiling runs as we don't
3759 // support it at all;
3760 int64_t frame_id = 0;
3761 pop(stack, frame_id);
3762 IValue value;
3763 pop(stack, value);
3764 TORCH_INTERNAL_ASSERT(
3765 value.isIntList(), "profiling seeing the wrong data type");
3766 if (!pn->hasAttribute(profileFailedAttr)) {
3767 if (!pn->hasAttribute(viewSizeAttr)) {
3768 pn->is_(viewSizeAttr, value.toIntVector());
3769 } else {
3770 auto profiled_ints = pn->is(viewSizeAttr);
3771 auto input_ints = value.toIntList();
3772 if (profiled_ints.size() != input_ints.size() ||
3773 !std::equal(
3774 profiled_ints.begin(),
3775 profiled_ints.end(),
3776 input_ints.begin())) {
3777 TORCH_WARN_ONCE(
3778 __FUNCTION__,
3779 " sees varying value in profiling, ignoring and this should be handled by GUARD logic");
3780 pn->s_(profileFailedAttr, "varying profile values");
3781 pn->removeAttribute(viewSizeAttr);
3782 }
3783 }
3784 } else {
3785 TORCH_INTERNAL_ASSERT(
3786 !pn->hasAttribute(viewSizeAttr),
3787 "profiled attribute should have been removed when profiling is marked as failed");
3788 }
3789 push(stack, value);
3790 };
3791
3792 pn->setCallback(ivalue_profiler);
3793}
3794
3795void profileIntList(ProfilingRecord* pr, Node* node, size_t offset) {
3796 auto pn = insertProfileIValueOp(node, offset, pr);
3797
3798 const auto ivalue_profiler = [pr, pn](Stack& stack) {
3799 std::lock_guard<std::mutex> lock(pr->mutex_);
3800
3801 // TODO: we don't care about merging multiple profiling runs as we don't
3802 // support it at all;
3803 int64_t frame_id = 0;
3804 pop(stack, frame_id);
3805 IValue value;
3806 pop(stack, value);
3807 TORCH_INTERNAL_ASSERT(
3808 value.isIntList(), "profiling seeing the wrong data type");
3809 if (!pn->hasAttribute(profileFailedAttr)) {
3810 if (!pn->hasAttribute(intListAttr)) {
3811 pn->is_(intListAttr, value.toIntVector());
3812 } else {
3813 auto profiled_ints = pn->is(intListAttr);
3814 auto input_ints = value.toIntList();
3815 if (profiled_ints.size() != input_ints.size() ||
3816 !std::equal(
3817 profiled_ints.begin(),
3818 profiled_ints.end(),
3819 input_ints.begin())) {
3820 TORCH_WARN_ONCE(
3821 __FUNCTION__,
3822 " sees varying value in profiling, ignoring and this should be handled by GUARD logic");
3823 pn->s_(profileFailedAttr, "varying profile values");
3824 pn->removeAttribute(intListAttr);
3825 }
3826 }
3827 } else {
3828 TORCH_INTERNAL_ASSERT(
3829 !pn->hasAttribute(intListAttr),
3830 "profiled attribute should have been removed when profiling is marked as failed");
3831 }
3832 push(stack, value);
3833 };
3834
3835 pn->setCallback(ivalue_profiler);
3836}
3837
3838void profileString(ProfilingRecord* pr, Node* node, size_t offset) {
3839 auto pn = insertProfileIValueOp(node, offset, pr);
3840
3841 const auto ivalue_profiler = [pr, pn](Stack& stack) {
3842 std::lock_guard<std::mutex> lock(pr->mutex_);
3843
3844 // TODO: we don't care about merging multiple profiling runs as we don't
3845 // support it at all;
3846 int64_t frame_id = 0;
3847 pop(stack, frame_id);
3848 IValue value;
3849 pop(stack, value);
3850 TORCH_INTERNAL_ASSERT(
3851 value.isString(), "profiling seeing the wrong data type");
3852 if (!pn->hasAttribute(profileFailedAttr)) {
3853 if (!pn->hasAttribute(strAttr)) {
3854 pn->s_(strAttr, value.toStringRef());
3855 } else {
3856 const auto& profiled_str = pn->s(strAttr);
3857 const auto& input_str = value.toStringRef();
3858 if (input_str != profiled_str) {
3859 TORCH_WARN_ONCE(
3860 __FUNCTION__,
3861 " sees varying value in profiling, ignoring and this should be handled by GUARD logic");
3862 pn->s_(profileFailedAttr, "varying profile values");
3863 pn->removeAttribute(strAttr);
3864 }
3865 }
3866 } else {
3867 TORCH_INTERNAL_ASSERT(
3868 !pn->hasAttribute(strAttr),
3869 "profiled attribute should have been removed when profiling is marked as failed");
3870 }
3871 push(stack, value);
3872 };
3873
3874 pn->setCallback(ivalue_profiler);
3875}
3876
3877void profileBool(ProfilingRecord* pr, Node* node, size_t offset) {
3878 auto pn = insertProfileIValueOp(node, offset, pr);
3879
3880 const auto ivalue_profiler = [pr, pn](Stack& stack) {
3881 std::lock_guard<std::mutex> lock(pr->mutex_);
3882
3883 // TODO: we don't care about merging multiple profiling runs as we don't
3884 // support it at all;
3885 int64_t frame_id = 0;
3886 pop(stack, frame_id);
3887 IValue value;
3888 pop(stack, value);
3889 TORCH_INTERNAL_ASSERT(
3890 value.isBool(), "profiling seeing the wrong data type");
3891 if (!pn->hasAttribute(profileFailedAttr)) {
3892 if (!pn->hasAttribute(boolAttr)) {
3893 pn->i_(boolAttr, value.toBool());
3894 } else {
3895 auto profiled_bool = pn->i(boolAttr);
3896 auto input_bool = value.toBool();
3897 if (input_bool != profiled_bool) {
3898 TORCH_WARN_ONCE(
3899 __FUNCTION__,
3900 " sees varying value in profiling, ignoring and this should be handled by GUARD logic");
3901 pn->s_(profileFailedAttr, "varying profile values");
3902 pn->removeAttribute(boolAttr);
3903 }
3904 }
3905 } else {
3906 TORCH_INTERNAL_ASSERT(
3907 !pn->hasAttribute(boolAttr),
3908 "profiled attribute should have been removed when profiling is marked as failed");
3909 }
3910 push(stack, value);
3911 };
3912
3913 pn->setCallback(ivalue_profiler);
3914}
3915
3916void profileInt(ProfilingRecord* pr, Node* node, size_t offset) {
3917 auto pn = insertProfileIValueOp(node, offset, pr);
3918
3919 const auto ivalue_profiler = [pr, pn](Stack& stack) {
3920 std::lock_guard<std::mutex> lock(pr->mutex_);
3921
3922 // TODO: we don't care about merging multiple profiling runs as we don't
3923 // support it at all;
3924 int64_t frame_id = 0;
3925 pop(stack, frame_id);
3926 IValue value;
3927 pop(stack, value);
3928 TORCH_INTERNAL_ASSERT(
3929 value.isInt(), "profiling seeing the wrong data type");
3930 if (!pn->hasAttribute(profileFailedAttr)) {
3931 if (!pn->hasAttribute(intAttr)) {
3932 pn->i_(intAttr, value.toInt());
3933 } else {
3934 auto profiled_int = pn->i(intAttr);
3935 auto input_int = value.toInt();
3936 if (input_int != profiled_int) {
3937 TORCH_WARN_ONCE(
3938 __FUNCTION__,
3939 " sees varying value in profiling, ignoring and this should be handled by GUARD logic");
3940 pn->s_(profileFailedAttr, "varying profile values");
3941 pn->removeAttribute(intAttr);
3942 }
3943 }
3944 } else {
3945 TORCH_INTERNAL_ASSERT(
3946 !pn->hasAttribute(intAttr),
3947 "profiled attribute should have been removed when profiling is marked as failed");
3948 }
3949 push(stack, value);
3950 };
3951
3952 pn->setCallback(ivalue_profiler);
3953}
3954
3955// profile ivalue, used for optional arguments
3956void profileIval(ProfilingRecord* pr, Node* node, size_t offset) {
3957 auto pn = insertProfileIValueOp(node, offset, pr);
3958
3959 const auto ivalue_profiler = [pr, pn](Stack& stack) {
3960 std::lock_guard<std::mutex> lock(pr->mutex_);
3961
3962 // TODO: we don't care about merging multiple profiling runs as we don't
3963 // support it at all;
3964 int64_t frame_id = 0;
3965 pop(stack, frame_id);
3966 IValue value;
3967 pop(stack, value);
3968 if (!pn->hasAttribute(profileFailedAttr)) {
3969 if (!pn->hasAttribute(ivalAttr)) {
3970 pn->ival_(ivalAttr, value);
3971 } else {
3972 auto profiled_ival = pn->ival(ivalAttr);
3973 if (value != profiled_ival) {
3974 TORCH_WARN_ONCE(
3975 __FUNCTION__,
3976 " sees varying value in profiling, ignoring and this should be handled by GUARD logic");
3977 pn->s_(profileFailedAttr, "varying profile values");
3978 pn->removeAttribute(ivalAttr);
3979 }
3980 }
3981 } else {
3982 TORCH_INTERNAL_ASSERT(
3983 !pn->hasAttribute(ivalAttr),
3984 "profiled attribute should have been removed when profiling is marked as failed");
3985 }
3986 push(stack, value);
3987 };
3988
3989 pn->setCallback(ivalue_profiler);
3990}
3991
3992void profileBoolList(ProfilingRecord* pr, Node* node, size_t offset) {
3993 auto pn = insertProfileIValueOp(node, offset, pr);
3994
3995 const auto ivalue_profiler = [pr, pn](Stack& stack) {
3996 std::lock_guard<std::mutex> lock(pr->mutex_);
3997
3998 // TODO: we don't care about merging multiple profiling runs as we don't
3999 // support it at all;
4000 int64_t frame_id = 0;
4001 pop(stack, frame_id);
4002 IValue value;
4003 pop(stack, value);
4004 TORCH_INTERNAL_ASSERT(
4005 value.isBoolList(), "profiling seeing the wrong data type");
4006 if (!pn->hasAttribute(profileFailedAttr)) {
4007 if (!pn->hasAttribute(boolListAttr)) {
4008 auto list = value.toBoolList();
4009 std::vector<int64_t> val(list.begin(), list.end());
4010 pn->is_(boolListAttr, val);
4011 } else {
4012 auto profiled_ints = pn->is(boolListAttr);
4013 auto input_bools = value.toBoolList();
4014 if (profiled_ints.size() != input_bools.size() ||
4015 !std::equal(
4016 input_bools.begin(),
4017 input_bools.end(),
4018 profiled_ints.begin())) {
4019 TORCH_WARN_ONCE(
4020 __FUNCTION__,
4021 " sees varying value in profiling, ignoring and this should be handled by GUARD logic");
4022 pn->s_(profileFailedAttr, "varying profile values");
4023 pn->removeAttribute(boolListAttr);
4024 }
4025 }
4026 } else {
4027 TORCH_INTERNAL_ASSERT(
4028 !pn->hasAttribute(boolListAttr),
4029 "profiled attribute should have been removed when profiling is marked as failed");
4030 }
4031 push(stack, value);
4032 };
4033
4034 pn->setCallback(ivalue_profiler);
4035}
4036
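// Returns true if `fn` holds for any node in `block`, recursing into nested
// blocks.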
4037bool anyInBlock(
4038 const Block* block,
4039 const std::function<bool(const Node*)>& fn) {
4040 for (auto node : block->nodes()) {
4041 if (fn(node)) {
4042 return true;
4043 }
4044 for (auto block : node->blocks()) {
4045 if (anyInBlock(block, fn)) {
4046 return true;
4047 }
4048 }
4049 }
4050 return false;
4051}
4052
4053} // namespace
4054
4055bool hasReductionNode(const Block* block) {
4056 return anyInBlock(block, isReductionNode);
4057}
4058
4059bool isReductionNode(const Node* node) {
4060 return IrParser::isReductionNode(node);
4061}
4062
4063bool isReductionToSizeNode(const Node* node) {
4064 return IrParser::isReductionToSizeNode(node);
4065}
4066
4067bool hasNormalizationNode(const Block* block) {
4068 return anyInBlock(block, isNormalizationNode);
4069}
4070
4071bool isNormalizationNode(const Node* node) {
4072 return IrParser::isNormalizationNode(node);
4073}
4074
4075bool isElementWiseNode(const Node* node) {
4076 return IrParser::isElementWiseNode(node);
4077}
4078
4079bool isNodeParsible(const Node* node) {
4080 return IrParser::canParseNode(node);
4081}
4082
4083bool shouldProfileNode(const Node* node) {
4084 return IrParser::lookupInSymbolSet(node);
4085}
4086
4087bool skipNodeKind(const std::string& symbol_str, bool flip) {
4088 return IrParser::querySkipSymbolSet(
4089 c10::Symbol::fromQualString(symbol_str), flip);
4090}
4091
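// Attaches the appropriate profileXxx callback to `node`'s input at `offset`,
// based on which known schema the node matches; returns true only when a
// profile node was inserted.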
4092bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
4093 // is skip constant necessary?
4094 if (node->input(offset)->node()->kind() == prim::Constant) {
4095 return false;
4096 }
4097
4098 static auto dropout_schema =
4099 getOperatorForLiteral(
4100 "aten::dropout(Tensor input, float p, bool train) -> Tensor")
4101 ->schema();
4102 static auto native_dropout_schema =
4103 getOperatorForLiteral(
4104 "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)")
4105 ->schema();
4106 if (node->matches(dropout_schema) || node->matches(native_dropout_schema)) {
4107 switch (offset) {
4108 // argument 2: Is training?
4109 case 2:
4110 profileBool(pr, node, offset);
4111 break;
4112 default:
4113 return false;
4114 }
4115 return true;
4116 }
4117
4118 static auto amax_schema =
4119 getOperatorForLiteral(
4120 "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor")
4121 ->schema();
4122 static auto amin_schema =
4123 getOperatorForLiteral(
4124 "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor")
4125 ->schema();
4126 if (node->matches(amax_schema) || node->matches(amin_schema)) {
4127 switch (offset) {
4128 // argument 1: reduction axes;
4129 case 1:
4130 profileIntList(pr, node, offset);
4131 break;
4132 // argument 2: keepdim;
4133 case 2:
4134 profileBool(pr, node, offset);
4135 break;
4136 default:
4137 return false;
4138 }
4139 return true;
4140 }
4141
4142 static auto reduction_operator_schema =
4143 getOperatorForLiteral(
4144 "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
4145 ->schema();
4146 if (node->matches(reduction_operator_schema)) {
4147 switch (offset) {
4148 // argument 1: reduction axes;
4149 case 1:
4150 profileIntList(pr, node, offset);
4151 break;
4152 // argument 2: keepdim;
4153 case 2:
4154 profileBool(pr, node, offset);
4155 break;
4156 default:
4157 return false;
4158 }
4159 return true;
4160 }
4161
4162 static auto sum_to_size_schema =
4163 getOperatorForLiteral(
4164 "aten::sum_to_size(Tensor self, int[] size) -> Tensor")
4165 ->schema();
4166 static auto grad_sum_to_size_schema =
4167 getOperatorForLiteral(
4168 "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")
4169 ->schema();
4170 if (node->matches(sum_to_size_schema) ||
4171 node->matches(grad_sum_to_size_schema)) {
4172 switch (offset) {
4173 // argument 1: reduction sizes;
4174 case 1:
4175 // TODO(profile_size): double check optional[size]?
4176 profileReductionSize(pr, node, offset);
4177 break;
4178 default:
4179 return false;
4180 }
4181 return true;
4182 }
4183
4184 static auto reshape_schema =
4185 getOperatorForLiteral("aten::reshape(Tensor self, int[] shape) -> Tensor")
4186 ->schema();
4187 static auto reshape_copy_schema =
4188 getOperatorForLiteral(
4189 "prim::reshape_copy(Tensor self, int[] shape) -> Tensor")
4190 ->schema();
4191 static auto view_schema =
4192 getOperatorForLiteral("aten::view(Tensor self, int[] size) -> Tensor")
4193 ->schema();
4194 static auto view_copy_schema =
4195 getOperatorForLiteral(
4196 "prim::view_copy(Tensor self, int[] size) -> Tensor")
4197 ->schema();
4198 if (node->matches(reshape_schema) || node->matches(reshape_copy_schema) ||
4199 node->matches(view_schema) || node->matches(view_copy_schema)) {
4200 switch (offset) {
4201 // argument 1: new tensor size;
4202 case 1:
4203 profileViewSize(pr, node, offset);
4204 break;
4205 default:
4206 return false;
4207 }
4208 return true;
4209 }
4210
4211 static auto flatten_schema1 =
4212 getOperatorForLiteral(
4213 "aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor")
4214 ->schema();
4215 static auto flatten_schema2 =
4216 getOperatorForLiteral(
4217 "prim::flatten_copy(Tensor self, int start_dim, int end_dim) -> Tensor")
4218 ->schema();
4219 if (node->matches(flatten_schema1) || node->matches(flatten_schema2)) {
4220 switch (offset) {
4221 // argument 1: start_dim;
4222 // argument 2: end_dim;
4223 case 1:
4224 case 2:
4225 profileInt(pr, node, offset);
4226 break;
4227 default:
4228 return false;
4229 }
4230 return true;
4231 }
4232
4233 static auto squeeze_dim_schema =
4234 getOperatorForLiteral(
4235 "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor")
4236 ->schema();
4237 static auto unsqueeze_schema =
4238 getOperatorForLiteral(
4239 "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor")
4240 ->schema();
4241 if (node->matches(squeeze_dim_schema) || node->matches(unsqueeze_schema)) {
4242 switch (offset) {
4243 // argument 1: unsqueeze dim;
4244 case 1:
4245 profileInt(pr, node, offset);
4246 break;
4247 default:
4248 return false;
4249 }
4250 return true;
4251 }
4252
4253 static auto permute_schema =
4254 getOperatorForLiteral(
4255 "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)")
4256 ->schema();
4257 static auto permute_copy_schema =
4258 getOperatorForLiteral(
4259 "prim::permute_copy(Tensor(a) self, int[] dims) -> Tensor")
4260 ->schema();
4261 if (node->matches(permute_schema) || node->matches(permute_copy_schema)) {
4262 switch (offset) {
4263 // argument 1: dims;
4264 case 1:
4265 profileIntList(pr, node, offset);
4266 break;
4267 default:
4268 return false;
4269 }
4270 return true;
4271 }
4272
4273 static auto transpose_int_copy_schema =
4274 getOperatorForLiteral(
4275 "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)")
4276 ->schema();
4277 static auto transpose_int_schema =
4278 getOperatorForLiteral(
4279 "prim::transpose_copy.int(Tensor(a) self, int dim0, int dim1) -> Tensor")
4280 ->schema();
4281 if (node->matches(transpose_int_copy_schema) ||
4282 node->matches(transpose_int_schema)) {
4283 switch (offset) {
4284 // argument 1: dim0;
4285 // argument 2: dim1;
4286 case 1:
4287 case 2:
4288 profileInt(pr, node, offset);
4289 break;
4290 default:
4291 return false;
4292 }
4293 return true;
4294 }
4295
4296 static auto batch_norm_impl_index_schema =
4297 getOperatorForLiteral(
4298 "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)")
4299 ->schema();
4300 static auto native_batch_norm_schema =
4301 getOperatorForLiteral(
4302 "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")
4303 ->schema();
4304 static auto batch_norm_schema =
4305 getOperatorForLiteral(
4306 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")
4307 ->schema();
4308 static auto instance_norm_schema =
4309 getOperatorForLiteral(
4310 "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor")
4311 ->schema();
4312 if (node->matches(native_batch_norm_schema) ||
4313 node->matches(batch_norm_impl_index_schema) ||
4314 node->matches(batch_norm_schema) || node->matches(instance_norm_schema)) {
4315 switch (offset) {
4316 // argument 5: training;
4317 case 5:
4318 profileBool(pr, node, offset);
4319 break;
4320 default:
4321 return false;
4322 }
4323 return true;
4324 }
4325
4326 static auto gelu_schema =
4327 getOperatorForLiteral(
4328 "aten::gelu(Tensor self, *, str approximate='none') -> Tensor")
4329 ->schema();
4330 if (node->matches(gelu_schema)) {
4331 switch (offset) {
4332 // argument 1: approximate;
4333 case 1:
4334 profileString(pr, node, offset);
4335 break;
4336 default:
4337 return false;
4338 }
4339 return true;
4340 }
4341
4342 static auto gelu_backward_schema =
4343 getOperatorForLiteral(
4344 "aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor")
4345 ->schema();
4346 if (node->matches(gelu_backward_schema)) {
4347 switch (offset) {
4348 // argument 2: approximate;
4349 case 2:
4350 profileString(pr, node, offset);
4351 break;
4352 default:
4353 return false;
4354 }
4355 return true;
4356 }
4357
4358 static auto native_layer_norm_schema =
4359 getOperatorForLiteral(
4360 "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)")
4361 ->schema();
4362 static auto layer_norm_schema =
4363 getOperatorForLiteral(
4364 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor")
4365 ->schema();
4366 if (node->matches(native_layer_norm_schema) ||
4367 node->matches(layer_norm_schema)) {
4368 switch (offset) {
4369 case 1:
4370 profileIntList(pr, node, offset);
4371 break;
4372 default:
4373 return false;
4374 }
4375 return true;
4376 }
4377
4378 static auto batch_norm_impl_index_backward_schema =
4379 getOperatorForLiteral(
4380 "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)")
4381 ->schema();
4382 if (node->matches(batch_norm_impl_index_backward_schema)) {
4383 switch (offset) {
4384 // TODO: guard impl_index, but I think that's not needed;
4385 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4386 case 8: // argument 8: training;
4387 profileBool(pr, node, offset);
4388 break;
4389 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4390 case 10:
4391 profileBoolList(pr, node, offset);
4392 break;
4393 default:
4394 return false;
4395 }
4396 return true;
4397 }
4398
4399 static auto batch_norm_backward_schema =
4400 getOperatorForLiteral(
4401 "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)")
4402 ->schema();
4403 if (node->matches(batch_norm_backward_schema)) {
4404 switch (offset) {
4405 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4406      case 7: // argument 7: training;
4407 profileBool(pr, node, offset);
4408 break;
4409 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4410 case 9:
4411 profileBoolList(pr, node, offset);
4412 break;
4413 default:
4414 return false;
4415 }
4416 return true;
4417 }
4418
4419 static auto native_layer_norm_backward_schema =
4420 getOperatorForLiteral(
4421 "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)")
4422 ->schema();
4423 if (node->matches(native_layer_norm_backward_schema)) {
4424 switch (offset) {
4425 case 2:
4426 profileIntList(pr, node, offset);
4427 break;
4428 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
4429 case 7:
4430 profileBoolList(pr, node, offset);
4431 break;
4432 default:
4433 return false;
4434 }
4435 return true;
4436 }
4437
4438 static auto to_copy_schema =
4439 getOperatorForLiteral(
4440 "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor")
4441 ->schema();
4442 if (node->matches(to_copy_schema)) {
4443 switch (offset) {
4444 case 1:
4445 profileInt(pr, node, offset);
4446 return true;
4447 default:
4448 return false;
4449 }
4450 }
4451
4452 static auto to_dtype_schema =
4453 getOperatorForLiteral(
4454 "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor")
4455 ->schema();
4456 if (node->matches(to_dtype_schema)) {
4457 switch (offset) {
4458 case 1:
4459 profileInt(pr, node, offset);
4460 return true;
4461 default:
4462 return false;
4463 }
4464 }
4465
4466 static auto log_softmax_data_schema =
4467 getOperatorForLiteral(
4468 "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor")
4469 ->schema();
4470 static auto softmax_data_schema =
4471 getOperatorForLiteral(
4472 "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor")
4473 ->schema();
4474 if (node->matches(log_softmax_data_schema) ||
4475 node->matches(softmax_data_schema)) {
4476 switch (offset) {
4477 case 2:
4478 profileIval(pr, node, offset);
4479 return true;
4480 default:
4481 return false;
4482 }
4483 }
4484
4485 static auto log_softmax_backward_data_schema =
4486 getOperatorForLiteral(
4487 "aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor")
4488 ->schema();
4489 static auto softmax_backward_data_schema =
4490 getOperatorForLiteral(
4491 "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor")
4492 ->schema();
4493 if (node->matches(log_softmax_backward_data_schema) ||
4494 node->matches(softmax_backward_data_schema)) {
4495 switch (offset) {
4496 case 2:
4497 profileInt(pr, node, offset);
4498 return true;
4499 case 3:
4500 profileInt(pr, node, offset);
4501 return true;
4502 default:
4503 return false;
4504 }
4505 }
4506
4507 static auto var_dim_schema =
4508 getOperatorForLiteral(
4509 "aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor")
4510 ->schema();
4511 static auto std_dim_schema =
4512 getOperatorForLiteral(
4513 "aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor")
4514 ->schema();
4515 if (node->matches(var_dim_schema) || node->matches(std_dim_schema)) {
4516 switch (offset) {
4517 case 1:
4518 profileIntList(pr, node, offset);
4519 return true;
4520 case 2:
4521 profileBool(pr, node, offset);
4522 return true;
4523 case 3:
4524 profileBool(pr, node, offset);
4525 return true;
4526 default:
4527 return false;
4528 }
4529 }
4530
4531 return false;
4532}
4533
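// Recursively walks `block` (and any nested blocks) and inserts profile nodes
// for every node input that insertProfileIValue knows how to profile.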
4534void insertProfileNodesForCUDAFuser_(Block* block, ProfilingRecord* pr) {
4535 for (const auto& n : block->nodes()) {
4536 for (const auto offset : c10::irange(n->inputs().size())) {
4537 insertProfileIValue(pr, n, offset);
4538 }
4539
4540 for (auto ib : n->blocks()) {
4541 insertProfileNodesForCUDAFuser_(ib, pr);
4542 }
4543 }
4544}
4545
4546void InsertProfileNodes(ProfilingRecord* pr) {
4547 insertProfileNodesForCUDAFuser_(pr->profiled_graph_->block(), pr);
4548}
4549
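// Entry point for the CUDA fuser: parses a JIT graph into a Fusion IR.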
4550std::unique_ptr<Fusion> parseJitIR(const std::shared_ptr<Graph>& graph) {
4551 FUSER_PERF_SCOPE("parseJitIR");
4552
4553 IrParser parser(graph);
4554 return parser.parse();
4555}
4556
4557} // namespace cuda
4558} // namespace fuser
4559} // namespace jit
4560} // namespace torch
4561