Skip to content

Commit

Permalink
ENH: reimplement bijection checks from 10bf4cd
Browse files Browse the repository at this point in the history
  • Loading branch information
phmbressan committed Apr 4, 2023
1 parent d5c4eab commit 1d39356
Showing 1 changed file with 88 additions and 4 deletions.
92 changes: 88 additions & 4 deletions rocketpy/Function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2411,20 +2411,99 @@ def integralFunction(self, lower=None, upper=None, datapoints=100):
outputs=[o + " Integral" for o in self.__outputs__],
)

def isBijective(self):
"""Checks whether the Function is bijective. Only applicable to
Functions whose source is a list of points, raises an error otherwise.
Returns
-------
result : bool
True if the Function is bijective, False otherwise.
"""
if isinstance(self.source, np.ndarray):
xDataDistinct = set(self.xArray)
yDataDistinct = set(self.yArray)
distinctMap = set(zip(xDataDistinct, yDataDistinct))
return len(distinctMap) == len(xDataDistinct) == len(yDataDistinct)
else:
raise TypeError(
"Only Functions whose source is a list of points can be "
"checked for bijectivity."
)

def isStrictlyBijective(self):
"""Checks whether the Function is "strictly" bijective.
Only applicable to Functions whose source is a list of points,
raises an error otherwise.
Notes
-----
By "strictly" bijective, this implementation considers the
list-of-points-defined Function bijective between each consecutive pair
of points. Therefore, the Function may be flagged as not bijective even
if the mapping between the set of points which define the Function is
bijective.
Returns
-------
result : bool
True if the Function is "strictly" bijective, False otherwise.
Examples
--------
>>> f = Function([[0, 0], [1, 1], [2, 4]])
>>> f.isBijective()
True
>>> f.isStrictlyBijective()
True
>>> f = Function([[-1, 1], [0, 0], [1, 1], [2, 4]])
>>> f.isBijective()
False
>>> f.isStrictlyBijective()
False
A Function which is not "strictly" bijective, but is bijective, can be
constructed as x^2 defined at -1, 0 and 2.
>>> f = Function([[-1, 1], [0, 0], [2, 4]])
>>> f.isBijective()
True
>>> f.isStrictlyBijective()
False
"""
if isinstance(self.source, np.ndarray):
# Assuming domain is sorted, range must also be
yData = self.yArray
# Both ascending and descending order means Function is bijective
yDataDiff = np.diff(yData)
return np.all(yDataDiff >= 0) or np.all(yDataDiff <= 0)
else:
raise TypeError(
"Only Functions whose source is a list of points can be "
"checked for bijectivity."
)

def inverseFunction(self, approxFunc=None, tol=1e-4):
"""
Returns the inverse of the Function. The inverse function of F is a
function that undoes the operation of F. The inverse of F exists if
and only if F is bijective. Makes the domain the range and the range
the domain.
If the Function is given by a list of points, its bijectivity is
checked and an error is raised if it is not bijective.
If the Function is given by a function, its bijection is not
checked and may lead to innacuracies outside of its bijective region.
Parameters
----------
approxFunc : callable, optional
A function that approximates the inverse of the Function. This
function is used to find the starting guesses for the inverse
root finding algorithm. This is better used when the inverse
in complex but has a simple approximation.
in complex but has a simple approximation or when the root
finding algorithm performs poorly due to default start point.
The default is None in which case the starting point is zero.
tol : float, optional
Expand All @@ -2437,10 +2516,15 @@ def inverseFunction(self, approxFunc=None, tol=1e-4):
A Function whose domain and range have been inverted.
"""
if isinstance(self.source, np.ndarray):
# Swap the columns
source = np.flip(self.source, axis=1)
if self.isStrictlyBijective():
# Swap the columns
source = np.flip(self.source, axis=1)
else:
raise ValueError(
"Function is not bijective, so it does not have an inverse."
)
else:
if approxFunc:
if approxFunc is not None:
source = lambda x: self.findInput(x, start=approxFunc(x), tol=tol)
else:
source = lambda x: self.findInput(x, start=0, tol=tol)
Expand Down

0 comments on commit 1d39356

Please sign in to comment.