# Load BERT pretrained weights into the DETR encoder
# Example python script of loading BERT-base model to DETR
# Create a DETR model whose transformer encoder is shape-compatible with
# BERT-base: 12 layers, 12 heads, hidden size 768, feed-forward size 3072.
import argparse
from models import build_model
from util.default_args import get_args_parser

parser = argparse.ArgumentParser(parents=[get_args_parser()])
args, _ = parser.parse_known_args()

# Override the parser defaults so the encoder matches BERT-base dimensions.
bert_base_config = {
    "model": "detr",
    "hidden_dim": 768,
    "dim_feedforward": 3072,
    "lr": 1e-4,
    "lr_backbone": 1e-5,
    "num_queries": 100,
    "enc_layers": 12,
    "nheads": 12,
}
for name, value in bert_base_config.items():
    setattr(args, name, value)

model, criterion, postprocessors = build_model(args)
# Load BERT
# Fetch pretrained BERT-base via torch.hub (downloads weights on first use)
# and take a snapshot of its transformer-encoder parameters.  Keys look like
# "layer.{i}.attention.self.query.weight" — see the conversion loop below.
import torch
bert = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')
bert_enc = bert.encoder.state_dict()
# Convert keys
# Map BERT encoder parameter names onto DETR's transformer-encoder layout.
# DETR's self-attention (nn.MultiheadAttention) fuses query/key/value into a
# single in_proj tensor, so BERT's three separate projections are concatenated
# along dim 0 in Q, K, V order; all remaining parameters map one-to-one.
#
# Fixes vs. the original snippet: the torch.cat calls had mismatched brackets
# (`[q, [k, [v]], dim=0)` — a SyntaxError) and the loop body was unindented.

# Loop-invariant table: DETR key suffix -> BERT key suffix (per layer i).
_ONE_TO_ONE = {
    "self_attn.out_proj.weight": "attention.output.dense.weight",
    "self_attn.out_proj.bias": "attention.output.dense.bias",
    "linear1.weight": "intermediate.dense.weight",
    "linear1.bias": "intermediate.dense.bias",
    "linear2.weight": "output.dense.weight",
    "linear2.bias": "output.dense.bias",
    "norm1.weight": "attention.output.LayerNorm.weight",
    "norm1.bias": "attention.output.LayerNorm.bias",
    "norm2.weight": "output.LayerNorm.weight",
    "norm2.bias": "output.LayerNorm.bias",
}

dict_bert2detr = {}
for i in range(args.enc_layers):
    # Fused Q/K/V projection: concatenate the three BERT tensors row-wise.
    for part in ("weight", "bias"):
        dict_bert2detr["layers.{}.self_attn.in_proj_{}".format(i, part)] = torch.cat(
            [
                bert_enc["layer.{}.attention.self.query.{}".format(i, part)],
                bert_enc["layer.{}.attention.self.key.{}".format(i, part)],
                bert_enc["layer.{}.attention.self.value.{}".format(i, part)],
            ],
            dim=0,
        )
    # One-to-one renames (output projection, FFN, layer norms).
    for detr_suffix, bert_suffix in _ONE_TO_ONE.items():
        dict_bert2detr["layers.{}.{}".format(i, detr_suffix)] = \
            bert_enc["layer.{}.{}".format(i, bert_suffix)]
# Load BERT weights into DETR
# NOTE(review): load_state_dict is strict by default, so dict_bert2detr must
# cover every parameter of model.transformer.encoder exactly (no missing or
# unexpected keys) — confirm the encoder has no extra parameters such as a
# final norm layer; otherwise pass strict=False deliberately.
model.transformer.encoder.load_state_dict(dict_bert2detr)