Load BERT pretrained weights into the DETR encoder
# Example Python script of loading a BERT-base model into DETR

# Create DETR with a transformer that matches BERT-base dimensions
import argparse

import torch

from models import build_model
from util.default_args import get_args_parser

parser = argparse.ArgumentParser(parents=[get_args_parser()])
args = parser.parse_known_args()[0]
args.model = "detr"
args.hidden_dim = 768        # BERT-base hidden size
args.dim_feedforward = 3072  # BERT-base intermediate (FFN) size
args.lr = 1e-4
args.lr_backbone = 1e-5
args.num_queries = 100
args.enc_layers = 12         # BERT-base has 12 encoder layers
args.nheads = 12             # BERT-base has 12 attention heads
model, criterion, postprocessors = build_model(args)

# Load BERT
bert = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')
bert_enc = bert.encoder.state_dict()

# Convert keys: DETR's encoder uses nn.MultiheadAttention, which stores the
# Q, K, and V projections as a single in_proj_weight, so the three separate
# BERT projection matrices are concatenated along dim 0
dict_bert2detr = {}
for i in range(args.enc_layers):
    key = "layers.{}.self_attn.in_proj_weight".format(i)
    dict_bert2detr[key] = torch.cat([
        bert_enc["layer.{}.attention.self.query.weight".format(i)],
        bert_enc["layer.{}.attention.self.key.weight".format(i)],
        bert_enc["layer.{}.attention.self.value.weight".format(i)],
    ], dim=0)
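The script above was cut off after the Q/K/V weight concatenation. Below is a sketch of how the remaining per-layer parameters could be mapped and loaded, assuming the stock parameter names of DETR's TransformerEncoderLayer / nn.MultiheadAttention (in_proj_bias, out_proj, linear1, linear2, norm1, norm2) and the standard BERT-base encoder key names; verify the names against model.transformer.encoder.state_dict().keys() before relying on it.

# Continuation (sketch): map the remaining per-layer parameters.
# The DETR-side key names below are assumptions based on the default
# nn.MultiheadAttention / TransformerEncoderLayer layout.
for i in range(args.enc_layers):
    # Attention input-projection bias: concatenate Q, K, V biases
    dict_bert2detr["layers.{}.self_attn.in_proj_bias".format(i)] = torch.cat([
        bert_enc["layer.{}.attention.self.query.bias".format(i)],
        bert_enc["layer.{}.attention.self.key.bias".format(i)],
        bert_enc["layer.{}.attention.self.value.bias".format(i)],
    ], dim=0)
    # Attention output projection
    dict_bert2detr["layers.{}.self_attn.out_proj.weight".format(i)] = \
        bert_enc["layer.{}.attention.output.dense.weight".format(i)]
    dict_bert2detr["layers.{}.self_attn.out_proj.bias".format(i)] = \
        bert_enc["layer.{}.attention.output.dense.bias".format(i)]
    # Feed-forward network: BERT intermediate -> linear1, output -> linear2
    dict_bert2detr["layers.{}.linear1.weight".format(i)] = \
        bert_enc["layer.{}.intermediate.dense.weight".format(i)]
    dict_bert2detr["layers.{}.linear1.bias".format(i)] = \
        bert_enc["layer.{}.intermediate.dense.bias".format(i)]
    dict_bert2detr["layers.{}.linear2.weight".format(i)] = \
        bert_enc["layer.{}.output.dense.weight".format(i)]
    dict_bert2detr["layers.{}.linear2.bias".format(i)] = \
        bert_enc["layer.{}.output.dense.bias".format(i)]
    # LayerNorms: both BERT and default DETR apply norms after the sublayer
    # (post-norm), so the post-attention norm maps to norm1 and the
    # post-FFN norm maps to norm2
    dict_bert2detr["layers.{}.norm1.weight".format(i)] = \
        bert_enc["layer.{}.attention.output.LayerNorm.weight".format(i)]
    dict_bert2detr["layers.{}.norm1.bias".format(i)] = \
        bert_enc["layer.{}.attention.output.LayerNorm.bias".format(i)]
    dict_bert2detr["layers.{}.norm2.weight".format(i)] = \
        bert_enc["layer.{}.output.LayerNorm.weight".format(i)]
    dict_bert2detr["layers.{}.norm2.bias".format(i)] = \
        bert_enc["layer.{}.output.LayerNorm.bias".format(i)]

# Load the converted weights into the DETR encoder; strict=False reports
# unmatched keys instead of raising, which makes name mismatches easy to spot
missing, unexpected = model.transformer.encoder.load_state_dict(
    dict_bert2detr, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)

Checking the missing/unexpected key lists is the quickest way to confirm the assumed DETR-side names match your build of the model.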