南强小屋 Design By 杰米
本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:
前言
记录下pytorch里如何使用lmdb的code,自用
制作部分的Code
code就是ASTER里数据制作部分的代码改了点。aster_train.txt里面就是图片的完整路径,每行一个;图片同目录下有同名的txt,里面记着jpg的标签
"""Build an LMDB dataset (ASTER/CRNN layout) from a list of image paths.

The list file holds one full image path per line; next to every image there
must be a text file with the same base name whose first line is the label.
"""
import os
import re

import cv2
import lmdb  # install lmdb by "pip install lmdb"
import numpy as np
import scipy.io as sio
import six
from PIL import Image
from tqdm import tqdm


def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to an image with non-zero area.

    ARGS:
        imageBin : raw encoded image bytes (or None)
    """
    if imageBin is None:
        return False
    # frombuffer replaces the deprecated binary mode of np.fromstring.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    if imageBuf.size == 0:
        return False
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode signals an undecodable buffer by returning None.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0


def writeCache(env, cache):
    """Write every ``str key -> bytes value`` pair of ``cache`` in one transaction."""
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(k.encode(), v)


def _is_difficult(word):
    """Return True if ``word`` contains anything other than ``\\w`` characters."""
    assert isinstance(word, str)
    return not re.match(r'^[\w]+$', word)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image

    Samples with an empty label, a missing file, or (when ``checkValid``)
    an undecodable image are skipped. Keys are ``image-%09d`` /
    ``label-%09d`` / ``lexicon-%09d`` numbered from 1 over the samples that
    were actually written, plus ``num-samples`` holding their count.
    """
    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    # map_size is the maximum database size: 1099511627776 bytes = 1 TiB.
    env = lmdb.open(outputPath, map_size=1099511627776)
    try:
        cache = {}
        cnt = 1
        for i, (imagePath, label) in enumerate(zip(imagePathList, labelList)):
            if len(label) == 0:
                continue
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid:
                if not checkImageIsValid(imageBin):
                    print('%s is not a valid image' % imagePath)
                    continue

            # LMDB stores raw bytes only, so every value is encoded here.
            imageKey = 'image-%09d' % cnt  # zero-padded to 9 digits
            labelKey = 'label-%09d' % cnt
            cache[imageKey] = imageBin
            cache[labelKey] = label.encode()
            if lexiconList:
                lexiconKey = 'lexicon-%09d' % cnt
                # Encoded like the label: txn.put rejects str values.
                cache[lexiconKey] = ' '.join(lexiconList[i]).encode()
            # Flush in batches of 1000 samples to bound memory use.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        nSamples = cnt - 1
        cache['num-samples'] = str(nSamples).encode()
        writeCache(env, cache)
    finally:
        # Release the environment even if a read or write failed midway.
        env.close()
    print('Created dataset with %d samples' % nSamples)


def _label_path_for(image_path):
    """Return the path of the label file that sits next to ``image_path``."""
    return os.path.splitext(image_path)[0] + '.txt'


def get_sample_list(txt_path: str):
    """Read the image list and the matching labels.

    ARGS:
        txt_path : text file with one image path per line

    RETURNS:
        (jpg_list, txt_content_list): image paths that have a label file
        next to them, and the stripped first line of each label file.
    """
    with open(txt_path, 'r') as fr:
        candidates = [line.strip() for line in fr]
    # Swap only the extension; str.replace would also rewrite '.jpg'
    # occurring elsewhere in the path.
    jpg_list = [p for p in candidates if p and os.path.exists(_label_path_for(p))]

    txt_content_list = []
    for jpg in jpg_list:
        label_path = _label_path_for(jpg)
        with open(label_path, 'r') as fr:
            try:
                str_tmp = fr.readline()
            except UnicodeDecodeError:
                # Report which label file is badly encoded, then propagate.
                print(label_path)
                raise
        txt_content_list.append(str_tmp.strip())
    return jpg_list, txt_content_list


if __name__ == "__main__":
    txt_path = '/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
    lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
    imagePathList, labelList = get_sample_list(txt_path)
    createDataset(lmdb_output_path, imagePathList, labelList)
读取部分
这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__
from __future__ import absolute_import # import sys # sys.path.append('./') import os # import moxing as mox import pickle from tqdm import tqdm from PIL import Image, ImageFile import numpy as np import random import cv2 import lmdb import sys import six import torch from torch.utils import data from torch.utils.data import sampler from torchvision import transforms from lib.utils.labelmaps import get_vocabulary, labels2strs from lib.utils import to_numpy ImageFile.LOAD_TRUNCATED_IMAGES = True from config import get_args global_args = get_args(sys.argv[1:]) if global_args.run_on_remote: import moxing as mox #moxing是一个分布式的框架 跳过 class LmdbDataset(data.Dataset): def __init__(self, root, voc_type, max_len, num_samples, transform=None): super(LmdbDataset, self).__init__() if global_args.run_on_remote: dataset_name = os.path.basename(root) data_cache_url = "/cache/%s" % dataset_name if not os.path.exists(data_cache_url): os.makedirs(data_cache_url) if mox.file.exists(root): mox.file.copy_parallel(root, data_cache_url) else: raise ValueError("%s not exists!" 
% root) self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True) else: self.env = lmdb.open(root, max_readers=32, readonly=True) assert self.env is not None, "cannot create lmdb from %s" % root self.txn = self.env.begin() self.voc_type = voc_type self.transform = transform self.max_len = max_len self.nSamples = int(self.txn.get(b"num-samples")) self.nSamples = min(self.nSamples, num_samples) assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS','DIGITS'] self.EOS = 'EOS' self.PADDING = 'PADDING' self.UNKNOWN = 'UNKNOWN' self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN) self.char2id = dict(zip(self.voc, range(len(self.voc)))) self.id2char = dict(zip(range(len(self.voc)), self.voc)) self.rec_num_classes = len(self.voc) self.lowercase = (voc_type == 'LOWERCASE') def __len__(self): return self.nSamples def __getitem__(self, index): assert index <= len(self), 'index range error' index += 1 img_key = b'image-%09d' % index imgbuf = self.txn.get(img_key) #由于Image.open需要一个类文件对象 所以这里需要把二进制转为一个类文件对象 buf = six.BytesIO() buf.write(imgbuf) buf.seek(0) try: img = Image.open(buf).convert('RGB') # img = Image.open(buf).convert('L') # img = img.convert('RGB') except IOError: print('Corrupted image for %d' % index) return self[index + 1] # reconition labels label_key = b'label-%09d' % index word = self.txn.get(label_key).decode() if self.lowercase: word = word.lower() ## fill with the padding token label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int) label_list = [] for char in word: if char in self.char2id: label_list.append(self.char2id[char]) else: ## add the unknown token print('{0} is out of vocabulary.'.format(char)) label_list.append(self.char2id[self.UNKNOWN]) ## add a stop token label_list = label_list + [self.char2id[self.EOS]] assert len(label_list) <= self.max_len label[:len(label_list)] = np.array(label_list) if len(label) <= 0: return self[index + 1] # label length label_len = 
len(label_list) if self.transform is not None: img = self.transform(img) return img, label, label_len
更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》
希望本文所述对大家Python程序设计有所帮助。
南强小屋 Design By 杰米
广告合作:本站广告合作请联系QQ:858582 申请时备注:广告合作(否则不回)
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
南强小屋 Design By 杰米
暂无pytorch制作自己的LMDB数据操作示例的评论...
RTX 5090要首发 性能要翻倍!三星展示GDDR7显存
三星在GTC上展示了专为下一代游戏GPU设计的GDDR7内存。
首次推出的GDDR7内存模块密度为16Gb,即每个模块容量为2GB。其速度预设为32 Gbps(PAM3),但也可以降至28 Gbps,以提高产量和初始阶段的整体性能和成本效益。
据三星表示,GDDR7内存的能效将提高20%,同时工作电压仅为1.1V,低于标准的1.2V。通过采用更新的封装材料和优化的电路设计,使得在高速运行时的发热量降低,GDDR7的热阻比GDDR6降低了70%。