forked from NVlabs/stylegan2-ada-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FirestoreDataset.py
42 lines (32 loc) · 1.29 KB
/
FirestoreDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# from io import BytesIO
# import firebase_admin
# import requests
# import torch
# import torchvision.transforms as transforms
# from firebase_admin import credentials, firestore
# from PIL import Image
# from torch.utils.data import Dataset
# # Initialize Firebase
# cred = credentials.Certificate('../fyp-project-83298-firebase-adminsdk-omga1-3c741ce672.json')
# firebase_admin.initialize_app(cred, {
# 'storageBucket': 'fyp-project-83298.appspot.com'
# }
# db = firestore.client()
# bucket = storage.bucket()
# class FirestoreDataset(Dataset):
# def __init__(self):
# self.metadata = list(db.collection('products').stream())
# self.transform = transforms.ToTensor() # Convert images to PyTorch tensors
# self.root_dir = 'data/training_images'
# def __len__(self):
# return len(self.metadata)
# def __getitem__(self, idx):
# meta = self.metadata[idx].to_dict()
# product_id = meta['product_id']
# img_path = os.path.join(self.root_dir, f'{product_id}.png')
# image = Image.open(img_path).convert('RGB')
# if self.transform:
# image = self.transform(image)
# labels = torch.tensor([meta['color_tag'], meta['complexity_tag']])
# weight = meta['weight']
# return image, labels, weight