diff --git a/blackjax/util.py b/blackjax/util.py index ee226a5ac..df527ed01 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -182,7 +182,7 @@ def run_inference_algorithm( init_key, sample_key = split(rng_key, 2) try: initial_state = inference_algorithm.init(initial_state_or_position, init_key) - except TypeError: + except (TypeError, ValueError, AttributeError): # We assume initial_state is already in the right format. initial_state = initial_state_or_position