Allow dtypes other than complex128 #22

Open
wants to merge 5 commits into master

Conversation

@Ericgig (Member) commented on Jun 22, 2023

Our data layer always uses double precision complex, but that is not always ideal; GPUs often don't support double precision well.

With this PR, the default is still to convert any array to complex128, but that default can be overridden.
However, the override is not yet usable through the QuTiP interface in this PR. I am not too sure how to make it available to the user...
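
For illustration, a minimal sketch of what "default to complex128, but allow an override" means at the array level. The helper name `as_jax_data` and its `dtype` keyword are hypothetical and not part of this PR's actual interface:

```python
import jax
import jax.numpy as jnp

# 64-bit types must be enabled, otherwise JAX silently stays in 32-bit mode.
jax.config.update("jax_enable_x64", True)

def as_jax_data(arr, dtype=jnp.complex128):
    """Hypothetical helper: convert input to a JAX array.

    complex128 by default, but any dtype (e.g. complex64 for GPUs with
    weak double-precision support) can be requested instead.
    """
    return jnp.asarray(arr, dtype=dtype)

x = as_jax_data([[1.0, 0.0], [0.0, 1.0]])                       # complex128 (default)
y = as_jax_data([[1.0, 0.0], [0.0, 1.0]], dtype=jnp.complex64)  # single precision
```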

@coveralls commented on Jun 26, 2023

Pull Request Test Coverage Report for Build 5455627965

  • 58 of 62 (93.55%) changed or added relevant lines in 9 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.1%) to 90.268%

Changes missing coverage:
  • src/qutip_jax/qobjevo.py — 19 of 23 changed/added lines covered (82.61%)

Totals:
  • Change from base Build 5378410853: -0.1%
  • Covered Lines: 742
  • Relevant Lines: 822

💛 - Coveralls

@AGaliciaMartinez (Member) left a comment

Hi Eric, I did not finish reviewing this PR but I wanted to ask the following:

  • Should we consider registering a "jax64" dtype to address the complex64 case? In qutip-tensorflow I used a _BaseTfTensor class to represent both types of tensors. I then wrapped this class to create TfTensor64 and TfTensor128. So in qutip-tensorflow these two classes are different but virtually the same. It does the work, but there may be cleaner approaches. It does require registering the specialisation twice (or perhaps three times if you want to allow operations between the 64 and 128 types); see the sketch after this list.
  • I see the use of fast_constructor instead of the `__init__` method. Is it to provide a significant speedup in the construction time?
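
For reference, a rough sketch of the wrapping pattern described in the first bullet, transposed to JAX; the class names are illustrative and not the actual qutip-tensorflow (or qutip-jax) classes:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # needed for complex128 arrays

class _BaseJaxArray:
    _dtype = None  # fixed by the concrete subclasses below

    def __init__(self, data):
        self._jxa = jnp.asarray(data, dtype=self._dtype)

class JaxArray64(_BaseJaxArray):
    _dtype = jnp.complex64    # the hypothetical "jax64" dtype

class JaxArray128(_BaseJaxArray):
    _dtype = jnp.complex128   # current default, double precision

# Each concrete class would then be registered with the qutip data layer
# separately, which is the duplicated specialisation work mentioned above.
```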

src/qutip_jax/binops.py (review thread resolved)
Comment on lines 15 to +26
 @jax.jit
 def _cplx2float(arr):
-    return jnp.stack([arr.real, arr.imag])
+    if jnp.iscomplexobj(arr):
+        return jnp.stack([arr.real, arr.imag])
+    return arr


 @jax.jit
 def _float2cplx(arr):
-    return arr[0] + 1j * arr[1]
+    if arr.ndim == 3:
+        return arr[0] + 1j * arr[1]
+    return arr
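
As a side note (not part of the diff), this is roughly how the guarded pair behaves once both changes are in, assuming the two helpers above are importable:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

cplx = jnp.array([[1.0 + 2.0j, 0.0j]])   # complex input: stacked and restored
stacked = _cplx2float(cplx)              # shape (2, 1, 2): real and imaginary parts
assert jnp.allclose(_float2cplx(stacked), cplx)

real = jnp.array([[1.0, 0.0]])           # float input: now passes through unchanged
assert jnp.allclose(_cplx2float(real), real)
assert jnp.allclose(_float2cplx(real), real)
```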
A reviewer (Member) commented:

Test

@Ericgig (Member, Author) replied on Jul 4, 2023:

It is covered in test_non_cplx128_Diffrax.

@Ericgig (Member, Author) commented on Jul 4, 2023

> Should we consider registering a `"jax64"` dtype to address the complex64 case? In qutip-tensorflow I used a _BaseTfTensor class to represent both types of tensors. I then wrapped this class to create `TfTensor64` and `TfTensor128`. So in qutip-tensorflow these two classes are different but virtually the same. It does the work, but there may be cleaner approaches. It does require registering the specialisation twice (or perhaps three times if you want to allow operations between the 64 and 128 types).

I would prefer not to.
I am interested in using float64 arrays in specific integrators, but these should not be registered as a dtype since they would not work anywhere else. As it is right now, they would just be converted to complex arrays when needed.

Also, having a jax64 type would encourage users to use it, but some users already have trouble telling floating-point error apart from physics or code errors even at double precision.
It would also run into issues with the default options (the integrator's atol is 1e-8, the core options' atol is 1e-12). Many users would not think of updating these settings or understand the errors those settings would cause.
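
For concreteness, the mismatch looks like this (a quick illustration, not code from the PR):

```python
import numpy as np

print(np.finfo(np.complex64).eps)    # ~1.2e-07 : single-precision machine epsilon
print(np.finfo(np.complex128).eps)   # ~2.2e-16 : double-precision machine epsilon

# At single precision the machine error (~1e-7) is already larger than the
# default integrator atol (1e-8) and the core atol (1e-12), so results would
# look like solver failures unless users loosen those tolerances themselves.
```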

> I see the use of fast_constructor instead of the `__init__` method. Is it to provide a significant speedup in the construction time?

  • `__init__` converts to complex128 unless explicitly told not to.
  • `_fast_constructor` can be jitted (I jitted it in the JaxDia PR).
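
To make the distinction concrete, here is an illustrative sketch of the pattern (assumed structure, not the actual qutip-jax JaxArray code): a full `__init__` that converts, and a bare `_fast_constructor` that can live inside jitted code once the class is registered as a pytree.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

jax.config.update("jax_enable_x64", True)

class Data:
    def __init__(self, array, dtype=jnp.complex128):
        # Full constructor: converts to complex128 unless told otherwise.
        self._jxa = jnp.asarray(array, dtype=dtype)

    @classmethod
    def _fast_constructor(cls, array):
        # No conversion or validation: cheap, and safe inside jitted code.
        out = cls.__new__(cls)
        out._jxa = array
        return out

# Registering as a pytree lets jitted functions take and return Data directly,
# rebuilding instances through the cheap constructor rather than __init__.
tree_util.register_pytree_node(
    Data,
    lambda d: ((d._jxa,), None),
    lambda aux, children: Data._fast_constructor(children[0]),
)

@jax.jit
def scale(data, a):
    return Data._fast_constructor(a * data._jxa)

d = Data([[1.0, 0.0], [0.0, 1.0]])   # converted to complex128
d2 = scale(d, 2.0)                   # no conversion happens inside the jit
```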
