Segmentation model implementation speedup ideas (model.predict is slow) #31

Open
dbuscombe-usgs opened this issue May 11, 2023 · 17 comments
Labels: bug (Something isn't working)
@dbuscombe-usgs
Member

At the moment, Unet models are called in a loop, one image at a time. Model inference time reported by keras is approximately constant, but the overall time per image increases steadily. Over the course of ~1000 images, the slow-down is about 4x - 5x, on Windows and Linux.

This is not a priority right now, but should be fixed eventually. There are at least two aspects:

  1. Model inference time: what is the quickest way to get the label? Should we batch all images together, or continue to use a batch size of 1?
  2. File I/O and other operations: is there a memory leak causing the slowdown?

Would this eventually be solved by switching to dask?
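A minimal sketch for separating the two (assumes plain TF2/Keras; load_and_resize is a hypothetical placeholder loader). One known culprit for loops like this is model.predict itself: the Keras docs recommend calling the model directly, e.g. model(x, training=False), for small single-batch inputs:

import gc
import time

times = []
for i, f in enumerate(files):          # files: list of image paths
    x = load_and_resize(f)             # placeholder; returns a (1, H, W, 3) array
    t0 = time.time()
    _ = model.predict(x, verbose=0)
    times.append(time.time() - t0)
    if i % 100 == 99:
        gc.collect()                   # if the upward trend flattens here, suspect a Python-side leak
# a steadily growing times[] with a constant Keras-reported step time points at
# overhead outside the graph rather than model inference itself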

Leaving this issue here to be looked at later

@2320sharon
Collaborator

Shoreline Extraction Process - Performance Issues

Overview

We are experiencing slowdowns in our shoreline extraction process. Below are the key differences in our approach compared to the standard CoastSat method, which we believe might be contributing to these performance issues.

Key Differences

  1. Use of Dask: Our implementation heavily relies on Dask, which might not be ideal due to the intensive I/O operations involved in reading and creating files.

  2. Handling of NPZ Files:

    • Reading NPZ Files: Unlike CoastSat, which generates predictions on the fly, our process involves reading predictions from NPZ files for each image.
    • Merging Labels from NPZ Files: We are also merging labels from these NPZ files, though this step's impact on performance is unclear.

The remaining steps are largely similar to the CoastSat approach and are not detailed here.

Planned Improvements

Please try the following to see if they improve performance of the shoreline extraction process in the zoo workflow.

Tasks

  • Remove Dask from the shoreline extraction process

    • I believe we could improve this process by removing Dask, because it adds overhead that we can't take advantage of given the heavy I/O
  • Run the model on the images and extract shorelines at the same time, like the CoastSat process

    • We might also want to consider generating predictions on the fly, like CoastSat does. That way we no longer need to read from the NPZ file, and it reduces the number of steps the user has to go through to extract shorelines; a rough sketch follows below.
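A rough sketch of that combined loop (untested; load_image and extract_shoreline are hypothetical placeholder names, not existing functions):

for path in image_paths:
    image = load_image(path)                             # read + resize one scene (placeholder)
    label = model.predict(image[np.newaxis], verbose=0)  # segment in memory
    shoreline = extract_shoreline(label.squeeze())       # straight to shorelines, no NPZ on disk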

@dbuscombe-usgs
Member Author

Okay, there are two problems being articulated here. I opened this issue about the doodleverse do_seg / model.predict() performance, i.e., the process that generates the npz files.

I think you are referring to the subsequent process of using the npz files (image segmentation outputs) for shoreline extraction.

My focus right now is how slow in general I made the doodleverse-utils do_seg style of model inference. The code was initially complex because of the need to resize imagery and apply image standardization, neither of which had a good keras implementation at the time.

However, I now see that keras.utils has image_dataset_from_directory, so imagery can be loaded using

dataset = tf.keras.utils.image_dataset_from_directory(
    folder,
    labels=None,
    color_mode="rgb",
    batch_size=32,
    image_size=TARGET_SIZE,
    shuffle=False,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation="bilinear"
)

and standardization can be applied using

normalization_layer = tf.keras.layers.Normalization()

https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization

This layer will shift and scale inputs into a distribution centered around 0
with standard deviation 1. It accomplishes this by precomputing the mean and variance of the data,
and calling (input - mean) / sqrt(var) at runtime.

which, I finally figured out, can be applied using .map

normalized_ds = dataset.map(lambda x: (normalization_layer(x)))
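(One caveat worth flagging: per the Keras docs, the Normalization layer only computes its statistics when adapt() is called on sample data, or when mean and variance are passed to the constructor; a minimal sketch:)

normalization_layer = tf.keras.layers.Normalization()
normalization_layer.adapt(dataset)  # precompute mean/variance over the dataset
normalized_ds = dataset.map(lambda x: normalization_layer(x))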

Now we can call the model.predict step on all the jpeg files like so (note that the workers argument only applies to generator/Sequence inputs; a tf.data pipeline handles its own parallelism)

predictions = model.predict(normalized_ds)

:)

Now I need to figure out how to implement it in the ensemble model, etc., and how to get it into the doodleverse, but I just tested it on 1175 images and it worked great

@dbuscombe-usgs dbuscombe-usgs changed the title Unet model implementation speedup ideas Segmentation model implementation speedup ideas (model.predict is slow) Feb 14, 2024
@dbuscombe-usgs
Member Author

This approach would be adapted to NDWI and MNDWI jpegs using color_mode="grayscale"

@dbuscombe-usgs
Member Author

I think maybe this issue should be in doodleverse_utils. Ultimately, I think you just need to call do_seg from doodleverse-utils

@2320sharon
Collaborator

@dbuscombe-usgs
This is a major win! I'm so glad there is a native way to do image preprocessing now in keras.
So it sounds like we can keep most of the zoo workflow in CoastSeg intact; it just means that when the new do_seg function is ready, we will update the required version of doodleverse_utils and it should run faster.

@dbuscombe-usgs
Member Author

This is a complicated upgrade because it needs to be able to deal with a lot of different scenarios (large and small images, large and small numbers of total imagery), and because of the custom inputs/outputs we need for model inference.

I'm working with the largest dataset (5500x7500 pixel imagery, up to tens of thousands of samples) in order to explore options, because the greatest limitation is always GPU memory. I have found that in order to use tf.keras.utils.image_dataset_from_directory effectively, I need to split tasks between GPUs and CPUs

This is typically defined within the scope of one or more GPUs:

dataset = tf.keras.utils.image_dataset_from_directory(
    folder,
    labels=None,
    color_mode="rgb",
    batch_size=16,# 32,
    image_size=TARGET_SIZE,
    shuffle=False,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation="bilinear"
)
## https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization
## This layer will shift and scale inputs into a distribution centered around 0 
## with standard deviation 1. It accomplishes this by precomputing the mean and variance of the data, 
## and calling (input - mean) / sqrt(var) at runtime.
normalization_layer = tf.keras.layers.Normalization(axis=None)
normalized_ds = dataset.map(lambda x: (normalization_layer(x))) 

Then my current implementation uses GPUs for model.predict and the CPU for the argmax step

start_time = time.time()
for counter, model in enumerate(M):
    if USE_MULTI_GPU:
        with strategy.scope():
            predictions = model.predict(normalized_ds, callbacks=callbacks).astype('float32')
            with tf.device('/cpu:0'):
                result = K.argmax(predictions, axis=-1)
    else:
        predictions = model.predict(normalized_ds, callbacks=callbacks).astype('float32')
        with tf.device('/cpu:0'):
            result = K.argmax(predictions, axis=-1)

    del predictions
    # accumulate results across the ensemble
    if counter == 0:
        acc_result = result
    else:
        acc_result += result

    print(f"--- Model {counter}: {time.time() - start_time} seconds ---")
    gc.collect()
    K.clear_session()

My custom garbage collector class is:

class ClearMemory(Callback):
    def on_predict_end(self, logs=None):
        gc.collect()
        K.clear_session()

callbacks = [ClearMemory()]

This is a profoundly different approach: it uses Keras' normalization layer (which I haven't yet been able to verify produces the same results as our custom image standardization routine), and it does the argmax and resizing steps afterwards, on the CPU, for memory management and speed. This is the binarization step, which can happen on GPUs:

def binarize_func(x):
    return x > 0.5

tmp2 = tf.keras.layers.Lambda(lambda x: binarize_func(x))(tf.cast(acc_result, 'float32'))

This is the resizing step, which must happen on the CPU for large datasets

resize_func = tf.keras.Sequential([tf.keras.layers.Resizing(image_size[0], image_size[1], dtype='uint8')])

tmp = np.moveaxis(tmp2, 0, -1)
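# (the moveaxis puts the sample axis in the channel position, so all masks are
# resized in a single Resizing call)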

start_time = time.time()
with tf.device('/cpu:0'): ##All cores are wrapped in cpu:0, i.e., TensorFlow does indeed use multiple CPU cores by default.
    out = tf.keras.layers.Lambda(lambda x: resize_func(x))(tmp)
print("--- Resize: %s seconds ---" % (time.time() - start_time)) 

It works, but I'm still exploring consistency with previous approaches
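A quick consistency check could look something like this (sketch; assumes the custom doodleverse routine is a per-image z-score):

# compare the Keras pipeline against a manual per-image standardization
raw_batch = next(iter(dataset))                         # un-normalized batch
keras_batch = next(iter(normalized_ds))                 # output of the Normalization layer
manual = tf.image.per_image_standardization(raw_batch)  # per-image (x - mean) / adjusted_std
print(float(tf.reduce_max(tf.abs(manual - keras_batch))))  # large values mean the two differ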

@dbuscombe-usgs
Member Author

The above workflow is specific to the ResUNet, because of the resizing step. SegFormer needs a slightly different implementation.

@dbuscombe-usgs
Member Author

Notes on the SegFormer implementation. First, the model is defined using a simple API provided by the Hugging Face transformers library

from transformers import TFSegformerForSemanticSegmentation
def segformer(
    id2label,
    num_classes=2,
):
    """
    https://keras.io/examples/vision/segformer/
    https://huggingface.co/nvidia/mit-b0
    """

    label2id = {label: id for id, label in id2label.items()}
    model_checkpoint = "nvidia/mit-b0"

    model = TFSegformerForSemanticSegmentation.from_pretrained(
        model_checkpoint,
        num_labels=num_classes,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )
    return model

then the model is constructed using

elif MODEL == 'segformer':
    id2label = {}
    for k in range(NCLASSES):
        id2label[k] = str(k)
    model = segformer(id2label, num_classes=NCLASSES)

SegFormer models take channels-first inputs, i.e., axes rearranged from (H, W, C) to (C, H, W). This adds a layer of complexity because only SegFormer inputs need to be rearranged. This is dealt with by mapping an extra layer over the dataset:

    dataset = tf.keras.utils.image_dataset_from_directory(
        folder,
        labels=None,
        color_mode="rgb",
        batch_size=16, #32,
        image_size=TARGET_SIZE,
        shuffle=False,
        seed=None,
        validation_split=None,
        subset=None,
        interpolation="bilinear"
    )
    if MODEL=='segformer':
        transpose_layer = tf.keras.layers.Reshape( (-1, TARGET_SIZE[0], TARGET_SIZE[1]))
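        # NB: Reshape((-1, H, W)) reinterprets the flattened (H, W, 3) buffer as (3, H, W)
        # without transposing it; a true NHWC->NCHW transpose would be Permute((3, 1, 2))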
        transpose_ds = dataset.map(lambda x: (transpose_layer(x)))

        normalization_layer = tf.keras.layers.Normalization(axis=None)
        normalized_ds = transpose_ds.map(lambda x: (normalization_layer(x)))
    else:
        normalization_layer = tf.keras.layers.Normalization(axis=None)
        normalized_ds = dataset.map(lambda x: (normalization_layer(x))) 

Finally, we use the logits from the SegFormer model, so the inference code is adapted like so


start_time = time.time()
for counter, model in enumerate(M):
    if USE_MULTI_GPU:
        with strategy.scope():
            if MODEL == 'segformer':
                predictions = model.predict(normalized_ds, callbacks=callbacks).logits.astype('float32')
            else:
                predictions = model.predict(normalized_ds, callbacks=callbacks).astype('float32')

            with tf.device('/cpu:0'):
                result = K.argmax(predictions, axis=-1)
    else:
        if MODEL == 'segformer':
            predictions = model.predict(normalized_ds, callbacks=callbacks).logits.astype('float32')
        else:
            predictions = model.predict(normalized_ds, callbacks=callbacks).astype('float32')

        with tf.device('/cpu:0'):
            result = K.argmax(predictions, axis=-1)

    del predictions
    # accumulate results across the ensemble
    if counter == 0:
        acc_result = result
    else:
        acc_result += result

    print(f"--- Model {counter}: {time.time() - start_time} seconds ---")
    gc.collect()
    K.clear_session()

Groovy!! We now have a single streamlined workflow for both ResUNets and SegFormers based on tf.keras.utils.image_dataset_from_directory 🎉

@dbuscombe-usgs
Member Author

This is the specific model checkpoint we fine-tune to new data: https://huggingface.co/nvidia/mit-b0. In the future we could upgrade to https://huggingface.co/nvidia/mit-b1 or similar.
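In principle that upgrade would be a one-line change in the segformer() function above:

model_checkpoint = "nvidia/mit-b1"  # instead of "nvidia/mit-b0"; from_pretrained downloads the new weights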

@2320sharon
Collaborator

Thank you for laying out the steps it takes to construct the model, add the new layers, how it's called in the code, and finally how it runs in inference mode. Having it all laid out like that made it much easier to understand.

I do have a minor question: in the code below, what is K?

start_time = time.time()
for counter, model in enumerate(M):
    if USE_MULTI_GPU:
        with strategy.scope():
            if MODEL == 'segformer':
                predictions = model.predict(normalized_ds, callbacks=callbacks).logits.astype('float32')
            else:
                predictions = model.predict(normalized_ds, callbacks=callbacks).astype('float32')

            with tf.device('/cpu:0'):
                result = K.argmax(predictions, axis=-1)
    else:
        if MODEL == 'segformer':
            predictions = model.predict(normalized_ds, callbacks=callbacks).logits.astype('float32')
        else:
            predictions = model.predict(normalized_ds, callbacks=callbacks).astype('float32')

        with tf.device('/cpu:0'):
            result = K.argmax(predictions, axis=-1)

    del predictions
    # accumulate results across the ensemble
    if counter == 0:
        acc_result = result
    else:
        acc_result += result

    print(f"--- Model {counter}: {time.time() - start_time} seconds ---")
    gc.collect()
    K.clear_session()

@dbuscombe-usgs
Member Author

That's my shorthand for the Keras backend, which I use to access https://www.tensorflow.org/api_docs/python/tf/keras/backend/clear_session

import tensorflow.keras.backend as K


@ebgoldstein
Member

ebgoldstein commented Apr 24, 2024

@dbuscombe-usgs - i still can't find the exact script i used to make the lambda layer output the predicted segmentation, but here is a related codeblock developing a lambda layer to output a confidence (the difference between highest and lowest logit for each pixel, then summed over all pixels).. (from this nb)... i think it can easily be adapted with a squeeze and an argmax

#define the function for the lambda layer.. in this case it's a confidence

def AvgConf(x):
    #sort the inputs
    sorted_x = tf.sort(x, axis=-1, direction='DESCENDING', name=None)
    #take first
    TopPred = sorted_x[:, :, :, 0]
    #calculate confidence
    LConfidence = 1 - TopPred
    #mean over image
    conf_x = tf.reduce_mean(LConfidence, axis=1)
    conf_x = tf.reduce_mean(conf_x, axis=1)
    return conf_x

#margin output too

then attach it to the Gym model at the end

AverageConf = tf.keras.layers.Lambda(AvgConf)(base_model.output)

model = tf.keras.Model(base_model.input, AverageConf)
model.summary()

and then use the model as normal...

so i bet something like this would work (untested):

def SegOutput(x):
    sq_x = tf.squeeze(x)
    pred = tf.math.argmax(sq_x, -1)
    return pred

PredSeg = tf.keras.layers.Lambda(SegOutput)(base_model.output)
model = tf.keras.Model(base_model.input, PredSeg)
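if that works, predict() would return the argmax'd label directly, and the separate cpu argmax step in the loop above could probably be dropped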

@dbuscombe-usgs
Member Author

Thanks for this. Just looking at it for the first time today. For a segformer model, there is no model.output (AttributeError: Layer tf_segformer_for_semantic_segmentation has no inbound nodes.)

There is no model.input either. "AttributeError: Layer tf_segformer_for_semantic_segmentation is not connected, no input to return."

The model is defined thus


    model = TFSegformerForSemanticSegmentation.from_pretrained(
        model_checkpoint,
        num_labels=num_classes,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

The **kwargs are (very typically) badly documented, so I don't know if there is an option to specify argmax (I doubt it).

After some searching, I don't know how to proceed. So, I am moving on to the 'normalization/standardization issue'.
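One possible workaround (an untested sketch; it assumes the subclassed Hugging Face model can be called inside the Keras functional API, which works for most TF transformers models):

# give the subclassed model explicit inputs/outputs by wrapping it in a functional Model
inputs = tf.keras.Input(shape=(3, TARGET_SIZE[0], TARGET_SIZE[1]))  # channels-first
logits = model(inputs).logits                                       # (batch, classes, H/4, W/4)
labels = tf.keras.layers.Lambda(lambda x: tf.math.argmax(x, axis=1))(logits)
wrapped = tf.keras.Model(inputs, labels)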

@dbuscombe-usgs
Member Author

I can't make any progress on standardization either. I can't seem to get https://keras.io/api/layers/preprocessing_layers/numerical/normalization/ to function properly. It never produces the intended output: a batched tensor with zero mean and unit variance. No matter whether I rescale the imagery first, reorder channels, specify channels, etc., it is always wrong. I can't find any other examples.
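Two things may be worth checking here (assumptions, untested against this pipeline): the Normalization layer only has valid statistics once adapt() has been called on sample data, and, if our custom routine standardizes each image independently, the closer built-in analogue is tf.image.per_image_standardization:

# per-image (x - mean) / adjusted_stddev, no adapt() step required
normalized_ds = dataset.map(lambda x: tf.image.per_image_standardization(x))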

@dbuscombe-usgs
Member Author

@ebgoldstein did you say you had an example of a custom standardization layer?

@dbuscombe-usgs
Member Author

For example, this code runs but always produces garbage

Example input (image): 20181231T185801_20181231T190045_T10SEG

Example output (image): tmp

What it should look like (image): correct


batch_size = 12

dataset = tf.keras.utils.image_dataset_from_directory(
    folder,
    labels=None,
    color_mode="rgb",
    batch_size=batch_size,
    image_size=image_size,
    shuffle=False,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation="bilinear"
)

transpose_layer = tf.keras.layers.Reshape( (-1, image_size[0], image_size[1]))
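# NB: as above, Reshape changes the tensor shape but does not transpose pixel data;
# tf.keras.layers.Permute((3, 1, 2)) would perform the actual NHWC->NCHW transpose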
transpose_ds = dataset.map(lambda x: (transpose_layer(x)))

## https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization
## This layer will shift and scale inputs into a distribution centered around 0 
## with standard deviation 1. It accomplishes this by precomputing the mean and variance of the data, 
## and calling (input - mean) / sqrt(var) at runtime.
normalization_layer = tf.keras.layers.Normalization(axis=None)
normalized_ds = transpose_ds.map(lambda x: (normalization_layer(x)))

for image_batch in normalized_ds:
  print(image_batch.shape)
  break

tmp = np.array(image_batch[0], dtype=np.uint8)
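# NB: casting the standardized (zero-mean) floats straight to uint8 mangles negative
# values and truncates the rest, which by itself can produce a garbled plot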
tmp = np.einsum('ijk->jki',tmp)

plt.imshow(tmp)
plt.savefig('tmp.png',dpi=300); plt.close()
