from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments

def compute_metrics(pred):
    """Build the metrics dictionary for one evaluation pass.

    Args:
        pred: Prediction object exposing ``label_ids`` (the reference
            labels) and ``predictions`` (an array whose last axis is
            reduced with ``argmax`` to obtain the predicted class).

    Returns:
        dict: ``accuracy``, ``f1``, ``precision`` and ``recall``, with
        precision/recall/F1 computed using ``average='binary'``.
    """
    y_true = pred.label_ids
    # Collapse the last axis of the raw predictions to a class index.
    y_pred = pred.predictions.argmax(-1)

    # The fourth element of the returned tuple is not reported.
    prec, rec, f_score, _unused = precision_recall_fscore_support(
        y_true, y_pred, average='binary'
    )
    accuracy = accuracy_score(y_true, y_pred)

    return dict(accuracy=accuracy, f1=f_score, precision=prec, recall=rec)


# Training configuration; evaluation_strategy is added so that evaluation
# runs during training.
training_args = TrainingArguments(
    # Output directory and directory for storing logs.
    output_dir='./classification_results',
    logging_dir='./classification_logs',
    # Total number of training epochs and per-device batch sizes
    # (training / evaluation).
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    # Warmup steps for the learning-rate scheduler and weight-decay strength.
    warmup_steps=5000,
    weight_decay=0.01,
    # Evaluation schedule:
    #   "no"    - no evaluation during training
    #   "steps" - evaluate (and log) on a step interval
    #   "epoch" - evaluate at the end of each epoch
    # NOTE(review): recent transformers releases rename this argument to
    # `eval_strategy` -- confirm against the installed version.
    evaluation_strategy='steps',
    # Log every 100 steps.
    # NOTE(review): eval_steps is not set here; presumably the evaluation
    # interval falls back to logging_steps -- confirm in the docs.
    logging_steps=100,
)

# Load the "distilbert-base-uncased" checkpoint into a sequence-classification model.
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

# Pass compute_metrics to the Trainer so evaluation reports the custom metrics.
# train_dataset and val_dataset are not defined in the visible part of this
# file; they are expected to exist before this point.
trainer = Trainer(
    args=training_args,               # training arguments built earlier
    model=model,                      # the model instance to be trained
    compute_metrics=compute_metrics,  # metric callback used during evaluation
    train_dataset=train_dataset,      # training split
    eval_dataset=val_dataset,         # evaluation split
)

# Run training, then a final evaluation pass.
trainer.train()
trainer.evaluate()

# Logo
#
# 腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。
#
# 更多推荐