// *** Tensor Expressions ***
//
// This tutorial covers the basics of NNC's tensor expressions, shows basic
// APIs for working with them, and outlines how they are used in the overall
// TorchScript compilation pipeline. This doc is permanently a "work in
// progress" since NNC is under active development and things change fast.
//
// This tutorial's code is compiled in the standard pytorch build, and the
// executable can be found in `build/bin/tutorial_tensorexpr`.
//
// *** What is NNC ***
//
// NNC stands for Neural Net Compiler. It is a component of the TorchScript
// JIT and it performs on-the-fly code generation for kernels, which are often
// a combination of multiple aten (torch) operators.
//
// When the JIT interpreter executes a TorchScript model, it automatically
// extracts subgraphs from the TorchScript IR graph for which specialized code
// can be JIT generated. This usually improves performance, as the 'combined'
// kernel created from the subgraph can avoid the unnecessary memory traffic
// that is unavoidable when the subgraph is interpreted as-is, operator by
// operator. This optimization is often referred to as 'fusion'. Relatedly,
// the process of finding and extracting subgraphs suitable for NNC code
// generation is done by a JIT pass called the 'fuser'.
//
// *** What is TE ***
//
// TE stands for Tensor Expressions. TE is a commonly used approach for
// compiling kernels that perform tensor (~matrix) computations. The idea
// behind it is that each operator is represented as a mathematical formula
// describing what computation it does (as TEs), and the TE engine can then
// perform mathematical simplification and other optimizations using those
// formulas, eventually generating executable code that produces the same
// results as the original sequence of operators, but more efficiently.
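// For example, an expression such as x * 1 + 0 can be recognized as
// equivalent to just x and simplified away before any code is generated.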
//
// NNC's design and implementation of TE was heavily inspired by the Halide
// and TVM projects.
#include <iostream>
#include <sstream>
#include <string>

#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

using namespace torch::jit::tensorexpr;

#ifdef TORCH_ENABLE_LLVM

// Helper function to print a snippet from a big multi-line string
static void printLinesToFrom(const std::string& input_str, int from, int to);

#endif

int main(int argc, char* argv[]) {
  std::cout << "*** Structure of tensor expressions and statements ***"
            << std::endl;
  {
    // A tensor expression is a tree of expressions. Each expression has a
    // type, and that type defines what sub-expressions the current expression
    // has. For instance, a 'Mul' expression has the type 'kMul' and two
    // sub-expressions: LHS and RHS. Each of these two sub-expressions could
    // also be a 'Mul' or some other expression.
    //
    // Let's construct a simple TE:
    ExprPtr lhs = alloc<IntImm>(5);
    ExprPtr rhs = alloc<Var>("x", kInt);
    ExprPtr mul = alloc<Mul>(lhs, rhs);
    std::cout << "Tensor expression: " << *mul << std::endl;
    // Prints: Tensor expression: 5 * x

    // Here we created an expression representing the computation 5 * x, where
    // x is an int variable.

    // Another, probably more convenient, way to construct tensor expressions
    // is to use so-called expression handles (as opposed to raw expressions
    // like we did in the previous example). Expression handles overload common
    // operations and allow us to express the same semantics in a more natural
    // way:
    ExprHandle l = 5;
    ExprHandle r = Var::make("x", kInt);
    ExprHandle m = l * r;
    std::cout << "Tensor expression: " << *m.node() << std::endl;
    // Prints: Tensor expression: 5 * x

    // Converting from handles to raw expressions and back is easy:
    ExprHandle handle = Var::make("x", kInt);
    ExprPtr raw_expr_from_handle = handle.node();
    ExprPtr raw_expr = alloc<Var>("x", kInt);
    ExprHandle handle_from_raw_expr = ExprHandle(raw_expr);

    // We can construct arbitrarily complex expressions using mathematical
    // and logical operations, casts between various data types, and a number
    // of intrinsics.
    ExprHandle a = Var::make("a", kInt);
    ExprHandle b = Var::make("b", kFloat);
    ExprHandle c = Var::make("c", kFloat);
    ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f);
    std::cout << "Tensor expression: " << *x.node() << std::endl;
    // Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f)
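    // Note the implicit cast in the printed expression: the sub-expression
    // 5 * a is an integer expression, so it is cast to float before being
    // combined with the float expression b / (sigmoid(c) - 3.f).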

    // The ultimate purpose of tensor expressions is to optimize tensor
    // computations, and in order to represent accesses to tensor data there
    // is a special kind of expression: a load.
    // To construct a load we need two pieces: the base and the indices. The
    // base of a load is a Buf expression, which can be thought of as a
    // placeholder similar to Var, but with dimensions info.
    //
    // Let's construct a simple load:
    BufHandle A("A", {64, 32}, kInt);
    VarPtr i_var = alloc<Var>("i", kInt), j_var = alloc<Var>("j", kInt);
    ExprHandle i(i_var), j(j_var);
    ExprHandle load = Load::make(A.dtype(), A, {i, j});
    std::cout << "Tensor expression: " << *load.node() << std::endl;
    // Prints: Tensor expression: A[i, j]
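    // For convenience, the same load can also be created with the BufHandle's
    // 'load' shorthand, which we will use later in this tutorial:
    ExprHandle load2 = A.load(i, j);
    std::cout << "Tensor expression: " << *load2.node() << std::endl;
    // Prints: Tensor expression: A[i, j]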

    // Tensor expressions are combined into tensor statements, which are used
    // to represent the computation of a given operator or a group of
    // operators from a fusion group.
    //
    // There are three main kinds of tensor statements:
    //  - block
    //  - store
    //  - loop
    //
    // A Store represents a store to a single element of a tensor (or to a
    // group of elements if it's a vectorized store). Store statements,
    // similarly to Load expressions, have a base and indices, but on top of
    // that they also include a value - an expression representing what needs
    // to be stored at the given memory location. Let's create a Store stmt:
    StmtPtr store_a = Store::make(A, {i, j}, i + j);
    std::cout << "Store statement: " << *store_a << std::endl;
    // Prints: Store statement: A[i, j] = i + j;

    // An operator fills the entire tensor, not just a single element, and to
    // represent this we need to use a For stmt: let's wrap our store stmt
    // with two nested loops to represent that the variables i and j need to
    // iterate over some ranges.
    ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a);
    ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a);

    std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl;
    // Prints:
    // Nested for loops:
    // for (const auto i : c10::irange(64)) {
    //   for (const auto j : c10::irange(32)) {
    //     A[i, j] = i + j;
    //   }
    // }

    // A Block statement is used when we need a sequence of other statements.
    // E.g. if a fusion group contains several operators, we initially define
    // a separate loopnest for each of them and put them all into a common
    // block:
    BufHandle B("B", {64, 32}, kInt);
    StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j));
    ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b);
    ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b);

    BlockPtr block = Block::make({loop_i_a, loop_i_b});
    std::cout << "Compound Block statement: " << std::endl
              << *block << std::endl;
    // Prints:
    // Compound Block statement:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       A[i, j] = i + j;
    //     }
    //   }
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       B[i, j] = A[i, j];
    //     }
    //   }
    // }

    // Manually constructing nested loops and blocks to represent a
    // computation might be laborious, so instead we can use the 'Compute'
    // API. This API requires us to specify dimensions and a lambda computing
    // a single element of the resulting tensor, and it returns a `Tensor`
    // structure. This structure is simply a pair of a buffer that was created
    // to represent the result of the computation (BufPtr) and a statement
    // representing the computation itself (StmtPtr).
    Tensor C =
        Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return i * j;
        });
    std::cout << "Stmt produced by 'Compute' API: " << std::endl
              << *C.stmt() << std::endl;
    // Prints:
    // Stmt produced by 'Compute' API:
    // for (const auto i : c10::irange(64)) {
    //   for (const auto j : c10::irange(32)) {
    //     C[i, j] = i * j;
    //   }
    // }

    // To construct statements representing computations with reductions, we
    // can use the 'Reduce' API - it is similar to 'Compute' but takes a
    // couple of extra arguments defining how to perform the reduction. Let's
    // define a simple 2D sum of C using it:
    Tensor D = Reduce(
        "D",
        {},
        Sum(),
        [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); },
        {64, 32});
    std::cout << "Stmt produced by 'Reduce' API: " << std::endl
              << *D.stmt() << std::endl;
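    // Conceptually, the produced statement initializes D to the sum's
    // identity (0) and then accumulates C[i, j] into it over the 64x32
    // iteration space; we omit the exact printed form here.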
  }

  std::cout << "*** Loop nest transformations ***" << std::endl;
  {
    // When a statement for the computation is generated, we might want to
    // apply some optimizations to it. These transformations allow us to end
    // up with a statement producing the same results, but more efficiently.
    //
    // Let's look at a couple of transformations that are used in NNC. We will
    // begin by constructing a Block statement like we did before.

    Tensor C =
        Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return i * (j + 1);
        });
    BufHandle c_buf(C.buf());
    Tensor D =
        Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return c_buf.load(i, j) - i;
        });
    StmtPtr block = Block::make({C.stmt(), D.stmt()});
    std::cout << "Stmt produced by 'Compute' API: " << std::endl
              << *block << std::endl;
    // Prints:
    // Stmt produced by 'Compute' API:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       C[i, j] = i * (j + 1);
    //     }
    //   }
    //   for (const auto i_1 : c10::irange(64)) {
    //     for (const auto j_1 : c10::irange(32)) {
    //       D[i_1, j_1] = (C[i_1, j_1]) - i_1;
    //     }
    //   }
    // }

    // One transformation we can apply to this computation is inlining, i.e.
    // taking the expression that defines the values of C and substituting it
    // for the loads from C.
    // To do that, we first need to create a special object called LoopNest -
    // all transformations are methods of this class. To create a loopnest we
    // need to provide a list of output buffers and the root statement:
    LoopNest nest(block, {D.buf()});

    // We can always retrieve the Stmt back from a LoopNest:
    std::cout << "LoopNest root stmt: " << std::endl
              << *nest.root_stmt() << std::endl;
    // Prints:
    // LoopNest root stmt:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       C[i, j] = i * (j + 1);
    //     }
    //   }
    //   for (const auto i_1 : c10::irange(64)) {
    //     for (const auto j_1 : c10::irange(32)) {
    //       D[i_1, j_1] = (C[i_1, j_1]) - i_1;
    //     }
    //   }
    // }

    // Now we can apply the inlining transformation:
    nest.computeInline(C.buf());
    std::cout << "Stmt after inlining:" << std::endl
              << *nest.root_stmt() << std::endl;
    // Prints:
    // Stmt after inlining:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       D[i, j] = i * (j + 1) - i;
    //     }
    //   }
    // }
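    // Note that the loops computing C have disappeared: C is not an output of
    // this loopnest, so once its definition is inlined into D there is
    // nothing left to compute for C.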

    // We can also apply algebraic simplification to a statement:
    StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());
    std::cout << "Stmt after simplification:" << std::endl
              << *simplified << std::endl;
    // Prints:
    // Stmt after simplification:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       D[i, j] = i * j;
    //     }
    //   }
    // }

    // Many loopnest transformations are stateless and can be applied without
    // creating a LoopNest object. In fact, we plan to make all transformations
    // stateless.
    // splitWithTail is one such transformation: it splits the iteration space
    // of a given loop into two loops with a given factor.
    ForPtr outer_loop = to<For>(to<Block>(simplified)->stmts().front());
    LoopNest::splitWithTail(outer_loop, 13);
    // Call the simplifier once more to fold some arithmetic.
    simplified = IRSimplifier::simplify(simplified);
    std::cout << "Stmt after splitWithTail:" << std::endl
              << *simplified << std::endl;
    // Prints:
    // Stmt after splitWithTail:
    // {
    //   for (const auto i_outer : c10::irange(4)) {
    //     for (const auto i_inner : c10::irange(13)) {
    //       for (const auto j : c10::irange(32)) {
    //         D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j);
    //       }
    //     }
    //   }
    //   for (const auto i_tail : c10::irange(12)) {
    //     for (const auto j : c10::irange(32)) {
    //       D[i_tail + 52, j] = i_tail * j + 52 * j;
    //     }
    //   }
    // }
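    // The original 64-iteration loop was split into a main loop running 4
    // chunks of 13 iterations (covering 52 rows) and a tail loop handling the
    // remaining 64 - 52 = 12 iterations, which is why the tail indices start
    // at offset 52.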

    // NNC supports a wide range of loop nest transformations, which we are not
    // listing here. Please refer to documentation in
    // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h
    // for more details.
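    // (At the time of writing, the transformations declared there include,
    // among others, splitWithMask, reorderAxis, unroll and vectorize.)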
  }

  std::cout << "*** Codegen ***" << std::endl;
  {
    // The ultimate goal of tensor expressions is to provide a mechanism to
    // execute a given computation in the fastest possible way. So far we've
    // looked at how to describe the computation we're interested in, but we
    // haven't looked at how to actually execute it.
    //
    // Everything we've been dealing with so far was just symbols with no
    // actual data associated with them; in this section we will look at how
    // to bridge that gap.

    // Let's start by constructing a simple computation for us to work with:
    BufHandle A("A", {64, 32}, kInt);
    BufHandle B("B", {64, 32}, kInt);
    Tensor X =
        Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return A.load(i, j) + B.load(i, j);
        });

    // And let's lower it to a loop nest, as we did in the previous section.
    // We can pass the Tensor object directly:
    LoopNest loopnest({X});
    std::cout << *loopnest.root_stmt() << std::endl;
    // Prints:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       X[i, j] = (A[i, j]) + (B[i, j]);
    //     }
    //   }
    // }

    // Now imagine that we have two actual 64x32 tensors that we want to sum
    // together. How do we pass those tensors to the computation, and how do
    // we carry it out?
    //
    // The Codegen object provides exactly that functionality. Codegen is an
    // abstract class and concrete codegens are derived from it. Currently,
    // we have three codegens:
    // 1) Simple Evaluator,
    // 2) LLVM Codegen for CPU,
    // 3) CUDA Codegen.
    // In this example we will be using the Simple Evaluator, since it's
    // available everywhere.

    // To create a codegen, we need to provide the statement - it specifies
    // the computation we want to perform - and a list of placeholders and
    // tensors used in the computation. The latter part is crucial since it is
    // the only way for the codegen to correlate the symbols in the statement
    // with the actual data arrays that we will pass in when we run the
    // computation.
    //
    // Let's create a Simple IR Evaluator codegen for our computation:
    SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X});

    // We are using the simplest codegen, and almost no work is done in it at
    // construction time. Real codegens such as CUDA and LLVM perform
    // compilation during that stage, so that everything is ready by the time
    // we're about to run the computation.

    // Let's now create some inputs and run our computation on them:
    std::vector<int> data_A(64 * 32, 3); // This will be the input A
    std::vector<int> data_B(64 * 32, 5); // This will be the input B
    std::vector<int> data_X(64 * 32, 0); // This will be used for the result

    // Now let's invoke our codegen to perform the computation on our data. We
    // need to provide as many arguments as there were placeholders and
    // tensors at codegen construction time. Positions in these lists define
    // how the real data arrays passed at call time (these arguments are
    // referred to as 'CallArg's in our codebase) correspond to the symbols
    // (placeholders and tensors) used in the tensor expressions we
    // constructed (these are referred to as 'BufferArg's).
    // Thus, we will provide three arguments: data_A, data_B, and data_X.
    // data_A contains data for the placeholder A, data_B for the placeholder
    // B, and data_X will be used for the contents of the tensor X.
    ir_eval(data_A, data_B, data_X);

    // Let's print one of the elements from each array to verify that the
    // computation did happen:
    std::cout << "A[10] = " << data_A[10] << std::endl
              << "B[10] = " << data_B[10] << std::endl
              << "X[10] = A[10] + B[10] = " << data_X[10] << std::endl;
    // Prints:
    // A[10] = 3
    // B[10] = 5
    // X[10] = A[10] + B[10] = 8
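    // In an LLVM-enabled build we could have constructed an LLVM codegen for
    // the same statement and invoked it with the same arguments. A minimal
    // sketch (assuming TORCH_ENABLE_LLVM is defined and
    // <torch/csrc/jit/tensorexpr/llvm_codegen.h> is included) might look
    // like:
    //
    //   LLVMCodeGen llvm_codegen(loopnest.root_stmt(), {A, B, X});
    //   llvm_codegen.call({data_A, data_B, data_X});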
  }

  std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl;
  {
    // This section requires an LLVM-enabled PyTorch build, so we have to use
    // a guard:
#ifdef TORCH_ENABLE_LLVM

    // Often we would like to convert a TorchScript IR to TE rather than
    // construct TE IR from scratch. NNC provides an API for such lowering: it
    // takes a TorchScript graph and returns an object that can be used to
    // invoke the generated kernel.
    // This API is currently used by the TorchScript JIT fuser and can also be
    // used ahead of time to pre-compile parts of a model.
    //
    // To get familiar with this API, let's start by defining a simple
    // TorchScript graph:
    const auto graph_string = R"IR(
      graph(%A : Float(5, 3, strides=[3, 1], device=cpu),
            %B : Float(5, 3, strides=[3, 1], device=cpu)):
        %AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B)
        %one : int = prim::Constant[value=1]()
        %AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB)
        %AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one)
        return (%AAB_plus_B))IR";
    auto graph = std::make_shared<torch::jit::Graph>();
    parseIR(graph_string, &*graph);

    // This graph defines a simple computation of A*A*B + B, where A and B are
    // 5x3 input tensors.

    // To lower this TorchScript graph to TE, we just need to create a
    // TensorExprKernel object. Its constructor constructs the corresponding
    // TE IR and compiles it for the given backend (in this example, for CPU
    // using the LLVM compiler).
    TensorExprKernel kernel(graph);

    // We can retrieve the generated TE stmt from the kernel object:
    StmtPtr kernel_stmt = kernel.getCodeGenStmt();
    std::cout << "TE Stmt constructed from TorchScript: " << std::endl
              << *kernel_stmt << std::endl;
    // Prints:
    // TE Stmt constructed from TorchScript:
    // {
    //   for (const auto v : c10::irange(5)) {
    //     for (const auto _tail_tail : c10::irange(3)) {
    //       aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) *
    //       ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) +
    //       (tB[_tail_tail + 3 * v]);
    //     }
    //   }
    // }
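    // Note how the two aten::mul calls and the aten::add call from the
    // TorchScript graph were fused into a single elementwise loop nest over
    // the flattened 5x3 output.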

    // We can also examine the generated LLVM IR and assembly code:
    std::cout << "Generated LLVM IR: " << std::endl;
    auto ir_str = kernel.getCodeText("ir");
    printLinesToFrom(ir_str, 15, 20);
    // Prints:
    // Generated LLVM IR:
    //   %9 = bitcast float* %2 to <8 x float>*
    //   %10 = load <8 x float>, <8 x float>* %9 ...
    //   %11 = bitcast float* %5 to <8 x float>*
    //   %12 = load <8 x float>, <8 x float>* %11 ...
    //   %13 = fmul <8 x float> %10, %12
    //   %14 = fmul <8 x float> %10, %13

    std::cout << "Generated assembly: " << std::endl;
    auto asm_str = kernel.getCodeText("asm");
    printLinesToFrom(asm_str, 10, 15);
    // Prints:
    // Generated assembly:
    //   vmulps %ymm1, %ymm0, %ymm2
    //   vfmadd213ps %ymm1, %ymm0, %ymm2
    //   vmovups %ymm2, (%rax)
    //   vmovss 32(%rcx), %xmm0
    //   vmovss 32(%rdx), %xmm1
    //   vmulss %xmm1, %xmm0, %xmm2
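    // The exact IR and assembly depend on the host CPU and LLVM version; in
    // the run above the kernel was vectorized to 8-wide float operations, and
    // the multiply-add pair was turned into a fused multiply-add.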

    // We can also execute the generated kernel:
    auto A =
        at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) *
        2.0;
    auto B =
        at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) *
        3.0;
    std::vector<at::Tensor> inputs = {A, B};
    std::vector<torch::IValue> stack = torch::fmap<torch::IValue>(inputs);
    kernel.run(stack);
    auto R = stack[0].toTensor();

    // Let's print one of the elements from the result tensor to verify that
    // the computation did happen and was correct. Every element of A is 2 and
    // every element of B is 3, so we expect every element of the result to be
    // A*A*B + B = 2*2*3 + 3 = 15:
    std::cout << "R[2][2] = " << R[2][2] << std::endl;
    // Prints:
    // R[2][2] = 15
    // [ CPUFloatType{} ]
#endif
  }
  return 0;
}

void printLinesToFrom(const std::string& input_str, int from, int to) {
  std::istringstream f(input_str);
  std::string s;
  int idx = 0;
  while (getline(f, s)) {
    if (idx > from) {
      std::cout << s << "\n";
    }
    if (idx++ > to) {
      break;
    }
  }
}