TensorFlow中accuracy.eval函数,softmax回归_eval('softmax')-程序员宅基地

技术标签: 深度学习  

下面是用TensorFlow实现Logistic Regression,步骤都做了标注,不详细说了。

#encoding:utf-8

import tensorflow as tf
# 装在MNIST数据
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_Data/data/", one_hot=True)

# 一些参数
learning_rate = 0.01
training_epochs = 25
batch_size = 100
display_step = 1

# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784]) # mnist图像数据 28*28=784
y = tf.placeholder(tf.float32, [None, 10]) # 图像类别,总共10类

# 设置模型参数变量w和b
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# 构建softmax模型
pred = tf.nn.softmax(tf.matmul(x, W) + b)


# 损失函数用cross entropy
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
# 梯度下降优化
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# 初始化所有变量
init = tf.initialize_all_variables()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    # Training cycle
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)
        # 每一轮迭代total_batches
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # 使用batch data训练数据
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            # 将每个batch的损失相加求平均
            avg_cost += c / total_batch
        # 每一轮打印损失
        if (epoch+1) % display_step == 0:
            print "Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost)

    print "Optimization Finished!"

    # 模型预测
    # tf.argmax(pred,axis=1)是预测值每一行最大值的索引,这里最大值是概率最大
    # tf.argmax(y,axis=1)是真实值的每一行最大值得索引,这里最大值就是1
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

    # 对3000个数据预测准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print "Accuracy:", accuracy.eval({x: mnist.test.images[:3000], y: mnist.test.labels[:3000]})

tf.argmax(input,axis)

根据axis取值的不同返回每行或者每列最大值的索引。

参考:https://blog.csdn.net/u012300744/article/details/81240580

f.cast()函数

是执行 tensorflow 中张量数据类型转换,比如读入的图片如果是int8类型的,一般在要在训练前把图像的数据格式转换为float32。

参考:https://blog.csdn.net/dcrmg/article/details/79747814

accuracy.eval()函数的作用:

f.Tensor.eval(feed_dict=None, session=None):

作用:  
在一个Seesion里面“评估”tensor的值(其实就是计算),首先执行之前的所有必要的操作来产生这个计算这个tensor需要的输入,然后通过这些输入产生这个tensor。在激发tensor.eval()这个函数之前,tensor的图必须已经投入到session里面,或者一个默认的session是有效的,或者显式指定session.  
参数:  
feed_dict:一个字典,用来表示tensor被feed的值(联系placeholder一起看)  
session:(可选) 用来计算(evaluate)这个tensor的session.要是没有指定的话,那么就会使用默认的session。  
返回:  
表示“计算”结果值的numpy ndarray

转自原文:https://blog.csdn.net/Yaphat/article/details/53349551 

 

 


但是注意这个测试集和训练集的X,Y的x_i,y_i都是以行的,而吴恩达教授的深度学习课程正好相反,这点需要注意。吴深度学习第二课编程作业3中计算cost和上述代码不同。

    def compute_cost(Z3, Y):
    """
    Computes the cost
    
    Arguments:
    Z3 -- output of forward propagation (output of the last LINEAR unit), of shape (6, number of examples)
    Y -- "true" labels vector placeholder, same shape as Z3
    
    Returns:
    cost - Tensor of the cost function
    """
    
    # to fit the tensorflow requirement for tf.nn.softmax_cross_entropy_with_logits(...,...)
    logits = tf.transpose(Z3)#选择默认参数,意为矩阵转置
    labels = tf.transpose(Y)
    
    ### START CODE HERE ### (1 line of code)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = logits, labels = labels))
    ### END CODE HERE ###
    
    return cost

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Msjiangmei/article/details/90296090

智能推荐

python调用linux软键盘_在Linux中使用Python模拟键盘按键-程序员宅基地

文章浏览阅读692次。之前在做自动化测试中需要用Python在linux中模拟Ctrl+V 进行路径粘贴,试了很多种方法,起初用了xerox和pyperclip这两个python库,但打开对话框后调用粘贴API无法进行粘贴,不知道为什么。然后发现了virtkey这个库,这个库可以在linux 中模拟键盘按键,但网上资料甚少。这个库主要有两个API1、press_keysym/release_keysym2、press_..._linux python模拟按键输入

aec一pc_AEC_PC_DLL.dll下载|AEC_PC_DLL.dll下载官方版【32位|64位】-太平洋下载中心-程序员宅基地

文章浏览阅读622次。AEC_PC_DLL.dll使用方法:方法一一、如果在运行某软件或编译程序时提示缺少、找不到AEC_PC_DLL.dll等类似提示,您可将从太平洋下载中心下载来的AEC_PC_DLL.dll拷贝到指定目录即可(一般是system系统目录或放到软件同级目录里面),或者重新添加文件引用。二、将软件包下载下来后,先将其解压(一般都是rar压缩包), 然后根据您系统的情况选择X86/X64,X86为32位..._aec_pc_dll.dll下载

【Docker | 4】Docker安装和配置指南_sudo yum install docker-程序员宅基地

文章浏览阅读1.4w次,点赞4次,收藏3次。Docker作为一种领先的容器化技术,极大地改变了现代应用部署和管理的方式。在开始使用Docker之前,首先需要在不同操作系统上安装和配置Docker。本文将提供详细的Docker安装和配置步骤,涵盖Linux、Windows和macOS平台,并介绍如何配置Docker Daemon和Docker客户端,轻松上手Docker容器化技术。_sudo yum install docker

omnet++中tictoc实例(中文注释) 1-6_"@display(\"bgb=216,207\")"-程序员宅基地

文章浏览阅读1.3k次。文章目录具体效果请自行复制运行tictoc1tictoc2tictoc3tictoc4tictoc5 4、5相差不大tictoc6具体效果请自行复制运行tictoc1tictoc1.nedsimple Txc1{ gates: input in; output out;}network Tictoc1{ @display("bgb=171,129"); submodules: tic: Txc1; _"@display(\"bgb=216,207\")"

element ui组件的自定义类名样式不生效_electron 自定义样式不起作用-程序员宅基地

文章浏览阅读618次。需要注意,样式不能写在标签中,会被vue自动加上data-v属性,导致样式失效。element ui中,类似描述列表这种组件。必须写在<style>标签里。会提供自定义类名属性。_electron 自定义样式不起作用

Java中char、String与int的运算结果_java int类型+string类型最后输出什么类型-程序员宅基地

文章浏览阅读1.1k次。学习地址:https://www.cnblogs.com/sekihin/archive/2007/06/11/779047.html1.Java中有char值参与的计算System.out.println('0'+'A');// 48+65 = 113 未指定类型--输出intSystem.out.println((char)('0'+'A'));// q 强转char --输出charSystem.out.println(10+'A'); //75 未指定类型--输出intSystem.o_java int类型+string类型最后输出什么类型

随便推点

基于python超市仓库管理系统的设计与实现-计算机毕业设计源码96723_编写一个程序,模拟库存管理系统,主要有商品入库,商品出库,输出仓库中商品信息的功-程序员宅基地

文章浏览阅读2.3k次,点赞13次,收藏34次。超市仓库管理系统从实际应用角度来说可以分成用户管理模块、系统模块、主要功能模块三大模块。1.系统用户管理模块系统用户管理模块可以分成用户管理、个人信息管理和权限管理模块。用户管理是对用户的相关信息进行查阅、修改,删除等操作。个人信息管理可以对个人信息的情况进行添加、修改信息删除、个人信息修改和个人信息查询。2.主要功能管理模块包括用户管理、商品分类、商品信息、商品出库、商品入库、通知公告管理。3.系统管理模块系统管理模块分为数据备份。_编写一个程序,模拟库存管理系统,主要有商品入库,商品出库,输出仓库中商品信息的功

Linux Kernel 4.16 系列停止维护,用户应升级至 4.17-程序员宅基地

文章浏览阅读88次。2019独角兽企业重金招聘Python工程师标准>>> ..._redhat系统内核升级4.17

第3.2章:StarRocks数据导入--Stream Load_starrocks stream load-程序员宅基地

文章浏览阅读1.1w次,点赞9次,收藏26次。一、环境准备Stream Load可以说是StarRocks最为核心的导入方式,StarRocks的主要导入方式例如Routine Load、Flink Connector、DataX StarRocksWriter等,底层实现都是基于Stream Load的思想,所以我们着重介绍。Stream Load是由用户发送HTTP请求将本地文件或数据流导入至StarRocks中的导入方式,其本身不依赖其他组件。Stream Load支持导入本地数据文件(csv、txt等)和json文件,建议单次导入的数据_starrocks stream load

Pytorch学习——22种transforms数据预处理方法_transforms 预处理-程序员宅基地

文章浏览阅读998次。一、图像增强数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力二、transforms——裁剪1. transforms.CenterCrop 功能:从图像中心裁剪图片 size:所需裁剪图片尺寸2. transforms.RandomCrop (size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')功能:从图片中随机裁剪出尺寸为size的图._transforms 预处理

element-ui message 显示重叠问题_element ui message 多个重叠-程序员宅基地

文章浏览阅读7.6k次,点赞3次,收藏7次。问题描述:在同一个方法中,触发了多个 message 组件提示信息时,会出现消息重叠的问题。解决方案:将方法定义为async异步函数,然后使用await等待执行。async checkLogin () { if (this.username === '') { await this.$message.warning('请输入用户名') } if (this.password === '') { await this.$message.warning('请输入用户..._element ui message 多个重叠

@程序员,代码清理有必要吗-程序员宅基地

文章浏览阅读136次。本文首发在CSDN 微信(ID:CSDNNews)。以下为译文:你的项目截止时间就要到了,你有一个紧急的 bug 需要修复,你的项目需要快速迭代输出产品。虽然你很忙,但..._不使用的代码段有必要清理吗

推荐文章

热门文章

相关标签