Complete the first stage of SERDataset dataset loading
KevinNuNu committed Mar 28, 2023
1 parent d3e16ad commit 1d0c5e3
Showing 2 changed files with 41 additions and 25 deletions.
55 changes: 31 additions & 24 deletions mmocr/datasets/ser_dataset.py
@@ -46,39 +46,46 @@ def load_data_list(self) -> List[dict]:
         data_list = super().load_data_list()
 
         # split text to several slices because of over-length
-        input_ids, bboxes, labels = [], [], []
-        segment_ids, position_ids = [], []
-        image_path = []
+        split_text_data_list = []
         for i in range(len(data_list)):
             start = 0
             cur_iter = 0
             while start < len(data_list[i]['input_ids']):
                 end = min(start + 510, len(data_list[i]['input_ids']))
-
-                input_ids.append([self.tokenizer.cls_token_id] +
-                                 data_list[i]['input_ids'][start:end] +
-                                 [self.tokenizer.sep_token_id])
-                bboxes.append([[0, 0, 0, 0]] +
-                              data_list[i]['bboxes'][start:end] +
-                              [[1000, 1000, 1000, 1000]])
-                labels.append([-100] + data_list[i]['labels'][start:end] +
-                              [-100])
-
-                cur_segment_ids = self.get_segment_ids(bboxes[-1])
-                cur_position_ids = self.get_position_ids(cur_segment_ids)
-                segment_ids.append(cur_segment_ids)
-                position_ids.append(cur_position_ids)
-                image_path.append(
-                    os.path.join(self.data_root, data_list[i]['img_path']))
+                # get input_ids
+                input_ids = [self.tokenizer.cls_token_id] + \
+                    data_list[i]['input_ids'][start:end] + \
+                    [self.tokenizer.sep_token_id]
+                # get bboxes
+                bboxes = [[0, 0, 0, 0]] + \
+                    data_list[i]['bboxes'][start:end] + \
+                    [[1000, 1000, 1000, 1000]]
+                # get labels
+                labels = [-100] + data_list[i]['labels'][start:end] + [-100]
+                # get segment_ids
+                segment_ids = self.get_segment_ids(bboxes)
+                # get position_ids
+                position_ids = self.get_position_ids(segment_ids)
+                # get img_path
+                img_path = os.path.join(self.data_root,
+                                        data_list[i]['img_path'])
+                # get attention_mask
+                attention_mask = [1] * len(input_ids)
+
+                data_info = {}
+                data_info['input_ids'] = input_ids
+                data_info['bboxes'] = bboxes
+                data_info['labels'] = labels
+                data_info['segment_ids'] = segment_ids
+                data_info['position_ids'] = position_ids
+                data_info['img_path'] = img_path
+                data_info['attention_mask'] = attention_mask
+                split_text_data_list.append(data_info)
 
                 start = end
                 cur_iter += 1
 
-        assert len(input_ids) == len(bboxes) == len(labels) == len(
-            segment_ids) == len(position_ids)
-        assert len(segment_ids) == len(image_path)
-
-        return data_list
+        return split_text_data_list
 
     def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
         instances = raw_data_info['instances']
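Why slice at 510? LayoutLMv3 uses a BERT-style text encoder with a 512-position budget per sequence, so each slice presumably keeps at most 510 content tokens while [CLS] and [SEP] fill the remaining two slots; the matching dummy bboxes ([0, 0, 0, 0] and [1000, 1000, 1000, 1000]) and the -100 ignore labels keep the parallel lists aligned. A minimal self-contained sketch of this windowing, using a hypothetical split_into_windows helper in place of the loop body above:

from typing import List

def split_into_windows(token_ids: List[int],
                       cls_id: int,
                       sep_id: int,
                       max_len: int = 512) -> List[List[int]]:
    # Reserve two of the max_len slots for the [CLS]/[SEP] specials.
    body = max_len - 2
    windows, start = [], 0
    while start < len(token_ids):
        end = min(start + body, len(token_ids))
        windows.append([cls_id] + token_ids[start:end] + [sep_id])
        start = end
    return windows

# 1200 tokens -> three windows of 512, 512 and 182 ids.
print([len(w) for w in split_into_windows(list(range(1200)), 101, 102)])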
11 changes: 10 additions & 1 deletion projects/LayoutLMv3/test.py
@@ -1,14 +1,23 @@
 from mmengine.config import Config
 from mmengine.registry import init_default_scope
 
 from mmocr.registry import DATASETS
 
 if __name__ == '__main__':
     cfg_path = '/Users/wangnu/Documents/GitHub/mmocr/projects/' \
                'LayoutLMv3/configs/layoutlmv3_xfund_zh.py'
     cfg = Config.fromfile(cfg_path)
     init_default_scope(cfg.get('default_scope', 'mmocr'))
 
     dataset_cfg = cfg.train_dataset
+    dataset_cfg['tokenizer'] = \
+        '/Users/wangnu/Documents/GitHub/mmocr/data/layoutlmv3-base-chinese'
+
+    train_pipeline = [
+        dict(type='LoadImageFromFile', color_type='color'),
+        dict(type='Resize', scale=(224, 224))
+    ]
+    dataset_cfg['pipeline'] = train_pipeline
     ds = DATASETS.build(dataset_cfg)
-    print(ds[0])
+    data = ds[0]
+    print('hi')
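When smoke-testing the freshly built dataset, printing the sample's keys and sequence lengths verifies the slicing more directly than the bare print('hi'); a hedged inspection snippet, assuming ds is built as above and the two pipeline transforms leave the data_info keys from load_data_list in place:

    data = ds[0]
    print(sorted(data.keys()))
    for key in ('input_ids', 'labels', 'segment_ids',
                'position_ids', 'attention_mask'):
        if key in data:
            # Each slice should fit the 512-token budget.
            print(key, len(data[key]))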
