QuadmasterXLII t1_ja7yo0f wrote
Reply to comment by QuadmasterXLII in [D] Training a UNet-like architecture for semantic segmentation with 200 outcome classes. by Scared_Employer6992
Your problem is the U-Net backbone, not the loss function. Assuming you're married to a batch size of 4, the final convolution producing the 4 × 200 × 500 × 500 logits, the cross-entropy, and the backward pass should take only about 10 GB, so cram your architecture into the remaining 30 GB.
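For example, this takes 7.5 GB:

```python
import torch

# Stand-in for the backbone's last feature map: batch 4, 128 channels, 500 x 500
x = torch.randn(4, 128, 500, 500, device="cuda")
# Final 3x3 conv to 200 class logits; without padding the output is 498 x 498
head = torch.nn.Conv2d(128, 200, kernel_size=3).cuda()
# Integer class labels matching the 498 x 498 logit map
target = torch.randint(0, 200, (4, 498, 498), device="cuda")
# Head forward, cross-entropy, and backward pass
torch.nn.CrossEntropyLoss()(head(x), target).backward()
```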
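To see where a number in that range comes from, you can size the big fp32 tensors by hand. The sketch below is a rough lower bound under stated assumptions, not a profiler. The function names and the `logit_copies` knob are hypothetical. It assumes that about three logit-sized tensors dominate: the log-probabilities that cross-entropy keeps for backward, plus the gradients with respect to them and to the logits.

```python
import math

BYTES_PER_FP32 = 4


def tensor_gb(*shape: int, bytes_per_elem: int = BYTES_PER_FP32) -> float:
    """Size of a dense tensor with the given shape, in decimal GB (1e9 bytes)."""
    return math.prod(shape) * bytes_per_elem / 1e9


def head_estimate_gb(batch=4, in_ch=128, classes=200, h=500, w=500, k=3,
                     logit_copies=3):
    """Rough fp32 footprint of the input feature map plus `logit_copies`
    logit-sized tensors (hypothetical knob: log-probs and gradients).

    Ignores cuDNN workspace, allocator caching, and the CUDA context.
    """
    out_h, out_w = h - k + 1, w - k + 1  # stride 1, no padding -> 498 x 498
    features = tensor_gb(batch, in_ch, h, w)
    logits = tensor_gb(batch, classes, out_h, out_w)
    return features + logit_copies * logits


print(f"one 4x200x500x500 logit tensor: {tensor_gb(4, 200, 500, 500):.2f} GB")
print(f"head estimate at batch 4: {head_estimate_gb():.2f} GB")
print(f"head estimate at batch 8: {head_estimate_gb(batch=8):.2f} GB")
```

Every term is linear in the batch size, so doubling the batch doubles the head's share of the budget. The estimate (about 2.9 GB at batch 4) sits below the measured 7.5 GB. Presumably the gap comes from the overheads the docstring lists, which this sketch does not model.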