```python
import torch
import math
import argparse
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

# Command-line options for the learning-rate schedule comparison plot.
parser = argparse.ArgumentParser(description='LR_Decay Figure')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--init_lr', '--init_learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial (base) learning rate', dest='init_lr')
parser.add_argument('--lr_decay_rate', '--Exp-learning-rate-decay-rate', default=0.9, type=float,
                    metavar='LRDR', help='learning rate decay rate for Exp-decay', dest='lr_decay_rate')
parser.add_argument('--lr_decay_step', '--Exp-learning-rate-decay-step', default=10, type=float,
                    metavar='LRDS',
                    help='number of epochs over which Exp-decay multiplies the LR by lr_decay_rate',
                    dest='lr_decay_step')

# NOTE: arguments are parsed inside main(), not here. Parsing at import time
# would make this module unimportable whenever sys.argv carries options meant
# for another program (e.g. a test runner), and the result was never used.


def adjust_learning_rate_cosine(epoch, args):
    """Return the cosine-annealed learning rate for ``epoch``.

    The rate follows half a cosine period: it equals ``args.init_lr`` at
    ``epoch == 0`` and falls to 0 at ``epoch == args.epochs``.

    Args:
        epoch: Current epoch index.
        args: Object exposing ``init_lr`` and ``epochs`` attributes.

    Returns:
        The learning rate as a float.
    """
    angle = math.pi * epoch / args.epochs
    half_amplitude = args.init_lr * 0.5
    return half_amplitude * (1. + math.cos(angle))

def adjust_learning_rate_exp(epoch, args):
    """Return the exponentially decayed learning rate for ``epoch``.

    The decay is continuous (no staircase): the rate is multiplied by
    ``args.lr_decay_rate`` once per ``args.lr_decay_step`` epochs, i.e.
    ``init_lr * lr_decay_rate ** (epoch / lr_decay_step)``.

    Args:
        epoch: Current epoch index.
        args: Object exposing ``init_lr``, ``lr_decay_rate`` and
            ``lr_decay_step`` attributes.

    Returns:
        The learning rate as a float.
    """
    decay_periods = epoch / args.lr_decay_step
    factor = args.lr_decay_rate ** decay_periods
    return args.init_lr * factor

def main(argv=None):
    """Plot the cosine and exponential learning-rate schedules side by side.

    Args:
        argv: Optional list of command-line arguments to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps normal CLI behavior.
    """
    args = parser.parse_args(argv)
    epochs = list(range(args.epochs))
    # Keep the data numeric: passing strings to plt.plot makes matplotlib treat
    # them as categorical labels, which destroys the numeric y-axis scale.
    lr_cosine = [adjust_learning_rate_cosine(epoch, args) for epoch in epochs]
    lr_exp = [adjust_learning_rate_exp(epoch, args) for epoch in epochs]

    plt.plot(epochs, lr_cosine, color='red', linewidth=2.0, linestyle='--',
             label='LR_cosine_decay')
    plt.plot(epochs, lr_exp, color='blue', linewidth=3.0, linestyle='-.',
             label='LR_exp_decay')

    plt.title("Comparison of different LR_decay methods.")
    plt.xlabel('epoch')
    plt.ylabel('learning rate')
    plt.legend()
    ax = plt.gca()
    # Hide the right and top borders of the axes.
    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')
    plt.show()

# Script entry point. (A stray top-level plt.show() that ran at import time,
# before any data was plotted, has been removed; main() shows the figure.)
if __name__ == '__main__':
    main()

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐