forked from dynamicslab/pysindy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request dynamicslab#85 from dynamicslab/expanded-derivatives
Expanded derivatives
- Loading branch information
Showing
15 changed files
with
1,220 additions
and
179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,12 @@ | ||
from .base import BaseDifferentiation | ||
from .finite_difference import FiniteDifference | ||
from .sindy_derivative import SINDyDerivative | ||
from .smoothed_finite_difference import SmoothedFiniteDifference | ||
|
||
__all__ = ["BaseDifferentiation", "FiniteDifference", "SmoothedFiniteDifference"] | ||
|
||
__all__ = [ | ||
"BaseDifferentiation", | ||
"FiniteDifference", | ||
"SINDyDerivative", | ||
"SmoothedFiniteDifference", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
""" | ||
Wrapper classes for differentiation methods from the :doc:`derivative:index` package. | ||
Some default values used here may differ from those used in :doc:`derivative:index`. | ||
""" | ||
from derivative import dxdt | ||
from numpy import arange | ||
from sklearn.base import BaseEstimator | ||
|
||
from pysindy.utils.base import validate_input | ||
|
||
|
||
class SINDyDerivative(BaseEstimator): | ||
""" | ||
Wrapper class for differentiation classes from the :doc:`derivative:index` package. | ||
This class is meant to provide all the same functionality as the | ||
`dxdt <https://derivative.readthedocs.io/en/latest/api.html\ | ||
#derivative.differentiation.dxdt>`_ method. | ||
This class also has ``_differentiate`` and ``__call__`` methods which are | ||
used by PySINDy. | ||
Parameters | ||
---------- | ||
derivative_kws: dictionary, optional | ||
Keyword arguments to be passed to the | ||
`dxdt <https://derivative.readthedocs.io/en/latest/api.html\ | ||
#derivative.differentiation.dxdt>`_ | ||
method. | ||
Notes | ||
----- | ||
See the `derivative documentation <https://derivative.readthedocs.io/en/latest/>`_ | ||
for acceptable keywords. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
self.kwargs = kwargs | ||
|
||
def set_params(self, **params): | ||
""" | ||
Set the parameters of this estimator. | ||
Modification of the pysindy method to allow unknown kwargs. This allows using | ||
the full range of derivative parameters that are not defined as member variables | ||
in sklearn grid search. | ||
Returns | ||
------- | ||
self | ||
""" | ||
if not params: | ||
# Simple optimization to gain speed (inspect is slow) | ||
return self | ||
else: | ||
self.kwargs.update(params) | ||
|
||
return self | ||
|
||
def get_params(self, deep=True): | ||
"""Get parameters.""" | ||
params = super().get_params(deep) | ||
|
||
if isinstance(self.kwargs, dict): | ||
params.update(self.kwargs) | ||
|
||
return params | ||
|
||
def _differentiate(self, x, t=1): | ||
if isinstance(t, (int, float)): | ||
if t < 0: | ||
raise ValueError("t must be a positive constant or an array") | ||
t = arange(x.shape[0]) * t | ||
|
||
return dxdt(x, t, axis=0, **self.kwargs) | ||
|
||
def __call__(self, x, t=1): | ||
x = validate_input(x, t=t) | ||
return self._differentiate(x, t) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.