-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
12d73aa
commit 18afc71
Showing
1 changed file
with
89 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ gh-badge: [star, fork] | |
* @Author: Conghao Wong | ||
* @Date: 2023-08-21 15:58:54 | ||
* @LastEditors: Conghao Wong | ||
* @LastEditTime: 2023-10-26 17:30:41 | ||
* @LastEditTime: 2024-04-11 16:04:35 | ||
* @Description: file content | ||
* @Github: https://cocoon2wong.github.io | ||
* Copyright 2023 Conghao Wong, All Rights Reserved. | ||
|
@@ -54,7 +54,7 @@ These codes are developed with different python versions: | |
|
||
- PyTorch Version: | ||
|
||
Codes ([https://github.com/cocoon2wong/SocialCircle/tree/TorchVersion(beta)](https://github.com/cocoon2wong/SocialCircle/tree/TorchVersion(beta))) are developed with **Python 3.11**. | ||
Codes ([https://github.com/cocoon2wong/SocialCircle/tree/TorchVersion(beta)](https://github.com/cocoon2wong/SocialCircle/tree/TorchVersion(beta))) are developed with **Python 3.10**. | ||
|
||
Additional packages used are included in the `requirements.txt` file. | ||
|
||
|
@@ -70,70 +70,126 @@ pip install -r requirements.txt | |
|
||
## Dataset Prepare and Process | ||
|
||
Before training `SocialCircle` models on your own dataset, you should add your dataset information. | ||
See [this document](https://cocoon2wong.github.io/Project-Luna/) for details. | ||
If you just want to validate these models on ETH-UCY and SDD, please head to the [Pre-Trained Model Weights](https://cocoon2wong.github.io/SocialCircle/guidelines/#pre-trained-model-weights) section. | ||
|
||
## Training | ||
|
||
*Available Soon* | ||
|
||
## Evaluation | ||
### ETH-UCY, SDD, NBA, nuScenes | ||
|
||
*Available Soon* | ||
|
||
### Pre-Trained Model Weights | ||
|
||
We have provided our pre-trained model weights to help you quickly evaluate the `SocialCircle` models' performance. | ||
Our pre-trained models contain: | ||
|
||
- `MSN` ([🔗homepage](https://northocean.github.io/MSN/)) and its SocialCircle variation `MSN-SC` (8-to-12 on SDD, forecasts 20 random sampled trajectories for each agent); | ||
- `V^2-Net` ([🔗homepage](https://cocoon2wong.github.io/Vertical/)) and its SocialCircle variation `V^2-Net-SC` (8-to-12 on SDD, 20 trajectories); | ||
- `E-V^2-Net` ([🔗homepage](https://cocoon2wong.github.io/E-Vertical/)) and its SocialCircle variation `E-V^2-Net-SC` (8-to-12 on SDD, 20 trajectories). | ||
{: .box-warning} | ||
**Warning:** If you want to validate `SocialCircle` models on these datasets, make sure you are getting this repository via `git clone` and that all `gitsubmodules` have been properly initialized via `git submodule update --init --recursive`. | ||
|
||
You can run the following commands to prepare ETH-UCY and SDD dataset files: | ||
You can run the following commands to prepare dataset files that have been validated in our paper: | ||
|
||
1. Run Python the script inner the `dataset_original` folder: | ||
|
||
```bash | ||
cd dataset_original && python main_ethucysdd.py | ||
cd dataset_original | ||
``` | ||
|
||
2. For the `PyTorch` version, you can run the following command inner the `dataset_original` folder to transfer dataset files of NBA dataset. ***(Optional)*** | ||
- For `ETH-UCY` and `SDD`, run | ||
|
||
```bash | ||
python main_nba.py | ||
``` | ||
```bash | ||
python main_ethucysdd.py | ||
``` | ||
|
||
- For `NBA` or `nuScenes`, you can download their original dataset files, put them into the given path listed within `dataset_original/main_nba.py` or `dataset_original/main_nuscenes.py`, then run | ||
|
||
3. Back to the repo folder and create soft links: | ||
```bash | ||
python main_nba.py | ||
python main_nuscenes.py | ||
``` | ||
|
||
(You can also download the processed dataset files manually from [here](https://github.com/cocoon2wong/Project-Monandaeg/tree/main/Dataset), and put them into `dataset_processed` and `dataset_configs` folders.) | ||
|
||
2. Back to the repo folder and create soft links: | ||
|
||
```bash | ||
cd .. | ||
ln -s dataset_original/dataset_processed ./ | ||
ln -s dataset_original/dataset_configs ./ | ||
``` | ||
|
||
Click the following buttons to download our weights and learn more about how to install these datasets. | ||
Click the following button to learn more about how to process these dataset files. | ||
|
||
<div style="text-align: center;"> | ||
<a class="btn btn-colorful btn-lg" href="https://cocoon2wong.github.io/Project-Luna/howToUse/">💡 Dataset Guidelines</a> | ||
</div> | ||
|
||
### Prepare Your New Datasets | ||
|
||
Before training `SocialCircle` models on your own dataset, you should add your dataset information. | ||
See [this document](https://cocoon2wong.github.io/Project-Luna/) for details. | ||
|
||
## Pre-Trained Model Weights and Evaluation | ||
|
||
We have provided our pre-trained model weights to help you quickly evaluate the `SocialCircle` models' performance. | ||
Click the following buttons to download our model weights. | ||
We recommend that you download the weights and place them in the `weights/SocialCircle` folder. | ||
<div style="text-align: center;"> | ||
<a class="btn btn-colorful btn-lg" href="https://github.com/cocoon2wong/SocialCircle/releases">⬇️ Download Weights (TensorFlow 2)</a> | ||
<a class="btn btn-colorful btn-lg" href="https://github.com/cocoon2wong/Project-Monandaeg/tree/main/Silverbullet-Torch">⬇️ Download Weights (PyTorch)</a> | ||
<a class="btn btn-colorful btn-lg" href="https://cocoon2wong.github.io/Project-Luna/howToUse/">💡 Dataset Guidelines</a> | ||
</div> | ||
{: .box-warning} | ||
**Warning:** The TensorFlow 2 version of codes only support weights that trained with TensorFlow 2, and the PyTorch version of codes only support weights that trained with PyTorch. | ||
Please download the correct weights file or the program will not run correctly. | ||
You can start evaluating our pre-trained weights by | ||
You can start evaluating models by | ||
```bash | ||
python main.py --sc SOME_MODEL_WEIGHTS | ||
``` | ||
Here, `SOME_MODEL_WEIGHTS` is the path of the weights folder, for example, `weights/SocialCircle/evsc_P8_sdd`. | ||
## Training | ||
You can start training a `SocialCircle` model via the following command: | ||
```bash | ||
python main.py --model MODEL_IDENTIFIER --split DATASET_SPLIT | ||
``` | ||
Here, `MODEL_IDENTIFIER` is the identifier of the model. | ||
These identifiers are supported in current codes: | ||
- The basic transformer model for trajectory prediction: | ||
- `trans` (named the `Transformer` in the paper); | ||
- `transsc` (SocialCircle variation `Transformer-SC`). | ||
- MSN ([🔗homepage](https://northocean.github.io/MSN/)): | ||
- `msna` (original model); | ||
- `msnsc` (SocialCircle variation). | ||
- *NOTE: `MSN` models are not supported in the PyTorch branch. Please validate them in the TensorFlow branch.* | ||
- V^2-Net ([🔗homepage](https://cocoon2wong.github.io/Vertical/)): | ||
- `va` (original model); | ||
- `vsc` (SocialCircle variation). | ||
- E-V^2-Net ([🔗homepage](https://cocoon2wong.github.io/E-Vertical/)): | ||
- `eva` (original model); | ||
- `evsc` (SocialCircle variation). | ||
`DATASET_SPLIT` is the identifier (i.e., the name of dataset's split files in `dataset_configs`, for example `eth` is the identifier of the split list in `dataset_configs/ETH-UCY/eth.plist`) of the dataset or splits used for training. | ||
It accepts: | ||
|
||
- ETH-UCY: {`eth`, `hotel`, `univ`, `zara1`, `zara2`}; | ||
- SDD: `sdd`; | ||
- NBA: `nba50k`; | ||
- nuScenes: {`nuScenes_v1.0`, `nuScenes_ov_v1.0`}; | ||
|
||
For example, you can start training the `E-V^2-Net-SC` model by | ||
|
||
```bash | ||
python main.py --model evsc --split zara1 | ||
``` | ||
|
||
You can also specify other needed args, like the learning rate `--lr`, batch size `--batch_size`, etc. | ||
See detailed args in the `Args Used` Section. | ||
|
||
In addition, the simplest way to reproduce our results is to copy all training args we used in the provided weights. | ||
For example, you can start a training of `E-V^2-Net-SC` on `zara1` by: | ||
|
||
```bash | ||
python main.py --restore_args weights/SocialCircle/evsczara1 | ||
``` | ||
|
||
### Toy Example | ||
|
||
You can run the following script to learn how the proposed `SocialCircle` works in an interactive way: | ||
|
@@ -393,4 +449,5 @@ About the `argtype`: | |
## Contact us | ||
|
||
Conghao Wong ([@cocoon2wong](https://github.com/cocoon2wong)): [email protected] | ||
Beihao Xia ([@NorthOcean](https://github.com/NorthOcean)): [email protected] | ||
Beihao Xia ([@NorthOcean](https://github.com/NorthOcean)): [email protected] | ||
Ziqian Zou ([@LivepoolQ](https://github.com/LivepoolQ)): [email protected] |