From 7c8a57b45e6c38c2a4f227241f284424ccec7bfa Mon Sep 17 00:00:00 2001 From: Joey Ballentine Date: Fri, 17 Nov 2023 00:18:12 -0500 Subject: [PATCH] README improvements --- README.md | 68 ++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index a9cf1181..e618b38a 100644 --- a/README.md +++ b/README.md @@ -10,31 +10,73 @@ Slightly selfishly, I'm also hoping this will encourage the community to help ad This package does not yet have easy inference code for these model types, but porting that code is planned as well. +## Installation + +Spandrel is available through pip and can be installed via a simple pip install command: + +```shell +pip install spandrel +``` + ## Usage **This package is still in early stages of development, and is subject to change at any time.** -To use this package, simply use the ArchSupport class like so: +To use this package for automatic architecture loading, simply use the ModelLoader class like so: ```python -from spandrel import ArchSupport +from spandrel import ModelLoader import torch -arch_loader = ArchSupport(torch.device("cuda:0")) -model = arch_loader.load_from_path(r"/path/to/your/model.pth") +# Initialize the ModelLoader class with an optional preferred torch.device. Defaults to cpu. +model_loader = ModelLoader(torch.device("cuda:0")) -print(model.metadata) -print(model.model) -print(model.state_dict) +# Load the model from the given path +loaded_model = model_loader.load_from_file(r"/path/to/your/model.pth") ``` -And that's it. The model gets loaded into a wrapper class that has some `metadata` on it that tells you a bit about the model and its size. You can also access the actual torch `model` and `state_dict` from it. +And that's it. The model gets loaded into a helper class with various helpful bits of information, as well as the actual model information. + +```py +# The model itself (a torch.nn.Module loaded with the weights) +loaded_model.model + +# The state dict of the model (the weights) +loaded_model.state_dict + +# The architecture of the model (e.g. "ESRGAN") +loaded_model.architecture + +# A list of tags for the model, usually describing the size (e.g. ["64nf", "large"]) +loaded_model.tags -You can also just use it for inference the same way you would with the `model` directly, so for example you could do `result = model(img)` and it will automatically call the forward method of the model. It also supports moving it to other devices, so you can call `.to` on it just like you would the direct model. +# A boolean indicating whether the model supports half precision (fp16) +loaded_model.supports_half + +# A boolean indicating whether the model supports bfloat16 precision +loaded_model.supports_bfloat16 + +# The scale of the model (e.g. 4) +loaded_model.scale + +# The number of input channels of the model (e.g. 3) +loaded_model.input_channels + +# The number of output channels of the model (e.g. 3) +loaded_model.output_channels + +# A SizeRequirements object describing the image size requirements of the model +# i.e the minimum size, the multiple of size, and whether the model requires a square input +loaded_model.size +``` + +You can also just use this helper class for inference the same way you would with the `model` directly, so for example you could do `result = loaded_model(img)` and it will automatically call the forward method of the model. It also supports moving it to other devices, so you can call `.to` on it just like you would the direct model. ## Model Architecture Support -spandrel currently supports a limited amount of neural network architectures. It can auto-detect these architectures just from their .pth files. This has only been tested with the models that are linked here, and any unofficial variants (especially if changes are made to their architectures) are not guaranteed to work. +Spandrel currently supports a limited amount of neural network architectures. It can auto-detect these architectures just from their files alone. + +This has only been tested with the models that are linked here, and any unofficial variants (especially if changes are made to their architectures) are not guaranteed to work. ### Pytorch @@ -70,11 +112,11 @@ spandrel currently supports a limited amount of neural network architectures. It ## Contributing -Feel free to contribute more model architecture support. When I add model support, I usually dig through the .pth file (state dict) keys and weights to find a way to get all the parameters of a model. At some point, I will document that entire process here. For now, there are plenty of references (most in the super_resolution folder) to reference. +Feel free to contribute more model architecture support. When I add model support, I usually dig through the .pth file (state dict) keys and weights to find a way to get all the parameters of a model. At some point, I will document that entire process here. For now, there are plenty of example to reference. -If the model arch you're adding does not have any parameter variants (for example, different scales or layer counts) then it should be fine adding it without any of the param detection. At the very least, you will need to find something uniquely identifiable in your model (usually a unique, really long key) that you can then add to `/spandrel/__helpers/model_loading.py` in order to load your model (preferably to the bottom of the if block before the else). You will also need to set up the `__init__.py` file for your arch to include a `load` method, returning the model and some metadata about the model and its parameters. +If the model arch you're adding does not have any parameter variants (for example, different scales or layer counts) then it should be fine adding it without any of the param detection. At the very least, you will need to find something uniquely identifiable in your model (usually a unique, really long key) that you can then add to `/spandrel/__helpers/main_registry.py` in order to load your model (preferably at the bottom). You will also need to set up the `__init__.py` file for your arch to include a `load` method, returning as ModelDescriptor with the model and some metadata about the model and its parameters. -Like with the parameter detection, there's plenty of examples there. This might seem like a lot of hardcoding (and it very well is), but it's the only way to identify models based on just the .pth file, since .pth files are just the weights of a model. If anybody can figure out a better way to do this, be my guest, but for now this is the best way and it works well. +Like with the parameter detection, there's plenty of examples there. This might seem like a lot of hardcoding (and it very well is), but it's the only way to identify models based on just the .pth file (or any other weight storage format), since these files are just the weights of a model. If anybody can figure out a better way to do this, be my guest, but for now this is the best way and it works well. ## License Notice