Dependent_Ad5120 t1_je5qfmp wrote
Reply to comment by oathbreakerkeeper in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
I don't have a GitHub repo for this, but it's pretty simple:
```
import torch
from torch import nn
from torch.backends.cuda import sdp_kernel

model = nn.Transformer().cuda().half()  # fp16 model on GPU
# nn.Transformer takes (src, tgt); shapes assume the default d_model=512
src = torch.rand(10, 32, 512).cuda().half()
tgt = torch.rand(20, 32, 512).cuda().half()
# restrict scaled-dot-product attention to the flash backend only
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    output = model(src, tgt)
```
These few lines should be enough.
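If you want to check that the flash kernel actually runs on a long context (rather than silently falling back to another backend), here is a minimal sketch calling `torch.nn.functional.scaled_dot_product_attention` directly; with only the flash backend enabled, PyTorch should raise an error instead of falling back when flash attention can't handle the inputs. The shapes below are arbitrary placeholders:

```
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

# (batch, heads, seq_len, head_dim) in fp16 on GPU; 32k sequence length
q = torch.rand(1, 8, 32768, 64, device="cuda", dtype=torch.float16)
k = torch.rand(1, 8, 32768, 64, device="cuda", dtype=torch.float16)
v = torch.rand(1, 8, 32768, 64, device="cuda", dtype=torch.float16)

with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    # fails loudly instead of falling back if flash attention can't run here
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```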