Tensorflow笔记__使用mnist数据集并测试自己的手写图片_py如何安装 mnist_backward-程序员宅基地

技术标签: Tensorflow笔记  

内容源于曹建老师的tensorflow笔记课程

源码链接:https://github.com/cj0012/AI-Practice-Tensorflow-Notes

测试图片下载:https://github.com/cj0012/AI-Practice-Tensorflow-Notes/blob/master/num.zip

主要包含四个文件,主要是mnist_forward.py,mnist_backward.py,mnist_test.py,mnist_app.py

定义前向传播过程 mnist_forward.py:

 

import tensorflow as tf

INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER_NODE = 500

# 定义神经网络参数,传入两个参数,一个是shape一个是正则化参数大小
def get_weight(shape,regularizer):
    # tf.truncated_normal截断的正态分布函数,超过标准差的重新生成
   w = tf.Variable(tf.truncated_normal(shape,stddev=0.1))
   if regularizer != None:
        # 将正则化结果存入losses中
      tf.add_to_collection("losses",tf.contrib.layers.l2_regularizer(regularizer)(w))
   return w

# 定义偏置b,传入shape参数
def get_bias(shape):
    # 初始化为0
   b = tf.Variable(tf.zeros(shape))
   return b

# 定义前向传播过程,两个参数,一个是输入数据,一个是正则化参数
def forward(x,regularizer):
    # w1的维度就是[输入神经元大小,第一层隐含层神经元大小]
   w1 = get_weight([INPUT_NODE,LAYER_NODE],regularizer)
    # 偏置b参数,与w的后一个参数相同
   b1 = get_bias(LAYER_NODE)
    # 激活函数
   y1 = tf.nn.relu(tf.matmul(x,w1)+b1)

   w2 = get_weight([LAYER_NODE,OUTPUT_NODE],regularizer)
   b2 = get_bias(OUTPUT_NODE)
   y = tf.matmul(y1,w2)+b2

   return y

   

定义反向传播过程 mnist_backward.py:

 

#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

BATCH_SIZE = 200
#学习率衰减的原始值
LEARNING_RATE_BASE = 0.1
# 学习率衰减率
LEARNING_RATE_DECAY = 0.99
# 正则化参数
REGULARIZER = 0.0001
# 训练轮数
STEPS = 50000
#这个使用滑动平均的衰减率
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"

def backward(mnist):
   #一共有多少个特征,784行,一列
   x = tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])
   y_ = tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])
   # 给前向传播传入参数x和正则化参数计算出y的值
   y = mnist_forward.forward(x,REGULARIZER)
   # 初始化global—step,它会随着训练轮数增加
   global_step = tf.Variable(0,trainable=False)

   # softmax和交叉商一起运算的函数,logits传入是x*w,也就是y
   ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
   cem = tf.reduce_mean(ce)
   loss = cem + tf.add_n(tf.get_collection("losses"))

   learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                               global_step,
                                               mnist.train.num_examples/BATCH_SIZE,
                                               LEARNING_RATE_DECAY,
                                               staircase = True)

   train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step = global_step)

    # 滑动平均处理,可以提高泛华能力
   ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)
   ema_op = ema.apply(tf.trainable_variables())
   # 将train_step和滑动平均计算ema_op放在同一个节点
   with tf.control_dependencies([train_step,ema_op]):
      train_op = tf.no_op(name="train")
        
   saver = tf.train.Saver()

   with tf.Session() as sess:
        
      init_op = tf.global_variables_initializer()
      sess.run(init_op)

      for i in range(STEPS):
         # mnist.train.next_batch()函数包含一个参数BATCH_SIZE,表示随机从训练集中抽取BATCH_SIZE个样本输入到神经网络
         # next_batch函数返回的是image的像素和标签label
         xs,ys = mnist.train.next_batch(BATCH_SIZE)
         # _,表示后面不使用这个变量
         _,loss_value,step = sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})
            
         if i % 1000 == 0:
            print("Ater {} training step(s),loss on training batch is {} ".format(step,loss_value))
            saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)

def main():
    
   mnist = input_data.read_data_sets("./data",one_hot = True)
   backward(mnist)

if __name__ == "__main__":
   main()

更新如下,增加tensorboard和断点续训内容:

# coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os
import time

BATCH_SIZE = 200
# 学习率衰减的原始值
LEARNING_RATE_BASE = 0.1
# 学习率衰减率
LEARNING_RATE_DECAY = 0.99
# 正则化参数
REGULARIZER = 0.0001
# 训练轮数
STEPS = 50000
# 这个使用滑动平均的衰减率
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "model/"
MODEL_NAME = "mnist_model"


def checkpoint_load(sess,saver,path):
    print('Reading Checkpoints... .. .\n')
    ckpt = tf.train.get_checkpoint_state(path)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_path = ckpt.model_checkpoint_path
        saver.restore(sess,os.path.join(os.getcwd(),ckpt_path))
        step = int(os.path.basename(ckpt_path).split('-')[-1])
    else:
        step = 0
        print("\nCheckpoint Loading Failed! \n")
    return step
def backward(mnist):
    # 一共有多少个特征,784行,一列
    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
    # 给前向传播传入参数x和正则化参数计算出y的值
    y = mnist_forward.forward(x, REGULARIZER)
    # 初始化global—step,它会随着训练轮数增加
    global_step = tf.Variable(0, trainable=False)

    # softmax和交叉商一起运算的函数,logits传入是x*w,也就是y
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection("losses"))
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                               global_step,
                                               mnist.train.num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY,
                                               staircase=True)

    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    # 滑动平均处理,可以提高泛华能力
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # 将train_step和滑动平均计算ema_op放在同一个节点
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name="train")

    tf.summary.scalar('loss',loss)
    merge_summary = tf.summary.merge_all()
    # 没有的话会自动创建
    summary_path = os.path.join(MODEL_SAVE_PATH,'summary')
    train_writer = tf.summary.FileWriter(summary_path)
    saver = tf.train.Saver()
    time_ = time.time()
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        # 加载模型并返回当前训练的step,加载模型,实现断点续训,如果模型存在返回step,否则返回0
        counter = checkpoint_load(sess,saver, MODEL_SAVE_PATH)
        for i in range(STEPS):
            # mnist.train.next_batch()函数包含一个参数BATCH_SIZE,表示随机从训练集中抽取BATCH_SIZE个样本输入到神经网络
            # next_batch函数返回的是image的像素和标签label
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            # _,表示后面不使用这个变量
            _, loss_value = sess.run([train_op, loss], feed_dict={x: xs, y_: ys})
            counter += 1
            if i % 1000 == 0:
                print("Ater {} training step(s),loss on training batch is {:.7f},time consuming is {:.4f} ".format(counter, loss_value,time.time()-time_))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=counter)
                train_summary = sess.run(merge_summary,feed_dict={x: xs, y_: ys})
                train_writer.add_summary(train_summary,counter)

def main():
    mnist = input_data.read_data_sets("./data", one_hot=True)
    backward(mnist)


if __name__ == "__main__":
    main()
Ater 168008 training step(s),loss on training batch is 0.1153447,time consuming is 1.0684 
Ater 169008 training step(s),loss on training batch is 0.1144902,time consuming is 3.5929 
Ater 170008 training step(s),loss on training batch is 0.1135995,time consuming is 6.2114 
Ater 171008 training step(s),loss on training batch is 0.1192827,time consuming is 8.6844 
Ater 172008 training step(s),loss on training batch is 0.1177420,time consuming is 11.2969 
Ater 173008 training step(s),loss on training batch is 0.1164965,time consuming is 13.8670 

定义测试部分 mnist_test.py:

#coding:utf-8
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward

TEST_INTERVAL_SECS = 5

def test(mnist):
    with tf.Graph().as_default() as g:
        # 占位符,第一个参数是tf.float32数据类型,第二个参数是shape,shape[0]=None表示输入维度任意,shpe[1]表示输入数据特征数
        x = tf.placeholder(tf.float32,shape = [None,mnist_forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32,shape = [None,mnist_forward.OUTPUT_NODE])
        """注意这里没有传入正则化参数,需要明确的是,在测试的时候不要正则化,不要dropout"""
        y = mnist_forward.forward(x,None)

        # 实例化可还原的滑动平均模型
        ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        ema_restore = ema.variables_to_restore()
        saver = tf.train.Saver(ema_restore)

        # y计算的过程:x是mnist.test.images是10000×784的,最后输出的y仕10000×10的,y_:mnist.test.labels也是10000×10的
        correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
        # tf.cast可以师兄数据类型的转换,tf.equal返回的只有True和False
        accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

        while True:
            with tf.Session() as sess:
                # 加载训练好的模型
                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    # 恢复模型到当前会话
                    saver.restore(sess,ckpt.model_checkpoint_path)
                    # 恢复轮数
                    global_step = ckpt.model_checkpoint_path.split("/")[-1].split("-")[-1]
                    # 计算准确率
                    accuracy_score = sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
                    print("After {} training step(s),test accuracy is: {} ".format(global_step,accuracy_score))
                else:
                    print("No chekpoint file found")
                    print(sess.run(y,feed_dict={x:mnist.test.images}))
                    return
            time.sleep(TEST_INTERVAL_SECS)

def main():

    mnist = input_data.read_data_sets("./data",one_hot=True)
    test(mnist)

if __name__== "__main__":
    main()

定义使用手写图片部分mnist_app.py:

 

import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_forward
import mnist_backward

# 定义加载使用模型进行预测的函数
def restore_model(testPicArr):

    with tf.Graph().as_default() as tg:
        
        x = tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])
        y = mnist_forward.forward(x,None)
        preValue = tf.argmax(y,1)
        # 加载滑动平均模型
        variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        with tf.Session() as sess:
            
            ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                # 恢复当前会话,将ckpt中的值赋值给w和b
                saver.restore(sess,ckpt.model_checkpoint_path)
                # 执行图计算
                preValue = sess.run(preValue,feed_dict={x:testPicArr})
                return preValue
            else:
                print("No checkpoint file found")
                return -1
# 图片预处理函数
def pre_pic(picName):
    # 先打开传入的原始图片
    img = Image.open(picName)
    # 使用消除锯齿的方法resize图片
    reIm = img.resize((28,28),Image.ANTIALIAS)
    # 变成灰度图,转换成矩阵
    im_arr = np.array(reIm.convert("L"))
    threshold = 50#对图像进行二值化处理,设置合理的阈值,可以过滤掉噪声,让他只有纯白色的点和纯黑色点
    for i in range(28):
        for j in range(28):
            im_arr[i][j] = 255-im_arr[i][j]
            if (im_arr[i][j]<threshold):
                im_arr[i][j] = 0
            else:
                im_arr[i][j] = 255
    # 将图像矩阵拉成1行784列,并将值变成浮点型(像素要求的仕0-1的浮点型输入)
    nm_arr = im_arr.reshape([1,784])
    nm_arr = nm_arr.astype(np.float32)
    img_ready = np.multiply(nm_arr,1.0/255.0)

    return img_ready

def application():
    # input函数可以从控制台接受数字
    testNum = int(input("input the number of test images:"))
    # 使用循环来历遍需要测试的图片才结束
    for i in range(testNum):
        # input可以实现从控制台接收字符格式,图片存储路径
        testPic = input("the path of test picture:")
        # 将图片路径传入图像预处理函数中
        testPicArr = pre_pic(testPic)
        # 将处理后的结果输入到预测函数最后返回预测结果
        preValue = restore_model(testPicArr)
        print("The prediction number is :",preValue)

def main():
    application()

if __name__ == "__main__":
    main()

output:

The end.

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Li_haiyu/article/details/80846657

智能推荐

oracle 12c 集群安装后的检查_12c查看crs状态-程序员宅基地

文章浏览阅读1.6k次。安装配置gi、安装数据库软件、dbca建库见下:http://blog.csdn.net/kadwf123/article/details/784299611、检查集群节点及状态:[root@rac2 ~]# olsnodes -srac1 Activerac2 Activerac3 Activerac4 Active[root@rac2 ~]_12c查看crs状态

解决jupyter notebook无法找到虚拟环境的问题_jupyter没有pytorch环境-程序员宅基地

文章浏览阅读1.3w次,点赞45次,收藏99次。我个人用的是anaconda3的一个python集成环境,自带jupyter notebook,但在我打开jupyter notebook界面后,却找不到对应的虚拟环境,原来是jupyter notebook只是通用于下载anaconda时自带的环境,其他环境要想使用必须手动下载一些库:1.首先进入到自己创建的虚拟环境(pytorch是虚拟环境的名字)activate pytorch2.在该环境下下载这个库conda install ipykernelconda install nb__jupyter没有pytorch环境

国内安装scoop的保姆教程_scoop-cn-程序员宅基地

文章浏览阅读5.2k次,点赞19次,收藏28次。选择scoop纯属意外,也是无奈,因为电脑用户被锁了管理员权限,所有exe安装程序都无法安装,只可以用绿色软件,最后被我发现scoop,省去了到处下载XXX绿色版的烦恼,当然scoop里需要管理员权限的软件也跟我无缘了(譬如everything)。推荐添加dorado这个bucket镜像,里面很多中文软件,但是部分国外的软件下载地址在github,可能无法下载。以上两个是官方bucket的国内镜像,所有软件建议优先从这里下载。上面可以看到很多bucket以及软件数。如果官网登陆不了可以试一下以下方式。_scoop-cn

Element ui colorpicker在Vue中的使用_vue el-color-picker-程序员宅基地

文章浏览阅读4.5k次,点赞2次,收藏3次。首先要有一个color-picker组件 <el-color-picker v-model="headcolor"></el-color-picker>在data里面data() { return {headcolor: ’ #278add ’ //这里可以选择一个默认的颜色} }然后在你想要改变颜色的地方用v-bind绑定就好了,例如:这里的:sty..._vue el-color-picker

迅为iTOP-4412精英版之烧写内核移植后的镜像_exynos 4412 刷机-程序员宅基地

文章浏览阅读640次。基于芯片日益增长的问题,所以内核开发者们引入了新的方法,就是在内核中只保留函数,而数据则不包含,由用户(应用程序员)自己把数据按照规定的格式编写,并放在约定的地方,为了不占用过多的内存,还要求数据以根精简的方式编写。boot启动时,传参给内核,告诉内核设备树文件和kernel的位置,内核启动时根据地址去找到设备树文件,再利用专用的编译器去反编译dtb文件,将dtb还原成数据结构,以供驱动的函数去调用。firmware是三星的一个固件的设备信息,因为找不到固件,所以内核启动不成功。_exynos 4412 刷机

Linux系统配置jdk_linux配置jdk-程序员宅基地

文章浏览阅读2w次,点赞24次,收藏42次。Linux系统配置jdkLinux学习教程,Linux入门教程(超详细)_linux配置jdk

随便推点

matlab(4):特殊符号的输入_matlab微米怎么输入-程序员宅基地

文章浏览阅读3.3k次,点赞5次,收藏19次。xlabel('\delta');ylabel('AUC');具体符号的对照表参照下图:_matlab微米怎么输入

C语言程序设计-文件(打开与关闭、顺序、二进制读写)-程序员宅基地

文章浏览阅读119次。顺序读写指的是按照文件中数据的顺序进行读取或写入。对于文本文件,可以使用fgets、fputs、fscanf、fprintf等函数进行顺序读写。在C语言中,对文件的操作通常涉及文件的打开、读写以及关闭。文件的打开使用fopen函数,而关闭则使用fclose函数。在C语言中,可以使用fread和fwrite函数进行二进制读写。‍ Biaoge 于2024-03-09 23:51发布 阅读量:7 ️文章类型:【 C语言程序设计 】在C语言中,用于打开文件的函数是____,用于关闭文件的函数是____。

Touchdesigner自学笔记之三_touchdesigner怎么让一个模型跟着鼠标移动-程序员宅基地

文章浏览阅读3.4k次,点赞2次,收藏13次。跟随鼠标移动的粒子以grid(SOP)为partical(SOP)的资源模板,调整后连接【Geo组合+point spirit(MAT)】,在连接【feedback组合】适当调整。影响粒子动态的节点【metaball(SOP)+force(SOP)】添加mouse in(CHOP)鼠标位置到metaball的坐标,实现鼠标影响。..._touchdesigner怎么让一个模型跟着鼠标移动

【附源码】基于java的校园停车场管理系统的设计与实现61m0e9计算机毕设SSM_基于java技术的停车场管理系统实现与设计-程序员宅基地

文章浏览阅读178次。项目运行环境配置:Jdk1.8 + Tomcat7.0 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。项目技术:Springboot + mybatis + Maven +mysql5.7或8.0+html+css+js等等组成,B/S模式 + Maven管理等等。环境需要1.运行环境:最好是java jdk 1.8,我们在这个平台上运行的。其他版本理论上也可以。_基于java技术的停车场管理系统实现与设计

Android系统播放器MediaPlayer源码分析_android多媒体播放源码分析 时序图-程序员宅基地

文章浏览阅读3.5k次。前言对于MediaPlayer播放器的源码分析内容相对来说比较多,会从Java-&amp;amp;gt;Jni-&amp;amp;gt;C/C++慢慢分析,后面会慢慢更新。另外,博客只作为自己学习记录的一种方式,对于其他的不过多的评论。MediaPlayerDemopublic class MainActivity extends AppCompatActivity implements SurfaceHolder.Cal..._android多媒体播放源码分析 时序图

java 数据结构与算法 ——快速排序法-程序员宅基地

文章浏览阅读2.4k次,点赞41次,收藏13次。java 数据结构与算法 ——快速排序法_快速排序法