
[💡SUG] In SASRec, can train, valid, and test differ in scope (i.e., cover different sets of users)? How can this be done? #2087

Open
gosgjkaj opened this issue Sep 23, 2024 · 0 comments
Labels: enhancement (New feature or request)


gosgjkaj commented Sep 23, 2024

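As the title says.
For example, suppose a dataset contains 1000 users. I split them into two groups of 500 users each and run leave-one-out via `data.data_preparation` only on the users of group B.
This leaves four datasets: group A's data, plus train, valid, and test.

Is it possible to use A + train as the training set? That is, train on all 1000 users, but run validation and testing only on group B's users?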
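The intended split can be sketched independently of RecBole. The following is a minimal, stdlib-only illustration, assuming interactions are `(user_id, item_id, timestamp)` tuples; `pick_group_b` and `split_groups` are hypothetical helpers, not part of RecBole's API. It only shows which interactions land in each split, not how RecBole builds SASRec's input sequences from them.

```python
import random
from collections import defaultdict
from typing import Dict, Iterable, List, Set, Tuple

# Hypothetical toy representation of one interaction: (user_id, item_id, timestamp)
Inter = Tuple[int, int, int]


def pick_group_b(user_ids: Iterable[int], size: int, seed: int = 0) -> Set[int]:
    """Randomly choose `size` users for the evaluation group B (reproducible via `seed`)."""
    rng = random.Random(seed)
    return set(rng.sample(sorted(user_ids), size))


def split_groups(
    inters: List[Inter], group_b: Set[int]
) -> Tuple[List[Inter], List[Inter], List[Inter]]:
    """Leave-one-out on group B only; every group-A interaction goes to training."""
    by_user: Dict[int, List[Inter]] = defaultdict(list)
    for inter in inters:
        by_user[inter[0]].append(inter)

    train: List[Inter] = []
    valid: List[Inter] = []
    test: List[Inter] = []
    for user, seq in by_user.items():
        seq.sort(key=lambda x: x[2])  # chronological order per user
        if user in group_b and len(seq) >= 3:
            train.extend(seq[:-2])  # earlier history -> training
            valid.append(seq[-2])   # second-to-last -> validation target
            test.append(seq[-1])    # last -> test target
        else:
            # Group A users (and B users too short to split) are used only for training
            train.extend(seq)
    return train, valid, test


if __name__ == "__main__":
    inters = [(1, 10, 1), (1, 11, 2), (1, 12, 3),
              (2, 20, 1), (2, 21, 2), (2, 22, 3), (2, 23, 4)]
    train, valid, test = split_groups(inters, group_b={2})
    print(len(train), valid, test)  # prints: 5 [(2, 22, 3)] [(2, 23, 4)]
```

Here user 1 plays the role of group A (all interactions go to training) and user 2 the role of group B (only its last two interactions are held out), which mirrors the "A + train for training, B for valid/test" setup.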

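My current approach:

```python
from recbole.data import create_dataset, data_preparation
from recbole.data.interaction import cat_interactions

# `config` is a recbole.config.Config built earlier (not shown)
dataset = create_dataset(config)
...  # omitted
# Dataset.copy() returns a copy whose inter_feat is replaced by the given table
A = dataset.copy(user_id_group_A)
B = dataset.copy(user_id_group_B)

# Leave-one-out split (train/valid/test) for group B only
train_data, valid_data, test_data = data_preparation(config, B)

# Convert A's inter_feat from a pandas DataFrame to a RecBole Interaction
A._change_feat_format()
to_cat = A.inter_feat

# Append group A's interactions to B's training interactions
be_cat = train_data.dataset.inter_feat
cat = cat_interactions([be_cat, to_cat])
train_data.dataset.inter_feat = cat
```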

Doing this produces the following error:

```
Traceback (most recent call last):
  File "/data/xxx/RecBole/test2.py", line 59, in <module>
    run_recbole_dm(model=model,
  File "/data/xxx/RecBole/recbole/quick_start/quick_start.py", line 155, in run_recbole_dm
    flops = get_flops(model, dataset, config["device"], logger, transform)
  File "/data/xxx/RecBole/recbole/utils/utils.py", line 295, in get_flops
    inter = dataset[torch.tensor([1])].to(device)
  File "/data/xxx/RecBole/recbole/data/dataset/dataset.py", line 1525, in __getitem__
    df = self.inter_feat[index]
  File "/home/xxx/anaconda3/envs/recbole/lib/python3.10/site-packages/pandas/core/frame.py", line 4108, in __getitem__
    indexer = self.columns._get_indexer_strict(key, "columns")[1]
  File "/home/xxx/anaconda3/envs/recbole/lib/python3.10/site-packages/pandas/core/indexes/base.py", line 6200, in _get_indexer_strict
    self._raise_if_missing(keyarr, indexer, axis_name)
  File "/home/xxx/anaconda3/envs/recbole/lib/python3.10/site-packages/pandas/core/indexes/base.py", line 6249, in _raise_if_missing
    raise KeyError(f"None of [{key}] are in the [{axis_name}]")
KeyError: "None of [Index([1], dtype='int64')] are in the [columns]"
```
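The last frames show RecBole's `Dataset.__getitem__` running `self.inter_feat[index]` and pandas handling that call. This suggests that the dataset reaching `get_flops` still holds a pandas DataFrame rather than a RecBole `Interaction` (the form `_change_feat_format()` converts it to, judging by the snippet's own use of it). Which dataset object `run_recbole_dm` actually passes there cannot be determined from the snippet. The pandas-only sketch below, independent of RecBole, shows why the lookup fails: a list-like key given to `DataFrame.__getitem__` is treated as column labels, while positional row selection goes through `.iloc`. `select_rows` is a hypothetical helper.

```python
import pandas as pd


def select_rows(df: pd.DataFrame, positions: list) -> pd.DataFrame:
    """Positional row selection; df[positions] would look up *column* labels instead."""
    return df.iloc[positions]


inter_feat = pd.DataFrame({"user_id": [1, 1, 2], "item_id": [10, 11, 20]})

try:
    inter_feat[[1]]  # list-like key -> treated as column labels -> KeyError
except KeyError as err:
    print("KeyError:", err)

print(select_rows(inter_feat, [1]))  # the second row: user_id 1, item_id 11
```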

I'm not very experienced with this, so I'd greatly appreciate any help. Thanks!

gosgjkaj added the enhancement (New feature or request) label on Sep 23, 2024