import os
import torch
import clip
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.manifold import TSNE
from tqdm import tqdm
from collections import defaultdict
from random import shuffle

# Project the collected embeddings to 2-D with t-SNE and save a scatter plot
# with one colour per group to "temp.png".
#
# NOTE(review): ``X`` (a sequence of tensors passed to torch.cat) and ``zs``
# are defined in code outside this view. The slicing below assumes the
# concatenated rows form one contiguous group of ``len(zs)`` rows per entry
# in ``labels``, in label order -- confirm against the code that builds X.
X = torch.cat(X, dim=0)

# detach() so a tensor that still tracks gradients can be converted to NumPy;
# reduced-precision floats are upcast because NumPy has no bfloat16.
X_np = X.detach().cpu()
if X_np.dtype in (torch.float16, torch.bfloat16):
    X_np = X_np.float()
X_emb = TSNE(init="pca", perplexity=30.0).fit_transform(X_np.numpy())

labels = [
    "source",
    "z_star/foggy-1",
    "z_star/foggy-2",
    "z_star/foggy-3",
    "z_bar/foggy-1",
    "z_bar/foggy-2",
    "z_bar/foggy-3",
]

group_len = len(zs)
# Fail loudly instead of plotting a silently mis-grouped or truncated figure.
if X_emb.shape[0] != len(labels) * group_len:
    raise ValueError(
        f"expected {len(labels)} groups of {group_len} rows "
        f"({len(labels) * group_len} total), got {X_emb.shape[0]}"
    )

fig, ax = plt.subplots(1)
# One scatter series per label; iterating labels keeps the group count in
# sync with the legend instead of relying on a hard-coded 7.
for i, label in enumerate(labels):
    group = X_emb[i * group_len:(i + 1) * group_len]
    ax.scatter(group[:, 0], group[:, 1], label=label, s=4)
ax.legend()
fig.savefig("temp.png")
# Release the figure so repeated runs do not accumulate open figures.
plt.close(fig)
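# 댓글 ("Comments") / 댓글 쓰기 ("Write a comment"):
# leftover web-page UI text captured along with this snippet; not part of the script.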