咱们直接进入正题!

def train(model, loss1, loss2, train_dataloader, optimizer_loss1, optimizer_loss2, epoch, writer, device_num):
    """Run one training epoch with a two-term loss and two optimizers.

    ``loss1`` drives the network weights through ``optimizer_loss1``;
    ``loss2`` is an ``nn.Module`` with its own trainable parameters
    (e.g. the class centers of a center loss) and is updated through
    ``optimizer_loss2``.

    Args:
        model: network whose forward returns a tuple where
            ``output[0]`` (features) feeds loss2 and
            ``output[1]`` (logits) feeds loss1.
        loss1: criterion applied to log-softmax logits and targets
            (e.g. ``nn.NLLLoss``).
        loss2: criterion module with trainable parameters, applied to
            the features and targets.
        train_dataloader: iterable yielding ``(data, target)`` batches.
        optimizer_loss1: optimizer over the network weights.
        optimizer_loss2: optimizer over ``loss2``'s parameters.
        epoch: current epoch index (unused in this excerpt).
        writer: summary writer for logging (unused in this excerpt).
        device_num: CUDA device index used when CUDA is available.
    """
    model.train()
    device = torch.device("cuda:" + str(device_num))
    # Relative weight of the second loss term; hoisted out of the batch
    # loop because it is constant for the whole epoch.
    weight_loss2 = 0.005
    correct = 0       # accumulators -- presumably updated later in the
    value_loss1 = 0   # full version of this function; kept so the
    value_loss2 = 0   # excerpt stays drop-in compatible
    result_loss = 0
    for data, target in train_dataloader:
        target = target.long()
        if torch.cuda.is_available():
            # NOTE(review): assumes the caller already moved `model`
            # to the same device -- confirm.
            data = data.to(device)
            target = target.to(device)

        optimizer_loss1.zero_grad()
        optimizer_loss2.zero_grad()
        output = model(data)
        classifier_output = F.log_softmax(output[1], dim=1)
        value_loss1_batch = loss1(classifier_output, target)  # first loss term
        value_loss2_batch = loss2(output[0], target)          # second loss term

        result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch

        result_loss_batch.backward()
        optimizer_loss1.step()
        # backward() through the weighted sum scaled loss2's gradients by
        # weight_loss2; undo that scaling so the center vectors are
        # updated with the raw loss2 gradient.
        for param in loss2.parameters():
            if param.grad is not None:  # skip params that received no grad
                param.grad.data *= (1. / weight_loss2)
        optimizer_loss2.step()

我这里采用的是两项损失:loss1 用于优化网络权重,loss2 用于优化中心矢量;网络权重与中心矢量均为可训练参数,因此需要两个优化器。如果多个损失项均只用于优化网络权重,那么只采用一个优化器即可,如下所示:

def train(model, loss1, loss2, train_dataloader, optimizer, epoch, writer, device_num):
    """Run one training epoch with a weighted two-term loss and a single optimizer.

    Both loss terms back-propagate into the same network weights, so one
    optimizer over the model parameters suffices.

    Args:
        model: network whose forward returns a tuple where
            ``output[0]`` (features) feeds loss2 and
            ``output[1]`` (logits) feeds loss1.
        loss1: criterion applied to log-softmax logits and targets
            (e.g. ``nn.NLLLoss``).
        loss2: second criterion applied to the features and targets.
        train_dataloader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer over the network weights.
        epoch: current epoch index (unused in this excerpt).
        writer: summary writer for logging (unused in this excerpt).
        device_num: CUDA device index used when CUDA is available.
    """
    model.train()
    device = torch.device("cuda:" + str(device_num))
    # Relative weight of the second loss term; hoisted out of the batch
    # loop because it is constant for the whole epoch.
    weight_loss2 = 0.005
    correct = 0       # accumulators -- presumably updated later in the
    value_loss1 = 0   # full version of this function; kept so the
    value_loss2 = 0   # excerpt stays drop-in compatible
    result_loss = 0
    for data, target in train_dataloader:
        target = target.long()
        if torch.cuda.is_available():
            # NOTE(review): assumes the caller already moved `model`
            # to the same device -- confirm.
            data = data.to(device)
            target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        classifier_output = F.log_softmax(output[1], dim=1)
        value_loss1_batch = loss1(classifier_output, target)  # first loss term
        value_loss2_batch = loss2(output[0], target)          # second loss term

        result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch

        result_loss_batch.backward()
        optimizer.step()

详细代码,请翻阅我们的论文,代码已开源,开源链接可查论文摘要。

若该经验贴对您科研、学习有所帮助,欢迎您引用我们的论文。

[1] X. Fu et al., "Semi-Supervised Specific Emitter Identification Method Using Metric-Adversarial Training," in IEEE Internet of Things Journal, vol. 10, no. 12, pp. 10778-10789, 15 June 2023, doi: 10.1109/JIOT.2023.3240242.

[2] X. Fu et al., "Semi-Supervised Specific Emitter Identification via Dual Consistency Regularization," in IEEE Internet of Things Journal, doi: 10.1109/JIOT.2023.3281668.

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐