源代码地址: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/tutorials/mnist


mnist.py

主要功能: 构建一个完全连接(fully connected)的MINST模型

推理 interface()

def inference(images, hidden1_units, hidden2_units):
  """
  功能: Build the MNIST model up to where it may be used for inference
  返回: 包含了预测结果的Tensor --- logits
  输入: 
      图像占位符
      第一个和第二个隐层的大小
  实现方式: 
      借助ReLu(Rectified Linear Units)激活函数,
      构建一对完全连接层(layers), 以及一个有着十个节点(node), 指明了输出logtis模型的线性层.
  """
  # Hidden 1 (隐层1)
  with tf.name_scope('hidden1'):
    # 每个变量在构建时, 都会获得初始化操作
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
        name='weights')
        # 权重变量的名称其实是"hidden1/weights"
        # tf.truncated_normal()用来初始化weights: 根据所得到的均值和标准差, 生成一个随机分布
            # shape是一个二维的tensor: [IMAGE_PIXELS, hidden1_units]
                # 第一个维度代表该层中权重变量所连接(connect from)的单元数量
                # 第二个维度代表该层中权重变量所连接到的(connect to)单元数量
    biases = tf.Variable(tf.zeros([hidden1_units]),
                         name='biases')
        # tf.zero()用来初始化biases
            # shape: [hidden1_units], 指的是该层中所接到的(connect to)单元数量
    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
        # tf.nn.relu()是个激活函数
  # Hidden 2 (隐层2)
  with tf.name_scope('hidden2'):
    weights = tf.Variable(
        tf.truncated_normal([hidden1_units, hidden2_units],
                            stddev=1.0 / math.sqrt(float(hidden1_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden2_units]),
                         name='biases')
    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
  # Linear
  with tf.name_scope('softmax_linear'):
    weights = tf.Variable(
        tf.truncated_normal([hidden2_units, NUM_CLASSES],
                            stddev=1.0 / math.sqrt(float(hidden2_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
                         name='biases')
    logits = tf.matmul(hidden2, weights) + biases
        # 生成预测结果
  return logits

损失 lose()

def loss(logits, labels):
  """
  功能: 计算logits和labels之间的损失值(loss)
  输入: logits和labels
  返回: 包含了损失值(loss)的tensor
  """
  labels = tf.to_int64(labels)
  return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

训练 training()

def training(loss, learning_rate):
  """
  功能: 通过梯度下降将损失最小化
  输入:
    损失loss
    梯度下降的学习速率
  返回: 包含了训练操作(training op)输出结果的Tensor
  """
  # Add a scalar summary for the snapshot loss.
  tf.summary.scalar('loss', loss)
  # 根据给定的学习速率生成一个梯度下降算法op
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  # 声明一个变量用于保存全局步骤进行到哪一步了
  global_step = tf.Variable(0, name='global_step', trainable=False)
  # 应用梯度下降算法去最小化损失(loss)
  train_op = optimizer.minimize(loss, global_step=global_step)
  return train_op

y_connected_feed.py

主要功能: 利用下载的数据集训练构建好的MNIST模型, 以数据反馈字典(feed dictionary)的形式作为输入

run_training()

  • 这是核心function()
  • 功能: 通过一系列步骤训练MNIST
def run_training():
  """通过一系列的步骤训练MNIST"""
  # 确保你下载了正确的数据
  # 解压这些输入并返回一个含有DataSet实例的字典
  data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

  # 模型将会建立在默认图上
  with tf.Graph().as_default():

    # 为数据创建占位符
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)

    # 推理预测节点(op)
    logits = mnist.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # 损失节点(op)
    loss = mnist.loss(logits, labels_placeholder)

    # 训练节点(op)
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Build the summary Tensor based on the TF collection of Summaries.
    summary = tf.summary.merge_all()

    # 变量初始化节点(op)
    init = tf.global_variables_initializer()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # 创建会话
    sess = tf.Session()

    # Instantiate a SummaryWriter to output summaries and the Graph.
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

    # And then after everything is built:

    # 执行变量的初始化操作
    sess.run(init)

    # 开始循环训练
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

      # 向图表提供反馈
      # 为每一个训练步骤提供一个反馈字典(包括图像和标签)
      # fill_feed_dict()
          # 查询给定的DataSet
          # 索要下一批次batch_size的图像和标签
          # 与占位符相匹配的Tensor则会包含下一批次的图像和标签
      # feed_dict之后会传入sess.run()中, 为其训练提供输入样例
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

      # Run one step of the model.  The return values are the activations
      # from the `train_op` (which is discarded) and the `loss` Op.  To
      # inspect the values of your Ops or variables, you may include them
      # in the list passed to sess.run() and the value tensors will be
      # returned in the tuple from the call.
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

      duration = time.time() - start_time

      # 定时输出当前的状态
      if step % 100 == 0:
        # 输出状态
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

      # 保存检查点并定期评估模型
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)
        # Evaluate against the validation set.
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)
        # Evaluate against the test set.
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)
Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐