vllm.compilation.passes.utility.scatter_split_replace
Replace slice_scatter and split_with_sizes nodes with a single assignment if the in-place tensor written by the slice_scatter call has no other users.
The in-place rotary_embedding custom op takes mutable query and key inputs that are split+getitem outputs of a single qkv tensor. After functionalization, the rotated query and key are fetched from the functionalized op with getitem calls. However, the graph also writes them back into the qkv tensor with slice_scatter and then splits that tensor again to recover the output tensors. If the in-place tensor has no subsequent users, this round trip is redundant: the slice_scatter and split_with_sizes nodes can be replaced by the getitem calls on the functionalized op, with the untouched value slice taken from the original split.
This rewrite is already done in fix_functionalization::FixFunctionalizationPass, but performing it in a dedicated pass before defunctionalization lets other passes, such as the RoPE+KVCache fusion pass, match the qkv split+rotary_embedding subpattern.
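The reasoning above can be checked with a small pure-Python model. It assumes one-dimensional lists stand in for the last dimension of qkv, and the helpers `slice_scatter`, `split_with_sizes`, and `fake_rotary` are hypothetical stand-ins written for illustration, not the aten ops or vLLM's kernel. The equivalence does not depend on what the rotary function computes.

```python
from typing import List, Sequence


def slice_scatter(base: Sequence[float], src: Sequence[float], start: int, end: int) -> List[float]:
    """Hypothetical out-of-place model of slice_scatter along one dimension."""
    if end - start != len(src):
        raise ValueError("src length must match the slice length")
    out = list(base)  # functionalized semantics: the input is never mutated
    out[start:end] = src
    return out


def split_with_sizes(t: Sequence[float], sizes: Sequence[int]) -> List[List[float]]:
    """Hypothetical model of split_with_sizes along one dimension."""
    if sum(sizes) != len(t):
        raise ValueError("sizes must sum to the tensor length")
    chunks, offset = [], 0
    for size in sizes:
        chunks.append(list(t[offset:offset + size]))
        offset += size
    return chunks


def fake_rotary(x: Sequence[float]) -> List[float]:
    """Hypothetical stand-in for the rotary embedding; any function works here."""
    return [2.0 * v + 1.0 for v in x]


q_size, kv_size = 4, 2  # small stand-ins for 512 and 64
qkv = [float(i) for i in range(q_size + 2 * kv_size)]

# Original graph: split, rotate q and k, scatter them back into qkv, split again.
q, k, v = split_with_sizes(qkv, (q_size, kv_size, kv_size))
q_rot, k_rot = fake_rotary(q), fake_rotary(k)
qkv_1 = slice_scatter(qkv, q_rot, 0, q_size)
qkv_2 = slice_scatter(qkv_1, k_rot, q_size, q_size + kv_size)
q_out, k_out, v_out = split_with_sizes(qkv_2, (q_size, kv_size, kv_size))

# Rewritten graph: use the rotated outputs directly and take v from the first split.
assert (q_out, k_out, v_out) == (q_rot, k_rot, v)
```

Because the scatter-then-split round trip only reproduces values that already exist as getitem outputs, the second split and the scatters can be dropped whenever nothing else reads the scattered tensor.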
ScatterSplitReplacementPass
Bases: VllmInductorPass
Replace getitem+slice_scatter+split nodes with a single getitem when the in-place subtensor written by the slice_scatter has no other users.
Here is an example graph with q_size = 512 and kv_size = 64:

```python
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
q = operator.getitem(split_with_sizes_1, 0)
k = operator.getitem(split_with_sizes_1, 1)
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
q_rot = operator.getitem(at, 1)
k_rot = operator.getitem(at, 2)
slice_scatter_1 = torch.ops.aten.slice_scatter.default(qkv, q_rot, -1, 0, 512)
slice_scatter_2 = torch.ops.aten.slice_scatter.default(slice_scatter_1, k_rot, -1, 512, 512 + 64)
split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(slice_scatter_2, (512, 64, 64), -1)
q_out = operator.getitem(split_with_sizes_2, 0)
k_out = operator.getitem(split_with_sizes_2, 1)
v = operator.getitem(split_with_sizes_2, 2)
```

After this pass, the sequence is replaced with the following, and users of q_out, k_out, and v are rewired to q_rot, k_rot, and the new v:

```python
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
q = operator.getitem(split_with_sizes_1, 0)
k = operator.getitem(split_with_sizes_1, 1)
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
q_rot = operator.getitem(at, 1)
k_rot = operator.getitem(at, 2)
v = operator.getitem(split_with_sizes_1, 2)
```
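The rewrite can be sketched on a toy graph representation. This is not vLLM's implementation, which operates on torch.fx graphs; the `Node` class, the op tags, and `scatter_split_replace` below are hypothetical, and slice_scatter nodes carry `(base, src, start, end)` with the dimension left implicit. The sketch walks back from a split through its slice_scatter chain, checks that each intermediate tensor feeds only the next node, and then rewires each getitem of the second split either to the scattered source or to the matching chunk of the first split.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass(eq=False)  # compare by identity, like FX nodes
class Node:
    name: str
    op: str  # hypothetical tag: "placeholder", "split", "getitem", "slice_scatter", ...
    args: Tuple[Any, ...] = ()


def users(graph: List[Node], node: Node) -> List[Node]:
    """All nodes that take `node` as a direct argument."""
    return [n for n in graph if any(a is node for a in n.args)]


def replace_all_uses(graph: List[Node], old: Node, new: Node) -> None:
    for n in graph:
        n.args = tuple(new if a is old else a for a in n.args)


def scatter_split_replace(graph: List[Node]) -> bool:
    """Apply one scatter+split rewrite; return True if the graph changed."""
    for split2 in [n for n in graph if n.op == "split"]:
        base, sizes = split2.args
        scatters: List[Node] = []  # newest first
        while isinstance(base, Node) and base.op == "slice_scatter":
            scatters.append(base)
            base = base.args[0]
        if not scatters:
            continue
        # Each scattered tensor may only feed the next node in the chain,
        # i.e. the in-place write to qkv has no other users.
        chain = scatters[::-1] + [split2]
        if any(users(graph, a) != [b] for a, b in zip(chain, chain[1:])):
            continue
        getitems = users(graph, split2)
        if any(g.op != "getitem" for g in getitems):
            continue
        # The earlier split of the same base tensor with the same sizes.
        split1: Optional[Node] = next(
            (n for n in graph if n.op == "split" and n is not split2
             and n.args[0] is base and n.args[1] == sizes), None)
        if split1 is None:
            continue
        offsets = [sum(sizes[:i]) for i in range(len(sizes))]
        chunks = [(o, o + s) for o, s in zip(offsets, sizes)]
        # Oldest scatter first, so a later write to the same range wins.
        written = {(s.args[2], s.args[3]): s.args[1] for s in chain[:-1]}
        if not set(written) <= set(chunks):
            continue  # a scatter does not line up with a split chunk
        for g in getitems:
            idx = g.args[1]
            new = written.get(chunks[idx])
            if new is None:
                # Chunk never overwritten (v): read it from the first split.
                new = next((u for u in users(graph, split1)
                            if u.op == "getitem" and u.args[1] == idx), None)
                if new is None:
                    new = Node(f"{split1.name}_getitem_{idx}", "getitem", (split1, idx))
                    graph.insert(graph.index(split1) + 1, new)
            replace_all_uses(graph, g, new)
        # Erase the now-dead getitems, the second split, and the scatters.
        for n in getitems + [split2] + scatters:
            if not users(graph, n):
                graph.remove(n)
        return True
    return False


# The example graph from the docstring, in the toy representation.
qkv, positions = Node("qkv", "placeholder"), Node("positions", "placeholder")
split1 = Node("split_with_sizes_1", "split", (qkv, (512, 64, 64)))
q = Node("q", "getitem", (split1, 0))
k = Node("k", "getitem", (split1, 1))
at = Node("at", "auto_functionalized", (positions, q, k))  # rotary_embedding
q_rot = Node("q_rot", "getitem", (at, 1))
k_rot = Node("k_rot", "getitem", (at, 2))
ss1 = Node("slice_scatter_1", "slice_scatter", (qkv, q_rot, 0, 512))
ss2 = Node("slice_scatter_2", "slice_scatter", (ss1, k_rot, 512, 512 + 64))
split2 = Node("split_with_sizes_2", "split", (ss2, (512, 64, 64)))
outs = [Node(f"{c}_out", "getitem", (split2, i)) for i, c in enumerate("qkv")]
output = Node("output", "output", tuple(outs))
graph = [qkv, positions, split1, q, k, at, q_rot, k_rot, ss1, ss2, split2, *outs, output]

while scatter_split_replace(graph):
    pass
print([n.name for n in graph])
```

The users check on the scatter chain is what encodes the "no other users" condition: if anything else reads slice_scatter_1 or slice_scatter_2, the scattered qkv tensor is still observable and the sketch leaves the graph unchanged.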
Source code in vllm/compilation/passes/utility/scatter_split_replace.py