Skip to content

Commit

Permalink
Merge pull request #102 from traja-team/numpy-complex-type
Browse files Browse the repository at this point in the history
Bump Numpy version and fix complex type
  • Loading branch information
WolfByttner authored Jun 2, 2024
2 parents 08f9e00 + 6e1d0ce commit f165aff
Show file tree
Hide file tree
Showing 32 changed files with 506 additions and 491 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@ on:

jobs:
miniconda:
name: Miniconda ${{ matrix.os }}
name: Miniconda ${{ matrix.os }} - Python ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: ["ubuntu-latest", "windows-latest"]
python-version: [ "3.8", "3.12" ]
steps:
- uses: actions/checkout@v2
- uses: conda-incubator/setup-miniconda@v2
- name: Set up Miniconda (Python ${{ matrix.python-version }})
uses: conda-incubator/setup-miniconda@v2
with:
activate-environment: test
channels: conda-forge,defaults
environment-file: environment.yml
python-version: 3.8
python-version: ${{ matrix.python-version }}
auto-activate-base: false
- shell: bash -l {0}
run: |
Expand All @@ -44,6 +46,7 @@ jobs:
conda install pytest
py.test . --cov-report=xml --cov=traja -vvv
- name: Upload coverage to Codecov
if: ${{ matrix.python-version }} == '3.8'
uses: codecov/codecov-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,5 @@ docs/source/reference

# Model parameter files
*.pt
.python-version
datasets/
3 changes: 1 addition & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ dependencies:
- networkx
- seaborn
- pytorch
- pytest==6.2.2
- pytest>=8.0.0
- numba>=0.50.0
- pyDOE2>=1.3.0
- statsmodels
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pandas>=1.2.0
numpy==1.18.5
numpy>=1.22.0
matplotlib
shapely
scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion requirements/docs.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pandas>=1.2.0
numpy==1.18.5
numpy>=1.22.0
matplotlib
shapely
scipy
Expand Down
12 changes: 3 additions & 9 deletions requirements/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@ pytest
h5py
ipython
pre-commit
shapely
scipy>=1.4.1
scikit-learn
fastdtw
networkx
seaborn
torch
h5py
numba>=0.50.0
pyDOE2>=1.3.0
numba>=0.50.1
black
isort
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.2.0
current_version = 23.0.0

[yapf]
column_limit = 120
Expand Down
10 changes: 5 additions & 5 deletions traja/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import logging

from traja import dataset
from traja import models
from traja import dataset, models

from .accessor import TrajaAccessor
from .frame import TrajaDataFrame, TrajaCollection
from .parsers import read_file, from_df
from .frame import TrajaCollection, TrajaDataFrame
from .parsers import from_df, read_file
from .plotting import *
from .trajectory import *

__author__ = "justinshenk"
__version__ = "22.0.0"
__version__ = "23.0.0"

logging.basicConfig(level=logging.INFO)

Expand Down
3 changes: 2 additions & 1 deletion traja/contrib/rdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
"""

from functools import partial
from typing import Union, Callable
from typing import Callable, Union

import numpy as np

Expand Down
8 changes: 8 additions & 0 deletions traja/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pandas as pd


# Check whether pandas series is datetime or timedelta
def is_datetime_or_timedelta_dtype(series: pd.Series) -> bool:
return pd.api.types.is_datetime64_dtype(
series
) or pd.api.types.is_timedelta64_dtype(series)
2 changes: 1 addition & 1 deletion traja/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import example
from .dataset import TimeSeriesDataset, MultiModalDataLoader
from .dataset import MultiModalDataLoader, TimeSeriesDataset
from .pedestrian import load_ped_data, ped_datasets
1 change: 1 addition & 0 deletions traja/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
1. Class distribution in the dataset
"""

import logging
import math
from collections import defaultdict
Expand Down
2 changes: 1 addition & 1 deletion traja/dataset/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
def jaguar(cache_url=default_cache_url):
# Sample data
data_url = "https://raw.githubusercontent.com/traja-team/traja-research/dataset_und_notebooks/dataset_analysis/jaguar5.csv"
df = pd.read_csv(data_url, error_bad_lines=False)
df = pd.read_csv(data_url, on_bad_lines="skip")
return df
7 changes: 4 additions & 3 deletions traja/dataset/pedestrian.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import subprocess
import glob
import os
import subprocess
from typing import List

import pandas as pd
from traja.dataset import dataset
import traja

import traja
from traja.dataset import dataset

"""Convenience module for downloading pedestrian-related datasets."""

Expand Down
151 changes: 0 additions & 151 deletions traja/dataset/pituitary_gland.py

This file was deleted.

12 changes: 6 additions & 6 deletions traja/frame.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from typing import Optional, Union, Tuple
import warnings
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -54,13 +54,13 @@ def __init__(self, *args, **kwargs):
args[0]._copy_attrs(self)
for name, value in traja_kwargs.items():
self.__dict__[name] = value
# Initialize

# Initialize
self._convex_hull = None

# Initialize metadata like 'fps','spatial_units', etc.
self._init_metadata()

@property
def _constructor(self):
return TrajaDataFrame
Expand Down Expand Up @@ -171,15 +171,15 @@ def center(self):
x = self.x
y = self.y
return float(x.mean()), float(y.mean())

@property
def convex_hull(self):
"""Property of TrajaDataFrame class representing
bounds for convex area enclosing trajectory points.
"""
# Calculate if it doesn't exist
if self._convex_hull is None:
if self._convex_hull is None:
xy_arr = self.traja.xy
point_arr = traja.trajectory.calc_convex_hull(xy_arr)
self._convex_hull = point_arr
Expand Down
3 changes: 2 additions & 1 deletion traja/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from traja.models.generative_models.vaegan import MultiModelVAEGAN
from traja.models.predictive_models.ae import MultiModelAE
from traja.models.predictive_models.lstm import LSTM

from .inference import *
from .train import HybridTrainer
from .utils import TimeDistributed, read_hyperparameters, save, load
from .utils import TimeDistributed, load, read_hyperparameters, save
1 change: 1 addition & 0 deletions traja/models/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def generate(self, num_steps, classify=True, scaler=None, plot_data=True):
elif self.model_type == "vaegan" or "custom":
return NotImplementedError


class Predictor:
def __init__(
self,
Expand Down
3 changes: 2 additions & 1 deletion traja/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

device = "cuda" if torch.cuda.is_available() else "cpu"


class Criterion:
"""Implements the loss functions of Autoencoders, Variational Autoencoders and LSTM models
Huber loss is set as default for reconstruction loss, alternative is to use rmse,
Expand Down Expand Up @@ -30,7 +31,7 @@ def forecasting_criterion(
"""

if mu is not None and logvar is not None:
kld = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())
kld = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
else:
kld = 0

Expand Down
Loading

0 comments on commit f165aff

Please sign in to comment.