diff --git a/tests/test_routing/test_fallbacks.py b/tests/test_routing/test_fallbacks.py index 7fbde4b..0d6bb08 100644 --- a/tests/test_routing/test_fallbacks.py +++ b/tests/test_routing/test_fallbacks.py @@ -13,3 +13,7 @@ def test_endpoint_fallback(): unify.Unify( "llama-3.1-405b-chat@together-ai->llama-3.1-70b-chat@groq", ).generate("Hello.") + + +if __name__ == "__main__": + pass diff --git a/unify/universal_api/clients/uni_llm.py b/unify/universal_api/clients/uni_llm.py index d187ce6..c196f8d 100644 --- a/unify/universal_api/clients/uni_llm.py +++ b/unify/universal_api/clients/uni_llm.py @@ -353,7 +353,12 @@ def set_endpoint(self, value: str) -> Self: """ _assert_is_valid_endpoint(value, api_key=self._api_key) self._endpoint = value - self._model, self._provider = value.split("->")[0].split("@") # noqa: WPS414 + lhs = value.split("->")[0] + if "@" in lhs: + self._model, self._provider = lhs.split("@") + else: + self._model = lhs + self._provider = value.split("->")[1] return self def set_model(self, value: str) -> Self: