Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TFRecord file format #13

Open
mingxoxo opened this issue May 22, 2021 · 0 comments
Open

TFRecord file format #13

mingxoxo opened this issue May 22, 2021 · 0 comments

Comments

@mingxoxo
Copy link
Member

출처 : https://digitalbourgeois.tistory.com/50

TFRecord file format

TFRecord 파일은 텐서플로우로 딥러닝 학습을 하는데 필요한 데이터들을 보관하기 위한 데이터 포맷

장점

  • 코드 구현시 더욱 효율적으로 구현 가능
    TFRecord 파일은 학습 데이터 정보와 label 정보를 하나의 파일에서 관리한다.
    별도의 작업 없이 TFRecord 파일을 읽어 원하는 정보를 받을 수 있다.
  • 학습 속도 개선
    TFRecord 파일은 바이너리 데이터 포맷으로 부수 작업이 필요 없다.
    (이미지 파일과 같은 경우 jpg, png 파일로 되어 있을 경우 매번 코드에서 인코딩 / 디코딩 작업을 해줘야 한다. 이와 같은 작업은 학습을 하는데 비효율적이다.)
  • 파일 사이즈
    이미지 파일을 TFRecord 파일로 생성하게 되면 파일 사이즈가 작아진다.

공부하는 코드에서 TFRecord 변환 부분

def _serialize_image(path, transform=None):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [CFG.img_size, CFG.img_size])
    image = tf.cast(image, tf.uint8)
    
    if transform is not None:
        image = transform(image=image.numpy())['image']
        
    return tf.image.encode_jpeg(image).numpy()


def _serialize_sample(image, image_name, label):
    feature = {
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
        'image_name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_name])),
        'complex': tf.train.Feature(int64_list=tf.train.Int64List(value=[label[0]])),
        'frog_eye_leaf_spot': tf.train.Feature(int64_list=tf.train.Int64List(value=[label[1]])),
        'powdery_mildew': tf.train.Feature(int64_list=tf.train.Int64List(value=[label[2]])),
        'rust': tf.train.Feature(int64_list=tf.train.Int64List(value=[label[3]])),
        'scab': tf.train.Feature(int64_list=tf.train.Int64List(value=[label[4]])),
        'healthy': tf.train.Feature(int64_list=tf.train.Int64List(value=[label[5]]))}
    sample = tf.train.Example(features=tf.train.Features(feature=feature))
    return sample.SerializeToString()


def serialize_fold(fold, name, transform=None, bar=None):
    samples = []
    
    for image_name, labels in fold.iterrows():
        path = os.path.join(CFG.root, image_name)
        image = _serialize_image(path, transform=transform)
        samples.append(_serialize_sample(image, image_name.encode(), labels))
    
    with tf.io.TFRecordWriter(name + '.tfrec') as writer:
        [writer.write(x) for x in samples]
        
    if bar is not None:
        bar.update(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant