# Needs pytorch and matplotlib.
import matplotlib.pyplot as plt
import numpy as np
import torch


def savehm2(img, hm, outname, q=100):
    """Save ``img`` with a signed heatmap ``hm`` overlaid on top of it.

    Args:
        img: (1, 3, H, W) torch tensor holding the image as it comes out of
            the data loader with standard processing (i.e. standardized);
            :func:`invert_normalize` undoes the standardization before display.
        hm: heatmap tensor, expected as (1, 3, H, W). The channels are summed
            (``hm.squeeze().sum(dim=0)``), so a gray heatmap may be repeated
            3 times along the channel axis. A heatmap that squeezes to a
            2-D (H, W) array is used as-is.
        outname: path of the output file; should end in ``.jpg`` or ``.png``.
        q: percentile of ``|hm|`` that is mapped to the ends of the colour
            scale (default 100, i.e. the maximum absolute value). Larger
            magnitudes saturate.
    """
    # Un-standardize the image and move channels last for matplotlib.
    image = invert_normalize(img.detach().cpu().squeeze())
    image = image.numpy().transpose((1, 2, 0))
    # Un-standardization can over/undershoot [0, 1] slightly; imshow would
    # clip float RGB data anyway, but warn about it.
    image = np.clip(image, 0.0, 1.0)

    heat = hm.detach().cpu().squeeze()
    if heat.dim() == 3:
        # Collapse the channel axis into a single relevance map.
        heat = heat.sum(dim=0)
    heat = heat.numpy()

    clim = np.percentile(np.abs(heat), q)
    # An all-zero (or non-finite) heatmap would otherwise become all NaN.
    if clim > 0:
        heat = heat / clim

    # A dedicated figure per call: drawing on the implicit current figure
    # stacked every previous call's plot underneath and never freed it.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image)
        ax.imshow(heat, cmap="seismic", vmin=-1, vmax=1, alpha=0.5)
        ax.axis("off")
        fig.savefig(outname, bbox_inches="tight")
    finally:
        plt.close(fig)


def invert_normalize(ten, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel standardization, returning ``ten * std + mean``.

    Args:
        ten: tensor whose channel axis is the third-from-last dimension,
            e.g. (C, H, W) or (N, C, H, W).
        mean: per-channel means that were subtracted (sequence or tensor).
        std: per-channel standard deviations that were divided by.

    Returns:
        A new tensor of the broadcast shape of ``ten`` and the statistics.
    """
    # (C,) -> (C, 1, 1) so the statistics broadcast over H and W (and over a
    # leading batch axis, if present); built on ten's device so CUDA inputs
    # work too.
    std_t = torch.as_tensor(std, dtype=torch.float32, device=ten.device).view(-1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=torch.float32, device=ten.device).view(-1, 1, 1)
    return ten * std_t + mean_t