t-SNE (scikit-learn) + matplotlib scatter plot
import os
import torch
import clip
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.manifold import TSNE
from tqdm import tqdm
from collections import defaultdict
from random import shuffle
# Stack the collected embeddings into one tensor and project them to 2-D with
# t-SNE, then draw one scatter series per group and save the plot.
# NOTE(review): X (a list of tensors) and zs are built earlier in the original
# post. The slicing below assumes the rows are ordered group by group, in the
# same order as `labels`, with len(zs) rows per group. Confirm against the caller.
X = torch.cat(X, dim=0)
# Give sklearn an explicit NumPy array. detach() avoids the
# "Can't call numpy() on Tensor that requires grad" error that the implicit
# conversion of a grad-tracking tensor would raise.
X_emb = TSNE(init="pca", perplexity=30.0).fit_transform(X.detach().cpu().numpy())
labels = ["source", "z_star/foggy-1", "z_star/foggy-2", "z_star/foggy-3", "z_bar/foggy-1", "z_bar/foggy-2", "z_bar/foggy-3"]
fig, ax = plt.subplots(1)
group_len = len(zs)
# Derive the number of groups from `labels` instead of a hard-coded 7, so the
# two cannot drift apart when a group is added or removed.
for i, label in enumerate(labels):
    rows = slice(i * group_len, (i + 1) * group_len)
    ax.scatter(X_emb[rows, 0], X_emb[rows, 1], label=label, s=4)
ax.legend()
fig.savefig("temp.png")
# Release the figure once it is written to disk, so repeated runs (e.g. in a
# notebook) do not accumulate open figures.
plt.close(fig)
댓글
댓글 쓰기