t-SNE with scikit-learn and a matplotlib scatter plot

import os
import torch
import clip
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.manifold import TSNE
from tqdm import tqdm
from collections import defaultdict
from random import shuffle

# Embed the collected feature vectors with t-SNE and plot each labelled group
# as its own colour in a 2-D scatter plot, saved to temp.png.
#
# NOTE(review): `X` (a list of tensors) and `zs` are produced by code that is
# not part of this snippet; the rows of `X` are assumed to be laid out as
# len(labels) consecutive groups of len(zs) rows each, in the order of
# `labels` below -- TODO confirm against the code that builds them.

# Stack the per-batch tensors into a single matrix (rows = samples).
X = torch.cat(X, dim=0)

# One legend entry per consecutive group of rows in X.
labels = [
    "source",
    "z_star/foggy-1",
    "z_star/foggy-2",
    "z_star/foggy-3",
    "z_bar/foggy-1",
    "z_bar/foggy-2",
    "z_bar/foggy-3",
]
group_len = len(zs)

# Fail fast (before the expensive t-SNE fit) if X cannot hold every group;
# otherwise the slices below would silently come back short or empty and the
# legend would mislabel the points.
expected_rows = len(labels) * group_len
if X.shape[0] < expected_rows:
    raise ValueError(
        f"expected at least {expected_rows} rows "
        f"({len(labels)} groups x {group_len}), got {X.shape[0]}"
    )

# sklearn needs a NumPy array: detach() drops any autograd graph (which would
# make .numpy() fail) and cpu() moves GPU tensors to host memory.
X_emb = TSNE(init="pca", perplexity=30.0).fit_transform(X.detach().cpu().numpy())

fig, ax = plt.subplots(1)
for i, label in enumerate(labels):
    start, stop = i * group_len, (i + 1) * group_len
    ax.scatter(X_emb[start:stop, 0], X_emb[start:stop, 1], label=label, s=4)
ax.legend()
fig.savefig("temp.png")
# Release the figure so repeated runs (e.g. in a notebook) don't leak memory.
plt.close(fig)

댓글

이 블로그의 인기 게시물

Implementation of Focal Loss using PyTorch