As specified in base.py, UpdateFn should take only two arguments: an RNG key and the previous state. However, many samplers take additional arguments. For example, sgld also takes a minibatch of data and a step size after the RNG key and state. Maybe we should add *args to the signature of UpdateFn.__call__?
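To make the proposal concrete, here is a minimal sketch of what a callback protocol with *args could look like. It does not reproduce the actual definition in base.py: the Any annotations, the return type, and the helpers two_arg_update, sgld_like_update and run_step are all hypothetical stand-ins, and the "gradient" is a toy placeholder rather than real SGLD.

```python
from typing import Any, Protocol


class UpdateFn(Protocol):
    """Hypothetical variant of UpdateFn that tolerates extra positional arguments.

    The argument and return types are placeholders (assumptions), not the
    types used in base.py.
    """

    def __call__(self, rng_key: Any, state: Any, *args: Any) -> Any:
        ...


def two_arg_update(rng_key: Any, state: float) -> float:
    # A sampler step that only needs the RNG key and the previous state.
    return state + 1.0


def sgld_like_update(
    rng_key: Any, state: float, minibatch: list, step_size: float
) -> float:
    # Toy stand-in for an SGLD-style step: the extra arguments come after
    # the RNG key and the state, as described above.
    toy_grad = sum(minibatch) / len(minibatch)
    return state - step_size * toy_grad


def run_step(update: UpdateFn, rng_key: Any, state: Any, *args: Any) -> Any:
    # Forward any sampler-specific arguments unchanged.
    return update(rng_key, state, *args)
```

At runtime both kinds of update function can be passed through run_step. Whether a static type checker accepts sgld_like_update as an UpdateFn under this signature is a separate question that this sketch does not settle.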
Not sure whether adding *args to UpdateFn.__call__ works, or whether it will just force us to add # type: ignore[arg-type] to every use of SamplingAlgorithm (instead of only the ones that deviate from the two arguments you mentioned).
But you can try opening a PR with the change to see if it works, @JGameCreation, and we can go from there.