#include <torch/csrc/autograd/function.h>
#include <torch/csrc/profiler/kineto_shim.h>
#include <torch/csrc/profiler/util.h>

#include <c10/util/ArrayRef.h>
#include <c10/util/irange.h>
#include <fmt/format.h>

#ifdef USE_KINETO
#include <libkineto.h>
#endif

namespace torch {
namespace profiler {
namespace impl {

ApproximateClockToUnixTimeConverter::ApproximateClockToUnixTimeConverter()
    : start_times_(measurePairs()) {}

ApproximateClockToUnixTimeConverter::UnixAndApproximateTimePair
ApproximateClockToUnixTimeConverter::measurePair() {
  // Take a measurement on either side to avoid an ordering bias.
  auto fast_0 = getApproximateTime();
  auto wall = std::chrono::system_clock::now();
  auto fast_1 = getApproximateTime();

  TORCH_INTERNAL_ASSERT(fast_1 >= fast_0, "getCount is non-monotonic.");
  auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
      wall.time_since_epoch());

  // `x + (y - x) / 2` is a more numerically stable average than `(x + y) / 2`.
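  // (It also sidesteps any overflow in `x + y` when the tick counts are
  // large; `fast_1 - fast_0` stays tiny because the two reads are adjacent.)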
  return {t.count(), fast_0 + (fast_1 - fast_0) / 2};
}

ApproximateClockToUnixTimeConverter::time_pairs
ApproximateClockToUnixTimeConverter::measurePairs() {
  static constexpr auto n_warmup = 5;
  for (C10_UNUSED const auto _ : c10::irange(n_warmup)) {
    getApproximateTime();
    steady_clock_t::now();
  }

  time_pairs out;
  for (const auto i : c10::irange(out.size())) {
    out[i] = measurePair();
  }
  return out;
}

std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
    makeConverter() {
  auto end_times = measurePairs();

  // Compute the real time that passes for each tick of the approximate clock.
  std::array<long double, replicates> scale_factors{};
  for (const auto i : c10::irange(replicates)) {
    auto delta_ns = end_times[i].t_ - start_times_[i].t_;
    auto delta_approx = end_times[i].approx_t_ - start_times_[i].approx_t_;
    scale_factors[i] = (double)delta_ns / (double)delta_approx;
  }
  std::sort(scale_factors.begin(), scale_factors.end());
  long double scale_factor = scale_factors[replicates / 2 + 1];

  // We shift all times by `t0` for better numerics. Double precision only has
  // 16 decimal digits of accuracy, so if we blindly multiply times by
  // `scale_factor` we may suffer from precision loss. The choice of `t0` is
  // mostly arbitrary; we just need a factor that is the correct order of
  // magnitude to bring the intermediate values closer to zero. We are not,
  // however, guaranteed that `t0_approx` is *exactly* the getApproximateTime
  // equivalent of `t0`; it is only an estimate that we have to fine tune.
  auto t0 = start_times_[0].t_;
  auto t0_approx = start_times_[0].approx_t_;
  std::array<double, replicates> t0_correction{};
  for (const auto i : c10::irange(replicates)) {
    auto dt = start_times_[i].t_ - t0;
    auto dt_approx =
        (double)(start_times_[i].approx_t_ - t0_approx) * scale_factor;
    t0_correction[i] = dt - (time_t)dt_approx;
  }
  t0 += t0_correction[t0_correction.size() / 2 + 1];

  return [=](approx_time_t t_approx) {
    // See above for why this is more stable than `A * t_approx + B`.
    return (time_t)((double)(t_approx - t0_approx) * scale_factor) + t0;
  };
}
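// Illustrative use of the returned converter (hypothetical numbers): if the
// measured scale_factor is ~0.5 ns per tick, an approximate timestamp that is
// 2000 ticks past t0_approx converts to roughly t0 + 1000 ns of Unix time.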

namespace {
c10::optional<bool> soft_assert_raises_;
} // namespace

void setSoftAssertRaises(c10::optional<bool> value) {
  soft_assert_raises_ = value;
}

bool softAssertRaises() {
  return soft_assert_raises_.value_or(false);
}

void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const char* args) {
#ifdef USE_KINETO
  std::string error = fmt::format(
      "{} SOFT ASSERT FAILED at {}:{}, func: {}, args: {}",
      cond,
      file,
      line,
      func,
      args);
  // TODO: Implement profile_id and group_profile_id as 3rd/4th arguments.
  kineto::logInvariantViolation(cond, error, "", "");
#endif
}

void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const std::string& args) {
#ifdef USE_KINETO
  std::string error = fmt::format(
      "{} SOFT ASSERT FAILED at {}:{}, func: {}, args: {}",
      cond,
      file,
      line,
      func,
      args);
  // TODO: Implement profile_id and group_profile_id as 3rd/4th arguments.
  kineto::logInvariantViolation(cond, error, "", "");
#endif
}

// ----------------------------------------------------------------------------
// -- NVTX --------------------------------------------------------------------
// ----------------------------------------------------------------------------
std::string getNvtxStr(
    const char* name,
    int64_t sequence_nr,
    const std::vector<std::vector<int64_t>>& shapes,
    at::RecordFunctionHandle op_id,
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids) {
  if (sequence_nr >= -1 || !shapes.empty()) {
    std::string str;
    if (sequence_nr >= 0) {
      str = fmt::format("{}, seq = {}", name, sequence_nr);
    } else if (sequence_nr == -1) {
      str = name;
    } else {
#if defined(USE_ROCM)
      // Only ROCM supports < -1 sequence_nr
      str = name;
#endif
    }
    if (op_id > 0) {
      str = fmt::format("{}, op_id = {}", str, op_id);
    }
    if (!shapes.empty()) {
      str = fmt::format("{}, sizes = {}", str, shapesToStr(shapes));
    }
    // Include the op ids of the input edges so
    // you can build the network graph
    if (!input_op_ids.empty()) {
      str = fmt::format(
          "{}, input_op_ids = {}", str, inputOpIdsToStr(input_op_ids));
    }
    return str;
  } else {
    return name;
  }
}
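// Example output (illustrative values): name "aten::add", sequence_nr 5,
// op_id 42, and shapes {{2, 3}} would yield
//   "aten::add, seq = 5, op_id = 42, sizes = [[2, 3]]".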

// ----------------------------------------------------------------------------
// -- Op context (shapes, call stack) -----------------------------------------
// ----------------------------------------------------------------------------
std::vector<FileLineFunc> prepareCallstack(
    const std::vector<jit::StackEntry>& cs) {
  std::vector<FileLineFunc> entries;
  entries.reserve(cs.size());
  for (const auto& entry : cs) {
    auto& range = entry.range;
    if (range.source()) {
      auto& src = range.source();
      if (src && src->filename()) {
        auto line =
            src->starting_line_no() + src->lineno_for_offset(range.start());
        entries.emplace_back(
            FileLineFunc{*(src->filename()), line, entry.filename});
      }
    }
  }
  return entries;
}

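// Renders each frame as "<file>(<line>): <funcname>", e.g. (illustrative)
// "model.py(12): forward".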
std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs) {
  std::vector<std::string> cs_str;
  cs_str.reserve(cs.size());
  for (const auto& entry : cs) {
    std::stringstream loc;
    loc << entry.filename << "(" << entry.line << "): " << entry.funcname;
    cs_str.push_back(loc.str());
  }
  return cs_str;
}

std::string stacksToStr(
    const std::vector<std::string>& stacks,
    const char* delim) {
  std::ostringstream oss;
  std::transform(
      stacks.begin(),
      stacks.end(),
      std::ostream_iterator<std::string>(oss, delim),
      [](std::string s) -> std::string {
#ifdef _WIN32
        // replace the windows backslash with forward slash
        std::replace(s.begin(), s.end(), '\\', '/');
#endif
        return s;
      });
  auto rc = oss.str();
  return "\"" + rc + "\"";
}

std::vector<std::vector<int64_t>> flattenList(
    c10::List<c10::IValue> list,
    std::string fn_name) {
  std::vector<std::vector<int64_t>> tensor_dims;
  for (const c10::IValue& input : list) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      if (tensor.defined()) {
        tensor_dims.push_back(tensor.sizes().vec());
      }
    }
  }
  return tensor_dims;
}

std::vector<std::vector<int64_t>> inputSizes(
    const at::RecordFunction& fn,
    bool flatten_list_enabled) {
  std::vector<std::vector<int64_t>> sizes;
  sizes.reserve(fn.inputs().size());
  for (const c10::IValue& input : fn.inputs()) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      if (tensor.defined()) {
        sizes.push_back(tensor.sizes().vec());
      } else {
        sizes.emplace_back();
      }
    } else if (input.isList()) {
      std::vector<std::vector<int64_t>> tmp_sizes;
      if (flatten_list_enabled) {
        tmp_sizes = flattenList(input.toList(), std::string(fn.name()));
      }
      // Extend the current sizes array by the array returned from input sizes
      if (!tmp_sizes.empty()) {
        sizes.insert(sizes.end(), tmp_sizes.begin(), tmp_sizes.end());
      } else {
        sizes.emplace_back();
      }
    } else {
      sizes.emplace_back();
    }
  }
  return sizes;
}

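// e.g. (illustrative) shapes {{2, 3}, {4}} render as "[[2, 3], [4]]".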
std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes) {
  std::string str("[");
  for (const auto t_idx : c10::irange(shapes.size())) {
    if (t_idx > 0) {
      str = fmt::format("{}, ", str);
    }
    str = fmt::format("{}[", str);
    for (const auto s_idx : c10::irange(shapes[t_idx].size())) {
      if (s_idx > 0) {
        str = fmt::format("{}, ", str);
      }
      str = fmt::format("{}{}", str, shapes[t_idx][s_idx]);
    }
    str = fmt::format("{}]", str);
  }
  str = fmt::format("{}]", str);
  return str;
}

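// e.g. (illustrative) input_op_ids {(7, 0), (8, 1)} render as
// "[(7,0), (8,1)]".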
std::string inputOpIdsToStr(
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids) {
  std::string str("[");
  int idx = 0;

  for (const auto& op_id_info_pair : input_op_ids) {
    if (idx++ > 0) {
      str = fmt::format("{}, ", str);
    }
    // (OpId,OutputNr)
    str = fmt::format(
        "{}({},{})", str, op_id_info_pair.first, op_id_info_pair.second);
  }
  str = fmt::format("{}]", str);
  return str;
}

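// e.g. (illustrative) {"float", "int"} render as ["float", "int"]; each entry
// is quoted and the joiner's trailing ", " is trimmed.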
std::string dtypesToStr(const std::vector<std::string>& types) {
  if (types.empty()) {
    return "[]";
  } else {
    std::ostringstream oss;
    std::transform(
        types.begin(),
        types.end(),
        std::ostream_iterator<std::string>(oss, ", "),
        [](std::string s) -> std::string { return "\"" + s + "\""; });
    auto rc = oss.str();
    rc.erase(rc.length() - 2); // remove last ", "
    return "[" + rc + "]";
  }
}

std::vector<std::string> inputTypes(const at::RecordFunction& fn) {
  std::vector<std::string> types;
  types.reserve(fn.inputs().size());
  for (const c10::IValue& input : fn.inputs()) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      if (tensor.defined()) {
        types.push_back(static_cast<std::string>(tensor.dtype().name()));
      } else {
        types.emplace_back();
      }
    } else if (input.isScalar() || input.isList()) {
      types.push_back(input.tagKind());
    } else {
      types.emplace_back();
    }
  }
  return types;
}

// ----------------------------------------------------------------------------
// -- FLOPS -------------------------------------------------------------------
// ----------------------------------------------------------------------------
static constexpr auto kConv2dStride = 3;
static constexpr auto kConv2dPadding = 4;
static constexpr auto kConv2dDilation = 5;
static constexpr auto kConv2dGroups = 6;

// List of supported operators
static constexpr auto kConv2dOp = "aten::conv2d";
static constexpr auto kMMOp = "aten::mm";
static constexpr auto kAddMMOp = "aten::addmm";
static constexpr auto kMulOp = "aten::mul";
static constexpr auto kAddOp = "aten::add";
static constexpr auto kBMMOp = "aten::bmm";
static constexpr auto kBAddBMMOp = "aten::baddbmm";

static constexpr auto kInputSize = "input_size";
static constexpr auto kWeightSize = "weight_size";
static constexpr auto kGroups = "groups";
static constexpr auto kPadding = "padding";
static constexpr auto kStride = "stride";
static constexpr auto kDilation = "dilation";
static constexpr auto kMatSize = "mat_size";
static constexpr auto kMat1Size = "mat1_size";
static constexpr auto kMat2Size = "mat2_size";

static bool validateInput(
    const std::string& op_name,
    size_t min_size,
    c10::ArrayRef<const c10::IValue> inputs,
    const c10::ArrayRef<int>& should_be_tensor) {
  std::stringstream ss;
  if (inputs.size() < min_size) {
    ss << "Failed to save extra arguments for flops computation of op "
       << op_name << ", min size: " << min_size
       << ", actual size: " << inputs.size();
    TORCH_WARN(ss.str());
    return false;
  }
  for (auto index : should_be_tensor) {
    if (!inputs[index].isTensor()) {
      ss << "Failed to save extra arguments for flops computation of op "
         << op_name << ", input[" << index << "] must be a tensor.";
      TORCH_WARN(ss.str());
      return false;
    }
  }
  return true;
}

std::unordered_map<std::string, c10::IValue> saveExtraArgs(
    const at::RecordFunction& fn) {
  // for specific types of fn, return the saved extra args for computing flops
  std::unordered_map<std::string, c10::IValue> map;
  auto inputs = fn.inputs();
  std::string fname(fn.name());

  if (inputs.empty()) {
    // Input shape is unavailable, return empty map
    return map;
  }

  if (fname == kConv2dOp) {
    bool check = validateInput(fname, kConv2dGroups + 1, inputs, {0, 1});
    if (!check) {
      return map;
    }

    at::Tensor input = inputs[0].toTensor();
    at::Tensor weight = inputs[1].toTensor();
    if (weight.sizes().size() != 4) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because it requires a 4D kernel tensor.");
      return map;
    }
    map[kInputSize] = at::IValue(input.sizes());
    map[kWeightSize] = at::IValue(weight.sizes());
    map[kStride] = inputs[kConv2dStride];
    map[kPadding] = inputs[kConv2dPadding];
    map[kDilation] = inputs[kConv2dDilation];
    map[kGroups] = inputs[kConv2dGroups];
  } else if (fname == kMMOp) {
    bool check = validateInput(fname, 2, inputs, {0, 1});
    if (!check) {
      return map;
    }

    at::Tensor left = inputs[0].toTensor();
    at::Tensor right = inputs[1].toTensor();
    map[kMat1Size] = at::IValue(left.sizes());
    map[kMat2Size] = at::IValue(right.sizes());
  } else if (fname == kAddMMOp) {
    bool check = validateInput(fname, 3, inputs, {0, 1, 2});
    if (!check) {
      return map;
    }

    // Exact FLOP count depends on scaling factors alpha and beta but
    // just assume these are +=1.
    // (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
    // "Operations Count for the BLAS and LAPACK", Table 3, SGEMM)
    at::Tensor left = inputs[1].toTensor();
    at::Tensor right = inputs[2].toTensor();
    map[kMat1Size] = at::IValue(left.sizes());
    map[kMat2Size] = at::IValue(right.sizes());
  } else if (fname == kMulOp) {
    bool check = validateInput(fname, 1, inputs, {0});
    if (!check) {
      return map;
    }

    at::Tensor mat = inputs[0].toTensor();
    map[kMatSize] = at::IValue(mat.sizes());
  } else if (fname == kAddOp) {
    bool check = validateInput(fname, 1, inputs, {0});
    if (!check) {
      return map;
    }

    at::Tensor mat = inputs[0].toTensor();
    map[kMatSize] = at::IValue(mat.sizes());
  } else if (fname == kBMMOp) {
    bool check = validateInput(fname, 2, inputs, {0, 1});
    if (!check) {
      return map;
    }

    at::Tensor left = inputs[0].toTensor();
    at::Tensor right = inputs[1].toTensor();
    map[kMat1Size] = at::IValue(left.sizes());
    map[kMat2Size] = at::IValue(right.sizes());
  } else if (fname == kBAddBMMOp) {
    bool check = validateInput(fname, 3, inputs, {0, 1, 2});
    if (!check) {
      return map;
    }

    // Exact FLOP count depends on scaling factors alpha and beta but
    // just assume these are +=1.
    // (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
    // "Operations Count for the BLAS and LAPACK", Table 3, SGEMM)
    at::Tensor left = inputs[1].toTensor();
    at::Tensor right = inputs[2].toTensor();
    map[kMat1Size] = at::IValue(left.sizes());
    map[kMat2Size] = at::IValue(right.sizes());
  }

  return map;
}

uint64_t computeFlops(
    const std::string& op_name,
    const std::unordered_map<std::string, c10::IValue>& extra_args) {
  if (op_name == kConv2dOp) {
    if (extra_args.find(kInputSize) == extra_args.end() ||
        extra_args.find(kWeightSize) == extra_args.end() ||
        extra_args.find(kGroups) == extra_args.end() ||
        extra_args.find(kPadding) == extra_args.end() ||
        extra_args.find(kStride) == extra_args.end() ||
        extra_args.find(kDilation) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for aten::conv2d requires groups, padding, stride, dilation, input_size, and weight_size in saved arguments.");
      return 0;
    }
    auto input_sizes_ref = extra_args.at(kInputSize);
    auto kernel_sizes_ref = extra_args.at(kWeightSize);
    auto groups_ref = extra_args.at(kGroups);
    auto padding_ref = extra_args.at(kPadding);
    auto stride_ref = extra_args.at(kStride);
    auto dilation_ref = extra_args.at(kDilation);
    if (!input_sizes_ref.isIntList() || !kernel_sizes_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because it requires input and weight tensor sizes.");
      return 0;
    }
    if (!padding_ref.isIntList() || !stride_ref.isIntList() ||
        !dilation_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because it requires padding, stride, and dilation values.");
      return 0;
    }

    const auto input_sizes = input_sizes_ref.toDimVector();
    const auto kernel_sizes = kernel_sizes_ref.toDimVector();
    const uint64_t groups = groups_ref.toInt();
    const std::vector<int64_t> padding = padding_ref.toIntVector();
    const std::vector<int64_t> stride = stride_ref.toIntVector();
    const std::vector<int64_t> dilation = dilation_ref.toIntVector();
    if (input_sizes.size() != 4 || kernel_sizes.size() != 4) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because both input and weight must be size 4.");
      return 0;
    }
    if (!groups) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because group size must not be 0.");
      return 0;
    }
    if (padding.size() != 2 || dilation.size() != 2) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because both padding and dilation must be size 2.");
      return 0;
    }
    if (stride.size() != 2 || (stride[0] * stride[1] == 0)) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because stride must be size 2 and cannot be 0.");
      return 0;
    }
    // format of the input is defined in
    // torch.ao.nn.quantized.functional.conv2d()
    uint64_t minibatch = 0, in_channels = 0, input_h = 0, input_w = 0;
    uint64_t out_channels = 0, kernel_h = 0, kernel_w = 0;
    const uint64_t conv2d_multiply_factor = 2;
    std::tie(minibatch, in_channels, input_h, input_w) = std::make_tuple(
        input_sizes[0], input_sizes[1], input_sizes[2], input_sizes[3]);
    std::tie(out_channels, std::ignore, kernel_h, kernel_w) = std::make_tuple(
        kernel_sizes[0], kernel_sizes[1], kernel_sizes[2], kernel_sizes[3]);
    uint64_t output_h =
        (input_h + 2 * padding[0] - dilation[0] * (kernel_h - 1) - 1) /
            stride[0] +
        1;
    uint64_t output_w =
        (input_w + 2 * padding[1] - dilation[1] * (kernel_w - 1) - 1) /
            stride[1] +
        1;

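    // Worked example (hypothetical sizes): minibatch=1, in_channels=3,
    // out_channels=8, 3x3 kernel, 32x32 output, groups=1 gives
    // 2 * 1 * 32 * 32 * 3 * 3 * 3 * 8 = 442,368 FLOPs.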
    return conv2d_multiply_factor * minibatch * output_h * output_w * kernel_h *
        kernel_w * in_channels * out_channels / groups;
  } else if (op_name == kMMOp || op_name == kAddMMOp) {
    if (extra_args.find(kMat1Size) == extra_args.end() ||
        extra_args.find(kMat2Size) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for ",
          op_name,
          " requires mat1_size and mat2_size in saved arguments.");
      return 0;
    }
    auto mat1_sizes_ref = extra_args.at(kMat1Size);
    auto mat2_sizes_ref = extra_args.at(kMat2Size);
    if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op ",
          op_name,
          " because it requires mat1_size and mat2_size to be IntList.");
      return 0;
    }

    const auto mat1_size = mat1_sizes_ref.toDimVector();
    const auto mat2_size = mat2_sizes_ref.toDimVector();
    if (mat1_size.empty()) {
      return 0;
    }

    int64_t overlap_dim = mat1_size.back();
    if (overlap_dim == 0) {
      return 0;
    }

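    // e.g. (illustrative) mat1 of size [m, k] and mat2 of size [k, n]:
    // prod(mat1) / k * prod(mat2) * 2 == 2 * m * k * n FLOPs.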
    const uint64_t gemm_multiply_factor = 2;
    uint64_t flops = 1;
    for (int64_t dim : mat1_size) {
      flops *= dim;
    }
    flops /= overlap_dim;
    for (int64_t dim : mat2_size) {
      flops *= dim;
    }
    flops *= gemm_multiply_factor;
    return flops;
  } else if (op_name == kBMMOp || op_name == kBAddBMMOp) {
    if (extra_args.find(kMat1Size) == extra_args.end() ||
        extra_args.find(kMat2Size) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for ",
          op_name,
          " requires mat1_size and mat2_size in saved arguments.");
      return 0;
    }
    auto mat1_sizes_ref = extra_args.at(kMat1Size);
    auto mat2_sizes_ref = extra_args.at(kMat2Size);
    if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op ",
          op_name,
          " because it requires mat1_size and mat2_size to be IntList.");
      return 0;
    }

    const auto mat1_size = mat1_sizes_ref.toDimVector();
    const auto mat2_size = mat2_sizes_ref.toDimVector();
    if (mat1_size.empty()) {
      return 0;
    }

    int64_t batch_size = mat1_size.front();
    if (batch_size == 0) {
      return 0;
    }

    int64_t overlap_dim = mat1_size.back();
    if (overlap_dim == 0) {
      return 0;
    }

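    // e.g. (illustrative) mat1 [b, m, k] and mat2 [b, k, n]:
    // prod(mat1) / k / b * prod(mat2) * 2 == 2 * b * m * k * n FLOPs.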
    const uint64_t gemm_multiply_factor = 2;
    uint64_t flops = 1;
    for (int64_t dim : mat1_size) {
      flops *= dim;
    }
    flops /= overlap_dim;
    flops /= batch_size;
    for (int64_t dim : mat2_size) {
      flops *= dim;
    }
    flops *= gemm_multiply_factor;
    return flops;
  } else if (op_name == kMulOp) {
    if (extra_args.find(kMatSize) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for aten::mul.Tensor requires mat_size in saved arguments.");
      return 0;
    }
    auto mat_sizes = extra_args.at(kMatSize);
    if (!mat_sizes.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::mul because it requires mat_size to be IntList.");
      return 0;
    }

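    // Elementwise op: count one FLOP per element of the first input, i.e. the
    // product of its dimensions (a convention; broadcasting is not modeled).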
    const auto mat_size = mat_sizes.toDimVector();
    uint64_t flops = 1;
    for (int64_t dim : mat_size) {
      flops *= dim;
    }
    return flops;
  } else if (op_name == kAddOp) {
    if (extra_args.find(kMatSize) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for aten::add.Tensor requires mat_size in saved arguments.");
      return 0;
    }
    auto mat_sizes = extra_args.at(kMatSize);
    if (!mat_sizes.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::add because it requires mat_size to be IntList.");
      return 0;
    }

    const auto mat_size = mat_sizes.toDimVector();
    uint64_t flops = 1;
    for (int64_t dim : mat_size) {
      flops *= dim;
    }
    return flops;
  }
  return 0;
}

} // namespace impl
} // namespace profiler
} // namespace torch