南强小屋 Design By 杰米
我就废话不多说了,直接上代码吧!
from os import listdir import os from time import time import torch.utils.data as data import torchvision.transforms as transforms from torch.utils.data import DataLoader def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='=', empty=' ', tip='>', begin='[', end=']', done="[DONE]", clear=True): percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total))) filledLength = int(length * iteration // total) bar = fill * filledLength if iteration != total: bar = bar + tip bar = bar + empty * (length - filledLength - len(tip)) display = '\r{prefix}{begin}{bar}{end} {percent}%{suffix}' .format(prefix=prefix, begin=begin, bar=bar, end=end, percent=percent, suffix=suffix) print(display, end=''), # comma after print() required for python 2 if iteration == total: # print with newline on complete if clear: # display given complete message with spaces to 'erase' previous progress bar finish = '\r{prefix}{done}'.format(prefix=prefix, done=done) if hasattr(str, 'decode'): # handle python 2 non-unicode strings for proper length measure finish = finish.decode('utf-8') display = display.decode('utf-8') clear = ' ' * max(len(display) - len(finish), 0) print(finish + clear) else: print('') class DatasetFromFolder(data.Dataset): def __init__(self, image_dir): super(DatasetFromFolder, self).__init__() self.photo_path = os.path.join(image_dir, "a") self.sketch_path = os.path.join(image_dir, "b") self.image_filenames = [x for x in listdir(self.photo_path) if is_image_file(x)] transform_list = [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] self.transform = transforms.Compose(transform_list) def __getitem__(self, index): # Load Image input = load_img(os.path.join(self.photo_path, self.image_filenames[index])) input = self.transform(input) target = load_img(os.path.join(self.sketch_path, self.image_filenames[index])) target = self.transform(target) return input, target def __len__(self): return len(self.image_filenames) if __name__ == '__main__': dataset = DatasetFromFolder("./dataset/facades/train") dataloader = DataLoader(dataset=dataset, num_workers=8, batch_size=1, shuffle=True) total = len(dataloader) for epoch in range(20): t0 = time() for i, batch in enumerate(dataloader): real_a, real_b = batch[0], batch[1] printProgressBar(i + 1, total + 1, length=20, prefix='Epoch %s ' % str(1), suffix=', d_loss: %d' % 1) printProgressBar(total, total, done='Epoch [%s] ' % str(epoch) + ', time: %.2f s' % (time() - t0) )
以上这篇pytorch 批次遍历数据集打印数据的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
南强小屋 Design By 杰米
广告合作:本站广告合作请联系QQ:858582 申请时备注:广告合作(否则不回)
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
南强小屋 Design By 杰米
暂无pytorch 批次遍历数据集打印数据的例子的评论...