diff --git a/hordelib/model_manager/lora.py b/hordelib/model_manager/lora.py index 4ea032e9..f9f0623e 100644 --- a/hordelib/model_manager/lora.py +++ b/hordelib/model_manager/lora.py @@ -44,14 +44,19 @@ class LoraModelManager(BaseModelManager): ) LORA_API = "https://civitai.com/api/v1/models?types=LORA&sort=Highest%20Rated&primaryFileOnly=true" MAX_RETRIES = 10 if not TESTS_ONGOING else 3 - MAX_DOWNLOAD_THREADS = 3 + MAX_DOWNLOAD_THREADS = 5 if not TESTS_ONGOING else 15 RETRY_DELAY = 3 if not TESTS_ONGOING else 0.2 """The time to wait between retries in seconds""" - REQUEST_METADATA_TIMEOUT = 20 - """The time to wait for a response from the server in seconds""" - REQUEST_DOWNLOAD_TIMEOUT = 300 - """The time to wait for a response from the server in seconds""" - THREAD_WAIT_TIME = 2 + REQUEST_METADATA_TIMEOUT = 20 # Longer because civitai performs poorly on metadata requests for more than 5 models + """The maximum time for no data to be received before we give up on a metadata fetch, in seconds""" + REQUEST_DOWNLOAD_TIMEOUT = 10 if not TESTS_ONGOING else 1 + """The maximum time for no data to be received before we give up on a download, in seconds + + This is not the time to download the file, but the time to wait in between data packets. \ + If we're actively downloading and the connection to the server is alive, this doesn't come into play + """ + + THREAD_WAIT_TIME = 0.1 """The time to wait between checking the download queue in seconds""" _file_lock: multiprocessing_lock | nullcontext @@ -274,21 +279,45 @@ def _add_lora_ids_to_download_queue(self, lora_ids, adhoc=False, version_compare def _get_json(self, url): retries = 0 while retries <= self.MAX_RETRIES: + response = None try: - response = requests.get(url, timeout=self.REQUEST_METADATA_TIMEOUT) + response = requests.get( + url, + timeout=self.REQUEST_METADATA_TIMEOUT if len(url) < 200 else self.REQUEST_METADATA_TIMEOUT * 1.5, + ) response.raise_for_status() # Attempt to decode the response to JSON return response.json() except (requests.HTTPError, requests.ConnectionError, requests.Timeout, json.JSONDecodeError) as e: - # CivitAI Errors when the model ID is too long - if response.status_code in [404, 500]: - logger.debug(f"url '{url}' download failed with status code {response.status_code}") - return None - logger.debug(f"url '{url}' download failed {type(e)} {e}") + + # If this is a 401, 404, or 500, we're not going to get anywhere, just give up + # The following are the CivitAI errors encountered so far + # [401: requires a token, 404: model ID too long, 500: internal server error] + if response is not None: + if response.status_code in [401, 404]: + logger.debug(f"url '{url}' download failed with status code {response.status_code}") + return None + if response.status_code == 500: + logger.debug(f"url '{url}' download failed with status code {response.status_code}") + retries += 3 + + # The json being invalid is a CivitAI issue, possibly it showing an HTML page and + # this isn't likely to change in the next 30 seconds, so we'll try twice more + # and give up if it doesn't work + if isinstance(e, json.JSONDecodeError): + logger.debug(f"url '{url}' download failed with {type(e)} {e}") + retries += 3 + + # If the network connection timed out, then self.REQUEST_METADATA_TIMEOUT seconds passed + # and the clock is ticking, so we'll try once more + if response is None: + retries += 5 + retries += 1 self.total_retries_attempted += 1 + if retries <= self.MAX_RETRIES: time.sleep(self.RETRY_DELAY) else: @@ -674,8 +703,8 @@ def clear_all_references(self): def wait_for_downloads(self, timeout=None): rtr = 0 while not self.are_downloads_complete(): - time.sleep(0.5) - rtr += 0.5 + time.sleep(self.THREAD_WAIT_TIME) + rtr += self.THREAD_WAIT_TIME if timeout and rtr > timeout: raise Exception(f"Lora downloads exceeded specified timeout ({timeout})") logger.debug("Downloads complete") @@ -973,7 +1002,7 @@ def reset_adhoc_loras(self): if self._stop_all_threads: logger.debug("Stopped processing thread") return - time.sleep(0.2) + time.sleep(self.THREAD_WAIT_TIME) now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self._adhoc_loras = set() unsorted_items = [] @@ -1073,8 +1102,8 @@ def is_adhoc_reset_complete(self): def wait_for_adhoc_reset(self, timeout=15): rtr = 0 while not self.is_adhoc_reset_complete(): - time.sleep(0.2) - rtr += 0.2 + time.sleep(self.THREAD_WAIT_TIME) + rtr += self.THREAD_WAIT_TIME if timeout and rtr > timeout: raise Exception(f"Lora adhoc reset exceeded specified timeout ({timeout})") diff --git a/hordelib/model_manager/ti.py b/hordelib/model_manager/ti.py index 30b14fb6..e8c7ffc0 100644 --- a/hordelib/model_manager/ti.py +++ b/hordelib/model_manager/ti.py @@ -151,6 +151,7 @@ def _add_ti_ids_to_download_queue(self, ti_ids, adhoc=False, version_compare=Non def _get_json(self, url): retries = 0 while retries <= self.MAX_RETRIES: + response = None try: response = requests.get(url, timeout=self.REQUEST_METADATA_TIMEOUT) response.raise_for_status() @@ -159,8 +160,19 @@ def _get_json(self, url): except (requests.HTTPError, requests.ConnectionError, requests.Timeout, json.JSONDecodeError): # CivitAI Errors when the model ID is too long - if response.status_code in [404, 500]: - return None + if response is not None: + if response.status_code in [401, 404]: + return None + if response.status_code == 500: + retries += 3 + logger.debug( + "CivitAI reported an internal error when downloading metadata. " + "Fewer retries will be attempted.", + ) + + if response is None: + retries += 5 + retries += 1 self.total_retries_attempted += 1 if retries <= self.MAX_RETRIES: diff --git a/tests/model_managers/test_mm_lora.py b/tests/model_managers/test_mm_lora.py index 40c8b549..138137f5 100644 --- a/tests/model_managers/test_mm_lora.py +++ b/tests/model_managers/test_mm_lora.py @@ -220,7 +220,7 @@ def test_adhoc_non_existing_intstring_large(self): lora_model_manager.wait_for_adhoc_reset(15) lora_name = "99999999999999" lora_key = lora_model_manager.fetch_adhoc_lora(lora_name) - assert lora_model_manager.total_retries_attempted == 0 + assert lora_model_manager.total_retries_attempted == 1 assert lora_key is None assert not lora_model_manager.is_model_available(lora_name) lora_model_manager.stop_all()