# Convert the cat/dog JPEG images matched by `cat_dog_train_path` into a
# TFRecords file of (grayscale 224x224 uint8 image, int label) examples.
from random import shuffle
import glob
import os
import sys

import cv2
import numpy as np
import tensorflow as tf
import tqdm

train_filename = 'test1.tfrecords'  # address to save the TFRecords file
shuffle_data = True  # shuffle the addresses before saving
cat_dog_train_path = "./all/test1/test1/*jpg"  # glob of input images

IMAGE_SIZE = (224, 224)  # (width, height) as passed to cv2.resize
TRAIN_END = 0.6  # first 60% of the addresses -> train split
VAL_END = 0.8    # next 20% -> validation, remaining 20% -> test


def label_from_path(addr):
    """Return 0 (cat) if the file name contains 'cat', otherwise 1 (dog).

    Only the base name is inspected, so directory names cannot change the
    label. NOTE(review): if the input files are not named with 'cat'/'dog'
    (e.g. numbered test images), every label will be 1 -- confirm intent.
    """
    return 0 if 'cat' in os.path.basename(addr) else 1


def split_dataset(addrs, labels):
    """Split parallel sequences into 60% train, 20% validation, 20% test.

    Returns:
        A 3-tuple ``(train, val, test)``; each element is an
        ``(addrs, labels)`` pair of slices of the inputs.
    """
    n = len(addrs)
    train_end = int(TRAIN_END * n)
    val_end = int(VAL_END * n)
    return ((addrs[:train_end], labels[:train_end]),
            (addrs[train_end:val_end], labels[train_end:val_end]),
            (addrs[val_end:], labels[val_end:]))


def load_image(addr, size=IMAGE_SIZE):
    """Read *addr* as a single-channel (grayscale) image resized to *size*.

    Args:
        addr: path of the image file.
        size: ``(width, height)`` target size; defaults to 224x224.

    Returns:
        A uint8 numpy array of the resized grayscale image.

    Raises:
        IOError: if OpenCV cannot read the file.
    """
    img = cv2.imread(addr, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('Could not read image: {}'.format(addr))
    img = cv2.resize(img, size, interpolation=cv2.INTER_CUBIC)
    return img.astype(np.uint8)


def _int64_feature(value):
    """Wrap a single int in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def write_tfrecords(filename, addrs, labels):
    """Serialize each (image, label) pair as a tf.train.Example into *filename*.

    Feature keys are 'train/label' (int64) and 'train/image' (the raw uint8
    bytes of the image returned by load_image).
    """
    total = len(addrs)
    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter(filename) as writer:
        # `tqdm` is imported as a module, so the progress-bar callable is tqdm.tqdm.
        for i, (addr, label) in enumerate(tqdm.tqdm(zip(addrs, labels), total=total)):
            # print how many images are saved every 1000 images
            if not i % 1000:
                print('Train data: {}/{}'.format(i, total))
                sys.stdout.flush()
            img = load_image(addr)
            feature = {'train/label': _int64_feature(label),
                       'train/image': _bytes_feature(img.tobytes())}
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
    sys.stdout.flush()


def main():
    """Collect, label, optionally shuffle and split the images, then write the train split."""
    addrs = glob.glob(cat_dog_train_path)
    if not addrs:
        # Fail early: otherwise an empty file is written or, when shuffling,
        # zip(*[]) below fails with an opaque unpacking error.
        sys.exit('No images found matching {!r}'.format(cat_dog_train_path))
    labels = [label_from_path(addr) for addr in addrs]
    if shuffle_data:
        # Shuffle addresses and labels together so they stay aligned.
        pairs = list(zip(addrs, labels))
        shuffle(pairs)
        addrs, labels = zip(*pairs)
    (train_addrs, train_labels), _val, _test = split_dataset(addrs, labels)
    # Only the train split is written; the val/test splits are currently unused.
    write_tfrecords(train_filename, train_addrs, train_labels)


if __name__ == '__main__':
    main()
Loading and testing TFRecord images
# Read back the TFRecords file and display a few shuffled batches.
# Uses the TensorFlow 1.x queue-runner input pipeline (tf.Session,
# tf.train.string_input_producer, tf.TFRecordReader, tf.train.shuffle_batch).
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt

data_path = 'test1.tfrecords'  # TFRecords file to read
IMAGE_SHAPE = [224, 224]  # must match the size the images were written with
NUM_BATCHES = 5   # batches to pull from the pipeline
GRID_ROWS, GRID_COLS = 2, 3  # matplotlib subplot layout per batch


def build_input_pipeline(path, batch_size=10):
    """Return ``(images, labels)`` tensors yielding shuffled batches from *path*.

    images: uint8, shape ``[batch_size] + IMAGE_SHAPE``; labels: int32,
    shape ``[batch_size]``.
    """
    feature = {'train/image': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.int64)}
    # Create a list of filenames and pass it to a queue; with num_epochs=1 the
    # pipeline raises OutOfRangeError once the file has been read once.
    filename_queue = tf.train.string_input_producer([path], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features=feature)
    # Convert the raw image bytes back to uint8 pixels and restore the shape.
    image = tf.reshape(tf.decode_raw(features['train/image'], tf.uint8), IMAGE_SHAPE)
    label = tf.cast(features['train/label'], tf.int32)
    # Any preprocessing here ...
    return tf.train.shuffle_batch([image, label], batch_size=batch_size,
                                  capacity=30, num_threads=1, min_after_dequeue=10)


def _class_name(label):
    """Map a numeric label to its class name (0 = cat, anything else = dog)."""
    return 'cat' if label == 0 else 'dog'


def show_batch(img, lbl, count=GRID_ROWS * GRID_COLS):
    """Show the first *count* images of a batch with matplotlib, then with OpenCV.

    *count* is capped by the grid size and the batch length.
    """
    count = min(count, GRID_ROWS * GRID_COLS, len(img))
    for j in range(count):
        plt.subplot(GRID_ROWS, GRID_COLS, j + 1)
        # Images are single-channel; without cmap='gray' matplotlib renders
        # them with its default false-colour colormap.
        plt.imshow(img[j, ...], cmap='gray')
        plt.title(_class_name(lbl[j]))
    plt.show()
    for j in range(count):
        print(img[j].shape)
        cv2.imshow(_class_name(lbl[j]), img[j])
        cv2.waitKey(0)


def main():
    """Pull NUM_BATCHES batches from data_path and display each one."""
    images, labels = build_input_pipeline(data_path)
    # num_epochs is tracked in a local variable, hence the local initializer.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:  # the context manager closes the session
        sess.run(init_op)
        # Create a coordinator and run all QueueRunner objects
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(NUM_BATCHES):
                img, lbl = sess.run([images, labels])
                show_batch(img, lbl)
        except tf.errors.OutOfRangeError:
            # The file holds too few records for NUM_BATCHES full batches.
            print('Input exhausted before {} batches were read.'.format(NUM_BATCHES))
        finally:
            # Always stop and join the queue-runner threads so the process can exit.
            coord.request_stop()
            coord.join(threads)
            cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
Manuel Cuevas
Hello, I'm Manuel Cuevas, a software engineer with a background in machine learning and artificial intelligence.