Load BERT pretrained weights into the DETR encoder

# Example python script of loading BERT-base model to DETR



# Create DETR

import argparse

from models import build_model
from util.default_args import get_args_parser

# Start from DETR's default CLI arguments; parse_known_args tolerates any
# extra flags on the command line (e.g. when run inside a notebook).
parser = argparse.ArgumentParser(parents=[get_args_parser()])
args = parser.parse_known_args()[0]

# Override the defaults so the DETR encoder is shape-compatible with
# BERT-base: 768-dim hidden states, 3072-dim feed-forward, 12 encoder
# layers, 12 attention heads.
for _name, _value in [
    ("model", "detr"),
    ("hidden_dim", 768),
    ("dim_feedforward", 3072),
    ("lr", 1e-4),
    ("lr_backbone", 1e-5),
    ("num_queries", 100),
    ("enc_layers", 12),
    ("nheads", 12),
]:
    setattr(args, _name, _value)

model, criterion, postprocessors = build_model(args)


# Load BERT

import torch

# Download the pretrained BERT-base (uncased) model from the HuggingFace hub.
# NOTE(review): requires network access on first run; torch.hub caches afterwards.
bert = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')

# Keep only the transformer encoder's parameters; embedding and pooler
# weights are not transferred to DETR.
bert_enc = bert.encoder.state_dict()


# Convert keys

# Build a DETR-encoder state dict from the BERT encoder's parameters.
# DETR's encoder uses nn.MultiheadAttention, which packs the query/key/value
# projections into single in_proj_weight / in_proj_bias tensors, so the three
# separate BERT matrices must be concatenated along dim 0 (rows). All other
# parameters map one-to-one between the two naming schemes.
dict_bert2detr = {}

# (DETR suffix, BERT suffix) pairs for the parameters that map directly.
_ONE_TO_ONE = [
    ("self_attn.out_proj.weight", "attention.output.dense.weight"),
    ("self_attn.out_proj.bias", "attention.output.dense.bias"),
    ("linear1.weight", "intermediate.dense.weight"),
    ("linear1.bias", "intermediate.dense.bias"),
    ("linear2.weight", "output.dense.weight"),
    ("linear2.bias", "output.dense.bias"),
    ("norm1.weight", "attention.output.LayerNorm.weight"),
    ("norm1.bias", "attention.output.LayerNorm.bias"),
    ("norm2.weight", "output.LayerNorm.weight"),
    ("norm2.bias", "output.LayerNorm.bias"),
]

for i in range(args.enc_layers):
    # Q, K, V projection weights/biases concatenated into the packed in_proj
    # tensors expected by nn.MultiheadAttention.
    # (Original code had unbalanced brackets here and did not parse.)
    dict_bert2detr["layers.{}.self_attn.in_proj_weight".format(i)] = torch.cat(
        [bert_enc["layer.{}.attention.self.query.weight".format(i)],
         bert_enc["layer.{}.attention.self.key.weight".format(i)],
         bert_enc["layer.{}.attention.self.value.weight".format(i)]],
        dim=0)
    dict_bert2detr["layers.{}.self_attn.in_proj_bias".format(i)] = torch.cat(
        [bert_enc["layer.{}.attention.self.query.bias".format(i)],
         bert_enc["layer.{}.attention.self.key.bias".format(i)],
         bert_enc["layer.{}.attention.self.value.bias".format(i)]],
        dim=0)

    # Everything else copies across unchanged, only the key name differs.
    for detr_suffix, bert_suffix in _ONE_TO_ONE:
        dict_bert2detr["layers.{}.{}".format(i, detr_suffix)] = \
            bert_enc["layer.{}.{}".format(i, bert_suffix)]


# Load BERT weights into DETR

# Copy the converted parameters into DETR's transformer encoder.
# NOTE(review): load_state_dict defaults to strict=True, so dict_bert2detr
# must cover every parameter of the DETR encoder exactly — confirm the
# encoder has no extra keys (e.g. a final norm) for the chosen args.
model.transformer.encoder.load_state_dict(dict_bert2detr)

Comments

Popular posts from this blog

sklearn tsne + matplotlib scatter

Implementation of Focal Loss using Pytorch