Skip to content

Commit

Permalink
doseg model compile for parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
dbuscombe-usgs committed Mar 2, 2023
1 parent 87924cb commit 0999c40
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions doodleverse_utils/prediction_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .imports import standardize, label_to_colors, fromhex

import os
import os,gc
import numpy as np
import matplotlib.pyplot as plt
from scipy import io
Expand All @@ -50,10 +50,10 @@

tf.random.set_seed(SEED)

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("GPU name: ", tf.config.experimental.list_physical_devices("GPU"))
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices("GPU")))
# print("Version: ", tf.__version__)
# print("Eager mode: ", tf.executing_eagerly())
# print("GPU name: ", tf.config.experimental.list_physical_devices("GPU"))
# print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices("GPU")))

##========================================================
def rescale(dat,
Expand Down Expand Up @@ -130,6 +130,12 @@ def do_seg(
out_dir_name='out'
):

Mc = []
for m in M:
m.compile(optimizer='adam')
Mc.append(m)
# del M

if f.endswith("jpg"):
segfile = f.replace(".jpg", "_predseg.png")
elif f.endswith("png"):
Expand Down Expand Up @@ -172,7 +178,7 @@ def do_seg(
E0 = []
E1 = []

for counter, model in enumerate(M):
for counter, model in enumerate(Mc):#M):
# heatmap = make_gradcam_heatmap(tf.expand_dims(image, 0) , model)

try:
Expand Down Expand Up @@ -301,8 +307,6 @@ def do_seg(
else:
est_label = model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1).squeeze()

# est_label = model.predict(tf.expand_dims(image, 0), batch_size=1).squeeze()

if TESTTIMEAUG == True:
# return the flipped prediction
if MODEL=='segformer':
Expand Down Expand Up @@ -477,6 +481,8 @@ def do_seg(
plt.savefig(tmpfile, dpi=200, bbox_inches="tight")
plt.close("all")

gc.collect()



# --------------------------------------------------------
Expand Down

0 comments on commit 0999c40

Please sign in to comment.