1 # -*- coding: utf-8 -*- 2 from pathlib import Path #从pathlib中导入Path 3 import os 4 import fileinput 5 import random 6 root_path='/home/tay/Videos/trash/垃圾分类项目/total/' 7 train = open('./trash_train.txt','a') 8 test = open('./trash_test.txt','a') 9 pwd = os.getcwd() +'/'# the val data path 训练集的路径10 11 12 def gen_txt():13 i =014 for file in os.listdir(root_path):15 print('file is{}'.format(str(file)))16 for init in os.listdir(os.path.join(root_path, file)): #子文件夹17 print('init is{}'.format(str(init)))18 i += 119 pathDir = os.listdir(os.path.join(root_path, file, init)) #20 print('pathDir is', pathDir)21 file_num = len(pathDir)22 rate = 0.223 pick_num = int(file_num * rate)24 sample = random.sample(pathDir, pick_num) #随机选取20%的pathDir字符串25 print('sample is', sample)26 for pick_name in sample:27 test.write(root_path.split('total/')[-1] +file + '/' + init +'/' + pick_name + ' ' + str(i) + '\n')28 # for name in pathDir: #文件夹中的图片名29 # print('name is{}'.format(str(name)))30 # if test31 # total.write(root_path.split('total/')[-1] +file + '/' + init +'/' + name + ' ' + str(i) + '\n' )32 same = [x for x in pathDir if x in sample] #列表中相同的内容33 diff = [y for y in (sample + pathDir) if y not in same] #列表中不同的内容34 print('different', diff)35 print('same', same)36 for train_name in diff:37 train.write(root_path.split('total/')[-1] +file + '/' + init +'/' + train_name + ' ' + str(i) + '\n')38 gen_txt()
采用了random.sample函数来随机选取特定数量的文件名作为测试集,通过比较两个列表中不同的元素来获取训练集的文件名。
总体上就是在进行字符串操作。