[NLP] TorchText 使用指南_torchtext使用指南-csdn-程序员宅基地

技术标签: deep learning  PyTorch  NLP  

TorchText 是 PyTorch 的一个功能包,主要提供文本数据读取、创建迭代器的的功能与语料库、词向量的信息,分别对应了 torchtext.datatorchtext.datasetstorchtext.vocab 三个子模块。本文参考了三篇文章

1. 语料库 torchtext.datasets

TorchText 内建的语料库有:

  • Language Modeling
    • WikiText-2
    • WikiText103
    • PennTreebank
  • Sentiment Analysis
    • SST
    • IMDb
  • Text Classification
    • TextClassificationDataset
    • AG_NEWS
    • SogouNews
    • DBpedia
    • YelpReviewPolarity
    • YelpReviewFull
    • YahooAnswers
    • AmazonReviewPolarity
    • AmazonReviewFull
  • Question Classification
    • TREC
  • Entailment
    • SNLI
    • MultiNLI
  • Machine Translation
    • Multi30k
    • IWSLT
    • WMT14
  • Sequence Tagging
    • UDPOS
    • CoNLL2000Chunking
  • Question Answering
    • BABI20
  • Unsupervised Learning
    • EnWik9

2. 预训练的词向量 torchtext.vocab

TorchText 内建的预训练词向量有:

  • charngram.100d
  • fasttext.en.300d
  • fasttext.simple.300d
  • glove.42B.300d
  • glove.840B.300d
  • glove.twitter.27B.25d
  • glove.twitter.27B.50d
  • glove.twitter.27B.100d
  • glove.twitter.27B.200d
  • glove.6B.50d
  • glove.6B.100d
  • glove.6B.200d
  • glove.6B.300d

3. 数据读取、数据框的创建 torchtext.data

3.1 创建 Field

Field 可以理解为一个告诉 TorchText 如何处理字段的声明。

torchtext.data.Field(sequential=True, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.int64, preprocessing=None, postprocessing=None, lower=False, tokenize=None, tokenizer_language='en', include_lengths=False, batch_first=False, pad_token='<pad>', unk_token='<unk>', pad_first=False, truncate_first=False, stop_words=None, is_target=False)

参数很多,这里仅仅介绍主要参数:

  • sequential:是否为已经被序列化的数据,默认为 True;
  • use_vocab:是否应用词汇表。若为 False 则数据应该已经是数字形式,默认为 True;
  • init_token:序列开头填充的 token,默认为 None 即不填充;
  • eos_token:序列结尾填充的 token,默认为 None 即不填充;
  • lower:是否将文本转换为小写,默认为 False;
  • tokenize:分词器,默认为 string.split
  • batch_first:batch 是否在第一维上;
  • pad_token:填充的 token,默认为 “”;
  • unk_token:词汇表以外的词汇的表示,默认为 “”;
  • pad_first:是否在序列的开头进行填充;默认为 False;
  • truncate_first:是否在序列的开头将序列超过规定长度的部分进行截断;默认为 False;
  • stop_words:是否过滤停用词,默认为 False;
  • is_target:这个 Field 是否为标签,默认为 False。

tokenize 可以使用 SpaCy 的分词功能,使用以前要先构建分词功能:

import spacy
spacy_en = spacy.load('en')
def tokenizer(text):
	return [token for toekn in spacy_en.tokenizer(text)]

spacy 分词的效果比原生的 split 函数好一点,但是速度也慢一些。然后可以创建对应文本的 Field 了:

TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True) # 假设文本为 raw data
LABEL = data.Field(sequential=False, use_vocab=False) # 假设标签为离散的数字变量

3.2 创建 Dataset

如果文本数据保存在 csvtsvjson 文件中,我们优先使用 torchtext.data.TabularDataset 进行读取。

torchtext.data.TabularDataset(path, format, fields, skip_header=False, csv_reader_params={}, **kwargs)

  • path:数据的路径;
  • format:文件的格式,为 csvtsvjson
  • fields:上面已经定义好的 Field;
  • skip_header:是否跳过第一行;
  • csv_reader_params:当文件为 csvtsv 时,可以自定义文件的格式。

例子:

train, val = data.TabularDataset.splits(
        path='.', train='train.csv',validation='val.csv', format='csv',skip_header=True,
        fields=[('PhraseId',None),('SentenceId',None),('Phrase', TEXT), ('Sentiment', LABEL)])

test = data.TabularDataset('test.tsv',
        format='tsv',skip_header=True,
        fields=[('PhraseId',None),('SentenceId',None),('Phrase', TEXT)])

上面的例子说,'PhraseId''SentenceId' 不读取(FieldNone),'Phrase'TEXT 的方式进行读取,'Sentiment'LABEL 的方式进行读取。

3.3 建立词汇表

现在我们需要将词转化为数字,并在模型中载入预训练好的词向量。词汇表存储在之前声明好的 Field 里面。

TEXT.build_vocab(train_data, # 建词表是用训练集建,不要用验证集和测试集
                  max_size=400000, # 单词表容量
                  vectors='glove.6B.300d', # 还有'glove.840B.300d'已经很多可以选
                  unk_init=torch.init.xavier_uniform # 初始化train_data中不存在预训练词向量词表中的单词
)

# 在神经网络里加载词向量
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)
UNK_IDX = REVIEW.vocab.stoi[REVIEW.unk_token]
PAD_IDX = REVIEW.vocab.stoi[REVIEW.pad_token]
# 因为预训练的权重的unk和pad的词向量不是在我们的数据集语料上训练得到的,所以最好置零
model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

3.4 创建迭代器

迭代器推荐使用 BucketIterator,因为它会将文本中长度相似的序列尽量放在同一个 batch 里,减少 padding,从而减少计算量,加速计算。

torchtext.data.BucketIterator(dataset, batch_size, sort_key=None, device=None, batch_size_fn=None, train=True, repeat=False, shuffle=None, sort=None, sort_within_batch=None)
  • dataset:目标数据;
  • batch_size:batch 的大小;
  • sort_key:排序的方式默认为 None;
  • device:载入的设备,默认为 CPU;
  • batch_size_fn:取 batch 的函数,默认为 None;
  • train:是否为训练集,默认为 True;
  • repeat:在不同的 epoch 中是否重复相同的 iterater,默认为 False;
  • shuffle:在不同的 epoch 中是否打乱数据的顺序,默认为 None;
  • sort:是否根据 sort_key 对数据进行排序,默认为 None;
  • sort_within_batch:是否根据 sort_key 对每个 batch 内的数据进行降序排序。

举例:

train_iter, val_iter = data.BucketIterator.split((train, val), batch_size=128, sort_key=lambda x: len(x.Phrase), 
                                 shuffle=True,device=DEVICE)

# 在 test_iter , sort一定要设置成 False, 要不然会被 torchtext 搞乱样本顺序
test_iter = data.Iterator(dataset=test, batch_size=128, train=False,
                          sort=False, device=DEVICE)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_44614687/article/details/106582015

智能推荐

c# 调用c++ lib静态库_c#调用lib-程序员宅基地

文章浏览阅读2w次,点赞7次,收藏51次。四个步骤1.创建C++ Win32项目动态库dll 2.在Win32项目动态库中添加 外部依赖项 lib头文件和lib库3.导出C接口4.c#调用c++动态库开始你的表演...①创建一个空白的解决方案,在解决方案中添加 Visual C++ , Win32 项目空白解决方案的创建:添加Visual C++ , Win32 项目这......_c#调用lib

deepin/ubuntu安装苹方字体-程序员宅基地

文章浏览阅读4.6k次。苹方字体是苹果系统上的黑体,挺好看的。注重颜值的网站都会使用,例如知乎:font-family: -apple-system, BlinkMacSystemFont, Helvetica Neue, PingFang SC, Microsoft YaHei, Source Han Sans SC, Noto Sans CJK SC, W..._ubuntu pingfang

html表单常见操作汇总_html表单的处理程序有那些-程序员宅基地

文章浏览阅读159次。表单表单概述表单标签表单域按钮控件demo表单标签表单标签基本语法结构<form action="处理数据程序的url地址“ method=”get|post“ name="表单名称”></form><!--action,当提交表单时,向何处发送表单中的数据,地址可以是相对地址也可以是绝对地址--><!--method将表单中的数据传送给服务器处理,get方式直接显示在url地址中,数据可以被缓存,且长度有限制;而post方式数据隐藏传输,_html表单的处理程序有那些

PHP设置谷歌验证器(Google Authenticator)实现操作二步验证_php otp 验证器-程序员宅基地

文章浏览阅读1.2k次。使用说明:开启Google的登陆二步验证(即Google Authenticator服务)后用户登陆时需要输入额外由手机客户端生成的一次性密码。实现Google Authenticator功能需要服务器端和客户端的支持。服务器端负责密钥的生成、验证一次性密码是否正确。客户端记录密钥后生成一次性密码。下载谷歌验证类库文件放到项目合适位置(我这边放在项目Vender下面)https://github.com/PHPGangsta/GoogleAuthenticatorPHP代码示例://引入谷_php otp 验证器

【Python】matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距-程序员宅基地

文章浏览阅读4.3k次,点赞5次,收藏11次。matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距

docker — 容器存储_docker 保存容器-程序员宅基地

文章浏览阅读2.2k次。①Storage driver 处理各镜像层及容器层的处理细节,实现了多层数据的堆叠,为用户 提供了多层数据合并后的统一视图②所有 Storage driver 都使用可堆叠图像层和写时复制(CoW)策略③docker info 命令可查看当系统上的 storage driver主要用于测试目的,不建议用于生成环境。_docker 保存容器

随便推点

网络拓扑结构_网络拓扑csdn-程序员宅基地

文章浏览阅读834次,点赞27次,收藏13次。网络拓扑结构是指计算机网络中各组件(如计算机、服务器、打印机、路由器、交换机等设备)及其连接线路在物理布局或逻辑构型上的排列形式。这种布局不仅描述了设备间的实际物理连接方式,也决定了数据在网络中流动的路径和方式。不同的网络拓扑结构影响着网络的性能、可靠性、可扩展性及管理维护的难易程度。_网络拓扑csdn

JS重写Date函数,兼容IOS系统_date.prototype 将所有 ios-程序员宅基地

文章浏览阅读1.8k次,点赞5次,收藏8次。IOS系统Date的坑要创建一个指定时间的new Date对象时,通常的做法是:new Date("2020-09-21 11:11:00")这行代码在 PC 端和安卓端都是正常的,而在 iOS 端则会提示 Invalid Date 无效日期。在IOS年月日中间的横岗许换成斜杠,也就是new Date("2020/09/21 11:11:00")通常为了兼容IOS的这个坑,需要做一些额外的特殊处理,笔者在开发的时候经常会忘了兼容IOS系统。所以就想试着重写Date函数,一劳永逸,避免每次ne_date.prototype 将所有 ios

如何将EXCEL表导入plsql数据库中-程序员宅基地

文章浏览阅读5.3k次。方法一:用PLSQL Developer工具。 1 在PLSQL Developer的sql window里输入select * from test for update; 2 按F8执行 3 打开锁, 再按一下加号. 鼠标点到第一列的列头,使全列成选中状态,然后粘贴,最后commit提交即可。(前提..._excel导入pl/sql

Git常用命令速查手册-程序员宅基地

文章浏览阅读83次。Git常用命令速查手册1、初始化仓库git init2、将文件添加到仓库git add 文件名 # 将工作区的某个文件添加到暂存区 git add -u # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,不处理untracked的文件git add -A # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,包括untracked的文件...

分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120-程序员宅基地

文章浏览阅读202次。分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120

【C++缺省函数】 空类默认产生的6个类成员函数_空类默认产生哪些类成员函数-程序员宅基地

文章浏览阅读1.8k次。版权声明:转载请注明出处 http://blog.csdn.net/irean_lau。目录(?)[+]1、缺省构造函数。2、缺省拷贝构造函数。3、 缺省析构函数。4、缺省赋值运算符。5、缺省取址运算符。6、 缺省取址运算符 const。[cpp] view plain copy_空类默认产生哪些类成员函数

推荐文章

热门文章

相关标签