1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
| import numpy as np import pandas as pd from skimage import io import matplotlib.pyplot as plt import os import lmdb import caffe
def make_datum(image, label, channels, height, width):
    """Pack an image buffer and its class label into a Caffe ``Datum``.

    Args:
        image: Object exposing ``tobytes()`` (e.g. a NumPy array). Its raw
            bytes are stored exactly as ``tobytes()`` returns them; no axis
            reordering or dtype conversion happens here.
        label: Class index; coerced with ``int()`` before being stored.
        channels: Value recorded in ``Datum.channels``.
        height: Value recorded in ``Datum.height``.
        width: Value recorded in ``Datum.width``.

    Returns:
        The populated ``caffe.proto.caffe_pb2.Datum`` message.
    """
    pixel_bytes = image.tobytes()
    class_index = int(label)
    return caffe.proto.caffe_pb2.Datum(
        channels=channels,
        height=height,
        width=width,
        label=class_index,
        data=pixel_bytes,
    )
# Input image directory, label table (CSV) and output LMDB directory.
dataset_path = './data'
label_file = 'labels.csv'
lmdb_path = 'cifar10_lmdb'

# Class name -> integer index, and the inverse lookup (index -> class name).
labels_mapping = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9,
}
classes = {index: name for name, index in labels_mapping.items()}

# Read the label table and replace each class name with its integer index.
# Names missing from ``labels_mapping`` become NaN (pandas ``Series.map``).
df = pd.read_csv(label_file)
df['label'] = df['label'].map(labels_mapping)
images = list(df['id'])
labels = list(df['label'])
# Write every (image, label) pair to the LMDB database as a serialized Datum.

# Upper bound on the database size, in bytes. py-lmdb expects an integer here
# (the previous ``1e6`` was a float) and the map must be larger than the whole
# dataset, otherwise writes fail with ``lmdb.MapFullError``. 1 GiB leaves ample
# headroom, assuming CIFAR-10 sized inputs (~3 KB of pixels per image).
map_size = 1024 ** 3
# Number of records written between transaction commits.
batch_size = 4

count = 0
lmdb_env = lmdb.open(lmdb_path, map_size=map_size)
try:
    lmdb_txn = lmdb_env.begin(write=True)
    for count, (image_id, label) in enumerate(zip(images, labels), 1):
        if pd.isnull(label):
            # The label column did not map to a known class index.
            raise ValueError(
                'image {!r} has a missing or unknown label'.format(image_id))

        image_file = os.path.join(dataset_path, str(image_id) + '.png')
        image = io.imread(image_file)
        if image.ndim == 2:
            # Grayscale image: add an explicit single-channel axis.
            image = image[:, :, np.newaxis]
        height, width, channels = image.shape

        # skimage loads images as H x W x C, whereas Caffe interprets
        # Datum.data as C x H x W, so reorder the axes before serializing.
        datum = make_datum(
            image.transpose(2, 0, 1), label, channels, height, width)

        # LMDB keys must be bytes; zero-padding keeps them in insertion order.
        str_id = '{:08}'.format(count)
        lmdb_txn.put(str_id.encode('ascii'), datum.SerializeToString())

        if count % batch_size == 0:
            lmdb_txn.commit()
            lmdb_txn = lmdb_env.begin(write=True)

    # Flush the records of the last, possibly partial, batch.
    lmdb_txn.commit()
finally:
    # Always release the environment; closing it also invalidates any
    # transaction left open by an error above.
    lmdb_env.close()
|