1 | #include <parser.h> |
2 | |
3 | #include <arith.h> |
4 | #include <instrumentation.h> |
5 | #include <ir_all_nodes.h> |
6 | #include <ir_builder.h> |
7 | #include <ir_iostream.h> |
8 | #include <ops/all_ops.h> |
9 | #include <type_inference.h> |
10 | #include <type_promotion.h> |
11 | #include <utils.h> |
12 | |
13 | #include <torch/csrc/jit/frontend/function_schema_parser.h> |
14 | #include <torch/csrc/jit/ir/constants.h> |
15 | |
16 | #include <ATen/native/Activation.h> |
17 | |
18 | #include <c10/util/CallOnce.h> |
19 | |
20 | #include <unordered_map> |
21 | #include <utility> |
22 | |
23 | namespace torch { |
24 | namespace jit { |
25 | |
26 | typedef Value JitValue; |
27 | typedef Node JitOp; |
28 | |
29 | namespace fuser { |
30 | namespace cuda { |
31 | |
32 | constexpr auto kNumUnaryOps = 10; |
33 | constexpr auto kNumUnaryFloatOps = 23; |
34 | constexpr auto kNumUnaryIsOps = 6; |
35 | |
36 | constexpr auto kNumBinaryFloatOps = 3; |
37 | constexpr auto kNumBinaryComparisonOps = 12; |
38 | constexpr auto kNumBinaryCastOps = 19; |
39 | |
40 | constexpr auto kNumBinaryOpsWithAlpha = 6; |
41 | constexpr auto kNumLerpOps = 2; |
42 | constexpr auto kNumLayernormFwd = 2; |
43 | constexpr auto kNumBatchnormFwd = 3; |
44 | constexpr auto kNumBatchnormBwd = 2; |
45 | constexpr auto kNumInstancenormFwd = 1; |
46 | constexpr auto kNumSumToSize = 2; |
47 | constexpr auto kNumAutocastOps = 2; |
48 | constexpr auto kNumAliasDimOps = 2; |
49 | constexpr auto kNumViewOps = 2; |
50 | constexpr auto kNumVarOps = 2; |
51 | constexpr auto kNumSoftmaxFwd = 2; |
52 | constexpr auto kNumSoftmaxBwd = 2; |
53 | constexpr auto kNumAminAmaxOps = 2; |
54 | |
55 | namespace { |
56 | |
57 | #define REGISTER_PARSE_RULE(op, func_body, ...) \ |
58 | registerParseRule( \ |
59 | op, \ |
60 | [](const Node* node, std::unordered_map<size_t, ValueHolder>& value_map) \ |
61 | -> void func_body, \ |
62 | __VA_ARGS__) |
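
// For reference, a registration such as
//   REGISTER_PARSE_RULE(ptr_op, { ... }, isInputNonSizeZeroTensor, nullptr);
// expands (roughly) to
//   registerParseRule(
//       ptr_op,
//       [](const Node* node,
//          std::unordered_map<size_t, ValueHolder>& value_map) -> void {
//         ...
//       },
//       isInputNonSizeZeroTensor,
//       nullptr);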
63 | |
const auto& reductionSizeAttr = Symbol::attr("profiled_reduction_size");
const auto& viewSizeAttr = Symbol::attr("profiled_view_size");
const auto& intListAttr = Symbol::attr("profiled_int_list");
const auto& intAttr = Symbol::attr("profiled_int");
const auto& boolListAttr = Symbol::attr("profiled_bool_list");
const auto& boolAttr = Symbol::attr("profiled_bool");
const auto& strAttr = Symbol::attr("profiled_str");
const auto& ivalAttr = Symbol::attr("profiled_ival");
const auto& profileFailedAttr = Symbol::attr("profile_failed");
73 | |
74 | typedef Val* CgValue; |
75 | typedef Expr* CgOp; |
76 | |
77 | Val* castTensoToDtype(CgValue self, JitValue* cast_val) { |
78 | auto cast_ival = toIValue(cast_val); |
79 | // we need static type for cast |
80 | TORCH_INTERNAL_ASSERT(cast_ival.has_value()); |
81 | if (cast_ival->isInt()) { |
82 | auto dtype = cast_ival->toScalarType(); |
83 | |
84 | // We want to keep our internal fusion math in FP32 |
85 | // Shape Inference will continue to propagate the right |
86 | // type to outputs unchanged. |
87 | if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16) { |
88 | dtype = at::ScalarType::Float; |
89 | } |
90 | |
91 | return castOp(aten_to_data_type(dtype), self); |
92 | } else { |
    TORCH_INTERNAL_ASSERT(
        cast_ival->isNone(),
        "unrecognized dtype option, expected 'int' but got: ",
        cast_ival->tagKind());
97 | |
98 | // return a copy if dtype is `None` |
99 | return set(self); |
100 | } |
101 | } |
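
// Illustrative usage sketch: for an op that carries an optional `dtype`
// argument, a parse rule would call
//   auto out = castTensoToDtype(self, node->input(dtype_offset));
// where `dtype_offset` is a placeholder for that op's dtype input index;
// half/bfloat16 requests keep the fusion math in fp32 (see above), and
// dtype=None yields a plain copy of `self`.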
102 | |
103 | bool isReductionNonCompatibleTensor( |
104 | const std::shared_ptr<c10::TensorType>& tensor_type) { |
105 | return is_zero_dim_tensor(tensor_type) || is_zero_sized_tensor(tensor_type); |
106 | } |
107 | |
108 | bool isInputNonSizeZeroTensor(const Node* node) { |
109 | for (const auto& val : node->inputs()) { |
110 | auto tensor_type = val->type()->cast<TensorType>(); |
111 | if (tensor_type && is_zero_sized_tensor(tensor_type)) { |
112 | return false; |
113 | } |
114 | } |
115 | return true; |
116 | } |
117 | |
118 | bool isScalarTypeCompatible(const Node* node, size_t offset) { |
119 | auto val = node->input(offset); |
120 | // return true if it's not specified |
121 | if (val->type()->isSubtypeOf(static_cast<c10::TypePtr>(NoneType::get()))) { |
122 | return true; |
123 | } |
124 | // return false if it's runtime value |
125 | if (val->node()->kind() != prim::Constant) { |
126 | return false; |
127 | } |
128 | auto dtype = toIValue(val)->toScalarType(); |
129 | |
130 | // we do NOT support half math type yet |
131 | if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16) { |
132 | return false; |
133 | } |
134 | return true; |
135 | } |
136 | |
137 | // Note [ Permutation Bookkeeping and Propagation in Parser ] |
138 | // |
// The goal in supporting permutation propagation in the parser is to:
// 1. resolve conflicts and propagate permutations;
// 2. bookkeep permutations on existing tensors;
//
// The requirement right now is that all parsing rules should support
// non-permuted inputs; some binary operations support inputs with arbitrary
// permutation, and a few operations support special inputs.
// In case "wrong" inputs are fed to an operation, we transpose them to a
// supported permutation. This allows us to progressively expand permutation
// support.
// Currently we bind all permuted codegen Vals in `ValueHolder`. This saves
// unnecessary transposes (not sure if it actually helps) since we can reuse
// permuted tensors.
152 | // |
// Parsing rule pattern:
// a. ops that only support non-permuted inputs (e.g. sum)
//
//    // Specifying `MemoryFormat::Contiguous` here forces all inputs into
//    // `Contiguous` format
//    auto [format, list_val] = getConsistentValues(
//        MemoryFormat::Contiguous,
//        value_map[node->inputs()[0]->unique()]);
//    auto self = list_val.front();
//    // ... use self
//
// b. format agnostic ops (e.g. PW unary/binary op like aten::add)
//
//    // getConsistentValues -> returns the target format and copies of the
//    // operands in that format
//    auto [format, list_val] = getConsistentValues(
//        c10::nullopt,
//        value_map[node->inputs()[0]->unique()],
//        value_map[node->inputs()[1]->unique()]);
//    auto lhs = list_val.front();
//    list_val.pop_front();
//    auto rhs = list_val.front();
//
//    // compute out
//    auto out = binaryOp(op_mapping[node->kind()], lhs, rhs);
//    // specify `format` for out when adding it to `value_map_`
//    value_map.emplace(node->output()->unique(), ValueHolder(out, format));
//
// c. ops that support special permutations, e.g. aten::batch_norm with
//    channels-last inputs.
179 | |
180 | struct MemoryFormat { |
181 | // indices of dimensions with increasing stride. |
182 | std::vector<int> permuted_order_; |
183 | |
  // permutation_ encodes `permuted_order_` by concatenating all of its
  // elements into a single integer, with the exception of unpermuted tensors,
  // which we special-case to permutation_ = 0.
  //
  // e.g. for a rank-n channels-last tensor, permutation_ would be
  // 1(n-1)(n-2)...20, i.e. 1320 for rank 4;
  // Note: the leading '0' is dropped automatically by the integer encoding,
  // and this encoding only works for rank <= 10
191 | // see [ Note: MemoryFormat and Stride Order ] |
192 | size_t permutation_ = 0; |
193 | |
194 | // default to non-permuted tensor |
195 | MemoryFormat() = default; |
196 | |
197 | // [ Note: MemoryFormat and Stride Order ] |
  // stride_order is extracted from
  // `TensorType::stride_properties()::stride_index_`; it lists the axis
  // indices from fastest to slowest.
  // For a 4d tensor with stride_order = {x0, x1, x2, x3}, the i-th fastest
  // dimension is stride_order[i].
  //
  // Look at the comment for c10::Stride in aten/src/ATen/core/jit_type.h
  //
  // eg0. for a rank 4 non-permuted tensor, stride_order would be
  // {3, 2, 1, 0}: the fastest dimension is axis-3, the next one is axis-2,
  // etc. So it's a non-permuted tensor, encoded as permutation_ = 3210
  // (which we special-case to 0).
  //
  // eg1. for a rank 4 channels-last tensor, stride_order would be
  // {1, 3, 2, 0}: the fastest dimension is axis-1, then 3, then 2, then 0.
  // So this is a channels-last tensor (logically NCHW, NHWC in memory),
  // encoded as permutation_ = 1320.
  //
  // eg2. for a rank 4 permuted tensor, stride_order can be {0, 3, 2, 1},
  // encoded as permutation_ = 321 (the leading '0' is omitted).
218 | void setPermutation(const std::vector<int>& stride_order) { |
219 | int rank = stride_order.size(); |
    TORCH_INTERNAL_ASSERT(
        rank <= 10, "MemoryFormat for permutation only supports rank <= 10");

    // store stride_order in `permuted_order_` so we don't have to decode
    // `permutation_` when we want to apply/restore the permutation.
225 | permuted_order_ = stride_order; |
226 | bool has_permutation = false; |
227 | permutation_ = 0; |
228 | for (const auto i : c10::irange(rank)) { |
229 | permutation_ = permutation_ * 10 + stride_order[i]; |
230 | if (!has_permutation && stride_order[i] != rank - 1 - i) { |
231 | has_permutation = true; |
232 | } |
233 | } |
234 | |
235 | // special case permutation_ to reflect non-permuted tensor |
236 | if (!has_permutation) { |
237 | permutation_ = 0; |
238 | } |
239 | } |
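
  // Worked example (encoding): stride_order {3, 2, 1, 0} builds permutation_
  // as 3 -> 32 -> 321 -> 3210; since every entry equals rank - 1 - i, no
  // permutation is detected and permutation_ is reset to 0.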
240 | |
241 | // returns the stride order for given MemoryFormat encoding permutation_ |
242 | // |
243 | // see details for encoding in [ Note: MemoryFormat and Stride Order ] |
244 | std::vector<int> toStrideOrder() const { |
245 | std::vector<int> stride_order; |
246 | // return empty vector for no permutation |
247 | if (hasPermutation()) { |
248 | // be generous with reserved space |
249 | stride_order.reserve(10); |
250 | bool encountered_zero = false; |
251 | size_t permutation = permutation_; |
252 | while (permutation != 0) { |
253 | int order = static_cast<int>(permutation % 10); |
254 | permutation /= 10; |
255 | if (order == 0) { |
256 | encountered_zero = true; |
257 | } |
258 | stride_order.push_back(order); |
259 | } |
260 | if (!encountered_zero) { |
261 | // in case leading '0' is omitted, push it back |
262 | stride_order.push_back(0); |
263 | } |
264 | // since we use push_back, our stride_order is reversed. |
265 | std::reverse(stride_order.begin(), stride_order.end()); |
266 | } |
267 | return stride_order; |
268 | } |
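
  // Worked example (decoding): permutation_ == 321 pops digits 1, 2, 3 (least
  // significant first), appends the omitted leading '0' since none was seen,
  // and reverses, yielding stride_order = {0, 3, 2, 1} as in eg2 above.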
269 | |
270 | // returns c10::nullopt when it's not safe to broadcast current permutation to |
271 | // rank |
272 | c10::optional<MemoryFormat> broadcastToRank(size_t rank) const { |
273 | auto ret = Contiguous(); |
274 | if (hasPermutation()) { |
275 | auto stride_order = toStrideOrder(); |
276 | auto cur_rank = stride_order.size(); |
277 | // no op for (cur_rank == 0) || (cur_rank == rank) |
278 | if (cur_rank < rank) { |
        // broadcasting to a higher rank can be done by:
        // 1. incrementing all existing stride order entries by rank_diff;
        // 2. appending the decreasing sequence rank_diff-1, ..., 0;
        // where rank_diff = rank - cur_rank
        //
        // see [ Note: MemoryFormat and Stride Order ]
285 | // e.g. |
286 | // taking broadcasted bias for channels last as an example |
287 | // stride_order = {0, 2, 1} broadcasted to rank == 4 would give us |
288 | // rank_diff = 4 - 3 = 1 |
289 | // take step 1 -> {1, 3, 2} |
290 | // take step 2 -> {1, 3, 2, 0} |
291 | int rank_diff = static_cast<int>(rank - cur_rank); |
292 | for (auto& val : stride_order) { |
293 | val += rank_diff; |
294 | } |
295 | for (int i = rank_diff - 1; i >= 0; i--) { |
296 | stride_order.push_back(i); |
297 | } |
298 | } else if (cur_rank > rank) { |
        // shrink permutation to a lower rank. We can only discard the
        // higher-rank entries of the stride order when the lower `rank`
        // entries are all permuted within the lower rank; otherwise we can't
        // obey broadcasting semantics while preserving the permutation. We
        // check the stride order to ensure this, then decrement each
        // remaining entry by collapsed_ranks to reflect the correct stride
        // order.
        //
        // see [ Note: MemoryFormat and Stride Order ]
        // e.g. for rank 4 channels-last {1, 3, 2, 0}:
        // 1. the format can safely shrink to rank 3, since every entry of
        // {1, 3, 2} >= (4-3); we ditch the last (4-3) entries and decrement
        // each remaining element by (4-3), which gives us {0, 2, 1};
        // 2. but when we shrink it to rank 2, we have {1, 3} where 1 < (4-2),
        // which can't be handled, so we return c10::nullopt.
314 | int collapsed_ranks = static_cast<int>(cur_rank - rank); |
315 | for (size_t i = 0; i < rank; i++) { |
316 | if (stride_order[i] < collapsed_ranks) { |
317 | // illegal collapsing, return c10::nullopt |
318 | return c10::nullopt; |
319 | } |
320 | // update collapsed stride_order |
321 | stride_order[i] -= collapsed_ranks; |
322 | } |
323 | // discard higher rank stride order. |
324 | stride_order.resize(rank); |
325 | } |
326 | ret.setPermutation(stride_order); |
327 | } |
328 | return ret; |
329 | } |
330 | |
331 | // returns non-permuted format |
332 | static MemoryFormat Contiguous() { |
333 | return MemoryFormat(); |
334 | } |
335 | |
336 | bool hasPermutation() const { |
337 | return permutation_ != 0; |
338 | } |
339 | |
340 | bool isChannelsLast() const { |
341 | int rank = permuted_order_.size(); |
342 | |
343 | if (rank > 2 && permuted_order_[0] == 1 && permuted_order_[rank - 1] == 0) { |
344 | for (const auto i : c10::irange(rank - 2)) { |
345 | if (permuted_order_[i + 1] != rank - 1 - i) { |
346 | return false; |
347 | } |
348 | } |
349 | return true; |
350 | } |
351 | return false; |
352 | } |
353 | |
  // returns the transpose map to achieve the permutation on a non-permuted
  // tensor
  // note: used for the aten::permute API and the codegen transpose API
356 | std::vector<int64_t> apply() const { |
357 | std::vector<int64_t> ret; |
358 | if (hasPermutation()) { |
359 | ret.resize(permuted_order_.size()); |
360 | std::copy(permuted_order_.rbegin(), permuted_order_.rend(), ret.begin()); |
361 | } |
362 | return ret; |
363 | } |
364 | |
365 | // returns transpose map to restore back to non-permuted tensor |
366 | // note: used for aten::permute API and codegen transpose API |
367 | std::vector<int64_t> restore() const { |
368 | std::vector<int64_t> ret; |
369 | if (hasPermutation()) { |
370 | int rank = permuted_order_.size(); |
371 | ret.resize(rank); |
372 | for (const auto i : c10::irange(rank)) { |
373 | ret[permuted_order_[i]] = rank - 1 - i; |
374 | } |
375 | } |
376 | return ret; |
377 | } |
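
  // Worked example: for a rank 4 channels-last tensor, permuted_order_ is
  // {1, 3, 2, 0}, so apply() returns {0, 2, 3, 1} (permutes a contiguous
  // tensor into channels-last order) and restore() returns {0, 3, 1, 2}
  // (permutes it back).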
378 | }; |
379 | |
380 | struct MemoryCompare { |
381 | bool operator()(const MemoryFormat& format0, const MemoryFormat& format1) |
382 | const { |
383 | return format0.permutation_ < format1.permutation_; |
384 | } |
385 | }; |
386 | |
387 | typedef std::map<MemoryFormat, CgValue, MemoryCompare> MemoryFormatMap; |
388 | |
MemoryFormat operator+(const MemoryFormat& a, const MemoryFormat& b) {
  // Note: TensorIterator logic lets the first input dominate the output
  // MemoryFormat, so instead of `a.permutation_ >= b.permutation_ ? a : b;`
  // we simply use:
  return a;
}
394 | |
//! ValueHolder holds multiple copies of a tensor view in different
//! permutations (`MemoryFormat`). This mainly serves two purposes:
397 | //! |
398 | //! 1. reuse permuted tensor views among consumers |
399 | //! 2. bookkeeping for permuted tensor views in input/output tensors |
400 | //! |
401 | //! refer to Note [ Permutation Bookkeeping and Propagation in Parser ] |
402 | class ValueHolder { |
403 | public: |
404 | // checks if given Val in target format exists. |
405 | bool hasValue(const MemoryFormat& format) const { |
406 | return vals_.count(format) != 0; |
407 | } |
408 | |
409 | // returns Val in target format. |
410 | CgValue value(const MemoryFormat& format) const { |
411 | auto iter_val = vals_.find(format); |
    TORCH_INTERNAL_ASSERT(
        iter_val != vals_.end(),
        "accessing a non-existing value for the given MemoryFormat");
414 | return iter_val->second; |
415 | } |
416 | |
  // returns the Val in the target format if it exists; otherwise transposes an
  // existing copy and adds it to the bookkeeping.
419 | CgValue maybeConvertValue(const MemoryFormat& format) { |
420 | auto cur_rank = rank(); |
    // for a scalar tensor (cur_rank == 0) the memory format doesn't carry
    // meaning, so we just return the value as-is; same for a non-tensor,
    // where cur_rank == -1
424 | if (cur_rank <= 0) { |
425 | return std::get<1>(getEntry()); |
426 | } |
427 | MemoryFormat format_s; |
428 | CgValue value_s = nullptr; |
429 | std::tie(format_s, value_s) = getEntry(); |
430 | |
431 | auto opt_format_d = format.broadcastToRank(static_cast<size_t>(cur_rank)); |
432 | TORCH_INTERNAL_ASSERT( |
433 | opt_format_d.has_value(), |
434 | "maybeConvertValue requested for illegal permutation" ); |
435 | MemoryFormat format_d = opt_format_d.value(); |
436 | |
437 | auto iter_val = vals_.find(format_d); |
438 | if (iter_val != vals_.end()) { |
439 | return iter_val->second; |
440 | } |
441 | auto val = convertValue(format_d, format_s, value_s); |
442 | vals_[format_d] = val; |
443 | return val; |
444 | } |
445 | |
446 | int rank() const { |
447 | if (!is_tensor_view_) { |
448 | return -1; |
449 | } else { |
450 | auto v = std::get<1>(getEntry()); |
451 | TORCH_INTERNAL_ASSERT( |
452 | v->isA<TensorView>(), "can only access rank of TensorView" ); |
453 | return static_cast<int>(v->as<TensorView>()->nDims()); |
454 | } |
455 | } |
456 | |
457 | // TODO: delete this and update accessor for value_map(_) |
458 | ValueHolder() { |
    TORCH_INTERNAL_ASSERT(false, "can't default construct ValueHolder");
460 | } |
461 | |
462 | ValueHolder(CgValue val, MemoryFormat format = MemoryFormat()) { |
463 | vals_[format] = val; |
464 | if (val->isA<TensorView>()) { |
465 | is_tensor_view_ = true; |
466 | } |
467 | } |
468 | |
469 | // returns the MemoryFormat and codegen Val with the highest precedence among |
470 | // existing copies. |
471 | std::tuple<MemoryFormat, CgValue> getEntry() const { |
    TORCH_CHECK(!vals_.empty(), "ValueHolder::getEntry() on empty vals_");
    // return the last entry; this lets us prioritize permuted (e.g.
    // channels-last) tensors over non-permuted ones
475 | return *vals_.rbegin(); |
476 | } |
477 | |
478 | // TODO: code cleaning in parser so we don't need these. |
479 | // returns Val*, keeping them here just so we have less code change. |
480 | CgValue operator*() const { |
481 | return std::get<1>(getEntry()); |
482 | } |
483 | CgValue operator->() const { |
484 | return std::get<1>(getEntry()); |
485 | } |
486 | operator CgValue() const { |
487 | return std::get<1>(getEntry()); |
488 | } |
489 | |
490 | private: |
491 | // helper function to convert value_s @ format_s to format_d |
492 | CgValue convertValue( |
493 | MemoryFormat format_d, |
494 | MemoryFormat format_s, |
495 | CgValue value_s) { |
496 | TORCH_INTERNAL_ASSERT( |
497 | value_s->isA<TensorView>(), "cannot convert non-TensorView" ); |
498 | auto tv = value_s->as<TensorView>(); |
499 | // TODO: we could probably merge the two if it has perf impact on generated |
500 | // kernel |
501 | |
502 | // restore source permutation |
503 | if (format_s.hasPermutation()) { |
504 | tv = permute(tv, format_s.restore()); |
505 | } |
506 | // apply destination permutation |
507 | if (format_d.hasPermutation()) { |
508 | tv = permute(tv, format_d.apply()); |
509 | } |
510 | return tv; |
511 | } |
512 | |
513 | private: |
514 | // container to hold all copies of value in different MemoryFormat |
516 | MemoryFormatMap vals_; |
517 | |
  // true when the held Val is a TensorView (i.e. not a scalar)
519 | bool is_tensor_view_ = false; |
520 | }; |
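
// Typical usage sketch (mirroring the parse rules below): a produced Val is
// stored alongside its MemoryFormat, and consumers request a specific format
// when they need it:
//   value_map.emplace(node->output()->unique(), ValueHolder(out, format));
//   auto contig_val = value_map[node->inputs()[0]->unique()].maybeConvertValue(
//       MemoryFormat::Contiguous());
// `contig_val` is only a name used in this sketch.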
521 | |
522 | template <class Func, class... Values> |
523 | auto iterate(Func f, ValueHolder& val) { |
524 | return f(val); |
525 | } |
526 | |
527 | template <class Func, class... Values> |
528 | auto iterate(Func f, ValueHolder& val, Values&... vals) { |
529 | return f(val, iterate(f, vals...)); |
530 | } |
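
// e.g. iterate(f, a, b, c) expands to f(a, f(b, f(c))), i.e. a right fold over
// the given ValueHolders.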
531 | |
532 | // iterate through all vals and return the output MemoryFormat and copies of |
533 | // vals. |
// 1. When `forced_format == c10::nullopt`, the target MemoryFormat follows
//    the format of the first val in `vals` (provided all vals share the same
//    rank), to stay coherent with eager TensorIterator behavior;
// 2. The target can be overridden by specifying `forced_format`.
538 | // |
539 | // Note: take `Values&` by reference, since `maybeConvertValue` needs to modify |
540 | // the entry and we want that to be updated in `value_map_` |
541 | template <class... Values> |
542 | std::pair<MemoryFormat, std::list<CgValue>> getConsistentValues( |
543 | c10::optional<MemoryFormat> forced_format, |
544 | Values&... vals) { |
545 | MemoryFormat format; |
546 | if (forced_format.has_value()) { |
547 | format = forced_format.value(); |
548 | } else { |
549 | // check for identical nDim on vals |
550 | auto rank_func = [](const ValueHolder& val, int rank = 0) { |
551 | int v_rank = val.rank(); |
552 | v_rank = std::max(0, v_rank); |
553 | if (rank == 0) { |
554 | return v_rank; |
555 | } else if (v_rank == 0) { |
556 | return rank; |
557 | } else if (rank == -1 || v_rank != rank) { |
558 | return -1; |
559 | } |
560 | return rank; |
561 | }; |
562 | int rank = iterate(rank_func, vals...); |
563 | |
564 | // TODO: this is not needed as we are only using the first val |
    // only apply a permutation when all inputs are of identical rank, since
    // a permutation could change semantics among broadcasted tensors.
    // Consider a pointwise operation between two tensors: [N, C, H, W] + [H, W]
568 | if (rank > 0) { |
569 | auto format_func = [](const ValueHolder& val, |
570 | MemoryFormat f = MemoryFormat::Contiguous()) { |
571 | return std::get<0>(val.getEntry()) + f; |
572 | }; |
573 | format = iterate(format_func, vals...); |
574 | } else { |
575 | format = MemoryFormat::Contiguous(); |
576 | } |
577 | } |
578 | |
579 | auto convert_func = [format]( |
580 | ValueHolder& val, std::list<CgValue> list_val = {}) { |
581 | list_val.push_front(val.maybeConvertValue(format)); |
582 | return list_val; |
583 | }; |
584 | auto list_val = iterate(convert_func, vals...); |
585 | |
586 | return std::make_pair(format, list_val); |
587 | } |
588 | |
589 | // iterate through all vals and return the output MemoryFormat and copies of |
590 | // vals. |
// 1. When `forced_format == c10::nullopt`, the target MemoryFormat follows
//    the format of the first val in `vals` with the highest rank, to stay
//    coherent with eager TensorIterator behavior; if any pair of permutations
//    is incompatible we fall back to `Contiguous`;
// 2. The target can be overridden by specifying `forced_format`.
595 | // |
596 | // Note: take `Values&` by reference, since `maybeConvertValue` needs to modify |
597 | // the entry and we want that to be updated in `value_map_` |
598 | template <class... Values> |
599 | std::pair<MemoryFormat, std::list<CgValue>> getPWFormatValues( |
600 | c10::optional<MemoryFormat> forced_format, |
601 | Values&... vals) { |
602 | MemoryFormat format; |
603 | if (forced_format.has_value()) { |
604 | format = forced_format.value(); |
605 | } else { |
606 | // get maximum rank on vals |
607 | std::vector<MemoryFormat> formats; |
608 | std::vector<int> ranks; |
609 | auto max_rank_func = [&ranks](const ValueHolder& val, int rank = 0) { |
610 | int v_rank = val.rank(); |
611 | ranks.push_back(v_rank); |
612 | return std::max(rank, v_rank); |
613 | }; |
614 | int max_rank = iterate(max_rank_func, vals...); |
615 | |
    // go through all permutations; to keep consistency with TensorIterator
    // behavior, the first tensor with the highest rank dictates the output
    // permutation
619 | auto format_func = [&formats, &max_rank]( |
620 | const ValueHolder& val, |
621 | MemoryFormat f = MemoryFormat::Contiguous()) { |
622 | auto cur_format = std::get<0>(val.getEntry()); |
623 | formats.push_back(cur_format); |
624 | return val.rank() == max_rank ? cur_format : f; |
625 | }; |
626 | format = iterate(format_func, vals...); |
627 | |
    // we need to do a pair-wise comparison to ensure that all permutations are
    // compatible, since a permutation could change semantics among broadcasted
    // tensors. Consider a pointwise operation between three tensors:
    // [N, C, H, W] + [C, H, W] + [H, W]
632 | for (size_t i = 0; i < formats.size() && format.hasPermutation(); i++) { |
633 | for (size_t j = 0; j < formats.size(); j++) { |
634 | // don't compare scalar tensor or scalar |
635 | if (ranks[i] <= 0 || ranks[j] <= 0 || i == j) { |
636 | continue; |
637 | } |
638 | size_t lower_rank = std::min(ranks[i], ranks[j]); |
639 | auto i_format = formats[i].broadcastToRank(lower_rank); |
640 | auto j_format = formats[j].broadcastToRank(lower_rank); |
641 | |
642 | // breaks permutation if any: |
643 | // 1. i_format can't be broadcasted to lower_rank; |
644 | // 2. j_format can't be broadcasted to lower_rank; |
645 | if (!i_format.has_value() || !j_format.has_value()) { |
646 | format = MemoryFormat::Contiguous(); |
647 | } |
648 | } |
649 | } |
650 | } |
651 | |
652 | auto convert_func = [format]( |
653 | ValueHolder& val, std::list<CgValue> list_val = {}) { |
654 | list_val.push_front(val.maybeConvertValue(format)); |
655 | return list_val; |
656 | }; |
657 | auto list_val = iterate(convert_func, vals...); |
658 | |
659 | return std::make_pair(format, list_val); |
660 | } |
661 | |
662 | typedef void ( |
663 | *ParseFuncPtr)(const Node*, std::unordered_map<size_t, ValueHolder>&); |
664 | typedef bool (*MergeQueryFuncPtr)(const Node*); |
665 | |
666 | // TODO: add a mutex to make it thread safe. |
667 | class IrParser { |
668 | enum class OperatorType { |
669 | ElementWise, |
670 | Reduction, |
671 | ReductionToSize, |
672 | Normalization |
673 | }; |
674 | typedef OperatorType (*OperatorTypeFuncPtr)(const Node*); |
675 | |
676 | class RegistrationEntry { |
677 | public: |
678 | RegistrationEntry( |
679 | ParseFuncPtr parse_f, |
680 | MergeQueryFuncPtr merge_f = nullptr, |
681 | OperatorTypeFuncPtr type_f = nullptr) |
682 | : parse_f_(parse_f), merge_f_(merge_f), type_f_(type_f) {} |
683 | |
684 | void parse( |
685 | const Node* node, |
686 | std::unordered_map<size_t, ValueHolder>& values) const { |
687 | parse_f_(node, values); |
688 | } |
689 | |
690 | bool isCompatible(const Node* node) const { |
691 | if (merge_f_ == nullptr) { |
692 | return true; |
693 | } |
694 | return merge_f_(node); |
695 | } |
696 | |
697 | bool isType(const Node* node, OperatorType type) const { |
698 | auto n_type = |
699 | type_f_ == nullptr ? OperatorType::ElementWise : type_f_(node); |
700 | return n_type == type; |
701 | } |
702 | |
703 | private: |
704 | ParseFuncPtr parse_f_; |
705 | MergeQueryFuncPtr merge_f_; |
706 | OperatorTypeFuncPtr type_f_; |
707 | }; |
708 | |
709 | public: |
710 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
711 | IrParser(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) { |
712 | initRegistry(); |
713 | } |
714 | |
715 | std::unique_ptr<Fusion> parse() { |
716 | auto fusion = std::make_unique<Fusion>(); |
717 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
718 | FusionGuard fg(fusion.get()); |
719 | auto block = graph_->block(); |
720 | |
721 | std::unordered_map<Val*, MemoryFormat> permuted_tensors; |
722 | // register all inputs; |
723 | for (auto val : block->inputs()) { |
724 | TORCH_INTERNAL_ASSERT( |
          registerValue(val),
          "Failure when registering value: ",
          *(val->node()),
          " with type: ",
729 | val->type()->repr_str()); |
730 | MemoryFormat format; |
731 | Val* operand = nullptr; |
732 | std::tie(format, operand) = value_map_[val->unique()].getEntry(); |
733 | fusion->addInput(operand); |
734 | |
735 | // mark input tensor as permuted; |
736 | if (format.hasPermutation()) { |
737 | permuted_tensors.insert({operand, format}); |
738 | } |
739 | |
740 | auto opt_dtype = operand->getDataType(); |
      // compute-type promotion: we cast fp16 or bf16 inputs to fp32 and use
      // the promoted type in the computation.
743 | if (opt_dtype.has_value() && |
744 | (opt_dtype.value() == DataType::Half || |
745 | opt_dtype.value() == DataType::BFloat16)) { |
746 | Val* promoted_val = castOp(DataType::Float, operand); |
747 | value_map_[val->unique()] = ValueHolder(promoted_val, format); |
748 | } |
749 | } |
750 | |
751 | // compose nodes in topo order; |
752 | for (const JitOp* node : block->nodes()) { |
753 | processJitNode(node); |
754 | } |
755 | |
756 | // mark output; |
757 | for (auto jit_output : block->outputs()) { |
758 | MemoryFormat format; |
759 | Val* operand = nullptr; |
760 | std::tie(format, operand) = value_map_[jit_output->unique()].getEntry(); |
761 | TensorView* out = operand->as<TensorView>(); |
      // demote the output dtype to match the PyTorch JIT graph.
      auto tensor_type = jit_output->type()->cast<TensorType>();
      TORCH_INTERNAL_ASSERT(
          tensor_type, "output of fusion group is not TensorType.");
766 | if (tensor_type->scalarType().has_value()) { |
767 | out = optionalCastStrict( |
768 | aten_to_data_type(*tensor_type->scalarType()), out) |
769 | ->as<TensorView>(); |
770 | } |
771 | |
772 | if (out->isFusionOutput()) { |
        // TODO: This wastes memory bandwidth; we need the copy since we can't
        // output the same tensor twice.
775 | out = set(out); |
776 | } |
777 | |
778 | fusion->addOutput(out); |
779 | |
780 | // mark output tensor as permuted; |
781 | if (format.hasPermutation()) { |
782 | permuted_tensors.insert({out, format}); |
783 | } |
784 | } |
785 | |
786 | for (const auto& i : c10::irange(fusion->inputs().size())) { |
787 | const auto& entry = permuted_tensors.find(fusion->inputs()[i]); |
788 | if (entry != permuted_tensors.end()) { |
789 | fusion->setPermutationOnInput(i, entry->second.apply()); |
790 | } |
791 | } |
792 | for (const auto& i : c10::irange(fusion->outputs().size())) { |
793 | const auto& entry = permuted_tensors.find(fusion->outputs()[i]); |
794 | if (entry != permuted_tensors.end()) { |
795 | fusion->setPermutationOnOutput(i, entry->second.restore()); |
796 | } |
797 | } |
798 | return fusion; |
799 | } |
800 | |
801 | static bool lookupInSymbolSet(const Node* node) { |
802 | initRegistry(); |
803 | |
804 | std::lock_guard<std::mutex> lock(parser_mutex_); |
805 | return parser_symbol_set_.count(node->kind()) != 0; |
806 | } |
807 | |
808 | // return nullptr if entry does not exist |
809 | static const RegistrationEntry* lookupInRegistry(const Node* node) { |
810 | std::lock_guard<std::mutex> lock(parser_mutex_); |
811 | |
812 | if (parser_skip_set_.count(node->kind()) != 0) { |
813 | return nullptr; |
814 | } |
815 | // we need to use maybeSchema for nodes like prim::Constant, which doesn't |
816 | // have a schema |
817 | auto schema_ptr = node->maybeSchema(); |
818 | if (schema_ptr != nullptr) { |
819 | // search cached entry first |
820 | auto cache_it = cached_registry_lookup_.find(schema_ptr); |
821 | if (cache_it != cached_registry_lookup_.end()) { |
822 | return cache_it->second; |
823 | } else { |
824 | // match signature |
825 | auto schema_str = canonicalSchemaString(*schema_ptr); |
826 | |
827 | auto iter = jit_operator_registry_.find(schema_str); |
828 | if (iter != jit_operator_registry_.end()) { |
829 | // update cache entry |
830 | cached_registry_lookup_.insert(cache_it, {schema_ptr, &iter->second}); |
831 | return &iter->second; |
832 | } |
833 | } |
834 | } |
835 | return nullptr; |
836 | } |
837 | |
838 | static bool querySkipSymbolSet(c10::Symbol symbol, bool flip) { |
839 | initRegistry(); |
840 | |
841 | std::lock_guard<std::mutex> lock(parser_mutex_); |
    // note: unlike `parser_symbol_set_`, `parser_skip_set_` is not populated
    // during operator registration.
844 | bool ret = parser_skip_set_.count(symbol) != 0; |
845 | if (flip) { |
846 | if (ret) { |
847 | parser_skip_set_.erase(symbol); |
848 | } else { |
849 | parser_skip_set_.insert(symbol); |
850 | } |
851 | } |
852 | return ret; |
853 | } |
854 | |
855 | static void initRegistry() { |
856 | c10::call_once(once_flag_, []() { |
857 | std::lock_guard<std::mutex> lock(parser_mutex_); |
858 | registerJitOperator(); |
859 | }); |
860 | } |
861 | |
862 | static bool canParseNode(const Node* node) { |
863 | initRegistry(); |
864 | |
865 | // match signature. |
866 | auto schema_ptr = node->maybeSchema(); |
867 | if (schema_ptr == nullptr) { |
868 | return false; |
869 | } |
870 | auto reg_entry = lookupInRegistry(node); |
871 | return reg_entry != nullptr && reg_entry->isCompatible(node); |
872 | } |
873 | |
874 | static bool isReductionToSizeNode(const Node* node) { |
875 | initRegistry(); |
876 | |
877 | auto reg_entry = lookupInRegistry(node); |
878 | return reg_entry != nullptr && |
879 | reg_entry->isType(node, OperatorType::ReductionToSize); |
880 | } |
881 | |
882 | static bool isReductionNode(const Node* node) { |
883 | initRegistry(); |
884 | |
885 | auto reg_entry = lookupInRegistry(node); |
886 | return reg_entry != nullptr && |
887 | (reg_entry->isType(node, OperatorType::Reduction) || |
888 | reg_entry->isType(node, OperatorType::ReductionToSize)); |
889 | } |
890 | |
891 | static bool isNormalizationNode(const Node* node) { |
892 | initRegistry(); |
893 | |
894 | auto reg_entry = lookupInRegistry(node); |
895 | return reg_entry != nullptr && |
896 | reg_entry->isType(node, OperatorType::Normalization); |
897 | } |
898 | |
899 | static bool isElementWiseNode(const Node* node) { |
900 | initRegistry(); |
901 | |
902 | auto reg_entry = lookupInRegistry(node); |
903 | return reg_entry != nullptr && |
904 | reg_entry->isType(node, OperatorType::ElementWise); |
905 | } |
906 | |
907 | // TODO: is_reduction is too hacky here. we should categorize operation types |
908 | // based on their memory accessing pattern, which would affect fusion |
909 | // strategy and partition logic. |
910 | static void registerParseRule( |
911 | std::shared_ptr<Operator>& op, |
912 | ParseFuncPtr parse_fn, |
913 | MergeQueryFuncPtr merge_query_fn = nullptr, |
914 | OperatorTypeFuncPtr type_fn = nullptr) { |
915 | auto op_name = op->schema().name(); |
916 | parser_symbol_set_.insert(c10::Symbol::fromQualString(op_name)); |
    // We blindly attempt to profile the in-place version of each supported op;
    // this ensures that in-place removal in the fusion partition has the
    // profile information readily available after the pass.
920 | parser_symbol_set_.insert(c10::Symbol::fromQualString(op_name + '_')); |
921 | jit_operator_registry_.emplace( |
922 | std::piecewise_construct, |
923 | std::forward_as_tuple(canonicalSchemaString(op->schema())), |
924 | std::forward_as_tuple(parse_fn, merge_query_fn, type_fn)); |
925 | } |
926 | |
927 | private: |
928 | static void registerJitOperator() { |
929 | // Register parse-function for each JIT operator; |
930 | // This is a one-time look up, our hash registry indexes on the pointer in |
931 | // OperatorRegistry. |
932 | |
933 | std::array<const char*, kNumBinaryOpsWithAlpha> BinaryOpWithAlpha = { |
934 | "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" , |
935 | "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor" , |
936 | "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" , |
937 | "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor" , |
938 | "aten::rsub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" , |
939 | "aten::rsub(Tensor self, Scalar other, Scalar alpha) -> Tensor" }; |
940 | for (auto signature : BinaryOpWithAlpha) { |
941 | auto ptr_op = getOperatorForLiteral(signature); |
942 | REGISTER_PARSE_RULE( |
943 | ptr_op, |
944 | { |
945 | using BinaryOpWithAlphaType = Val* (*)(Val*, Val*, Val*); |
946 | static std::unordered_map< |
947 | Symbol, |
948 | std::pair<BinaryOpType, BinaryOpWithAlphaType>> |
949 | op_mapping( |
950 | {{aten::add, |
951 | std::make_pair( |
952 | BinaryOpType::Add, |
953 | static_cast<BinaryOpWithAlphaType>(&add_alpha))}, |
954 | {aten::sub, |
955 | std::make_pair( |
956 | BinaryOpType::Sub, |
957 | static_cast<BinaryOpWithAlphaType>(&sub_alpha))}, |
958 | {aten::rsub, |
959 | std::make_pair( |
960 | BinaryOpType::Sub, |
961 | static_cast<BinaryOpWithAlphaType>(&sub_alpha))}}); |
962 | // TODO: handle scaling factor when it's not constant 1; |
963 | MemoryFormat format; |
964 | std::list<Val*> list_val; |
965 | std::tie(format, list_val) = getPWFormatValues( |
966 | c10::nullopt, |
967 | value_map[node->inputs()[0]->unique()], |
968 | value_map[node->inputs()[1]->unique()]); |
969 | auto lhs = list_val.front(); |
970 | list_val.pop_front(); |
971 | auto rhs = list_val.front(); |
972 | list_val.pop_front(); |
973 | Val* alpha = value_map[node->inputs()[2]->unique()]; |
974 | |
975 | auto out = alpha->isOneInt() |
976 | ? binaryOp( |
977 | op_mapping[node->kind()].first, |
978 | node->kind() == aten::rsub ? rhs : lhs, |
979 | node->kind() == aten::rsub ? lhs : rhs, |
980 | TypePromotion::default_op_config) |
981 | : (node->kind() == aten::rsub |
982 | ? op_mapping[node->kind()].second(rhs, lhs, alpha) |
983 | : op_mapping[node->kind()].second(lhs, rhs, alpha)); |
984 | value_map.emplace( |
985 | node->output()->unique(), ValueHolder(out, format)); |
986 | }, |
987 | isInputNonSizeZeroTensor, |
988 | nullptr); |
989 | } |
990 | |
991 | std::array<const char*, kNumBinaryFloatOps> BinaryFloatOp = { |
992 | "aten::div(Tensor self, Tensor other) -> Tensor" , |
993 | "aten::div(Tensor self, Scalar other) -> Tensor" , |
994 | "aten::atan2(Tensor self, Tensor other) -> Tensor" }; |
995 | for (auto signature : BinaryFloatOp) { |
996 | auto ptr_op = getOperatorForLiteral(signature); |
997 | REGISTER_PARSE_RULE( |
998 | ptr_op, |
999 | { |
1000 | static std::unordered_map<Symbol, BinaryOpType> op_mapping( |
1001 | {{aten::div, BinaryOpType::Div}, |
1002 | {aten::atan2, BinaryOpType::Atan2}}); |
1003 | |
1004 | MemoryFormat format; |
1005 | std::list<Val*> list_val; |
1006 | std::tie(format, list_val) = getPWFormatValues( |
1007 | c10::nullopt, |
1008 | value_map[node->inputs()[0]->unique()], |
1009 | value_map[node->inputs()[1]->unique()]); |
1010 | auto lhs = list_val.front(); |
1011 | list_val.pop_front(); |
1012 | auto rhs = list_val.front(); |
1013 | list_val.pop_front(); |
1014 | |
1015 | auto out = binaryOp( |
1016 | op_mapping[node->kind()], |
1017 | lhs, |
1018 | rhs, |
1019 | TypePromotion::float_op_config); |
1020 | value_map.emplace( |
1021 | node->output()->unique(), ValueHolder(out, format)); |
1022 | }, |
1023 | isInputNonSizeZeroTensor, |
1024 | nullptr); |
1025 | } |
1026 | |
1027 | std::array<const char*, kNumBinaryCastOps> BinaryCastOp = { |
1028 | "aten::mul(Tensor self, Tensor other) -> Tensor" , |
1029 | "aten::mul(Tensor self, Scalar other) -> Tensor" , |
1030 | "aten::max(Tensor self, Tensor other) -> Tensor" , |
1031 | "aten::min(Tensor self, Tensor other) -> Tensor" , |
1032 | "aten::pow(Tensor self, Tensor exponent) -> Tensor" , |
1033 | "aten::pow(Tensor self, Scalar exponent) -> Tensor" , |
1034 | "aten::pow(Scalar self, Tensor exponent) -> Tensor" , |
1035 | "aten::remainder(Tensor self, Tensor other) -> Tensor" , |
1036 | "aten::fmod(Tensor self, Tensor other) -> Tensor" , |
1037 | "aten::bitwise_and(Tensor self, Tensor other) -> Tensor" , |
1038 | "aten::__and__(Tensor self, Tensor other) -> Tensor" , |
1039 | "aten::bitwise_or(Tensor self, Tensor other) -> Tensor" , |
1040 | "aten::__or__(Tensor self, Tensor other) -> Tensor" , |
1041 | "aten::bitwise_xor(Tensor self, Tensor other) -> Tensor" , |
1042 | "aten::__xor__(Tensor self, Tensor other) -> Tensor" , |
1043 | "aten::bitwise_left_shift(Tensor self, Tensor other) -> Tensor" , |
1044 | "aten::__lshift__(Tensor self, Tensor other) -> Tensor" , |
1045 | "aten::bitwise_right_shift(Tensor self, Tensor other) -> Tensor" , |
1046 | "aten::__rshift__(Tensor self, Tensor other) -> Tensor" }; |
1047 | for (auto signature : BinaryCastOp) { |
1048 | auto ptr_op = getOperatorForLiteral(signature); |
1049 | REGISTER_PARSE_RULE( |
1050 | ptr_op, |
1051 | { |
1052 | static std::unordered_map<Symbol, BinaryOpType> op_mapping( |
1053 | {{aten::mul, BinaryOpType::Mul}, |
1054 | {aten::min, BinaryOpType::Min}, |
1055 | {aten::max, BinaryOpType::Max}, |
1056 | {aten::pow, BinaryOpType::Pow}, |
1057 | {aten::remainder, BinaryOpType::Remainder}, |
1058 | {aten::fmod, BinaryOpType::Fmod}, |
1059 | {aten::bitwise_and, BinaryOpType::And}, |
1060 | {aten::__and__, BinaryOpType::And}, |
1061 | {aten::bitwise_or, BinaryOpType::Or}, |
1062 | {aten::__or__, BinaryOpType::Or}, |
1063 | {aten::bitwise_xor, BinaryOpType::Xor}, |
1064 | {aten::__xor__, BinaryOpType::Xor}, |
1065 | {aten::bitwise_left_shift, BinaryOpType::Lshift}, |
1066 | {aten::__lshift__, BinaryOpType::Lshift}, |
1067 | {aten::bitwise_right_shift, BinaryOpType::Rshift}, |
1068 | {aten::__rshift__, BinaryOpType::Rshift}}); |
1069 | |
1070 | MemoryFormat format; |
1071 | std::list<Val*> list_val; |
1072 | std::tie(format, list_val) = getPWFormatValues( |
1073 | c10::nullopt, |
1074 | value_map[node->inputs()[0]->unique()], |
1075 | value_map[node->inputs()[1]->unique()]); |
1076 | auto lhs = list_val.front(); |
1077 | list_val.pop_front(); |
1078 | auto rhs = list_val.front(); |
1079 | list_val.pop_front(); |
1080 | |
1081 | auto out = binaryOp( |
1082 | op_mapping[node->kind()], |
1083 | lhs, |
1084 | rhs, |
1085 | TypePromotion::default_op_config); |
1086 | value_map.emplace( |
1087 | node->output()->unique(), ValueHolder(out, format)); |
1088 | }, |
1089 | isInputNonSizeZeroTensor, |
1090 | nullptr); |
1091 | } |
1092 | |
1093 | std::array<const char*, kNumBinaryComparisonOps> BinaryOp = { |
1094 | "aten::eq(Tensor self, Tensor other) -> Tensor" , |
1095 | "aten::eq(Tensor self, Scalar other) -> Tensor" , |
1096 | "aten::ne(Tensor self, Tensor other) -> Tensor" , |
1097 | "aten::ne(Tensor self, Scalar other) -> Tensor" , |
1098 | "aten::ge(Tensor self, Tensor other) -> Tensor" , |
1099 | "aten::ge(Tensor self, Scalar other) -> Tensor" , |
1100 | "aten::gt(Tensor self, Tensor other) -> Tensor" , |
1101 | "aten::gt(Tensor self, Scalar other) -> Tensor" , |
1102 | "aten::le(Tensor self, Tensor other) -> Tensor" , |
1103 | "aten::le(Tensor self, Scalar other) -> Tensor" , |
1104 | "aten::lt(Tensor self, Tensor other) -> Tensor" , |
1105 | "aten::lt(Tensor self, Scalar other) -> Tensor" }; |
1106 | for (auto signature : BinaryOp) { |
1107 | auto ptr_op = getOperatorForLiteral(signature); |
1108 | REGISTER_PARSE_RULE( |
1109 | ptr_op, |
1110 | { |
1111 | static std::unordered_map<Symbol, BinaryOpType> op_mapping( |
1112 | {{aten::lt, BinaryOpType::LT}, |
1113 | {aten::le, BinaryOpType::LE}, |
1114 | {aten::gt, BinaryOpType::GT}, |
1115 | {aten::ge, BinaryOpType::GE}, |
1116 | {aten::ne, BinaryOpType::NE}, |
1117 | {aten::eq, BinaryOpType::Eq}}); |
1118 | |
1119 | MemoryFormat format; |
1120 | std::list<Val*> list_val; |
1121 | std::tie(format, list_val) = getPWFormatValues( |
1122 | c10::nullopt, |
1123 | value_map[node->inputs()[0]->unique()], |
1124 | value_map[node->inputs()[1]->unique()]); |
1125 | auto lhs = list_val.front(); |
1126 | list_val.pop_front(); |
1127 | auto rhs = list_val.front(); |
1128 | list_val.pop_front(); |
1129 | |
1130 | auto out = binaryOp( |
1131 | op_mapping[node->kind()], |
1132 | lhs, |
1133 | rhs, |
1134 | TypePromotion::comparison_op_config); |
1135 | value_map.emplace( |
1136 | node->output()->unique(), ValueHolder(out, format)); |
1137 | }, |
1138 | isInputNonSizeZeroTensor, |
1139 | nullptr); |
1140 | } |
1141 | |
1142 | std::array<const char*, kNumUnaryOps> UnaryOp = { |
1143 | "aten::abs(Tensor self) -> Tensor" , |
1144 | "aten::bitwise_not(Tensor self) -> Tensor" , |
1145 | "aten::ceil(Tensor self) -> Tensor" , |
1146 | "aten::floor(Tensor self) -> Tensor" , |
1147 | "aten::frac(Tensor self) -> Tensor" , |
1148 | "aten::neg(Tensor self) -> Tensor" , |
1149 | "aten::relu(Tensor self) -> Tensor" , |
1150 | "aten::round(Tensor self) -> Tensor" , |
1151 | "aten::silu(Tensor self) -> Tensor" , |
1152 | "aten::trunc(Tensor self) -> Tensor" , |
1153 | }; |
1154 | for (auto signature : UnaryOp) { |
1155 | auto ptr_op = getOperatorForLiteral(signature); |
1156 | REGISTER_PARSE_RULE( |
1157 | ptr_op, |
1158 | { |
1159 | static std::unordered_map<Symbol, UnaryOpType> op_mapping({ |
1160 | {aten::abs, UnaryOpType::Abs}, |
1161 | {aten::bitwise_not, UnaryOpType::Not}, |
1162 | {aten::ceil, UnaryOpType::Ceil}, |
1163 | {aten::floor, UnaryOpType::Floor}, |
1164 | {aten::frac, UnaryOpType::Frac}, |
1165 | {aten::neg, UnaryOpType::Neg}, |
1166 | {aten::relu, UnaryOpType::Relu}, |
1167 | {aten::round, UnaryOpType::Round}, |
1168 | {aten::silu, UnaryOpType::Silu}, |
1169 | {aten::trunc, UnaryOpType::Trunc}, |
1170 | }); |
1171 | MemoryFormat format; |
1172 | std::list<Val*> list_val; |
1173 | std::tie(format, list_val) = getConsistentValues( |
1174 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
1175 | auto operand = list_val.front(); |
1176 | list_val.pop_front(); |
1177 | auto out = unaryOp(op_mapping[node->kind()], operand); |
1178 | value_map.emplace( |
1179 | node->output()->unique(), ValueHolder(out, format)); |
1180 | }, |
1181 | isInputNonSizeZeroTensor, |
1182 | nullptr); |
1183 | } |
1184 | |
1185 | std::array<const char*, kNumUnaryFloatOps> UnaryFloatOp = { |
1186 | "aten::log(Tensor self) -> Tensor" , |
1187 | "aten::log10(Tensor self) -> Tensor" , |
1188 | "aten::log1p(Tensor self) -> Tensor" , |
1189 | "aten::log2(Tensor self) -> Tensor" , |
1190 | "aten::lgamma(Tensor self) -> Tensor" , |
1191 | "aten::exp(Tensor self) -> Tensor" , |
1192 | "aten::expm1(Tensor self) -> Tensor" , |
1193 | "aten::erf(Tensor self) -> Tensor" , |
1194 | "aten::erfc(Tensor self) -> Tensor" , |
1195 | "aten::cos(Tensor self) -> Tensor" , |
1196 | "aten::acos(Tensor self) -> Tensor" , |
1197 | "aten::cosh(Tensor self) -> Tensor" , |
1198 | "aten::sin(Tensor self) -> Tensor" , |
1199 | "aten::asin(Tensor self) -> Tensor" , |
1200 | "aten::sinh(Tensor self) -> Tensor" , |
1201 | "aten::tan(Tensor self) -> Tensor" , |
1202 | "aten::atan(Tensor self) -> Tensor" , |
1203 | "aten::tanh(Tensor self) -> Tensor" , |
1204 | "aten::atanh(Tensor self) -> Tensor" , |
1205 | "aten::sqrt(Tensor self) -> Tensor" , |
1206 | "aten::rsqrt(Tensor self) -> Tensor" , |
1207 | "aten::reciprocal(Tensor self) -> Tensor" , |
1208 | "aten::sigmoid(Tensor self) -> Tensor" }; |
1209 | for (auto signature : UnaryFloatOp) { |
1210 | auto ptr_op = getOperatorForLiteral(signature); |
1211 | REGISTER_PARSE_RULE( |
1212 | ptr_op, |
1213 | { |
1214 | static std::unordered_map<Symbol, UnaryOpType> op_mapping({ |
1215 | {aten::log, UnaryOpType::Log}, |
1216 | {aten::log10, UnaryOpType::Log10}, |
1217 | {aten::log1p, UnaryOpType::Log1p}, |
1218 | {aten::log2, UnaryOpType::Log2}, |
1219 | {aten::lgamma, UnaryOpType::Lgamma}, |
1220 | {aten::exp, UnaryOpType::Exp}, |
1221 | {aten::expm1, UnaryOpType::Expm1}, |
1222 | {aten::erf, UnaryOpType::Erf}, |
1223 | {aten::erfc, UnaryOpType::Erfc}, |
1224 | {aten::cos, UnaryOpType::Cos}, |
1225 | {aten::acos, UnaryOpType::Acos}, |
1226 | {aten::cosh, UnaryOpType::Cosh}, |
1227 | {aten::sin, UnaryOpType::Sin}, |
1228 | {aten::asin, UnaryOpType::Asin}, |
1229 | {aten::sinh, UnaryOpType::Sinh}, |
1230 | {aten::tan, UnaryOpType::Tan}, |
1231 | {aten::tanh, UnaryOpType::Tanh}, |
1232 | {aten::atan, UnaryOpType::Atan}, |
1233 | {aten::atanh, UnaryOpType::Atanh}, |
1234 | {aten::sqrt, UnaryOpType::Sqrt}, |
1235 | {aten::rsqrt, UnaryOpType::Rsqrt}, |
1236 | {aten::reciprocal, UnaryOpType::Reciprocal}, |
1237 | {aten::sigmoid, UnaryOpType::Sigmoid}, |
1238 | }); |
1239 | MemoryFormat format; |
1240 | std::list<Val*> list_val; |
1241 | std::tie(format, list_val) = getConsistentValues( |
1242 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
1243 | auto operand = list_val.front(); |
1244 | list_val.pop_front(); |
1245 | auto out = unaryOp( |
1246 | op_mapping[node->kind()], |
1247 | operand, |
1248 | TypePromotion::float_op_config); |
1249 | value_map.emplace( |
1250 | node->output()->unique(), ValueHolder(out, format)); |
1251 | }, |
1252 | isInputNonSizeZeroTensor, |
1253 | nullptr); |
1254 | } |
1255 | |
1256 | std::array<const char*, kNumUnaryIsOps> UnaryIsOp = { |
1257 | "aten::isfinite(Tensor self) -> Tensor" , |
1258 | "aten::isinf(Tensor self) -> Tensor" , |
1259 | "aten::isnan(Tensor self) -> Tensor" , |
1260 | "aten::isneginf(Tensor self) -> Tensor" , |
1261 | "aten::isposinf(Tensor self) -> Tensor" , |
1262 | "aten::isreal(Tensor self) -> Tensor" }; |
1263 | for (auto signature : UnaryIsOp) { |
1264 | auto ptr_op = getOperatorForLiteral(signature); |
1265 | REGISTER_PARSE_RULE( |
1266 | ptr_op, |
1267 | { |
1268 | static std::unordered_map<Symbol, UnaryOpType> op_mapping({ |
1269 | {aten::isfinite, UnaryOpType::IsFinite}, |
1270 | {aten::isinf, UnaryOpType::IsInf}, |
1271 | {aten::isnan, UnaryOpType::IsNan}, |
1272 | {aten::isneginf, UnaryOpType::IsNegInf}, |
1273 | {aten::isposinf, UnaryOpType::IsPosInf}, |
1274 | {aten::isreal, UnaryOpType::IsReal}, |
1275 | }); |
1276 | MemoryFormat format; |
1277 | std::list<Val*> list_val; |
1278 | std::tie(format, list_val) = getConsistentValues( |
1279 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
1280 | auto operand = list_val.front(); |
1281 | list_val.pop_front(); |
1282 | auto out = unaryIsOp(op_mapping[node->kind()], operand); |
1283 | value_map.emplace( |
1284 | node->output()->unique(), ValueHolder(out, format)); |
1285 | }, |
1286 | isInputNonSizeZeroTensor, |
1287 | nullptr); |
1288 | } |
1289 | |
1290 | { |
1291 | auto ptr_op = getOperatorForLiteral( |
1292 | "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" ); |
1293 | REGISTER_PARSE_RULE( |
1294 | ptr_op, |
1295 | { |
1296 | MemoryFormat format; |
1297 | std::list<Val*> list_val; |
1298 | std::tie(format, list_val) = getConsistentValues( |
1299 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
1300 | auto operand = list_val.front(); |
1301 | list_val.pop_front(); |
1302 | |
1303 | if (!node->input(3)->type()->isSubtypeOf( |
1304 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1305 | auto device = constant_as<c10::Device>(node->input(3)); |
1306 | TORCH_INTERNAL_ASSERT( |
1307 | device.has_value() && device->is_cuda(), |
1308 | "rand_like in nvfuser is not on cuda device" ); |
1309 | auto input_tensor_type = |
1310 | node->input(0)->type()->cast<TensorType>(); |
              // device->index() == -1 indicates that we don't change the
              // device index
1313 | if (device->index() != -1 && input_tensor_type) { |
1314 | auto input_device = input_tensor_type->device(); |
1315 | // we expect device index to be consistent with input and it |
1316 | // should have already been handled by partition |
1317 | TORCH_INTERNAL_ASSERT( |
1318 | !input_device.has_value() || |
1319 | input_device->index() == device->index(), |
1320 | "rand_like in nvfuser is not on cuda device" ); |
1321 | } |
1322 | } |
1323 | |
1324 | auto out = rand_like(operand); |
1325 | value_map.emplace( |
1326 | node->output()->unique(), ValueHolder(out, format)); |
1327 | }, |
1328 | [](const Node* node) -> bool { |
1329 | if (!isInputNonSizeZeroTensor(node)) { |
1330 | return false; |
1331 | } |
1332 | if (!node->input(1)->type()->isSubtypeOf( |
1333 | static_cast<c10::TypePtr>(NoneType::get())) || |
1334 | !node->input(2)->type()->isSubtypeOf( |
1335 | static_cast<c10::TypePtr>(NoneType::get())) || |
1336 | !node->input(5)->type()->isSubtypeOf( |
1337 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1338 | return false; |
1339 | } |
1340 | return true; |
1341 | }, |
1342 | nullptr); |
1343 | } |
1344 | |
1345 | { |
1346 | auto ptr_op = getOperatorForLiteral( |
1347 | "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor" ); |
1348 | REGISTER_PARSE_RULE( |
1349 | ptr_op, |
1350 | { |
1351 | MemoryFormat format; |
1352 | std::list<Val*> list_val; |
1353 | std::tie(format, list_val) = getConsistentValues( |
1354 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
1355 | auto operand = list_val.front()->as<TensorView>(); |
1356 | list_val.pop_front(); |
1357 | auto& beta = value_map[node->inputs()[1]->unique()]; |
1358 | auto& threshold = value_map[node->inputs()[2]->unique()]; |
1359 | auto out = softplus(operand, beta, threshold); |
1360 | value_map.emplace( |
1361 | node->output()->unique(), ValueHolder(out, format)); |
1362 | }, |
1363 | isInputNonSizeZeroTensor, |
1364 | nullptr); |
1365 | } |
1366 | |
1367 | { |
1368 | auto ptr_op = getOperatorForLiteral( |
1369 | "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor" ); |
1370 | REGISTER_PARSE_RULE( |
1371 | ptr_op, |
1372 | { |
1373 | MemoryFormat format; |
1374 | std::list<Val*> list_val; |
1375 | std::tie(format, list_val) = getConsistentValues( |
1376 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
1377 | auto operand = list_val.front(); |
1378 | list_val.pop_front(); |
1379 | auto& th = value_map[node->inputs()[1]->unique()]; |
1380 | auto& value = value_map[node->inputs()[2]->unique()]; |
1381 | |
1382 | auto out = threshold(operand, th, value); |
1383 | value_map.emplace( |
1384 | node->output()->unique(), ValueHolder(out, format)); |
1385 | }, |
1386 | isInputNonSizeZeroTensor, |
1387 | nullptr); |
1388 | } |
1389 | |
1390 | { // LTC uses threshold_backward for relu_backward |
1391 | auto ptr_op = getOperatorForLiteral( |
1392 | "aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor" ); |
1393 | REGISTER_PARSE_RULE( |
1394 | ptr_op, |
1395 | { |
1396 | MemoryFormat format; |
1397 | std::list<Val*> list_val; |
1398 | std::tie(format, list_val) = getPWFormatValues( |
1399 | c10::nullopt, |
1400 | value_map[node->inputs()[0]->unique()], |
1401 | value_map[node->inputs()[1]->unique()]); |
1402 | auto grad_output = list_val.front(); |
1403 | list_val.pop_front(); |
1404 | auto input = list_val.front(); |
1405 | auto& threshold = value_map[node->inputs()[2]->unique()]; |
1406 | |
1407 | auto comparison = binaryOp( |
1408 | BinaryOpType::GT, |
1409 | input, |
1410 | threshold, |
1411 | TypePromotion::comparison_op_config); |
1412 | auto mask = castOp(input->getDataType().value(), comparison); |
1413 | auto out = mul(grad_output, mask); |
1414 | |
1415 | value_map.emplace( |
1416 | node->output()->unique(), ValueHolder(out, format)); |
1417 | }, |
1418 | isInputNonSizeZeroTensor, |
1419 | nullptr); |
1420 | } |
1421 | |
1422 | { |
1423 | auto ptr_op = getOperatorForLiteral( |
1424 | "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor" ); |
1425 | REGISTER_PARSE_RULE( |
1426 | ptr_op, |
1427 | { |
1428 | MemoryFormat format; |
1429 | std::list<Val*> list_val; |
1430 | std::tie(format, list_val) = getConsistentValues( |
1431 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
1432 | auto operand = list_val.front(); |
1433 | list_val.pop_front(); |
1434 | Val* min = value_map.count(node->inputs()[1]->unique()) != 0 |
1435 | ? *value_map[node->inputs()[1]->unique()] |
1436 | : nullptr; |
1437 | Val* max = value_map.count(node->inputs()[2]->unique()) != 0 |
1438 | ? *value_map[node->inputs()[2]->unique()] |
1439 | : nullptr; |
1440 | |
1441 | Val* out = clamp(operand, min, max); |
1442 | value_map.emplace( |
1443 | node->output()->unique(), ValueHolder(out, format)); |
1444 | }, |
1445 | isInputNonSizeZeroTensor, |
1446 | nullptr); |
1447 | } |
1448 | |
1449 | { |
1450 | auto ptr_op = getOperatorForLiteral( |
1451 | "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor" ); |
1452 | REGISTER_PARSE_RULE( |
1453 | ptr_op, |
1454 | { |
1455 | MemoryFormat format; |
1456 | std::list<Val*> list_val; |
1457 | std::tie(format, list_val) = getPWFormatValues( |
1458 | c10::nullopt, |
1459 | value_map[node->inputs()[0]->unique()], |
1460 | value_map[node->inputs()[1]->unique()], |
1461 | value_map[node->inputs()[2]->unique()]); |
1462 | auto condition = list_val.front(); |
1463 | list_val.pop_front(); |
1464 | auto x = list_val.front(); |
1465 | list_val.pop_front(); |
1466 | auto y = list_val.front(); |
1467 | list_val.pop_front(); |
1468 | |
1469 | auto out = where(condition, x, y); |
1470 | value_map.emplace( |
1471 | node->output()->unique(), ValueHolder(out, format)); |
1472 | }, |
1473 | isInputNonSizeZeroTensor, |
1474 | nullptr); |
1475 | } |
1476 | |
1477 | { |
1478 | std::array<const char*, kNumLerpOps> LerpOp = { |
1479 | "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor" , |
1480 | "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor" }; |
1481 | for (auto signature : LerpOp) { |
1482 | auto ptr_op = getOperatorForLiteral(signature); |
1483 | REGISTER_PARSE_RULE( |
1484 | ptr_op, |
1485 | { |
1486 | MemoryFormat format; |
1487 | std::list<Val*> list_val; |
1488 | std::tie(format, list_val) = getPWFormatValues( |
1489 | c10::nullopt, |
1490 | value_map[node->inputs()[0]->unique()], |
1491 | value_map[node->inputs()[1]->unique()], |
1492 | value_map[node->inputs()[2]->unique()]); |
1493 | auto self = list_val.front(); |
1494 | list_val.pop_front(); |
1495 | auto end = list_val.front(); |
1496 | list_val.pop_front(); |
1497 | auto weight = list_val.front(); |
1498 | list_val.pop_front(); |
1499 | |
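            // lerp computes self + weight * (end - self), i.e. a linear
            // interpolation between self and end.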
1500 | auto out = lerp(self, end, weight); |
1501 | value_map.emplace( |
1502 | node->output()->unique(), ValueHolder(out, format)); |
1503 | }, |
1504 | isInputNonSizeZeroTensor, |
1505 | nullptr); |
1506 | } |
1507 | } |
1508 | |
1509 | { |
1510 | auto ptr_op = getOperatorForLiteral( |
1511 | "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor" ); |
1512 | REGISTER_PARSE_RULE( |
1513 | ptr_op, |
1514 | { |
1515 | MemoryFormat format; |
1516 | std::list<Val*> list_val; |
1517 | std::tie(format, list_val) = getPWFormatValues( |
1518 | c10::nullopt, |
1519 | value_map[node->inputs()[0]->unique()], |
1520 | value_map[node->inputs()[1]->unique()], |
1521 | value_map[node->inputs()[2]->unique()], |
1522 | value_map[node->inputs()[3]->unique()]); |
1523 | auto self = list_val.front(); |
1524 | list_val.pop_front(); |
1525 | auto tensor1 = list_val.front(); |
1526 | list_val.pop_front(); |
1527 | auto tensor2 = list_val.front(); |
1528 | list_val.pop_front(); |
1529 | auto value = list_val.front(); |
1530 | list_val.pop_front(); |
1531 | |
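          // addcmul computes self + value * tensor1 * tensor2 element-wise.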
1532 | auto out = addcmul(self, tensor1, tensor2, value); |
1533 | value_map.emplace( |
1534 | node->output()->unique(), ValueHolder(out, format)); |
1535 | }, |
1536 | isInputNonSizeZeroTensor, |
1537 | nullptr); |
1538 | } |
1539 | |
1540 | { |
1541 | auto ptr_op = getOperatorForLiteral( |
1542 | "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)" ); |
1543 | REGISTER_PARSE_RULE( |
1544 | ptr_op, |
1545 | { |
1546 | MemoryFormat format; |
1547 | std::list<Val*> list_val; |
1548 | std::tie(format, list_val) = getConsistentValues( |
1549 | c10::nullopt, |
1550 | value_map[node->inputs()[0]->unique()], |
1551 | value_map[node->inputs()[1]->unique()]); |
1552 | auto input = list_val.front(); |
1553 | list_val.pop_front(); |
1554 | auto prob = list_val.front(); |
1555 | list_val.pop_front(); |
1556 | auto train = constant_as<bool>(node->input(2)); |
1557 | |
1558 | TORCH_INTERNAL_ASSERT( |
1559 | train.has_value(), "dropout needs constant `train` flag" ); |
1560 | |
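          // Training mode lowers to the dropout helper, which yields both the
          // dropped-out output and the sampled mask. In eval mode the input is
          // forwarded unchanged and an empty TensorView placeholder stands in
          // for the mask output.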
1561 | if (train.value()) { |
1562 | auto result = dropout(input->as<TensorView>(), prob); |
1563 | |
1564 | value_map.emplace( |
1565 | node->output(0)->unique(), |
1566 | ValueHolder(result.output, format)); |
1567 | value_map.emplace( |
1568 | node->output(1)->unique(), ValueHolder(result.mask, format)); |
1569 | } else { |
1570 | value_map.emplace(node->output(0)->unique(), input); |
1571 | value_map.emplace( |
1572 | node->output(1)->unique(), |
1573 | ValueHolder(TensorViewBuilder().build(), format)); |
1574 | } |
1575 | }, |
1576 | [](const Node* node) -> bool { |
1577 | if (!isInputNonSizeZeroTensor(node)) { |
1578 | return false; |
1579 | } |
1580 | if (node->inputs()[2]->node()->kind() != prim::Constant) { |
1581 | return false; |
1582 | } |
1583 | return true; |
1584 | }, |
1585 | nullptr); |
1586 | } |
1587 | |
1588 | { |
1589 | auto ptr_op = getOperatorForLiteral( |
1590 | "aten::dropout(Tensor input, float p, bool train) -> Tensor" ); |
1591 | REGISTER_PARSE_RULE( |
1592 | ptr_op, |
1593 | { |
1594 | MemoryFormat format; |
1595 | std::list<Val*> list_val; |
1596 | std::tie(format, list_val) = getConsistentValues( |
1597 | c10::nullopt, |
1598 | value_map[node->inputs()[0]->unique()], |
1599 | value_map[node->inputs()[1]->unique()]); |
1600 | auto input = list_val.front(); |
1601 | list_val.pop_front(); |
1602 | auto prob = list_val.front(); |
1603 | list_val.pop_front(); |
1604 | |
1605 | auto train = constant_as<bool>(node->input(2)); |
1606 | TORCH_INTERNAL_ASSERT( |
1607 | train.has_value(), "dropout needs constant `train` flag" ); |
1608 | |
1609 | if (train.value()) { |
1610 | auto result = dropout(input->as<TensorView>(), prob); |
1611 | |
1612 | value_map.emplace( |
1613 | node->output()->unique(), ValueHolder(result.output, format)); |
1614 | } else { |
1615 | value_map.emplace( |
1616 | node->output()->unique(), ValueHolder(input, format)); |
1617 | } |
1618 | }, |
1619 | [](const Node* node) -> bool { |
1620 | if (!isInputNonSizeZeroTensor(node)) { |
1621 | return false; |
1622 | } |
1623 | if (node->inputs()[2]->node()->kind() != prim::Constant) { |
1624 | return false; |
1625 | } |
1626 | return true; |
1627 | }, |
1628 | nullptr); |
1629 | } |
1630 | |
1631 | { |
1632 | auto ptr_op = getOperatorForLiteral( |
1633 | "aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor" ); |
1634 | REGISTER_PARSE_RULE( |
1635 | ptr_op, |
1636 | { |
1637 | MemoryFormat format; |
1638 | std::list<Val*> list_val; |
1639 | std::tie(format, list_val) = getPWFormatValues( |
1640 | c10::nullopt, |
1641 | value_map[node->inputs()[0]->unique()], |
1642 | value_map[node->inputs()[1]->unique()], |
1643 | value_map[node->inputs()[2]->unique()]); |
1644 | auto grad = list_val.front(); |
1645 | list_val.pop_front(); |
1646 | auto mask = list_val.front(); |
1647 | list_val.pop_front(); |
1648 | auto scale = list_val.front(); |
1649 | list_val.pop_front(); |
1650 | |
1651 | auto output = dropout_backward( |
1652 | grad->as<TensorView>(), mask->as<TensorView>(), scale); |
1653 | value_map.emplace( |
1654 | node->output()->unique(), ValueHolder(output, format)); |
1655 | }, |
1656 | isInputNonSizeZeroTensor, |
1657 | nullptr); |
1658 | } |
1659 | |
1660 | { |
1661 | std::array<const char*, kNumInstancenormFwd> InstanceNormFwd = { |
1662 | "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor" }; |
1663 | for (auto signature : InstanceNormFwd) { |
1664 | auto ptr_op = getOperatorForLiteral(signature); |
1665 | REGISTER_PARSE_RULE( |
1666 | ptr_op, |
1667 | { |
1668 | // TODO: handle channels last |
1669 | MemoryFormat format; |
1670 | std::list<Val*> list_val; |
1671 | std::tie(format, list_val) = getConsistentValues( |
1672 | MemoryFormat::Contiguous(), |
1673 | value_map[node->inputs()[0]->unique()]); |
1674 | auto input_t = list_val.front(); |
1675 | list_val.pop_front(); |
1676 | auto input = input_t->as<TensorView>(); |
1677 | |
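            // Optional arguments (weight, bias, running stats) show up as None
            // in the JIT graph; a nullptr TensorView tells instance_norm that
            // the argument is absent.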
1678 | TensorView* weight = nullptr; |
1679 | if (!node->input(1)->type()->isSubtypeOf( |
1680 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1681 | weight = value_map[node->input(1)->unique()]->as<TensorView>(); |
1682 | } |
1683 | |
1684 | TensorView* bias = nullptr; |
1685 | if (!node->input(2)->type()->isSubtypeOf( |
1686 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1687 | bias = value_map[node->input(2)->unique()]->as<TensorView>(); |
1688 | } |
1689 | |
1690 | TensorView* running_mean = nullptr; |
1691 | if (!node->input(3)->type()->isSubtypeOf( |
1692 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1693 | running_mean = |
1694 | value_map[node->input(3)->unique()]->as<TensorView>(); |
1695 | } |
1696 | |
1697 | TensorView* running_var = nullptr; |
1698 | if (!node->input(4)->type()->isSubtypeOf( |
1699 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1700 | running_var = |
1701 | value_map[node->input(4)->unique()]->as<TensorView>(); |
1702 | } |
1703 | |
1704 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1705 | auto use_input_stats = constant_as<bool>(node->input(5)); |
1706 | TORCH_INTERNAL_ASSERT( |
1707 | use_input_stats.has_value(), |
1708 | "The use_input_stats (bool) parameter is required." ); |
1709 | const bool kUseInputStats = use_input_stats.value(); |
1710 | |
1711 | Val* momentum_ptr = nullptr; |
1712 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1713 | if (auto momentum = constant_as<float>(node->input(6))) { |
1714 | momentum_ptr = IrBuilder::create<Double>(momentum.value()); |
1715 | } else { |
1716 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1717 | momentum_ptr = value_map[node->input(6)->unique()]; |
1718 | } |
1719 | |
1720 | Val* eps_ptr = nullptr; |
1721 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1722 | if (auto eps = constant_as<float>(node->input(7))) { |
1723 | eps_ptr = IrBuilder::create<Double>(eps.value()); |
1724 | } else { |
1725 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1726 | eps_ptr = value_map[node->input(7)->unique()]; |
1727 | } |
1728 | |
1729 | auto result = instance_norm( |
1730 | input, |
1731 | weight, |
1732 | bias, |
1733 | running_mean, |
1734 | running_var, |
1735 | kUseInputStats, |
1736 | momentum_ptr, |
1737 | eps_ptr); |
1738 | |
1739 | if (node->kind() == |
1740 | c10::Symbol::fromQualString("aten::instance_norm" )) { |
1741 | value_map.emplace(node->output()->unique(), result.output); |
1742 | } |
1743 | }, |
1744 | [](const Node* node) -> bool { |
1745 | if (isReductionNonCompatibleTensor( |
1746 | node->input(0)->type()->cast<TensorType>())) { |
1747 | return false; |
1748 | } |
1749 | return true; |
1750 | }, |
1751 | [](const Node* node) -> OperatorType { |
1752 | return OperatorType::Normalization; |
1753 | }); |
1754 | } |
1755 | } |
1756 | |
1757 | { |
1758 | std::array<const char*, kNumBatchnormFwd> BatchNormFwd = { |
1759 | "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)" , |
1760 | "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)" , |
1761 | "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor" }; |
1762 | for (auto signature : BatchNormFwd) { |
1763 | auto ptr_op = getOperatorForLiteral(signature); |
1764 | REGISTER_PARSE_RULE( |
1765 | ptr_op, |
1766 | { |
1767 | MemoryFormat format; |
1768 | Val* operand = nullptr; |
1769 | std::tie(format, operand) = |
1770 | value_map[node->input(0)->unique()].getEntry(); |
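            // Only the channels-last permutation is supported below; any other
            // permutation falls back to a contiguous copy of the input.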
1771 | if (format.hasPermutation() && !format.isChannelsLast()) { |
1772 | format = MemoryFormat::Contiguous(); |
1773 | operand = value_map[node->input(0)->unique()].maybeConvertValue( |
1774 | format); |
1775 | } |
1776 | auto input = operand->as<TensorView>(); |
1777 | |
1778 | TensorView* weight = nullptr; |
1779 | if (!node->input(1)->type()->isSubtypeOf( |
1780 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1781 | weight = value_map[node->input(1)->unique()]->as<TensorView>(); |
1782 | } |
1783 | |
1784 | TensorView* bias = nullptr; |
1785 | if (!node->input(2)->type()->isSubtypeOf( |
1786 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1787 | bias = value_map[node->input(2)->unique()]->as<TensorView>(); |
1788 | } |
1789 | |
1790 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1791 | auto training = constant_as<bool>(node->input(5)); |
1792 | TORCH_INTERNAL_ASSERT( |
1793 | training.has_value(), |
1794 | "The training (bool) parameter is required." ); |
1795 | const bool kTraining = training.value(); |
1796 | |
1797 | TensorView* running_mean = nullptr; |
1798 | if (!node->input(3)->type()->isSubtypeOf( |
1799 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1800 | running_mean = |
1801 | value_map[node->input(3)->unique()]->as<TensorView>(); |
1802 | } |
1803 | |
1804 | TensorView* running_var = nullptr; |
1805 | if (!node->input(4)->type()->isSubtypeOf( |
1806 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1807 | running_var = |
1808 | value_map[node->input(4)->unique()]->as<TensorView>(); |
1809 | } |
1810 | |
1811 | Val* momentum_ptr = nullptr; |
1812 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1813 | if (auto momentum = constant_as<float>(node->input(6))) { |
1814 | momentum_ptr = IrBuilder::create<Double>(momentum.value()); |
1815 | } else { |
1816 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1817 | momentum_ptr = value_map[node->input(6)->unique()]; |
1818 | } |
1819 | |
1820 | Val* eps_ptr = nullptr; |
1821 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1822 | if (auto eps = constant_as<float>(node->input(7))) { |
1823 | eps_ptr = IrBuilder::create<Double>(eps.value()); |
1824 | } else { |
1825 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1826 | eps_ptr = value_map[node->input(7)->unique()]; |
1827 | } |
1828 | |
1829 | auto result = batch_norm( |
1830 | input, |
1831 | weight, |
1832 | bias, |
1833 | running_mean, |
1834 | running_var, |
1835 | kTraining, |
1836 | momentum_ptr, |
1837 | eps_ptr, |
1838 | format.isChannelsLast()); |
1839 | |
1840 | if (node->kind() == |
1841 | c10::Symbol::fromQualString("aten::native_batch_norm" ) || |
1842 | node->kind() == |
1843 | c10::Symbol::fromQualString( |
1844 | "aten::_batch_norm_impl_index" )) { |
              // TODO: outputs 3 & 4 are not created
              // we are not creating these outputs because codegen
              // currently lacks support for them.
1848 | value_map.emplace( |
1849 | node->output(0)->unique(), |
1850 | ValueHolder(result.output, format)); |
1851 | value_map.emplace(node->output(1)->unique(), result.mean); |
1852 | value_map.emplace(node->output(2)->unique(), result.invstd); |
1853 | } else if ( |
1854 | node->kind() == |
1855 | c10::Symbol::fromQualString("aten::batch_norm" )) { |
1856 | value_map.emplace( |
1857 | node->output()->unique(), |
1858 | ValueHolder(result.output, format)); |
1859 | } |
1860 | }, |
1861 | [](const Node* node) -> bool { |
1862 | if (isReductionNonCompatibleTensor( |
1863 | node->input(0)->type()->cast<TensorType>())) { |
1864 | return false; |
1865 | } |
1866 | if (node->input(5)->node()->kind() != prim::Constant) { |
1867 | return false; |
1868 | } |
1869 | return true; |
1870 | }, |
1871 | [](const Node* node) -> OperatorType { |
1872 | return OperatorType::Normalization; |
1873 | }); |
1874 | } |
1875 | } |
1876 | |
1877 | { |
1878 | std::array<const char*, kNumBatchnormBwd> BatchNormBwd = { |
1879 | "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)" , |
1880 | "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)" }; |
1881 | for (auto signature : BatchNormBwd) { |
1882 | auto ptr_op = getOperatorForLiteral(signature); |
1883 | REGISTER_PARSE_RULE( |
1884 | ptr_op, |
1885 | { |
1886 | JitValue* ts_input = nullptr; |
            JitValue* ts_grad_output = nullptr;
1888 | JitValue* ts_weight = nullptr; |
1889 | JitValue* ts_r_mean = nullptr; |
1890 | JitValue* ts_r_var = nullptr; |
1891 | JitValue* ts_save_mean = nullptr; |
1892 | JitValue* ts_save_invstd = nullptr; |
1893 | JitValue* ts_train = nullptr; |
1894 | JitValue* ts_eps = nullptr; |
1895 | JitValue* ts_mask = nullptr; |
1896 | if (node->kind() == |
1897 | c10::Symbol::fromQualString( |
1898 | "aten::_batch_norm_impl_index_backward" )) { |
1899 | ts_input = node->input(1); |
1900 | ts_grad_output = node->input(2); |
1901 | ts_weight = node->input(3); |
1902 | ts_r_mean = node->input(4); |
1903 | ts_r_var = node->input(5); |
1904 | ts_save_mean = node->input(6); |
1905 | ts_save_invstd = node->input(7); |
1906 | ts_train = node->input(8); |
1907 | ts_eps = node->input(9); |
1908 | ts_mask = node->input(10); |
1909 | } else if ( |
1910 | node->kind() == |
1911 | c10::Symbol::fromQualString( |
1912 | "aten::native_batch_norm_backward" )) { |
1913 | ts_grad_output = node->input(0); |
1914 | ts_input = node->input(1); |
1915 | ts_weight = node->input(2); |
1916 | ts_r_mean = node->input(3); |
1917 | ts_r_var = node->input(4); |
1918 | ts_save_mean = node->input(5); |
1919 | ts_save_invstd = node->input(6); |
1920 | ts_train = node->input(7); |
1921 | ts_eps = node->input(8); |
1922 | ts_mask = node->input(9); |
1923 | } else { |
1924 | TORCH_INTERNAL_ASSERT( |
1925 | false, |
1926 | "Forgot to register the key for BN variation: " , |
1927 | node->kind().toDisplayString()); |
1928 | } |
1929 | |
1930 | // discard impl_index and reservedSpace since we don't use them |
1931 | MemoryFormat format; |
1932 | std::list<Val*> list_val; |
1933 | std::tie(format, list_val) = getConsistentValues( |
1934 | c10::nullopt, |
1935 | value_map[ts_input->unique()], |
1936 | value_map[ts_grad_output->unique()]); |
1937 | if (format.hasPermutation() && !format.isChannelsLast()) { |
1938 | std::tie(format, list_val) = getConsistentValues( |
1939 | MemoryFormat::Contiguous(), |
1940 | value_map[ts_input->unique()], |
1941 | value_map[ts_grad_output->unique()]); |
1942 | } |
1943 | auto operand0 = list_val.front(); |
1944 | list_val.pop_front(); |
1945 | auto operand1 = list_val.front(); |
1946 | list_val.pop_front(); |
1947 | auto input = operand0->as<TensorView>(); |
1948 | auto grad_out = operand1->as<TensorView>(); |
1949 | |
1950 | TensorView* weight = nullptr; |
1951 | if (!ts_weight->type()->isSubtypeOf( |
1952 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1953 | weight = value_map[ts_weight->unique()]->as<TensorView>(); |
1954 | } |
1955 | |
1956 | TensorView* running_mean = nullptr; |
1957 | if (!ts_r_mean->type()->isSubtypeOf( |
1958 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1959 | running_mean = value_map[ts_r_mean->unique()]->as<TensorView>(); |
1960 | } |
1961 | |
1962 | TensorView* running_var = nullptr; |
1963 | if (!ts_r_var->type()->isSubtypeOf( |
1964 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1965 | running_var = value_map[ts_r_var->unique()]->as<TensorView>(); |
1966 | } |
1967 | |
1968 | TensorView* save_mean = nullptr; |
1969 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1970 | if (!ts_save_mean->type()->isSubtypeOf( |
1971 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1972 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1973 | save_mean = value_map[ts_save_mean->unique()]->as<TensorView>(); |
1974 | } |
1975 | |
1976 | TensorView* save_invstd = nullptr; |
1977 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1978 | if (!ts_save_invstd->type()->isSubtypeOf( |
1979 | static_cast<c10::TypePtr>(NoneType::get()))) { |
1980 | save_invstd = |
1981 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1982 | value_map[ts_save_invstd->unique()]->as<TensorView>(); |
1983 | } |
1984 | |
1985 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1986 | auto training = constant_as<bool>(ts_train); |
1987 | TORCH_INTERNAL_ASSERT( |
1988 | training.has_value(), |
1989 | "The training (bool) parameter is required." ); |
1990 | const bool kTraining = training.value(); |
1991 | |
1992 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1993 | Val* eps_ptr = nullptr; |
1994 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1995 | if (auto eps = constant_as<float>(ts_eps)) { |
1996 | eps_ptr = IrBuilder::create<Double>(eps.value()); |
1997 | } else { |
1998 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
1999 | eps_ptr = value_map[ts_eps->unique()]; |
2000 | } |
2001 | |
2002 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
2003 | auto out_mask_list = constant_as<c10::List<bool>>(ts_mask); |
2004 | TORCH_INTERNAL_ASSERT( |
2005 | out_mask_list.has_value(), |
2006 | "output mask for batch_norm_backward" ); |
2007 | std::vector<bool> output_mask; |
2008 | for (const auto value : out_mask_list->vec()) { |
2009 | output_mask.emplace_back(static_cast<bool>(value)); |
2010 | } |
2011 | |
2012 | // TODO: merge this loop below. |
2013 | if (kTraining) { |
2014 | TORCH_INTERNAL_ASSERT( |
2015 | save_mean != nullptr && save_invstd != nullptr, |
2016 | "When training=True, save_mean and save_invstd are required." ); |
2017 | } else { |
              // TODO: is this assumption legitimate? We could run with
              // track_running_stats == false && training == false, which
              // should simply fall through to the case above.
2021 | TORCH_INTERNAL_ASSERT( |
2022 | running_mean != nullptr && running_var != nullptr, |
2023 | "When training=False, running_mean and running_invstd are required." ); |
2024 | } |
2025 | |
2026 | auto grads = batch_norm_backward( |
2027 | input, |
2028 | grad_out, |
2029 | weight, |
2030 | running_mean, |
2031 | running_var, |
2032 | save_mean, |
2033 | save_invstd, |
2034 | kTraining, |
2035 | eps_ptr, |
2036 | output_mask, |
2037 | format.isChannelsLast()); |
2038 | |
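            // output_mask selects which gradients are actually computed; for
            // disabled entries an empty TensorView placeholder is emplaced so
            // the node keeps its three outputs.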
2039 | if (output_mask[0]) { |
2040 | TORCH_INTERNAL_ASSERT(grads.grad_input != nullptr); |
2041 | value_map.emplace( |
2042 | node->output(0)->unique(), |
2043 | ValueHolder(grads.grad_input, format)); |
2044 | } else { |
2045 | TORCH_INTERNAL_ASSERT(grads.grad_input == nullptr); |
2046 | value_map.emplace( |
2047 | node->output(0)->unique(), |
2048 | ValueHolder(TensorViewBuilder().build(), format)); |
2049 | } |
2050 | |
2051 | if (output_mask[1]) { |
2052 | TORCH_INTERNAL_ASSERT(grads.grad_weight != nullptr); |
2053 | value_map.emplace(node->output(1)->unique(), grads.grad_weight); |
2054 | } else { |
2055 | TORCH_INTERNAL_ASSERT(grads.grad_weight == nullptr); |
2056 | value_map.emplace( |
2057 | node->output(1)->unique(), TensorViewBuilder().build()); |
2058 | } |
2059 | |
2060 | if (output_mask[2]) { |
2061 | TORCH_INTERNAL_ASSERT(grads.grad_bias != nullptr); |
2062 | value_map.emplace(node->output(2)->unique(), grads.grad_bias); |
2063 | } else { |
2064 | TORCH_INTERNAL_ASSERT(grads.grad_bias == nullptr); |
2065 | value_map.emplace( |
2066 | node->output(2)->unique(), TensorViewBuilder().build()); |
2067 | } |
2068 | }, |
2069 | [](const Node* node) -> bool { |
2070 | if (isReductionNonCompatibleTensor( |
2071 | node->input(1)->type()->cast<TensorType>())) { |
2072 | return false; |
2073 | } |
2074 | if (node->kind() == |
2075 | c10::Symbol::fromQualString( |
2076 | "aten::_batch_norm_impl_index_backward" )) { |
2077 | if (node->inputs()[8]->node()->kind() != prim::Constant) { |
2078 | return false; |
2079 | } |
2080 | if (node->inputs()[10]->node()->kind() != prim::Constant) { |
2081 | return false; |
2082 | } |
2083 | } else if ( |
2084 | node->kind() == |
2085 | c10::Symbol::fromQualString( |
2086 | "aten::native_batch_norm_backward" )) { |
2087 | if (node->inputs()[7]->node()->kind() != prim::Constant) { |
2088 | return false; |
2089 | } |
2090 | if (node->inputs()[9]->node()->kind() != prim::Constant) { |
2091 | return false; |
2092 | } |
2093 | } else { |
2094 | TORCH_INTERNAL_ASSERT( |
2095 | false, |
2096 | "Forgot to update profiled constant check for" , |
2097 | node->kind().toDisplayString()); |
2098 | } |
2099 | return true; |
2100 | }, |
2101 | [](const Node* node) -> OperatorType { |
2102 | return OperatorType::Normalization; |
2103 | }); |
2104 | } |
2105 | } |
2106 | |
2107 | { |
2108 | std::array<const char*, kNumLayernormFwd> LayerNormFwd = { |
2109 | "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)" , |
2110 | "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor" }; |
2111 | for (auto signature : LayerNormFwd) { |
2112 | auto ptr_op = getOperatorForLiteral(signature); |
2113 | REGISTER_PARSE_RULE( |
2114 | ptr_op, |
2115 | { |
2116 | MemoryFormat format; |
2117 | std::list<Val*> list_val; |
2118 | std::tie(format, list_val) = getConsistentValues( |
2119 | MemoryFormat::Contiguous(), |
2120 | value_map[node->inputs()[0]->unique()]); |
2121 | auto input_t = list_val.front(); |
2122 | list_val.pop_front(); |
2123 | auto input = input_t->as<TensorView>(); |
2124 | |
2125 | auto norm_shape_optional = |
2126 | constant_as<c10::List<int64_t>>(node->input(1)); |
2127 | TORCH_INTERNAL_ASSERT( |
2128 | norm_shape_optional.has_value(), |
2129 | "The Normalized_Shape list is required." ); |
2130 | auto norm_shape = norm_shape_optional->vec(); |
2131 | |
2132 | TensorView* weight = nullptr; |
2133 | if (!node->input(2)->type()->isSubtypeOf( |
2134 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2135 | weight = value_map[node->input(2)->unique()]->as<TensorView>(); |
2136 | } |
2137 | |
2138 | TensorView* bias = nullptr; |
2139 | if (!node->input(3)->type()->isSubtypeOf( |
2140 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2141 | bias = value_map[node->input(3)->unique()]->as<TensorView>(); |
2142 | } |
2143 | |
2144 | Val* eps_ptr = nullptr; |
2145 | if (auto eps = constant_as<float>(node->input(4))) { |
2146 | eps_ptr = IrBuilder::create<Double>(eps.value()); |
2147 | } else { |
2148 | eps_ptr = value_map[node->input(4)->unique()]; |
2149 | } |
2150 | |
2151 | auto result = |
2152 | layer_norm(input, norm_shape, weight, bias, eps_ptr); |
2153 | |
2154 | if (node->kind() == |
2155 | c10::Symbol::fromQualString("aten::native_layer_norm" )) { |
2156 | value_map.emplace(node->output(0)->unique(), result.output); |
2157 | value_map.emplace(node->output(1)->unique(), result.mean); |
2158 | value_map.emplace(node->output(2)->unique(), result.invstd); |
2159 | } else if ( |
2160 | node->kind() == |
2161 | c10::Symbol::fromQualString("aten::layer_norm" )) { |
2162 | value_map.emplace(node->output()->unique(), result.output); |
2163 | } |
2164 | }, |
2165 | // TODO: #ProfileIValue List should update this |
2166 | [](const Node* node) -> bool { |
2167 | if (isReductionNonCompatibleTensor( |
2168 | node->input(0)->type()->cast<TensorType>())) { |
2169 | return false; |
2170 | } |
2171 | if (node->inputs()[1]->node()->kind() != prim::Constant) { |
2172 | return false; |
2173 | } |
2174 | return true; |
2175 | }, |
2176 | [](const Node* node) -> OperatorType { |
2177 | return OperatorType::Normalization; |
2178 | }); |
2179 | } |
2180 | } |
2181 | |
2182 | { |
2183 | auto ptr_op = getOperatorForLiteral( |
2184 | "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)" ); |
2185 | REGISTER_PARSE_RULE( |
2186 | ptr_op, |
2187 | { |
2188 | MemoryFormat format; |
2189 | std::list<Val*> list_val; |
2190 | std::tie(format, list_val) = getConsistentValues( |
2191 | MemoryFormat::Contiguous(), |
2192 | value_map[node->inputs()[0]->unique()], |
2193 | value_map[node->inputs()[1]->unique()]); |
2194 | auto grad_out_t = list_val.front(); |
2195 | list_val.pop_front(); |
2196 | auto input_t = list_val.front(); |
2197 | list_val.pop_front(); |
2198 | auto grad_out = grad_out_t->as<TensorView>(); |
2199 | auto input = input_t->as<TensorView>(); |
2200 | |
2201 | auto norm_shape_optional = |
2202 | constant_as<c10::List<int64_t>>(node->input(2)); |
2203 | TORCH_INTERNAL_ASSERT( |
2204 | norm_shape_optional.has_value(), |
2205 | "The Normalized_Shape list is required." ); |
2206 | auto norm_shape = norm_shape_optional->vec(); |
2207 | |
2208 | auto mean = value_map[node->input(3)->unique()]->as<TensorView>(); |
2209 | auto rstd = value_map[node->input(4)->unique()]->as<TensorView>(); |
2210 | |
2211 | TensorView* weight = nullptr; |
2212 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
2213 | if (!node->input(5)->type()->isSubtypeOf( |
2214 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2215 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
2216 | weight = value_map[node->input(5)->unique()]->as<TensorView>(); |
2217 | } |
2218 | |
2219 | TensorView* bias = nullptr; |
2220 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
2221 | if (!node->input(6)->type()->isSubtypeOf( |
2222 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2223 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
2224 | bias = value_map[node->input(6)->unique()]->as<TensorView>(); |
2225 | } |
2226 | |
2227 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
2228 | auto output_mask_optional = |
2229 | constant_as<c10::List<bool>>(node->input(7)); |
2230 | TORCH_INTERNAL_ASSERT( |
2231 | output_mask_optional.has_value(), |
2232 | "output mask for layer_norm_backward" ); |
2233 | std::vector<bool> output_mask = output_mask_optional->vec(); |
2234 | |
2235 | auto grad = layer_norm_backward( |
2236 | grad_out, |
2237 | input, |
2238 | norm_shape, |
2239 | mean, |
2240 | rstd, |
2241 | weight, |
2242 | bias, |
2243 | output_mask); |
2244 | |
2245 | if (output_mask[0]) { |
2246 | TORCH_INTERNAL_ASSERT(grad.grad_input != nullptr); |
2247 | value_map.emplace(node->output(0)->unique(), grad.grad_input); |
2248 | } else { |
2249 | TORCH_INTERNAL_ASSERT(grad.grad_input == nullptr); |
2250 | value_map.emplace( |
2251 | node->output(0)->unique(), TensorViewBuilder().build()); |
2252 | } |
2253 | |
2254 | if (output_mask[1] && weight != nullptr) { |
2255 | TORCH_INTERNAL_ASSERT(grad.grad_weight != nullptr); |
2256 | value_map.emplace(node->output(1)->unique(), grad.grad_weight); |
2257 | } else { |
2258 | TORCH_INTERNAL_ASSERT(grad.grad_weight == nullptr); |
2259 | value_map.emplace( |
2260 | node->output(1)->unique(), TensorViewBuilder().build()); |
2261 | } |
2262 | |
2263 | if (output_mask[2] && bias != nullptr) { |
2264 | TORCH_INTERNAL_ASSERT(grad.grad_bias != nullptr); |
2265 | value_map.emplace(node->output(2)->unique(), grad.grad_bias); |
2266 | } else { |
2267 | TORCH_INTERNAL_ASSERT(grad.grad_bias == nullptr); |
2268 | value_map.emplace( |
2269 | node->output(2)->unique(), TensorViewBuilder().build()); |
2270 | } |
2271 | }, |
2272 | // TODO: #ProfileIValue List should update this |
2273 | [](const Node* node) -> bool { |
2274 | if (isReductionNonCompatibleTensor( |
2275 | node->input(0)->type()->cast<TensorType>())) { |
2276 | return false; |
2277 | } |
2278 | if (node->inputs()[2]->node()->kind() != prim::Constant) { |
2279 | return false; |
2280 | } |
2281 | if (node->inputs()[7]->node()->kind() != prim::Constant) { |
2282 | return false; |
2283 | } |
2284 | return true; |
2285 | }, |
2286 | [](const Node* node) -> OperatorType { |
2287 | return OperatorType::Normalization; |
2288 | }); |
2289 | } |
2290 | |
2291 | { |
2292 | std::array<const char*, kNumSoftmaxFwd> SoftmaxFwd = { |
2293 | "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor" , |
2294 | "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor" }; |
2295 | for (auto signature : SoftmaxFwd) { |
2296 | auto ptr_op = getOperatorForLiteral(signature); |
2297 | REGISTER_PARSE_RULE( |
2298 | ptr_op, |
2299 | { |
2300 | MemoryFormat format; |
2301 | std::list<Val*> list_val; |
2302 | std::tie(format, list_val) = getConsistentValues( |
2303 | MemoryFormat::Contiguous(), |
2304 | value_map[node->inputs()[0]->unique()]); |
2305 | auto input_t = list_val.front(); |
2306 | list_val.pop_front(); |
2307 | auto input = input_t->as<TensorView>(); |
2308 | |
2309 | auto dim_value = constant_as<int>(node->input(1)); |
2310 | TORCH_INTERNAL_ASSERT( |
2311 | dim_value.has_value(), "dim in softmax is not valid" ); |
2312 | |
2313 | auto data_type = DataType::Null; |
2314 | if (const auto opt_ivalue = toIValue(node->input(2))) { |
2315 | if (!opt_ivalue->isNone()) { |
2316 | data_type = aten_to_data_type(opt_ivalue->toScalarType()); |
2317 | } |
2318 | } |
2319 | |
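            // An explicit dtype argument casts the input before the softmax;
            // otherwise the input's own dtype is used.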
2320 | input = (data_type != DataType::Null) |
2321 | ? optionalCastStrict(data_type, input)->as<TensorView>() |
2322 | : input; |
2323 | |
2324 | bool is_log_softmax = node->kind() == |
2325 | c10::Symbol::fromQualString("aten::log_softmax" ); |
2326 | |
2327 | auto output = (is_log_softmax) |
2328 | ? log_softmax(input, dim_value.value()) |
2329 | : softmax(input, dim_value.value()); |
2330 | |
2331 | value_map.emplace(node->output()->unique(), output); |
2332 | }, |
2333 | [](const Node* node) -> bool { |
2334 | if (isReductionNonCompatibleTensor( |
2335 | node->input(0)->type()->cast<TensorType>())) { |
2336 | return false; |
2337 | } |
2338 | if (node->inputs()[1]->node()->kind() != prim::Constant) { |
2339 | return false; |
2340 | } |
2341 | if (!isScalarTypeCompatible(node, 2)) { |
2342 | return false; |
2343 | } |
2344 | return true; |
2345 | }, |
2346 | [](const Node* node) -> OperatorType { |
2347 | return OperatorType::Normalization; |
2348 | }); |
2349 | } |
2350 | } |
2351 | |
2352 | { // LTC uses this op for softmax |
2353 | auto ptr_op = getOperatorForLiteral( |
2354 | "aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor" ); |
2355 | REGISTER_PARSE_RULE( |
2356 | ptr_op, |
2357 | { |
2358 | MemoryFormat format; |
2359 | std::list<Val*> list_val; |
2360 | std::tie(format, list_val) = getConsistentValues( |
2361 | MemoryFormat::Contiguous(), |
2362 | value_map[node->inputs()[0]->unique()]); |
2363 | auto input_t = list_val.front(); |
2364 | list_val.pop_front(); |
2365 | auto input = input_t->as<TensorView>(); |
2366 | |
2367 | auto dim_value = constant_as<int>(node->input(1)); |
2368 | TORCH_INTERNAL_ASSERT( |
2369 | dim_value.has_value(), "dim in softmax is not valid" ); |
2370 | |
2371 | auto output = softmax(input, dim_value.value()); |
2372 | value_map.emplace(node->output()->unique(), output); |
2373 | }, |
2374 | [](const Node* node) -> bool { |
2375 | if (isReductionNonCompatibleTensor( |
2376 | node->input(0)->type()->cast<TensorType>())) { |
2377 | return false; |
2378 | } |
2379 | if (node->inputs()[1]->node()->kind() != prim::Constant) { |
2380 | return false; |
2381 | } |
2382 | if (node->inputs()[2]->node()->kind() != prim::Constant) { |
2383 | return false; |
2384 | } else { |
2385 | const auto half_to_float = constant_as<bool>(node->input(2)); |
2386 | TORCH_INTERNAL_ASSERT( |
2387 | half_to_float.has_value(), "Bool half_to_float is not valid" ); |
2388 | auto input_tensor_type = |
2389 | node->input(0)->type()->cast<TensorType>(); |
2390 | if (half_to_float.value() && |
2391 | input_tensor_type->scalarType() != at::ScalarType::Half) { |
2392 | return false; |
2393 | } |
2394 | } |
2395 | return true; |
2396 | }, |
2397 | [](const Node* node) -> OperatorType { |
2398 | return OperatorType::Normalization; |
2399 | }); |
2400 | } |
2401 | |
2402 | { |
2403 | std::array<const char*, kNumSoftmaxBwd> SoftmaxBwd = { |
2404 | "aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor" , |
2405 | "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor" }; |
2406 | for (auto signature : SoftmaxBwd) { |
2407 | auto ptr_op = getOperatorForLiteral(signature); |
2408 | REGISTER_PARSE_RULE( |
2409 | ptr_op, |
2410 | { |
2411 | MemoryFormat format; |
2412 | std::list<Val*> list_val; |
2413 | std::tie(format, list_val) = getConsistentValues( |
2414 | MemoryFormat::Contiguous(), |
2415 | value_map[node->inputs()[0]->unique()], |
2416 | value_map[node->inputs()[1]->unique()]); |
2417 | auto grad_output_t = list_val.front(); |
2418 | list_val.pop_front(); |
2419 | auto grad_output = grad_output_t->as<TensorView>(); |
2420 | |
2421 | auto output_t = list_val.front(); |
2422 | list_val.pop_front(); |
2423 | auto output = output_t->as<TensorView>(); |
2424 | |
2425 | auto dim_value = constant_as<int>(node->input(2)); |
2426 | TORCH_INTERNAL_ASSERT( |
2427 | dim_value.has_value(), "dim in softmax is not valid" ); |
2428 | |
2429 | // input_dtype here is ignored! type_inference handles it |
2430 | bool is_log_softmax = node->kind() == |
2431 | c10::Symbol::fromQualString( |
2432 | "aten::_log_softmax_backward_data" ); |
2433 | auto grad_input = (is_log_softmax) |
2434 | ? log_softmax_backward(grad_output, output, dim_value.value()) |
2435 | : softmax_backward(grad_output, output, dim_value.value()); |
2436 | |
2437 | value_map.emplace(node->output()->unique(), grad_input); |
2438 | }, |
2439 | [](const Node* node) -> bool { |
2440 | if (isReductionNonCompatibleTensor( |
2441 | node->input(0)->type()->cast<TensorType>())) { |
2442 | return false; |
2443 | } |
2444 | if (node->inputs()[2]->node()->kind() != prim::Constant) { |
2445 | return false; |
2446 | } |
2447 | if (node->inputs()[3]->node()->kind() != prim::Constant) { |
2448 | return false; |
2449 | } |
2450 | return true; |
2451 | }, |
2452 | [](const Node* node) -> OperatorType { |
2453 | return OperatorType::Normalization; |
2454 | }); |
2455 | } |
2456 | } |
2457 | |
2458 | { |
2459 | std::array<const char*, kNumVarOps> Variance = { |
2460 | "aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor" , |
2461 | "aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor" }; |
2462 | for (auto signature : Variance) { |
2463 | auto ptr_op = getOperatorForLiteral(signature); |
2464 | REGISTER_PARSE_RULE( |
2465 | ptr_op, |
2466 | { |
2467 | MemoryFormat format; |
2468 | std::list<Val*> list_val; |
2469 | std::tie(format, list_val) = getConsistentValues( |
2470 | MemoryFormat::Contiguous(), |
2471 | value_map[node->inputs()[0]->unique()]); |
2472 | auto input_t = list_val.front(); |
2473 | list_val.pop_front(); |
2474 | auto input = input_t->as<TensorView>(); |
2475 | |
2476 | bool is_variance = |
2477 | node->kind() == c10::Symbol::fromQualString("aten::var" ); |
2478 | |
2479 | auto dims_list = constant_as<c10::List<int64_t>>(node->input(1)); |
2480 | TORCH_INTERNAL_ASSERT( |
2481 | dims_list.has_value(), "Cannot fuse with dynamic axes" ); |
2482 | std::vector<int> dims; |
2483 | if (!dims_list->empty()) { |
2484 | for (const auto dim : dims_list->vec()) { |
2485 | dims.emplace_back(static_cast<int>(dim)); |
2486 | } |
2487 | } else { |
2488 | dims.resize(input->as<TensorView>()->nDims()); |
2489 | std::iota(dims.begin(), dims.end(), 0); |
2490 | } |
2491 | |
2492 | auto unbiased = constant_as<bool>(node->input(2)); |
2493 | TORCH_INTERNAL_ASSERT( |
2494 | unbiased.has_value(), "Cannot fuse with dynamic unbiased" ); |
2495 | |
2496 | auto keepdim = constant_as<bool>(node->input(3)); |
2497 | TORCH_INTERNAL_ASSERT( |
2498 | keepdim.has_value(), "Cannot fuse with dynamic keepdim" ); |
2499 | |
2500 | auto output = (is_variance) |
2501 | ? variance(input, dims, unbiased.value(), keepdim.value()) |
2502 | : standard_deviation( |
2503 | input, dims, unbiased.value(), keepdim.value()); |
2504 | value_map.emplace(node->output()->unique(), output); |
2505 | }, |
2506 | [](const Node* node) -> bool { |
2507 | if (isReductionNonCompatibleTensor( |
2508 | node->input(0)->type()->cast<TensorType>())) { |
2509 | return false; |
2510 | } |
2511 | return true; |
2512 | }, |
2513 | [](const Node* node) -> OperatorType { |
2514 | return OperatorType::Normalization; |
2515 | }); |
2516 | } |
2517 | } |
2518 | |
2519 | { |
2520 | auto ptr_op = getOperatorForLiteral( |
2521 | "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)" ); |
2522 | REGISTER_PARSE_RULE( |
2523 | ptr_op, |
2524 | { |
2525 | // TODO: support channels last in sum |
2526 | MemoryFormat format; |
2527 | std::list<Val*> list_val; |
2528 | std::tie(format, list_val) = getConsistentValues( |
2529 | MemoryFormat::Contiguous(), |
2530 | value_map[node->inputs()[0]->unique()]); |
2531 | auto self = list_val.front(); |
2532 | list_val.pop_front(); |
2533 | auto dims_list = constant_as<c10::List<int64_t>>(node->input(1)); |
2534 | TORCH_INTERNAL_ASSERT( |
2535 | dims_list.has_value(), |
2536 | "aten::sum cannot be fused with dynamic axes" ); |
2537 | std::vector<int> dims; |
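          // An empty dim list requests a full reduction, so every axis is
          // added to dims.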
2538 | if (!dims_list->empty()) { |
2539 | for (const auto dim : dims_list->vec()) { |
2540 | dims.emplace_back(static_cast<int>(dim)); |
2541 | } |
2542 | } else { |
2543 | dims.resize(self->as<TensorView>()->nDims()); |
2544 | std::iota(dims.begin(), dims.end(), 0); |
2545 | } |
2546 | auto keepdim = constant_as<bool>(node->input(2)); |
2547 | TORCH_INTERNAL_ASSERT( |
2548 | keepdim.has_value(), |
2549 | "aten::sum cannot be fused with dynamic keepdim" ); |
2550 | auto out = sum(self->as<TensorView>(), dims, keepdim.value()); |
2551 | value_map.emplace(node->output()->unique(), out); |
2552 | }, |
2553 | [](const Node* node) -> bool { |
2554 | if (isReductionNonCompatibleTensor( |
2555 | node->input(0)->type()->cast<TensorType>())) { |
2556 | return false; |
2557 | } |
2558 | // TODO: support cast of output types |
2559 | if (!node->inputs()[3]->type()->isSubtypeOf( |
2560 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2561 | // We can only handle output as half, float, and double; |
2562 | if (const auto opt_ivalue = toIValue(node->input(3))) { |
2563 | const auto scalar_type = opt_ivalue->toScalarType(); |
2564 | if (!at::isFloatingType(scalar_type)) { |
2565 | return false; |
2566 | } |
2567 | } |
2568 | } |
2569 | // we don't support dynamic reduction axes; |
2570 | if (node->inputs()[1]->node()->kind() != prim::Constant) { |
2571 | return false; |
2572 | } |
2573 | // we don't support dynamic keepdim yet; |
2574 | if (node->inputs()[2]->node()->kind() != prim::Constant) { |
2575 | return false; |
2576 | } |
2577 | return true; |
2578 | }, |
2579 | [](const Node* node) -> OperatorType { |
2580 | return OperatorType::Reduction; |
2581 | }); |
2582 | } |
2583 | |
2584 | { |
2585 | auto ptr_op = getOperatorForLiteral( |
2586 | "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor" ); |
2587 | REGISTER_PARSE_RULE( |
2588 | ptr_op, |
2589 | { |
2590 | MemoryFormat format; |
2591 | std::list<Val*> list_val; |
2592 | std::tie(format, list_val) = getConsistentValues( |
2593 | MemoryFormat::Contiguous(), |
2594 | value_map[node->inputs()[0]->unique()]); |
2595 | auto operand = list_val.front(); |
2596 | list_val.pop_front(); |
2597 | auto self = operand->as<TensorView>(); |
2598 | auto dims_list = constant_as<c10::List<int64_t>>(node->input(1)); |
2599 | TORCH_INTERNAL_ASSERT( |
2600 | dims_list.has_value(), |
2601 | "aten::mean cannot be fused with dynamic axes" ); |
2602 | std::vector<int> dims; |
2603 | if (!dims_list->empty()) { |
2604 | for (const auto dim : dims_list->vec()) { |
2605 | dims.emplace_back(static_cast<int>(dim)); |
2606 | } |
2607 | } else { |
2608 | dims.resize(self->as<TensorView>()->nDims()); |
2609 | std::iota(dims.begin(), dims.end(), 0); |
2610 | } |
2611 | auto keepdim = constant_as<bool>(node->input(2)); |
2612 | TORCH_INTERNAL_ASSERT( |
2613 | keepdim.has_value(), |
2614 | "aten::mean cannot be fused with dynamic keepdim" ); |
2615 | auto o_sum = sum(self, dims, keepdim.value()); |
2616 | Val* num_features = IrBuilder::create<Double>(1); |
2617 | for (auto axis : dims) { |
2618 | if (axis < 0) { |
2619 | axis += int(self->nDims()); |
2620 | } |
2621 | num_features = |
2622 | mul(num_features, self->domain()->domain()[axis]->extent()); |
2623 | } |
2624 | auto out = div(o_sum, num_features); |
2625 | value_map.emplace(node->output()->unique(), out); |
2626 | }, |
2627 | [](const Node* node) -> bool { |
2628 | if (isReductionNonCompatibleTensor( |
2629 | node->input(0)->type()->cast<TensorType>())) { |
2630 | return false; |
2631 | } |
2632 | // TODO: support cast of output types |
2633 | if (!node->inputs()[3]->type()->isSubtypeOf( |
2634 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2635 | // We can only handle output as half, float, and double; |
2636 | if (const auto opt_ivalue = toIValue(node->input(3))) { |
2637 | const auto scalar_type = opt_ivalue->toScalarType(); |
2638 | if (!at::isFloatingType(scalar_type)) { |
2639 | return false; |
2640 | } |
2641 | } |
2642 | } |
2643 | // we don't support dynamic reduction axes; |
2644 | if (node->inputs()[1]->node()->kind() != prim::Constant) { |
2645 | return false; |
2646 | } |
2647 | // we don't support dynamic keepdim yet; |
2648 | if (node->inputs()[2]->node()->kind() != prim::Constant) { |
2649 | return false; |
2650 | } |
2651 | return true; |
2652 | }, |
2653 | [](const Node* node) -> OperatorType { |
2654 | return OperatorType::Reduction; |
2655 | }); |
2656 | } |
2657 | { |
2658 | std::array<const char*, kNumSumToSize> SumToSize = { |
2659 | "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)" , |
2660 | "aten::sum_to_size(Tensor self, int[] size) -> Tensor" }; |
2661 | for (auto signature : SumToSize) { |
2662 | auto ptr_op = getOperatorForLiteral(signature); |
2663 | REGISTER_PARSE_RULE( |
2664 | ptr_op, |
2665 | { |
2666 | MemoryFormat format; |
2667 | std::list<Val*> list_val; |
2668 | std::tie(format, list_val) = getConsistentValues( |
2669 | MemoryFormat::Contiguous(), |
2670 | value_map[node->inputs()[0]->unique()]); |
2671 | auto self = list_val.front(); |
2672 | list_val.pop_front(); |
2673 | auto size_to = constant_as<c10::List<int64_t>>(node->input(1)); |
2674 | TORCH_INTERNAL_ASSERT( |
2675 | size_to.has_value(), |
2676 | "aten::sum cannot be fused with dynamic axes" ); |
2677 | if (!size_to->empty()) { |
2678 | auto input = self->as<TensorView>(); |
2679 | auto out = sum_to(input, size_to->vec()); |
              // this copy is not strictly necessary, but making a copy avoids
              // a tricky computational graph where handling a no-op could be
              // challenging.
2682 | if (out == input) { |
2683 | out = set(input); |
2684 | } |
2685 | value_map.emplace(node->output()->unique(), out); |
2686 | } else { |
2687 | // We are introducing alias here! |
2688 | value_map.emplace(node->output()->unique(), self); |
2689 | } |
2690 | }, |
2691 | [](const Node* node) -> bool { |
2692 | if (isReductionNonCompatibleTensor( |
2693 | node->input(0)->type()->cast<TensorType>())) { |
2694 | return false; |
2695 | } |
2696 | // we don't support dynamic reduction axes; |
2697 | if (node->inputs()[1]->node()->kind() != prim::Constant) { |
2698 | return false; |
2699 | } |
2700 | return true; |
2701 | }, |
2702 | [](const Node* node) -> OperatorType { |
2703 | auto size_to = constant_as<c10::List<int64_t>>(node->input(1)); |
            // technically size_to->empty() should never occur, as the
            // specialized _grad_sum_to_size should have been removed by an
            // optimization pass
2706 | if (size_to->empty()) { |
2707 | return OperatorType::ElementWise; |
2708 | } else { |
2709 | return OperatorType::ReductionToSize; |
2710 | } |
2711 | }); |
2712 | } |
2713 | } |
2714 | |
2715 | { |
2716 | std::array<const char*, kNumAutocastOps> AutocastOps = { |
2717 | "aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)" , |
2718 | "aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)" }; |
2719 | for (auto signature : AutocastOps) { |
2720 | auto ptr_op = getOperatorForLiteral(signature); |
2721 | REGISTER_PARSE_RULE( |
2722 | ptr_op, |
2723 | { |
2724 | MemoryFormat format; |
2725 | std::list<Val*> list_val; |
2726 | std::tie(format, list_val) = getConsistentValues( |
2727 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
2728 | auto self = list_val.front(); |
2729 | list_val.pop_front(); |
2730 | |
2731 | auto out = set(self); |
2732 | value_map.emplace( |
2733 | node->output()->unique(), ValueHolder(out, format)); |
2734 | }, |
2735 | isInputNonSizeZeroTensor, |
2736 | nullptr); |
2737 | } |
2738 | } |
2739 | |
2740 | { |
2741 | auto ptr_op = getOperatorForLiteral( |
2742 | "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor" ); |
2743 | REGISTER_PARSE_RULE( |
2744 | ptr_op, |
2745 | { |
2746 | MemoryFormat format; |
2747 | std::list<Val*> list_val; |
2748 | std::tie(format, list_val) = getConsistentValues( |
2749 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
2750 | auto self = list_val.front(); |
2751 | list_val.pop_front(); |
2752 | |
2753 | auto out = castTensoToDtype(self, node->input(1)); |
2754 | |
2755 | value_map.emplace( |
2756 | node->output()->unique(), ValueHolder(out, format)); |
2757 | }, |
2758 | [](const Node* node) -> bool { |
2759 | if (!isInputNonSizeZeroTensor(node)) { |
2760 | return false; |
2761 | } |
2762 | if (node->inputs()[1]->node()->kind() != prim::Constant) { |
2763 | return false; |
2764 | } |
          // we do not support explicit layout on output
2766 | if (!node->inputs()[2]->type()->isSubtypeOf( |
2767 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2768 | return false; |
2769 | } |
          // we do not support explicit device on output
2771 | if (!node->inputs()[3]->type()->isSubtypeOf( |
2772 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2773 | return false; |
2774 | } |
          // we do not support explicit pin_memory on output
2776 | if (!node->inputs()[4]->type()->isSubtypeOf( |
2777 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2778 | return false; |
2779 | } |
2780 | // we do not support explicit memory_format on output |
2781 | if (!node->inputs()[6]->type()->isSubtypeOf( |
2782 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2783 | return false; |
2784 | } |
2785 | return true; |
2786 | }, |
2787 | nullptr); |
2788 | } |
2789 | |
2790 | // Limiting aten::to implementation to only change the dtype of a tensor |
2791 | { |
2792 | auto ptr_op = getOperatorForLiteral( |
2793 | "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor" ); |
2794 | REGISTER_PARSE_RULE( |
2795 | ptr_op, |
2796 | { |
2797 | MemoryFormat format; |
2798 | std::list<Val*> list_val; |
2799 | std::tie(format, list_val) = getConsistentValues( |
2800 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
2801 | auto self = list_val.front(); |
2802 | list_val.pop_front(); |
2803 | |
2804 | auto out = castTensoToDtype(self, node->input(1)); |
2805 | |
2806 | value_map.emplace( |
2807 | node->output()->unique(), ValueHolder(out, format)); |
2808 | }, |
2809 | [](const Node* node) -> bool { |
2810 | if (!isInputNonSizeZeroTensor(node)) { |
2811 | return false; |
2812 | } |
2813 | if (node->inputs()[1]->node()->kind() != prim::Constant) { |
2814 | return false; |
2815 | } |
2816 | // we do not support explicit memory_format on output |
2817 | if (!node->inputs()[4]->type()->isSubtypeOf( |
2818 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2819 | return false; |
2820 | } |
2821 | return true; |
2822 | }, |
2823 | nullptr); |
2824 | } |
2825 | |
2826 | { |
2827 | auto ptr_op = getOperatorForLiteral( |
2828 | "aten::type_as(Tensor self, Tensor other) -> Tensor" ); |
2829 | REGISTER_PARSE_RULE( |
2830 | ptr_op, |
2831 | { |
2832 | MemoryFormat format; |
2833 | std::list<Val*> list_val; |
2834 | std::tie(format, list_val) = getConsistentValues( |
2835 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
2836 | auto self = list_val.front(); |
2837 | list_val.pop_front(); |
2838 | |
          // TODO: switch to PyTorch dtype as it's closer to the truth.
          // For now, the reality is that PyTorch IR profiling information can
          // be missing even with the profiling executor, due to upstream
          // transformations between the profiling runs and the fusion pass.
2843 | auto opt_dtype = |
2844 | value_map[node->inputs()[1]->unique()]->getDataType(); |
2845 | TORCH_INTERNAL_ASSERT(opt_dtype.has_value()); |
2846 | |
2847 | auto out = castOp(opt_dtype.value(), self); |
2848 | value_map.emplace( |
2849 | node->output()->unique(), ValueHolder(out, format)); |
2850 | }, |
2851 | isInputNonSizeZeroTensor, |
2852 | nullptr); |
2853 | } |
2854 | |
2855 | { |
    // We are not fusing `linear` yet, because we can't codegen an efficient
    // gemm. However, we still register the rule here so the profiling
    // executor (PE) inserts a profile node for `linear`.
    // During the fusion pass, we decompose linear into gemm + elementwise.
2860 | auto ptr_op = getOperatorForLiteral( |
2861 | "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor" ); |
2862 | REGISTER_PARSE_RULE( |
2863 | ptr_op, |
2864 | { |
          // this entry is created so that we do profile the input tensors;
2866 | TORCH_INTERNAL_ASSERT(false, "not implemented yet" ); |
2867 | }, |
2868 | [](const Node* node) -> bool { |
          // We only profile the `linear` layer; we do not fuse it.
2870 | return false; |
2871 | }); |
2872 | } |
2873 | |
2874 | { |
2875 | auto ptr_op = getOperatorForLiteral( |
2876 | "prim::add_optional(Tensor(a) input, Tensor? bias) -> Tensor(a)" ); |
2877 | REGISTER_PARSE_RULE( |
2878 | ptr_op, |
2879 | { |
          // this entry is created so that we do profile the input tensors;
2881 | if (node->input(1)->type()->isSubtypeOf( |
2882 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2883 | // forwarding the value; |
2884 | value_map.emplace( |
2885 | node->output()->unique(), |
2886 | value_map[node->inputs()[0]->unique()]); |
2887 | } else { |
2888 | MemoryFormat format; |
2889 | std::list<Val*> list_val; |
2890 | std::tie(format, list_val) = getPWFormatValues( |
2891 | c10::nullopt, |
2892 | value_map[node->inputs()[0]->unique()], |
2893 | value_map[node->inputs()[1]->unique()]); |
2894 | auto lhs = list_val.front(); |
2895 | list_val.pop_front(); |
2896 | auto rhs = list_val.front(); |
2897 | list_val.pop_front(); |
2898 | |
2899 | auto out = binaryOp( |
2900 | BinaryOpType::Add, |
2901 | lhs, |
2902 | rhs, |
2903 | TypePromotion::default_op_config); |
2904 | value_map.emplace( |
2905 | node->output()->unique(), ValueHolder(out, format)); |
2906 | } |
2907 | }, |
2908 | isInputNonSizeZeroTensor, |
2909 | nullptr); |
2910 | } |
2911 | |
2912 | { |
2913 | auto ptr_op = getOperatorForLiteral( |
2914 | "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor" ); |
2915 | REGISTER_PARSE_RULE( |
2916 | ptr_op, |
2917 | { |
2918 | MemoryFormat format; |
2919 | std::list<Val*> list_val; |
2920 | std::tie(format, list_val) = getConsistentValues( |
2921 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
2922 | auto self = list_val.front()->as<TensorView>(); |
2923 | list_val.pop_front(); |
2924 | |
2925 | Val* negative_slope = value_map[node->inputs()[1]->unique()]; |
2926 | |
2927 | auto out = leaky_relu(self, negative_slope); |
2928 | value_map.emplace( |
2929 | node->output()->unique(), ValueHolder(out, format)); |
2930 | }, |
2931 | [](const Node* node) -> bool { |
2932 | if (!isInputNonSizeZeroTensor(node)) { |
2933 | return false; |
2934 | } |
2935 | return true; |
2936 | }, |
2937 | nullptr); |
2938 | } |
2939 | |
2940 | { |
2941 | auto ptr_op = getOperatorForLiteral( |
2942 | "aten::gelu(Tensor self, *, str approximate='none') -> Tensor" ); |
2943 | REGISTER_PARSE_RULE( |
2944 | ptr_op, |
2945 | { |
2946 | MemoryFormat format; |
2947 | std::list<Val*> list_val; |
2948 | std::tie(format, list_val) = getConsistentValues( |
2949 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
2950 | auto self = list_val.front()->as<TensorView>(); |
2951 | list_val.pop_front(); |
2952 | |
2953 | auto approximate = constant_as<std::string>(node->input(1)); |
2954 | TORCH_INTERNAL_ASSERT( |
2955 | approximate.has_value(), |
2956 | "The approximate parameter is required." ); |
2957 | const auto kTanhGelu = |
2958 | at::native::get_gelutype_enum(approximate.value()) == |
2959 | at::native::GeluType::Tanh; |
2960 | |
2961 | auto out = (kTanhGelu) ? tanh_gelu(self) : gelu(self); |
2962 | value_map.emplace( |
2963 | node->output()->unique(), ValueHolder(out, format)); |
2964 | }, |
2965 | [](const Node* node) -> bool { |
2966 | if (!isInputNonSizeZeroTensor(node)) { |
2967 | return false; |
2968 | } |
2969 | if (node->input(1)->node()->kind() != prim::Constant) { |
2970 | return false; |
2971 | } |
2972 | return true; |
2973 | }, |
2974 | nullptr); |
2975 | } |
2976 | |
2977 | { |
2978 | auto ptr_op = getOperatorForLiteral( |
2979 | "aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor" ); |
2980 | REGISTER_PARSE_RULE( |
2981 | ptr_op, |
2982 | { |
2983 | MemoryFormat format; |
2984 | std::list<Val*> list_val; |
2985 | std::tie(format, list_val) = getPWFormatValues( |
2986 | c10::nullopt, |
2987 | value_map[node->inputs()[0]->unique()], |
2988 | value_map[node->inputs()[1]->unique()]); |
2989 | auto grad_out = list_val.front()->as<TensorView>(); |
2990 | list_val.pop_front(); |
2991 | auto self = list_val.front()->as<TensorView>(); |
2992 | list_val.pop_front(); |
2993 | |
2994 | auto approximate = constant_as<std::string>(node->input(2)); |
2995 | TORCH_INTERNAL_ASSERT( |
2996 | approximate.has_value(), |
2997 | "The approximate parameter is required." ); |
2998 | const auto kTanhGelu = |
2999 | at::native::get_gelutype_enum(approximate.value()) == |
3000 | at::native::GeluType::Tanh; |
3001 | |
3002 | auto grad_in = (kTanhGelu) ? tanh_gelu_backward(grad_out, self) |
3003 | : gelu_backward(grad_out, self); |
3004 | value_map.emplace( |
3005 | node->output()->unique(), ValueHolder(grad_in, format)); |
3006 | }, |
3007 | [](const Node* node) -> bool { |
3008 | if (!isInputNonSizeZeroTensor(node)) { |
3009 | return false; |
3010 | } |
3011 | if (node->input(2)->node()->kind() != prim::Constant) { |
3012 | return false; |
3013 | } |
3014 | return true; |
3015 | }, |
3016 | nullptr); |
3017 | } |
3018 | |
3019 | { |
3020 | auto ptr_op = getOperatorForLiteral( |
3021 | "aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor" ); |
3022 | REGISTER_PARSE_RULE( |
3023 | ptr_op, |
3024 | { |
3025 | MemoryFormat format; |
3026 | std::list<Val*> list_val; |
3027 | std::tie(format, list_val) = getPWFormatValues( |
3028 | c10::nullopt, |
3029 | value_map[node->inputs()[0]->unique()], |
3030 | value_map[node->inputs()[1]->unique()]); |
3031 | auto grad_out = list_val.front()->as<TensorView>(); |
3032 | list_val.pop_front(); |
3033 | auto self = list_val.front()->as<TensorView>(); |
3034 | list_val.pop_front(); |
3035 | |
3036 | auto grad_in = tanh_backward(grad_out, self); |
3037 | value_map.emplace( |
3038 | node->output()->unique(), ValueHolder(grad_in, format)); |
3039 | }, |
3040 | isInputNonSizeZeroTensor, |
3041 | nullptr); |
3042 | } |
3043 | |
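  // aten::amax / aten::amin are parsed as reduction ops. The reduction axes
  // and keepdim flag must be compile-time constants, and an empty dim list
  // reduces over all dimensions.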
3044 | { |
    std::array<const char*, kNumAminAmaxOps> AminAmaxOps = {
        "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor",
        "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"};
    for (auto signature : AminAmaxOps) {
3049 | auto ptr_op = getOperatorForLiteral(signature); |
3050 | REGISTER_PARSE_RULE( |
3051 | ptr_op, |
3052 | { |
3053 | MemoryFormat format; |
3054 | std::list<Val*> list_val; |
3055 | std::tie(format, list_val) = getConsistentValues( |
3056 | MemoryFormat::Contiguous(), |
3057 | value_map[node->inputs()[0]->unique()]); |
3058 | auto self = list_val.front(); |
3059 | list_val.pop_front(); |
3060 | auto dims_list = constant_as<c10::List<int64_t>>(node->input(1)); |
3061 | TORCH_INTERNAL_ASSERT( |
3062 | dims_list.has_value(), |
3063 | "aten::amax/amin cannot be fused with dynamic axes" ); |
3064 | std::vector<int> dims; |
3065 | if (!dims_list->empty()) { |
3066 | for (const auto dim : dims_list->vec()) { |
3067 | dims.emplace_back(static_cast<int>(dim)); |
3068 | } |
3069 | } else { |
3070 | dims.resize(self->as<TensorView>()->nDims()); |
3071 | std::iota(dims.begin(), dims.end(), 0); |
3072 | } |
3073 | auto keepdim = constant_as<bool>(node->input(2)); |
3074 | TORCH_INTERNAL_ASSERT( |
3075 | keepdim.has_value(), |
3076 | "aten::amax/amin cannot be fused with dynamic keepdim" ); |
3077 | |
3078 | TensorView* out = nullptr; |
3079 | if (node->kind() == c10::Symbol::fromQualString("aten::amax" )) { |
3080 | out = max(self->as<TensorView>(), dims, keepdim.value()); |
3081 | } else if ( |
3082 | node->kind() == c10::Symbol::fromQualString("aten::amin" )) { |
3083 | out = min(self->as<TensorView>(), dims, keepdim.value()); |
3084 | } else { |
3085 | TORCH_INTERNAL_ASSERT( |
3086 | false, "unrecognized operation in aten::amax/amin" ); |
3087 | } |
3088 | value_map.emplace(node->output()->unique(), out); |
3089 | }, |
3090 | [](const Node* node) -> bool { |
3091 | if (isReductionNonCompatibleTensor( |
3092 | node->input(0)->type()->cast<TensorType>())) { |
3093 | return false; |
3094 | } |
3095 | // we don't support dynamic reduction axes; |
3096 | if (node->inputs()[1]->node()->kind() != prim::Constant) { |
3097 | return false; |
3098 | } |
3099 | // we don't support dynamic keepdim yet; |
3100 | if (node->inputs()[2]->node()->kind() != prim::Constant) { |
3101 | return false; |
3102 | } |
3103 | return true; |
3104 | }, |
3105 | [](const Node* node) -> OperatorType { |
3106 | return OperatorType::Reduction; |
3107 | }); |
3108 | } |
3109 | } |
3110 | |
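  // prim::reshape_copy / prim::view_copy require concrete sizes for the input
  // tensor and a constant target size; fusion is rejected when the target
  // size contains an inferred (-1) dimension.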
3111 | { |
3112 | std::array<const char*, kNumViewOps> ViewOps = { |
3113 | "prim::reshape_copy(Tensor self, int[] shape) -> Tensor" , |
3114 | "prim::view_copy(Tensor self, int[] size) -> Tensor" }; |
3115 | for (auto signature : ViewOps) { |
3116 | auto ptr_op = getOperatorForLiteral(signature); |
3117 | REGISTER_PARSE_RULE( |
3118 | ptr_op, |
3119 | { |
3120 | auto self_value = node->inputs()[0]; |
3121 | MemoryFormat format; |
3122 | std::list<Val*> list_val; |
3123 | std::tie(format, list_val) = getConsistentValues( |
3124 | MemoryFormat::Contiguous(), value_map[self_value->unique()]); |
3125 | auto self = list_val.front()->as<TensorView>(); |
3126 | list_val.pop_front(); |
3127 | |
3128 | auto self_type = self_value->type()->cast<c10::TensorType>(); |
3129 | TORCH_INTERNAL_ASSERT(self_type != nullptr); |
3130 | auto self_sizes = getTensorSizes(self_type); |
3131 | |
3132 | auto view_sizes = constant_as<c10::List<int64_t>>(node->input(1)); |
3133 | TORCH_INTERNAL_ASSERT( |
3134 | view_sizes.has_value(), "The size parameter is required." ); |
3135 | |
3136 | auto output = view(self, self_sizes, view_sizes->vec()); |
3137 | value_map.emplace(node->output()->unique(), output); |
3138 | }, |
3139 | [](const Node* node) -> bool { |
3140 | auto self_value = node->inputs()[0]; |
3141 | auto tensor_type = self_value->type()->cast<c10::TensorType>(); |
3142 | if (tensor_type == nullptr) { |
3143 | return false; |
3144 | } |
3145 | if (!tensor_type->sizes().concrete_sizes().has_value()) { |
3146 | // Shape information for input tensor is required. |
3147 | return false; |
3148 | } |
3149 | |
3150 | if (!isInputNonSizeZeroTensor(node)) { |
3151 | return false; |
3152 | } |
3153 | // Reject fusing node if view_sizes contains an inferred dimension |
3154 | auto view_sizes = constant_as<c10::List<int64_t>>(node->input(1)); |
3155 | if (!view_sizes.has_value()) { |
3156 | // The size parameter is required. |
3157 | return false; |
3158 | } |
3159 | |
3160 | for (auto axis_size : view_sizes->vec()) { |
3161 | if (axis_size == -1) { |
3162 | return false; |
3163 | } |
3164 | } |
3165 | return true; |
3166 | }, |
3167 | nullptr); |
3168 | } |
3169 | } |
3170 | |
3171 | { |
3172 | auto flatten_op = getOperatorForLiteral( |
3173 | "prim::flatten_copy(Tensor self, int start_dim, int end_dim) -> Tensor" ); |
3174 | REGISTER_PARSE_RULE( |
3175 | flatten_op, |
3176 | { |
3177 | auto self_value = node->inputs()[0]; |
3178 | MemoryFormat format; |
3179 | std::list<Val*> list_val; |
3180 | std::tie(format, list_val) = getConsistentValues( |
3181 | MemoryFormat::Contiguous(), value_map[self_value->unique()]); |
3182 | auto self = list_val.front()->as<TensorView>(); |
3183 | list_val.pop_front(); |
3184 | |
3185 | auto start_dim_value = constant_as<int>(node->input(1)); |
3186 | TORCH_INTERNAL_ASSERT( |
3187 | start_dim_value.has_value(), "start_dim is not valid" ); |
3188 | auto end_dim_value = constant_as<int>(node->input(2)); |
3189 | TORCH_INTERNAL_ASSERT( |
3190 | end_dim_value.has_value(), "end_dim is not valid" ); |
3191 | |
3192 | TensorView* output = |
3193 | flatten(self, start_dim_value.value(), end_dim_value.value()); |
3194 | value_map.emplace(node->output()->unique(), output); |
3195 | }, |
3196 | [](const Node* node) -> bool { |
3197 | // we don't support dynamic start_dim; |
3198 | if (node->inputs()[1]->node()->kind() != prim::Constant) { |
3199 | return false; |
3200 | } |
3201 | // we don't support dynamic end_dim yet; |
3202 | if (node->inputs()[2]->node()->kind() != prim::Constant) { |
3203 | return false; |
3204 | } |
3205 | return true; |
3206 | }, |
3207 | nullptr); |
3208 | } |
3209 | |
3210 | { |
3211 | auto ptr_op = |
3212 | getOperatorForLiteral("prim::squeeze_copy(Tensor self) -> Tensor" ); |
3213 | REGISTER_PARSE_RULE( |
3214 | ptr_op, |
3215 | { |
3216 | auto self_value = node->inputs()[0]; |
3217 | MemoryFormat format; |
3218 | std::list<Val*> list_val; |
3219 | std::tie(format, list_val) = getConsistentValues( |
3220 | MemoryFormat::Contiguous(), value_map[self_value->unique()]); |
3221 | auto self = list_val.front()->as<TensorView>(); |
3222 | list_val.pop_front(); |
3223 | |
3224 | auto self_type = self_value->type()->cast<c10::TensorType>(); |
3225 | TORCH_INTERNAL_ASSERT(self_type != nullptr); |
3226 | auto self_sizes = getTensorSizes(self_type); |
3227 | |
3228 | TensorView* output = nullptr; |
3229 | if (self_sizes.empty()) { |
            // squeeze on a scalar (0-dim) tensor is a no-op; just return a copy;
3231 | output = set(self); |
3232 | } else { |
3233 | output = squeeze(self, self_sizes); |
3234 | } |
3235 | value_map.emplace(node->output()->unique(), output); |
3236 | }, |
3237 | [](const Node* node) -> bool { |
3238 | // Shape information for input tensor is required. |
3239 | auto self_value = node->inputs()[0]; |
3240 | auto tensor_type = self_value->type()->cast<c10::TensorType>(); |
3241 | if (tensor_type == nullptr) { |
3242 | return false; |
3243 | } |
3244 | if (!isInputNonSizeZeroTensor(node)) { |
3245 | return false; |
3246 | } |
3247 | return tensor_type->sizes().concrete_sizes().has_value(); |
3248 | }, |
3249 | nullptr); |
3250 | } |
3251 | |
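  // prim::squeeze_copy.dim / prim::unsqueeze_copy share one rule: dim must be
  // a compile-time constant and the input tensor needs concrete sizes.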
3252 | { |
3253 | std::array<const char*, kNumAliasDimOps> AliasOpWithDim = { |
3254 | "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor" , |
3255 | "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor" }; |
3256 | for (auto signature : AliasOpWithDim) { |
3257 | auto ptr_op = getOperatorForLiteral(signature); |
3258 | REGISTER_PARSE_RULE( |
3259 | ptr_op, |
3260 | { |
3261 | auto self_value = node->inputs()[0]; |
3262 | MemoryFormat format; |
3263 | std::list<Val*> list_val; |
3264 | std::tie(format, list_val) = getConsistentValues( |
3265 | MemoryFormat::Contiguous(), |
3266 | value_map[node->inputs()[0]->unique()]); |
3267 | auto self = list_val.front()->as<TensorView>(); |
3268 | list_val.pop_front(); |
3269 | |
3270 | auto dim_value = constant_as<int>(node->input(1)); |
3271 | TORCH_INTERNAL_ASSERT(dim_value.has_value(), "dim is not valid" ); |
3272 | |
3273 | TensorView* output = nullptr; |
3274 | if (node->kind() == prim::unsqueeze_copy) { |
3275 | output = unsqueeze(self, dim_value.value()); |
3276 | } else { |
3277 | auto self_type = self_value->type()->cast<c10::TensorType>(); |
3278 | TORCH_INTERNAL_ASSERT(self_type != nullptr); |
3279 | auto self_sizes = getTensorSizes(self_type); |
3280 | if (self_sizes.empty()) { |
              // squeeze on a scalar (0-dim) tensor is a no-op; just return a copy;
3282 | output = set(self); |
3283 | } else { |
3284 | output = squeeze(self, self_sizes, dim_value.value()); |
3285 | } |
3286 | } |
3287 | value_map.emplace(node->output()->unique(), output); |
3288 | }, |
3289 | [](const Node* node) -> bool { |
3290 | // Shape information for input tensor is required. |
3291 | auto self_value = node->inputs()[0]; |
3292 | auto tensor_type = self_value->type()->cast<c10::TensorType>(); |
3293 | if (tensor_type == nullptr) { |
3294 | return false; |
3295 | } |
3296 | if (!isInputNonSizeZeroTensor(node)) { |
3297 | return false; |
3298 | } |
3299 | if (node->input(1)->node()->kind() != prim::Constant) { |
3300 | return false; |
3301 | } |
          auto optional_sizes = tensor_type->sizes().concrete_sizes();
          return optional_sizes.has_value();
3304 | }, |
3305 | nullptr); |
3306 | } |
3307 | } |
3308 | |
3309 | { |
3310 | auto ptr_op = getOperatorForLiteral( |
3311 | "prim::expand_as_copy(Tensor self, Tensor other) -> Tensor" ); |
3312 | REGISTER_PARSE_RULE( |
3313 | ptr_op, |
3314 | { |
3315 | MemoryFormat format; |
3316 | std::list<Val*> list_val; |
3317 | std::tie(format, list_val) = getPWFormatValues( |
3318 | c10::nullopt, |
3319 | value_map[node->inputs()[0]->unique()], |
3320 | value_map[node->inputs()[1]->unique()]); |
3321 | auto self = list_val.front()->as<TensorView>(); |
3322 | list_val.pop_front(); |
3323 | auto other = list_val.front()->as<TensorView>(); |
3324 | list_val.pop_front(); |
3325 | |
3326 | auto output = expand_as(self, other); |
3327 | value_map.emplace( |
3328 | node->output()->unique(), ValueHolder(output, format)); |
3329 | }, |
3330 | [](const Node* node) -> bool { |
3331 | if (!isInputNonSizeZeroTensor(node)) { |
3332 | return false; |
3333 | } |
3334 | |
3335 | return true; |
3336 | }, |
3337 | nullptr); |
3338 | } |
3339 | |
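  // prim::expand_copy requires constant expand sizes; they are converted to
  // codegen Int scalars below.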
3340 | { |
3341 | auto ptr_op = getOperatorForLiteral( |
3342 | "prim::expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor" ); |
3343 | REGISTER_PARSE_RULE( |
3344 | ptr_op, |
3345 | { |
3346 | auto self_value = node->inputs()[0]; |
3347 | MemoryFormat format; |
3348 | std::list<Val*> list_val; |
3349 | std::tie(format, list_val) = getConsistentValues( |
3350 | MemoryFormat::Contiguous(), value_map[self_value->unique()]); |
3351 | auto self = list_val.front()->as<TensorView>(); |
3352 | list_val.pop_front(); |
3353 | |
3354 | auto expand_sizes = constant_as<c10::List<int64_t>>(node->input(1)); |
3355 | TORCH_INTERNAL_ASSERT( |
3356 | expand_sizes.has_value(), "The size parameter is required." ); |
3357 | |
3358 | std::vector<CgValue> expand_sizes_vec; |
3359 | for (const int64_t& size : expand_sizes.value()) { |
3360 | expand_sizes_vec.push_back(IrBuilder::create<Int>(size)); |
3361 | } |
3362 | |
3363 | // TODO: we should be able to support dynamic expand values |
3364 | auto output = expand(self, expand_sizes_vec); |
3365 | value_map.emplace(node->output()->unique(), output); |
3366 | }, |
3367 | [](const Node* node) -> bool { |
3368 | if (!isInputNonSizeZeroTensor(node)) { |
3369 | return false; |
3370 | } |
3371 | // expand_sizes needs to be constant |
3372 | auto expand_sizes = constant_as<c10::List<int64_t>>(node->input(1)); |
3373 | if (!expand_sizes.has_value()) { |
3374 | return false; |
3375 | } |
3376 | |
3377 | return true; |
3378 | }, |
3379 | nullptr); |
3380 | } |
3381 | |
3382 | { |
3383 | auto ptr_op = getOperatorForLiteral( |
3384 | "prim::permute_copy.int(Tensor(a) self, int[] dims) -> Tensor" ); |
3385 | REGISTER_PARSE_RULE( |
3386 | ptr_op, |
3387 | { |
3388 | MemoryFormat format; |
3389 | std::list<Val*> list_val; |
3390 | std::tie(format, list_val) = getConsistentValues( |
3391 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
3392 | auto self_t = list_val.front(); |
3393 | list_val.pop_front(); |
3394 | auto self = self_t->as<TensorView>(); |
3395 | |
3396 | auto dims = constant_as<c10::List<int64_t>>(node->input(1)); |
3397 | TORCH_INTERNAL_ASSERT( |
3398 | dims.has_value(), "The dims parameter is required." ); |
3399 | TORCH_INTERNAL_ASSERT( |
3400 | dims.value().size() == self->getMaybeRFactorDomain().size()); |
3401 | |
3402 | auto output = permute(self, dims->vec()); |
3403 | value_map.emplace( |
3404 | node->output()->unique(), ValueHolder(output, format)); |
3405 | }, |
3406 | [](const Node* node) -> bool { |
3407 | if (!isInputNonSizeZeroTensor(node)) { |
3408 | return false; |
3409 | } |
3410 | auto dims = constant_as<c10::List<int64_t>>(node->input(1)); |
3411 | if (!dims.has_value()) { |
3412 | return false; |
3413 | } |
3414 | |
3415 | return true; |
3416 | }, |
3417 | nullptr); |
3418 | } |
3419 | |
3420 | { |
3421 | auto ptr_op = getOperatorForLiteral( |
3422 | "prim::transpose_copy.int(Tensor(a) self, int dim0, int dim1) -> Tensor" ); |
3423 | REGISTER_PARSE_RULE( |
3424 | ptr_op, |
3425 | { |
3426 | MemoryFormat format; |
3427 | std::list<Val*> list_val; |
3428 | std::tie(format, list_val) = getConsistentValues( |
3429 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
3430 | auto self_t = list_val.front(); |
3431 | list_val.pop_front(); |
3432 | auto self = self_t->as<TensorView>(); |
3433 | |
3434 | auto dim0 = constant_as<int>(node->input(1)); |
3435 | TORCH_INTERNAL_ASSERT( |
3436 | dim0.has_value(), "dim0 in transpose is not valid." ); |
3437 | |
3438 | auto dim1 = constant_as<int>(node->input(2)); |
3439 | TORCH_INTERNAL_ASSERT( |
3440 | dim1.has_value(), "dim1 in transpose is not valid." ); |
3441 | |
3442 | auto output = transpose(self, dim0.value(), dim1.value()); |
3443 | value_map.emplace( |
3444 | node->output()->unique(), ValueHolder(output, format)); |
3445 | }, |
3446 | [](const Node* node) -> bool { |
3447 | if (!isInputNonSizeZeroTensor(node)) { |
3448 | return false; |
3449 | } |
3450 | if (node->input(1)->node()->kind() != prim::Constant) { |
3451 | return false; |
3452 | } |
3453 | if (node->input(2)->node()->kind() != prim::Constant) { |
3454 | return false; |
3455 | } |
3456 | return true; |
3457 | }, |
3458 | nullptr); |
3459 | } |
3460 | |
3461 | { |
3462 | auto ptr_op = |
3463 | getOperatorForLiteral("prim::t_copy(Tensor(a) self) -> Tensor" ); |
3464 | REGISTER_PARSE_RULE( |
3465 | ptr_op, |
3466 | { |
3467 | MemoryFormat format; |
3468 | std::list<Val*> list_val; |
3469 | std::tie(format, list_val) = getConsistentValues( |
3470 | c10::nullopt, value_map[node->inputs()[0]->unique()]); |
3471 | auto self_t = list_val.front(); |
3472 | list_val.pop_front(); |
3473 | auto self = self_t->as<TensorView>(); |
3474 | |
3475 | TORCH_INTERNAL_ASSERT(self->getMaybeRFactorDomain().size() <= 2); |
3476 | |
3477 | auto output = transpose(self); |
3478 | value_map.emplace( |
3479 | node->output()->unique(), ValueHolder(output, format)); |
3480 | }, |
3481 | [](const Node* node) -> bool { |
3482 | if (!isInputNonSizeZeroTensor(node)) { |
3483 | return false; |
3484 | } |
3485 | |
3486 | return true; |
3487 | }, |
3488 | nullptr); |
3489 | } |
3490 | } |
3491 | |
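  // Parses a single node of the fusion group: constant nodes are registered
  // as codegen scalars, everything else is dispatched to its registered parse
  // rule.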
3492 | void processJitNode(const JitOp* node) { |
3493 | if (node->kind() == prim::Constant) { |
      // The partition pass doesn't take constant nodes explicitly, but it does
      // copy constants into the subgraph, so we need to register the constants
      // in the codegen IR;
3496 | for (auto output : node->outputs()) { |
3497 | TORCH_INTERNAL_ASSERT( |
3498 | registerScalar(output), |
3499 | "registration of output failed at index " , |
3500 | output->offset(), |
3501 | " for node " , |
3502 | *node); |
3503 | } |
3504 | } else { |
3505 | auto reg_entry = lookupInRegistry(node); |
3506 | TORCH_INTERNAL_ASSERT( |
3507 | reg_entry != nullptr, |
3508 | "CudaFusionGroup Parser doesn't handle node: " , |
3509 | canonicalSchemaString(node->schema())); |
3510 | reg_entry->parse(node, value_map_); |
3511 | } |
3512 | } |
3513 | |
3514 | bool registerValue(const JitValue* val) { |
3515 | return registerInputTensor(val) || registerScalar(val); |
3516 | } |
3517 | |
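  // Registers a JIT scalar as a codegen scalar (ComplexDouble, Double, Int or
  // Bool), constant-valued when the JIT value is a constant and symbolic
  // otherwise. String/Device/None values and constant lists are accepted but
  // not registered in the codegen IR.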
3518 | bool registerScalar(const JitValue* val) { |
3519 | if (val->type()->isSubtypeOf( |
3520 | static_cast<c10::TypePtr>(ComplexType::get()))) { |
3521 | CgValue cg_val = nullptr; |
3522 | if (auto ival = constant_as<c10::complex<double>>(val)) { |
3523 | cg_val = IrBuilder::create<ComplexDouble>(ival.value()); |
3524 | } else { |
3525 | cg_val = IrBuilder::create<ComplexDouble>(); |
3526 | } |
3527 | value_map_.emplace(val->unique(), cg_val); |
3528 | return true; |
3529 | } else if (val->type()->isSubtypeOf( |
3530 | static_cast<c10::TypePtr>(FloatType::get()))) { |
3531 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
3532 | CgValue cg_val; |
3533 | if (auto ival = constant_as<double>(val)) { |
3534 | cg_val = IrBuilder::create<Double>(ival.value()); |
3535 | } else { |
3536 | cg_val = IrBuilder::create<Double>(); |
3537 | } |
3538 | value_map_.emplace(val->unique(), cg_val); |
3539 | return true; |
3540 | } else if (val->type()->isSubtypeOf( |
3541 | static_cast<c10::TypePtr>(IntType::get()))) { |
3542 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
3543 | CgValue cg_val; |
3544 | if (auto ival = constant_as<int64_t>(val)) { |
3545 | cg_val = IrBuilder::create<Int>(ival.value()); |
3546 | } else { |
3547 | cg_val = IrBuilder::create<Int>(); |
3548 | } |
3549 | value_map_.emplace(val->unique(), cg_val); |
3550 | return true; |
3551 | } else if (val->type()->isSubtypeOf( |
3552 | static_cast<c10::TypePtr>(BoolType::get()))) { |
3553 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
3554 | CgValue cg_val; |
3555 | if (auto ival = constant_as<bool>(val)) { |
3556 | cg_val = IrBuilder::create<Bool>(ival.value()); |
3557 | } else { |
3558 | cg_val = IrBuilder::create<Bool>(); |
3559 | } |
3560 | value_map_.emplace(val->unique(), cg_val); |
3561 | return true; |
3562 | } else if ( |
3563 | val->type()->isSubtypeOf( |
3564 | static_cast<c10::TypePtr>(StringType::get())) || |
3565 | val->type()->isSubtypeOf( |
3566 | static_cast<c10::TypePtr>(DeviceObjType::get())) || |
3567 | val->type()->isSubtypeOf(static_cast<c10::TypePtr>(NoneType::get()))) { |
      // TODO: should we consider adding support for NoneType?
      // Note: String/Device scalars are only used in parsing rules; do not
      // register them with the codegen IR.
3571 | return true; |
3572 | } else if (val->type()->cast<ListType>()) { |
      // TODO: we don't support list types in codegen yet;
      // This is a WAR to allow reduction axes to be passed as a constant list;
      // We simply ignore the conversion if the scalar value is a constant;
3576 | auto ivalue = toIValue(val); |
3577 | TORCH_INTERNAL_ASSERT( |
3578 | ivalue.has_value(), |
3579 | "List[T] is not supported as an argument by NvFuser. Use a Constant List." ); |
3580 | return true; |
3581 | } |
3582 | return false; |
3583 | } |
3584 | |
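  // Registers a JIT tensor input as a codegen TensorView. When complete
  // stride order information is available, the permutation (e.g.
  // channels-last) is recorded in MemoryFormat and the tensor type is
  // rewritten so the TensorView is created as if contiguous in the permuted
  // order.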
3585 | bool registerInputTensor(const JitValue* val) { |
3586 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
3587 | CgValue cg_val; |
3588 | // Don't register if we don't support the type |
3589 | if (auto tensor_type = val->type()->cast<c10::TensorType>()) { |
3590 | if (!tensor_type->scalarType().has_value()) { |
3591 | return false; |
3592 | } |
3593 | |
3594 | if (aten_to_data_type(tensor_type->scalarType().value()) == |
3595 | DataType::Null) { |
3596 | return false; |
3597 | } |
3598 | |
3599 | // check for NHWC contiguous tensor |
3600 | TORCH_CHECK(tensor_type->dim().has_value(), "rank missing" ); |
3601 | const auto n_dim = tensor_type->dim().value(); |
3602 | |
3603 | MemoryFormat format; |
3604 | std::vector<int> stride_index; |
3605 | for (const auto i : c10::irange(n_dim)) { |
3606 | const auto& stride_property_i = tensor_type->stride_properties()[i]; |
3607 | if (stride_property_i->stride_index_.has_value()) { |
3608 | stride_index.emplace_back(stride_property_i->stride_index_.value()); |
3609 | } |
3610 | } |
3611 | |
    // only set the permutation when all stride_index values are available
3613 | if (stride_index.size() == n_dim) { |
3614 | format.setPermutation(stride_index); |
3615 | } |
3616 | |
3617 | // construct permuted tensor_type |
3618 | if (format.hasPermutation()) { |
3619 | auto opt_s_vec = tensor_type->symbolic_sizes().sizes(); |
3620 | TORCH_CHECK(opt_s_vec.has_value(), "missing rank of symbolic sizes" ); |
3621 | std::vector<c10::ShapeSymbol> s_vec = opt_s_vec.value(); |
3622 | // apply permutation |
3623 | auto permutation = format.apply(); |
3624 | for (auto new_axis : c10::irange(permutation.size())) { |
3625 | auto old_axis = permutation.at(new_axis); |
3626 | s_vec[new_axis] = opt_s_vec.value()[old_axis]; |
3627 | } |
3628 | |
      // copying stride properties because we need to permute them
3630 | auto opt_stride_vec = tensor_type->stride_properties().sizes(); |
3631 | TORCH_CHECK(opt_stride_vec.has_value(), "missing stride properties" ); |
3632 | auto nhwc_stride_vec = opt_stride_vec.value(); |
      // Make the tensor contiguous after permutation.
      // Note that we are only updating stride_properties.stride_index, since
      // the contiguous_ and stride_ values should remain the same after
      // permutation.
3637 | for (const auto i : c10::irange(n_dim)) { |
3638 | nhwc_stride_vec[i]->stride_index_ = n_dim - i - 1; |
3639 | } |
3640 | |
3641 | tensor_type = c10::TensorType::create( |
3642 | tensor_type->scalarType(), |
3643 | tensor_type->device(), |
3644 | s_vec, |
3645 | nhwc_stride_vec, |
3646 | tensor_type->requires_grad(), |
3647 | tensor_type->undefined()); |
3648 | } |
3649 | |
3650 | cg_val = IrBuilder::create<TensorView>(tensor_type); |
3651 | if (is_cpu_scalar(*tensor_type)) { |
3652 | cg_val->as<TensorView>()->setCpuScalar(true); |
3653 | } |
3654 | value_map_.emplace(val->unique(), ValueHolder(cg_val, format)); |
3655 | return true; |
3656 | } |
3657 | return false; |
3658 | } |
3659 | |
3660 | std::shared_ptr<Graph> graph_; |
3661 | |
3662 | // maps from JitValue::unique() to fusion Val; |
3663 | std::unordered_map<size_t, ValueHolder> value_map_; |
3664 | |
3665 | static std::unordered_set<Symbol> parser_symbol_set_; |
3666 | static std::unordered_set<Symbol> parser_skip_set_; |
3667 | static std::mutex parser_mutex_; |
3668 | |
3669 | // parsing rule registry. |
3670 | static std::unordered_map<std::string, RegistrationEntry> |
3671 | jit_operator_registry_; // NOLINT |
3672 | |
  // points to cached entries stored in `jit_operator_registry_`
3674 | static std::unordered_map<const FunctionSchema*, const RegistrationEntry*> |
3675 | cached_registry_lookup_; // NOLINT |
3676 | |
3677 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
3678 | static c10::once_flag once_flag_; |
3679 | }; |
3680 | std::unordered_set<Symbol> IrParser::parser_symbol_set_; // NOLINT |
3681 | std::unordered_set<Symbol> IrParser::parser_skip_set_; // NOLINT |
3682 | std::mutex IrParser::parser_mutex_; |
3683 | std::unordered_map<std::string, IrParser::RegistrationEntry> |
3684 | IrParser::jit_operator_registry_; // NOLINT |
3685 | std::unordered_map<const FunctionSchema*, const IrParser::RegistrationEntry*> |
3686 | IrParser::cached_registry_lookup_; // NOLINT |
3687 | |
3688 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
3689 | c10::once_flag IrParser::once_flag_; |
3690 | |
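// Inserts a profile_ivalue node observing `node`'s input at `offset` and
// reroutes that input through the profile node's output.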
3691 | ProfileIValueOp* insertProfileIValueOp( |
3692 | Node* node, |
3693 | size_t offset, |
3694 | ProfilingRecord* pr) { |
3695 | auto in_val = node->input(offset); |
3696 | auto pn = pr->createProfileIValueNode(in_val); |
3697 | pn->insertBefore(node); |
3698 | node->replaceInput(offset, pn->output()); |
3699 | return pn; |
3700 | } |
3701 | |
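// Profiles the reduction-size argument: the observed int list (or an empty
// list for a None value) is stored on the profile node; if a later run sees a
// different value, profiling is marked as failed and the GUARD logic takes
// over.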
3702 | void profileReductionSize(ProfilingRecord* pr, Node* node, size_t offset) { |
3703 | auto pn = insertProfileIValueOp(node, offset, pr); |
3704 | |
3705 | const auto ivalue_profiler = [pr, pn](Stack& stack) { |
3706 | std::lock_guard<std::mutex> lock(pr->mutex_); |
3707 | |
3708 | // TODO: we don't care about merging multiple profiling runs as we don't |
3709 | // support it at all; |
3710 | int64_t frame_id = 0; |
3711 | pop(stack, frame_id); |
3712 | IValue value; |
3713 | pop(stack, value); |
3714 | |
3715 | std::vector<int64_t> size_vec; |
3716 | if (value.isIntList()) { |
3717 | size_vec = value.toIntVector(); |
3718 | } else if (value.isNone()) { |
3719 | size_vec.clear(); |
3720 | } else { |
3721 | TORCH_INTERNAL_ASSERT( |
3722 | false, |
3723 | "profileReductionSize does not support data type: " , |
3724 | value.tagKind()); |
3725 | } |
3726 | // We stop profiling when it has failed |
3727 | if (!pn->hasAttribute(profileFailedAttr)) { |
3728 | if (!pn->hasAttribute(reductionSizeAttr)) { |
3729 | pn->is_(reductionSizeAttr, size_vec); |
3730 | } else { |
3731 | auto profiled_ints = pn->is(reductionSizeAttr); |
3732 | if (profiled_ints.size() != size_vec.size() || |
3733 | !std::equal( |
3734 | profiled_ints.begin(), profiled_ints.end(), size_vec.begin())) { |
3735 | TORCH_WARN_ONCE( |
3736 | __FUNCTION__, |
3737 | " sees varying value in profiling, ignoring and this should be handled by GUARD logic" ); |
3738 | pn->s_(profileFailedAttr, "varying profile values" ); |
3739 | pn->removeAttribute(reductionSizeAttr); |
3740 | } |
3741 | } |
3742 | } else { |
3743 | TORCH_INTERNAL_ASSERT( |
3744 | !pn->hasAttribute(reductionSizeAttr), |
3745 | "profiled attribute should have been removed when profiling is marked as failed" ); |
3746 | } |
3747 | push(stack, value); |
3748 | }; |
3749 | pn->setCallback(ivalue_profiler); |
3750 | } |
3751 | |
3752 | void profileViewSize(ProfilingRecord* pr, Node* node, size_t offset) { |
3753 | auto pn = insertProfileIValueOp(node, offset, pr); |
3754 | |
3755 | const auto ivalue_profiler = [pr, pn](Stack& stack) { |
3756 | std::lock_guard<std::mutex> lock(pr->mutex_); |
3757 | |
3758 | // TODO: we don't care about merging multiple profiling runs as we don't |
3759 | // support it at all; |
3760 | int64_t frame_id = 0; |
3761 | pop(stack, frame_id); |
3762 | IValue value; |
3763 | pop(stack, value); |
3764 | TORCH_INTERNAL_ASSERT( |
3765 | value.isIntList(), "profiling seeing the wrong data type" ); |
3766 | if (!pn->hasAttribute(profileFailedAttr)) { |
3767 | if (!pn->hasAttribute(viewSizeAttr)) { |
3768 | pn->is_(viewSizeAttr, value.toIntVector()); |
3769 | } else { |
3770 | auto profiled_ints = pn->is(viewSizeAttr); |
3771 | auto input_ints = value.toIntList(); |
3772 | if (profiled_ints.size() != input_ints.size() || |
3773 | !std::equal( |
3774 | profiled_ints.begin(), |
3775 | profiled_ints.end(), |
3776 | input_ints.begin())) { |
3777 | TORCH_WARN_ONCE( |
3778 | __FUNCTION__, |
3779 | " sees varying value in profiling, ignoring and this should be handled by GUARD logic" ); |
3780 | pn->s_(profileFailedAttr, "varying profile values" ); |
3781 | pn->removeAttribute(viewSizeAttr); |
3782 | } |
3783 | } |
3784 | } else { |
3785 | TORCH_INTERNAL_ASSERT( |
3786 | !pn->hasAttribute(viewSizeAttr), |
3787 | "profiled attribute should have been removed when profiling is marked as failed" ); |
3788 | } |
3789 | push(stack, value); |
3790 | }; |
3791 | |
3792 | pn->setCallback(ivalue_profiler); |
3793 | } |
3794 | |
3795 | void profileIntList(ProfilingRecord* pr, Node* node, size_t offset) { |
3796 | auto pn = insertProfileIValueOp(node, offset, pr); |
3797 | |
3798 | const auto ivalue_profiler = [pr, pn](Stack& stack) { |
3799 | std::lock_guard<std::mutex> lock(pr->mutex_); |
3800 | |
3801 | // TODO: we don't care about merging multiple profiling runs as we don't |
3802 | // support it at all; |
3803 | int64_t frame_id = 0; |
3804 | pop(stack, frame_id); |
3805 | IValue value; |
3806 | pop(stack, value); |
3807 | TORCH_INTERNAL_ASSERT( |
3808 | value.isIntList(), "profiling seeing the wrong data type" ); |
3809 | if (!pn->hasAttribute(profileFailedAttr)) { |
3810 | if (!pn->hasAttribute(intListAttr)) { |
3811 | pn->is_(intListAttr, value.toIntVector()); |
3812 | } else { |
3813 | auto profiled_ints = pn->is(intListAttr); |
3814 | auto input_ints = value.toIntList(); |
3815 | if (profiled_ints.size() != input_ints.size() || |
3816 | !std::equal( |
3817 | profiled_ints.begin(), |
3818 | profiled_ints.end(), |
3819 | input_ints.begin())) { |
3820 | TORCH_WARN_ONCE( |
3821 | __FUNCTION__, |
3822 | " sees varying value in profiling, ignoring and this should be handled by GUARD logic" ); |
3823 | pn->s_(profileFailedAttr, "varying profile values" ); |
3824 | pn->removeAttribute(intListAttr); |
3825 | } |
3826 | } |
3827 | } else { |
3828 | TORCH_INTERNAL_ASSERT( |
3829 | !pn->hasAttribute(intListAttr), |
3830 | "profiled attribute should have been removed when profiling is marked as failed" ); |
3831 | } |
3832 | push(stack, value); |
3833 | }; |
3834 | |
3835 | pn->setCallback(ivalue_profiler); |
3836 | } |
3837 | |
3838 | void profileString(ProfilingRecord* pr, Node* node, size_t offset) { |
3839 | auto pn = insertProfileIValueOp(node, offset, pr); |
3840 | |
3841 | const auto ivalue_profiler = [pr, pn](Stack& stack) { |
3842 | std::lock_guard<std::mutex> lock(pr->mutex_); |
3843 | |
3844 | // TODO: we don't care about merging multiple profiling runs as we don't |
3845 | // support it at all; |
3846 | int64_t frame_id = 0; |
3847 | pop(stack, frame_id); |
3848 | IValue value; |
3849 | pop(stack, value); |
3850 | TORCH_INTERNAL_ASSERT( |
3851 | value.isString(), "profiling seeing the wrong data type" ); |
3852 | if (!pn->hasAttribute(profileFailedAttr)) { |
3853 | if (!pn->hasAttribute(strAttr)) { |
3854 | pn->s_(strAttr, value.toStringRef()); |
3855 | } else { |
3856 | const auto& profiled_str = pn->s(strAttr); |
3857 | const auto& input_str = value.toStringRef(); |
3858 | if (input_str != profiled_str) { |
3859 | TORCH_WARN_ONCE( |
3860 | __FUNCTION__, |
3861 | " sees varying value in profiling, ignoring and this should be handled by GUARD logic" ); |
3862 | pn->s_(profileFailedAttr, "varying profile values" ); |
3863 | pn->removeAttribute(strAttr); |
3864 | } |
3865 | } |
3866 | } else { |
3867 | TORCH_INTERNAL_ASSERT( |
3868 | !pn->hasAttribute(strAttr), |
3869 | "profiled attribute should have been removed when profiling is marked as failed" ); |
3870 | } |
3871 | push(stack, value); |
3872 | }; |
3873 | |
3874 | pn->setCallback(ivalue_profiler); |
3875 | } |
3876 | |
3877 | void profileBool(ProfilingRecord* pr, Node* node, size_t offset) { |
3878 | auto pn = insertProfileIValueOp(node, offset, pr); |
3879 | |
3880 | const auto ivalue_profiler = [pr, pn](Stack& stack) { |
3881 | std::lock_guard<std::mutex> lock(pr->mutex_); |
3882 | |
3883 | // TODO: we don't care about merging multiple profiling runs as we don't |
3884 | // support it at all; |
3885 | int64_t frame_id = 0; |
3886 | pop(stack, frame_id); |
3887 | IValue value; |
3888 | pop(stack, value); |
3889 | TORCH_INTERNAL_ASSERT( |
3890 | value.isBool(), "profiling seeing the wrong data type" ); |
3891 | if (!pn->hasAttribute(profileFailedAttr)) { |
3892 | if (!pn->hasAttribute(boolAttr)) { |
3893 | pn->i_(boolAttr, value.toBool()); |
3894 | } else { |
3895 | auto profiled_bool = pn->i(boolAttr); |
3896 | auto input_bool = value.toBool(); |
3897 | if (input_bool != profiled_bool) { |
3898 | TORCH_WARN_ONCE( |
3899 | __FUNCTION__, |
3900 | " sees varying value in profiling, ignoring and this should be handled by GUARD logic" ); |
3901 | pn->s_(profileFailedAttr, "varying profile values" ); |
3902 | pn->removeAttribute(boolAttr); |
3903 | } |
3904 | } |
3905 | } else { |
3906 | TORCH_INTERNAL_ASSERT( |
3907 | !pn->hasAttribute(boolAttr), |
3908 | "profiled attribute should have been removed when profiling is marked as failed" ); |
3909 | } |
3910 | push(stack, value); |
3911 | }; |
3912 | |
3913 | pn->setCallback(ivalue_profiler); |
3914 | } |
3915 | |
3916 | void profileInt(ProfilingRecord* pr, Node* node, size_t offset) { |
3917 | auto pn = insertProfileIValueOp(node, offset, pr); |
3918 | |
3919 | const auto ivalue_profiler = [pr, pn](Stack& stack) { |
3920 | std::lock_guard<std::mutex> lock(pr->mutex_); |
3921 | |
3922 | // TODO: we don't care about merging multiple profiling runs as we don't |
3923 | // support it at all; |
3924 | int64_t frame_id = 0; |
3925 | pop(stack, frame_id); |
3926 | IValue value; |
3927 | pop(stack, value); |
3928 | TORCH_INTERNAL_ASSERT( |
3929 | value.isInt(), "profiling seeing the wrong data type" ); |
3930 | if (!pn->hasAttribute(profileFailedAttr)) { |
3931 | if (!pn->hasAttribute(intAttr)) { |
3932 | pn->i_(intAttr, value.toInt()); |
3933 | } else { |
3934 | auto profiled_int = pn->i(intAttr); |
3935 | auto input_int = value.toInt(); |
3936 | if (input_int != profiled_int) { |
3937 | TORCH_WARN_ONCE( |
3938 | __FUNCTION__, |
3939 | " sees varying value in profiling, ignoring and this should be handled by GUARD logic" ); |
3940 | pn->s_(profileFailedAttr, "varying profile values" ); |
3941 | pn->removeAttribute(intAttr); |
3942 | } |
3943 | } |
3944 | } else { |
3945 | TORCH_INTERNAL_ASSERT( |
3946 | !pn->hasAttribute(intAttr), |
3947 | "profiled attribute should have been removed when profiling is marked as failed" ); |
3948 | } |
3949 | push(stack, value); |
3950 | }; |
3951 | |
3952 | pn->setCallback(ivalue_profiler); |
3953 | } |
3954 | |
3955 | // profile ivalue, used for optional arguments |
3956 | void profileIval(ProfilingRecord* pr, Node* node, size_t offset) { |
3957 | auto pn = insertProfileIValueOp(node, offset, pr); |
3958 | |
3959 | const auto ivalue_profiler = [pr, pn](Stack& stack) { |
3960 | std::lock_guard<std::mutex> lock(pr->mutex_); |
3961 | |
3962 | // TODO: we don't care about merging multiple profiling runs as we don't |
3963 | // support it at all; |
3964 | int64_t frame_id = 0; |
3965 | pop(stack, frame_id); |
3966 | IValue value; |
3967 | pop(stack, value); |
3968 | if (!pn->hasAttribute(profileFailedAttr)) { |
3969 | if (!pn->hasAttribute(ivalAttr)) { |
3970 | pn->ival_(ivalAttr, value); |
3971 | } else { |
3972 | auto profiled_ival = pn->ival(ivalAttr); |
3973 | if (value != profiled_ival) { |
3974 | TORCH_WARN_ONCE( |
3975 | __FUNCTION__, |
3976 | " sees varying value in profiling, ignoring and this should be handled by GUARD logic" ); |
3977 | pn->s_(profileFailedAttr, "varying profile values" ); |
3978 | pn->removeAttribute(ivalAttr); |
3979 | } |
3980 | } |
3981 | } else { |
3982 | TORCH_INTERNAL_ASSERT( |
3983 | !pn->hasAttribute(ivalAttr), |
3984 | "profiled attribute should have been removed when profiling is marked as failed" ); |
3985 | } |
3986 | push(stack, value); |
3987 | }; |
3988 | |
3989 | pn->setCallback(ivalue_profiler); |
3990 | } |
3991 | |
3992 | void profileBoolList(ProfilingRecord* pr, Node* node, size_t offset) { |
3993 | auto pn = insertProfileIValueOp(node, offset, pr); |
3994 | |
3995 | const auto ivalue_profiler = [pr, pn](Stack& stack) { |
3996 | std::lock_guard<std::mutex> lock(pr->mutex_); |
3997 | |
3998 | // TODO: we don't care about merging multiple profiling runs as we don't |
3999 | // support it at all; |
4000 | int64_t frame_id = 0; |
4001 | pop(stack, frame_id); |
4002 | IValue value; |
4003 | pop(stack, value); |
4004 | TORCH_INTERNAL_ASSERT( |
4005 | value.isBoolList(), "profiling seeing the wrong data type" ); |
4006 | if (!pn->hasAttribute(profileFailedAttr)) { |
4007 | if (!pn->hasAttribute(boolListAttr)) { |
4008 | auto list = value.toBoolList(); |
4009 | std::vector<int64_t> val(list.begin(), list.end()); |
4010 | pn->is_(boolListAttr, val); |
4011 | } else { |
4012 | auto profiled_ints = pn->is(boolListAttr); |
4013 | auto input_bools = value.toBoolList(); |
4014 | if (profiled_ints.size() != input_bools.size() || |
4015 | !std::equal( |
4016 | input_bools.begin(), |
4017 | input_bools.end(), |
4018 | profiled_ints.begin())) { |
4019 | TORCH_WARN_ONCE( |
4020 | __FUNCTION__, |
4021 | " sees varying value in profiling, ignoring and this should be handled by GUARD logic" ); |
4022 | pn->s_(profileFailedAttr, "varying profile values" ); |
4023 | pn->removeAttribute(boolListAttr); |
4024 | } |
4025 | } |
4026 | } else { |
4027 | TORCH_INTERNAL_ASSERT( |
4028 | !pn->hasAttribute(boolListAttr), |
4029 | "profiled attribute should have been removed when profiling is marked as failed" ); |
4030 | } |
4031 | push(stack, value); |
4032 | }; |
4033 | |
4034 | pn->setCallback(ivalue_profiler); |
4035 | } |
4036 | |
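// Returns true if `fn` holds for any node in `block`, recursing into nested
// blocks.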
4037 | bool anyInBlock( |
4038 | const Block* block, |
4039 | const std::function<bool(const Node*)>& fn) { |
4040 | for (auto node : block->nodes()) { |
4041 | if (fn(node)) { |
4042 | return true; |
4043 | } |
4044 | for (auto block : node->blocks()) { |
4045 | if (anyInBlock(block, fn)) { |
4046 | return true; |
4047 | } |
4048 | } |
4049 | } |
4050 | return false; |
4051 | } |
4052 | |
4053 | } // namespace |
4054 | |
4055 | bool hasReductionNode(const Block* block) { |
4056 | return anyInBlock(block, isReductionNode); |
4057 | } |
4058 | |
4059 | bool isReductionNode(const Node* node) { |
4060 | return IrParser::isReductionNode(node); |
4061 | } |
4062 | |
4063 | bool isReductionToSizeNode(const Node* node) { |
4064 | return IrParser::isReductionToSizeNode(node); |
4065 | } |
4066 | |
4067 | bool hasNormalizationNode(const Block* block) { |
4068 | return anyInBlock(block, isNormalizationNode); |
4069 | } |
4070 | |
4071 | bool isNormalizationNode(const Node* node) { |
4072 | return IrParser::isNormalizationNode(node); |
4073 | } |
4074 | |
4075 | bool isElementWiseNode(const Node* node) { |
4076 | return IrParser::isElementWiseNode(node); |
4077 | } |
4078 | |
4079 | bool isNodeParsible(const Node* node) { |
4080 | return IrParser::canParseNode(node); |
4081 | } |
4082 | |
4083 | bool shouldProfileNode(const Node* node) { |
4084 | return IrParser::lookupInSymbolSet(node); |
4085 | } |
4086 | |
4087 | bool skipNodeKind(const std::string& symbol_str, bool flip) { |
4088 | return IrParser::querySkipSymbolSet( |
4089 | c10::Symbol::fromQualString(symbol_str), flip); |
4090 | } |
4091 | |
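// Inserts the appropriate ivalue profiler (bool, int, int list, string, ...)
// for the input at `offset` when `node` matches one of the schemas that need
// profiled arguments; returns true iff a profiler was inserted.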
4092 | bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { |
4093 | // is skip constant necessary? |
4094 | if (node->input(offset)->node()->kind() == prim::Constant) { |
4095 | return false; |
4096 | } |
4097 | |
4098 | static auto dropout_schema = |
4099 | getOperatorForLiteral( |
4100 | "aten::dropout(Tensor input, float p, bool train) -> Tensor" ) |
4101 | ->schema(); |
4102 | static auto native_dropout_schema = |
4103 | getOperatorForLiteral( |
4104 | "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)" ) |
4105 | ->schema(); |
4106 | if (node->matches(dropout_schema) || node->matches(native_dropout_schema)) { |
4107 | switch (offset) { |
4108 | // argument 2: Is training? |
4109 | case 2: |
4110 | profileBool(pr, node, offset); |
4111 | break; |
4112 | default: |
4113 | return false; |
4114 | } |
4115 | return true; |
4116 | } |
4117 | |
4118 | static auto amax_schema = |
4119 | getOperatorForLiteral( |
4120 | "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor" ) |
4121 | ->schema(); |
4122 | static auto amin_schema = |
4123 | getOperatorForLiteral( |
4124 | "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor" ) |
4125 | ->schema(); |
4126 | if (node->matches(amax_schema) || node->matches(amin_schema)) { |
4127 | switch (offset) { |
4128 | // argument 1: reduction axes; |
4129 | case 1: |
4130 | profileIntList(pr, node, offset); |
4131 | break; |
4132 | // argument 2: keepdim; |
4133 | case 2: |
4134 | profileBool(pr, node, offset); |
4135 | break; |
4136 | default: |
4137 | return false; |
4138 | } |
4139 | return true; |
4140 | } |
4141 | |
4142 | static auto reduction_operator_schema = |
4143 | getOperatorForLiteral( |
4144 | "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)" ) |
4145 | ->schema(); |
4146 | if (node->matches(reduction_operator_schema)) { |
4147 | switch (offset) { |
4148 | // argument 1: reduction axes; |
4149 | case 1: |
4150 | profileIntList(pr, node, offset); |
4151 | break; |
4152 | // argument 2: keepdim; |
4153 | case 2: |
4154 | profileBool(pr, node, offset); |
4155 | break; |
4156 | default: |
4157 | return false; |
4158 | } |
4159 | return true; |
4160 | } |
4161 | |
4162 | static auto sum_to_size_schema = |
4163 | getOperatorForLiteral( |
4164 | "aten::sum_to_size(Tensor self, int[] size) -> Tensor" ) |
4165 | ->schema(); |
4166 | static auto grad_sum_to_size_schema = |
4167 | getOperatorForLiteral( |
4168 | "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)" ) |
4169 | ->schema(); |
4170 | if (node->matches(sum_to_size_schema) || |
4171 | node->matches(grad_sum_to_size_schema)) { |
4172 | switch (offset) { |
4173 | // argument 1: reduction sizes; |
4174 | case 1: |
4175 | // TODO(profile_size): double check optional[size]? |
4176 | profileReductionSize(pr, node, offset); |
4177 | break; |
4178 | default: |
4179 | return false; |
4180 | } |
4181 | return true; |
4182 | } |
4183 | |
4184 | static auto reshape_schema = |
4185 | getOperatorForLiteral("aten::reshape(Tensor self, int[] shape) -> Tensor" ) |
4186 | ->schema(); |
4187 | static auto reshape_copy_schema = |
4188 | getOperatorForLiteral( |
4189 | "prim::reshape_copy(Tensor self, int[] shape) -> Tensor" ) |
4190 | ->schema(); |
4191 | static auto view_schema = |
4192 | getOperatorForLiteral("aten::view(Tensor self, int[] size) -> Tensor" ) |
4193 | ->schema(); |
4194 | static auto view_copy_schema = |
4195 | getOperatorForLiteral( |
4196 | "prim::view_copy(Tensor self, int[] size) -> Tensor" ) |
4197 | ->schema(); |
4198 | if (node->matches(reshape_schema) || node->matches(reshape_copy_schema) || |
4199 | node->matches(view_schema) || node->matches(view_copy_schema)) { |
4200 | switch (offset) { |
4201 | // argument 1: new tensor size; |
4202 | case 1: |
4203 | profileViewSize(pr, node, offset); |
4204 | break; |
4205 | default: |
4206 | return false; |
4207 | } |
4208 | return true; |
4209 | } |
4210 | |
4211 | static auto flatten_schema1 = |
4212 | getOperatorForLiteral( |
4213 | "aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor" ) |
4214 | ->schema(); |
4215 | static auto flatten_schema2 = |
4216 | getOperatorForLiteral( |
4217 | "prim::flatten_copy(Tensor self, int start_dim, int end_dim) -> Tensor" ) |
4218 | ->schema(); |
4219 | if (node->matches(flatten_schema1) || node->matches(flatten_schema2)) { |
4220 | switch (offset) { |
4221 | // argument 1: start_dim; |
4222 | // argument 2: end_dim; |
4223 | case 1: |
4224 | case 2: |
4225 | profileInt(pr, node, offset); |
4226 | break; |
4227 | default: |
4228 | return false; |
4229 | } |
4230 | return true; |
4231 | } |
4232 | |
4233 | static auto squeeze_dim_schema = |
4234 | getOperatorForLiteral( |
4235 | "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor" ) |
4236 | ->schema(); |
4237 | static auto unsqueeze_schema = |
4238 | getOperatorForLiteral( |
4239 | "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor" ) |
4240 | ->schema(); |
4241 | if (node->matches(squeeze_dim_schema) || node->matches(unsqueeze_schema)) { |
4242 | switch (offset) { |
      // argument 1: dim;
4244 | case 1: |
4245 | profileInt(pr, node, offset); |
4246 | break; |
4247 | default: |
4248 | return false; |
4249 | } |
4250 | return true; |
4251 | } |
4252 | |
4253 | static auto permute_schema = |
4254 | getOperatorForLiteral( |
4255 | "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)" ) |
4256 | ->schema(); |
4257 | static auto permute_copy_schema = |
4258 | getOperatorForLiteral( |
4259 | "prim::permute_copy(Tensor(a) self, int[] dims) -> Tensor" ) |
4260 | ->schema(); |
4261 | if (node->matches(permute_schema) || node->matches(permute_copy_schema)) { |
4262 | switch (offset) { |
4263 | // argument 1: dims; |
4264 | case 1: |
4265 | profileIntList(pr, node, offset); |
4266 | break; |
4267 | default: |
4268 | return false; |
4269 | } |
4270 | return true; |
4271 | } |
4272 | |
4273 | static auto transpose_int_copy_schema = |
4274 | getOperatorForLiteral( |
4275 | "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)" ) |
4276 | ->schema(); |
4277 | static auto transpose_int_schema = |
4278 | getOperatorForLiteral( |
4279 | "prim::transpose_copy.int(Tensor(a) self, int dim0, int dim1) -> Tensor" ) |
4280 | ->schema(); |
4281 | if (node->matches(transpose_int_copy_schema) || |
4282 | node->matches(transpose_int_schema)) { |
4283 | switch (offset) { |
4284 | // argument 1: dim0; |
4285 | // argument 2: dim1; |
4286 | case 1: |
4287 | case 2: |
4288 | profileInt(pr, node, offset); |
4289 | break; |
4290 | default: |
4291 | return false; |
4292 | } |
4293 | return true; |
4294 | } |
4295 | |
4296 | static auto batch_norm_impl_index_schema = |
4297 | getOperatorForLiteral( |
4298 | "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)" ) |
4299 | ->schema(); |
4300 | static auto native_batch_norm_schema = |
4301 | getOperatorForLiteral( |
4302 | "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)" ) |
4303 | ->schema(); |
4304 | static auto batch_norm_schema = |
4305 | getOperatorForLiteral( |
4306 | "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor" ) |
4307 | ->schema(); |
4308 | static auto instance_norm_schema = |
4309 | getOperatorForLiteral( |
4310 | "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor" ) |
4311 | ->schema(); |
4312 | if (node->matches(native_batch_norm_schema) || |
4313 | node->matches(batch_norm_impl_index_schema) || |
4314 | node->matches(batch_norm_schema) || node->matches(instance_norm_schema)) { |
4315 | switch (offset) { |
4316 | // argument 5: training; |
4317 | case 5: |
4318 | profileBool(pr, node, offset); |
4319 | break; |
4320 | default: |
4321 | return false; |
4322 | } |
4323 | return true; |
4324 | } |
4325 | |
4326 | static auto gelu_schema = |
4327 | getOperatorForLiteral( |
4328 | "aten::gelu(Tensor self, *, str approximate='none') -> Tensor" ) |
4329 | ->schema(); |
4330 | if (node->matches(gelu_schema)) { |
4331 | switch (offset) { |
4332 | // argument 1: approximate; |
4333 | case 1: |
4334 | profileString(pr, node, offset); |
4335 | break; |
4336 | default: |
4337 | return false; |
4338 | } |
4339 | return true; |
4340 | } |
4341 | |
4342 | static auto gelu_backward_schema = |
4343 | getOperatorForLiteral( |
4344 | "aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor" ) |
4345 | ->schema(); |
4346 | if (node->matches(gelu_backward_schema)) { |
4347 | switch (offset) { |
4348 | // argument 2: approximate; |
4349 | case 2: |
4350 | profileString(pr, node, offset); |
4351 | break; |
4352 | default: |
4353 | return false; |
4354 | } |
4355 | return true; |
4356 | } |
4357 | |
4358 | static auto native_layer_norm_schema = |
4359 | getOperatorForLiteral( |
4360 | "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)" ) |
4361 | ->schema(); |
4362 | static auto layer_norm_schema = |
4363 | getOperatorForLiteral( |
4364 | "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor" ) |
4365 | ->schema(); |
4366 | if (node->matches(native_layer_norm_schema) || |
4367 | node->matches(layer_norm_schema)) { |
4368 | switch (offset) { |
4369 | case 1: |
4370 | profileIntList(pr, node, offset); |
4371 | break; |
4372 | default: |
4373 | return false; |
4374 | } |
4375 | return true; |
4376 | } |
4377 | |
4378 | static auto batch_norm_impl_index_backward_schema = |
4379 | getOperatorForLiteral( |
4380 | "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)" ) |
4381 | ->schema(); |
4382 | if (node->matches(batch_norm_impl_index_backward_schema)) { |
4383 | switch (offset) { |
4384 | // TODO: guard impl_index, but I think that's not needed; |
4385 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4386 | case 8: // argument 8: training; |
4387 | profileBool(pr, node, offset); |
4388 | break; |
4389 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4390 | case 10: |
4391 | profileBoolList(pr, node, offset); |
4392 | break; |
4393 | default: |
4394 | return false; |
4395 | } |
4396 | return true; |
4397 | } |
4398 | |
4399 | static auto batch_norm_backward_schema = |
4400 | getOperatorForLiteral( |
4401 | "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)" ) |
4402 | ->schema(); |
4403 | if (node->matches(batch_norm_backward_schema)) { |
4404 | switch (offset) { |
4405 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
      case 7: // argument 7: training;
4407 | profileBool(pr, node, offset); |
4408 | break; |
4409 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4410 | case 9: |
4411 | profileBoolList(pr, node, offset); |
4412 | break; |
4413 | default: |
4414 | return false; |
4415 | } |
4416 | return true; |
4417 | } |
4418 | |
4419 | static auto native_layer_norm_backward_schema = |
4420 | getOperatorForLiteral( |
4421 | "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)" ) |
4422 | ->schema(); |
4423 | if (node->matches(native_layer_norm_backward_schema)) { |
4424 | switch (offset) { |
4425 | case 2: |
4426 | profileIntList(pr, node, offset); |
4427 | break; |
4428 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
4429 | case 7: |
4430 | profileBoolList(pr, node, offset); |
4431 | break; |
4432 | default: |
4433 | return false; |
4434 | } |
4435 | return true; |
4436 | } |
4437 | |
4438 | static auto to_copy_schema = |
4439 | getOperatorForLiteral( |
4440 | "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor" ) |
4441 | ->schema(); |
4442 | if (node->matches(to_copy_schema)) { |
4443 | switch (offset) { |
4444 | case 1: |
4445 | profileInt(pr, node, offset); |
4446 | return true; |
4447 | default: |
4448 | return false; |
4449 | } |
4450 | } |
4451 | |
4452 | static auto to_dtype_schema = |
4453 | getOperatorForLiteral( |
4454 | "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor" ) |
4455 | ->schema(); |
4456 | if (node->matches(to_dtype_schema)) { |
4457 | switch (offset) { |
4458 | case 1: |
4459 | profileInt(pr, node, offset); |
4460 | return true; |
4461 | default: |
4462 | return false; |
4463 | } |
4464 | } |
4465 | |
4466 | static auto log_softmax_data_schema = |
4467 | getOperatorForLiteral( |
4468 | "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor" ) |
4469 | ->schema(); |
4470 | static auto softmax_data_schema = |
4471 | getOperatorForLiteral( |
4472 | "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor" ) |
4473 | ->schema(); |
4474 | if (node->matches(log_softmax_data_schema) || |
4475 | node->matches(softmax_data_schema)) { |
4476 | switch (offset) { |
4477 | case 2: |
4478 | profileIval(pr, node, offset); |
4479 | return true; |
4480 | default: |
4481 | return false; |
4482 | } |
4483 | } |
4484 | |
4485 | static auto log_softmax_backward_data_schema = |
4486 | getOperatorForLiteral( |
4487 | "aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor" ) |
4488 | ->schema(); |
4489 | static auto softmax_backward_data_schema = |
4490 | getOperatorForLiteral( |
4491 | "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor" ) |
4492 | ->schema(); |
4493 | if (node->matches(log_softmax_backward_data_schema) || |
4494 | node->matches(softmax_backward_data_schema)) { |
4495 | switch (offset) { |
4496 | case 2: |
4497 | profileInt(pr, node, offset); |
4498 | return true; |
4499 | case 3: |
4500 | profileInt(pr, node, offset); |
4501 | return true; |
4502 | default: |
4503 | return false; |
4504 | } |
4505 | } |
4506 | |
4507 | static auto var_dim_schema = |
4508 | getOperatorForLiteral( |
4509 | "aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor" ) |
4510 | ->schema(); |
4511 | static auto std_dim_schema = |
4512 | getOperatorForLiteral( |
4513 | "aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor" ) |
4514 | ->schema(); |
4515 | if (node->matches(var_dim_schema) || node->matches(std_dim_schema)) { |
4516 | switch (offset) { |
4517 | case 1: |
4518 | profileIntList(pr, node, offset); |
4519 | return true; |
4520 | case 2: |
4521 | profileBool(pr, node, offset); |
4522 | return true; |
4523 | case 3: |
4524 | profileBool(pr, node, offset); |
4525 | return true; |
4526 | default: |
4527 | return false; |
4528 | } |
4529 | } |
4530 | |
4531 | return false; |
4532 | } |
4533 | |
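// Walks all nodes in the block (and nested blocks) and attempts to insert
// ivalue profile nodes for each of their inputs.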
4534 | void insertProfileNodesForCUDAFuser_(Block* block, ProfilingRecord* pr) { |
4535 | for (const auto& n : block->nodes()) { |
4536 | for (const auto offset : c10::irange(n->inputs().size())) { |
4537 | insertProfileIValue(pr, n, offset); |
4538 | } |
4539 | |
4540 | for (auto ib : n->blocks()) { |
4541 | insertProfileNodesForCUDAFuser_(ib, pr); |
4542 | } |
4543 | } |
4544 | } |
4545 | |
4546 | void InsertProfileNodes(ProfilingRecord* pr) { |
4547 | insertProfileNodesForCUDAFuser_(pr->profiled_graph_->block(), pr); |
4548 | } |
4549 | |
4550 | std::unique_ptr<Fusion> parseJitIR(const std::shared_ptr<Graph>& graph) { |
4551 | FUSER_PERF_SCOPE("parseJitIR" ); |
4552 | |
4553 | IrParser parser(graph); |
4554 | return parser.parse(); |
4555 | } |
4556 | |
4557 | } // namespace cuda |
4558 | } // namespace fuser |
4559 | } // namespace jit |
4560 | } // namespace torch |
4561 | |