Bert是年google发布的新模型,打破了11项纪录,关于模型基础部分就不在这篇文章里多说了。
这次想和大家一起读的是huggingface的pytorch-pretrained-BERT代码examples里的文本分类任务run_classifier。
关于源代码可以在huggingface的github中找到。
在上一篇文章老周带你读Bert文本分类代码(pytorch篇一)中我介绍完了数据预处理部分。
接上一篇文章,在这篇文章中我会和大家一起读模型部分。继续接着看主函数部分:train_examples=Nonenum_train_steps=Noneifargs.do_train:train_examples=processor.get_train_examples(args.data_dir)num_train_steps=int(len(train_examples)/args.train_batch_size/args.gradient_accumulation_steps*args.num_train_epochs)#Preparemodelmodel=BertForSequenceClassification.from_pretrained(args.bert_model,cache_dir=PYTORCH_PRETRAINED_BERT_CACHE/distributed_{}.format(args.local_rank),num_labels=num_labels)
这段代码先是将训练数据导入train_example中,然后根据训练数据的总数算出需要多少个steps。
我们打开pytorch_pretrained_bert.modeling.py找到BertForSequenceClassification类。我先整理了BertForSequenceClassification类中调用关系,如下图所示。
本篇文章中,我会和大家一起读BertForSequenceClassification类,PreTrainedBertModel类和BertForSequenceClassification类中调用的BertModel的代码。
BertForSequenceClassification类,代码如下:
classBertForSequenceClassification(PreTrainedBertModel):"""参数:config:指定的bert模型的预训练参数num_labels:分类的类别数量输入:input_ids:训练集,torch.LongTensor类型,shape是[batch_size,sequence_length]token_type_ids:可选项,当训练集是两句话时才有的。attention_mask:可选项,当使用mask才有,可参考原论文。labels:数据标签,torch.LongTensor类型,shape是[batch_size]输出:如果labels不是None(训练时):输出的是分类的交叉熵如果labels是None(评价时):输出的是shape为[batch_size,num_labels]估计值"""#AlreadybeenconvertedintoWordPiecetokenidsinput_ids=torch.LongTensor([[31,51,99],[15,5,0]])input_mask=torch.LongTensor([[1,1,1],[1,1,0]])token_type_ids=torch.LongTensor([[0,0,1],[0,1,0]])config=BertConfig(vocab_size_or_config_json_file=,hidden_size=,num_hidden_layers=12,num_attention_heads=12,intermediate_size=)num_labels=2model=BertForSequenceClassification(config,num_labels)logits=model(input_ids,token_type_ids,input_mask)```"""def__init__(self,config,num_labels=2):super(BertForSequenceClassification,self).__init__(config)self.num_labels=num_labelsself.bert=BertModel(config)self.dropout=nn.Dropout(config.hidden_dropout_prob)self.classifier=nn.Linear(config.hidden_size,num_labels)self.apply(self.init_bert_weights)defforward(self,input_ids,token_type_ids=None,attention_mask=None,labels=None):_,pooled_output=self.bert(input_ids,token_type_ids,attention_mask,output_all_encoded_layers=False)pooled_output=self.dropout(pooled_output)logits=self.classifier(pooled_output)iflabelsisnotNone:loss_fct=CrossEntropyLoss()loss=loss_fct(logits.view(-1,self.num_labels),labels.view(-1))returnlosselse:returnlogits
我们从forward函数看,数据先输入BertModel中,然后进行dropout,之后是一个用作分类的Linear层。也就是说分类任务只是在bert的模型基础上加了一个线形层。
我们可以看到上面的BertForSequenceClassification类是继承于PreTrainedBertModel的子类,我们再来看看PreTrainedBertModel类的代码:
classPreTrainedBertModel(nn.Module):"""Anabstractclasstohandleweightsinitializationandasimpleinterfacefordowloadingandloadingpretrainedmodels."""def__init__(self,config,*inputs,**kwargs):super(PreTrainedBertModel,self).__init__()ifnotisinstance(config,BertConfig):raiseValueError("Parameterconfigin`{}(config)`shouldbeaninstanceofclass`BertConfig`.""TocreateamodelfromaGooglepretrainedmodeluse""`model={}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(self.__class__.__name__,self.__class__.__name__))self.config=configdefinit_bert_weights(self,module):"""Initializetheweights."""ifisinstance(module,(nn.Linear,nn.Embedding)):#SlightlydifferentfromtheTFversionwhichusestruncated_normalforinitialization#cf