Merge pull request #109 from nasaharvest/comprehensive-fillna
Comprehensive fillna
ivanzvonkov authored Oct 3, 2022
2 parents ebf188e + ee7ced5 commit 61dad07
Showing 7 changed files with 117 additions and 162 deletions.
103 changes: 44 additions & 59 deletions openmapflow/engineer.py
@@ -17,17 +17,47 @@
)


def _fillna(data):
"""Fill in the missing values in the data array"""
bands_np = np.array(DYNAMIC_BANDS + STATIC_BANDS)
if len(data.shape) != 4:
raise ValueError(
f"Expected data to be 4D (time, band, x, y) - got {data.shape}"
)
if data.shape[1] != len(bands_np):
raise ValueError(
f"Expected data to have {len(bands_np)} bands - got {data.shape[1]}"
)

is_nan = np.isnan(data)
if not is_nan.any().item():
return data
mean_per_time_band = data.mean(axis=(2, 3), skipna=True)
is_nan_any = is_nan.any(axis=(0, 2, 3)).values
is_nan_all = is_nan.all(axis=(0, 2, 3)).values
bands_all_nans = bands_np[is_nan_all]
bands_some_nans = bands_np[is_nan_any & ~is_nan_all]
if bands_all_nans.size > 0:
print(f"WARNING: Bands: {bands_all_nans} have all nan values")
# If a band has all nan values, fill with default: 0
mean_per_time_band[:, is_nan_all] = 0
if bands_some_nans.size > 0:
print(f"WARNING: Bands: {bands_some_nans} have some nan values")
if np.isnan(mean_per_time_band).any():
mean_per_band = mean_per_time_band.mean(axis=0, skipna=True)
return data.fillna(mean_per_band)
return data.fillna(mean_per_time_band)


def load_tif(
filepath: Path,
start_date: Optional[datetime] = None,
num_timesteps: Optional[int] = None,
filepath: Path, start_date: Optional[datetime] = None, fillna: bool = True
):
r"""
The sentinel files exported from google earth have all the timesteps
    concatenated together. This function loads a tif file and splits the
timesteps
Returns: The loaded xr.DataArray, and the average slope (used for filling nan slopes)
Returns: The loaded xr.DataArray
"""
xr = import_optional_dependency("xarray")

@@ -40,11 +70,9 @@ def load_tif(
num_dynamic_bands = num_bands - len(STATIC_BANDS)

assert num_dynamic_bands % bands_per_timestep == 0
if num_timesteps is None:
num_timesteps = num_dynamic_bands // bands_per_timestep
num_timesteps = num_dynamic_bands // bands_per_timestep

static_data = da.isel(band=slice(num_bands - len(STATIC_BANDS), num_bands))
average_slope = np.nanmean(static_data.values[STATIC_BANDS.index("slope"), :, :])

for timestep in range(num_timesteps):
time_specific_da = da.isel(
@@ -61,14 +89,15 @@
start_date + timedelta(days=DAYS_PER_TIMESTEP) * i
for i in range(len(da_split_by_time))
]
dynamic_data = xr.concat(da_split_by_time, pd.Index(timesteps, name="time"))
data = xr.concat(da_split_by_time, pd.Index(timesteps, name="time"))
else:
dynamic_data = xr.concat(
data = xr.concat(
da_split_by_time, pd.Index(range(len(da_split_by_time)), name="time")
)
dynamic_data.attrs["band_descriptions"] = BANDS

return dynamic_data, average_slope
if fillna:
data = _fillna(data)
data.attrs["band_descriptions"] = BANDS
return data


def calculate_ndvi(input_array: np.ndarray) -> np.ndarray:
@@ -105,40 +134,6 @@ def calculate_ndvi(input_array: np.ndarray) -> np.ndarray:
return np.append(input_array, np.expand_dims(ndvi, -1), axis=-1)


def fillna(array: np.ndarray, average_slope: float) -> Optional[np.ndarray]:
r"""
Given an input array of shape [timesteps, BANDS]
fill NaN values with the mean of each band across the timestep
The slope values may all be nan so average_slope is manually passed
"""
num_dims = len(array.shape)
if num_dims == 2:
bands_index = 1
mean_per_band = np.nanmean(array, axis=0)
elif num_dims == 3:
bands_index = 2
mean_per_band = np.nanmean(np.nanmean(array, axis=0), axis=0)
else:
raise ValueError(f"Expected num_dims to be 2 or 3 - got {num_dims}")

assert array.shape[bands_index] == len(BANDS)

if np.isnan(mean_per_band).any():
if (sum(np.isnan(mean_per_band)) == bands_index) & (
np.isnan(mean_per_band[BANDS.index("slope")]).all()
):
mean_per_band[BANDS.index("slope")] = average_slope
assert not np.isnan(mean_per_band).any()
else:
return None
for i in range(array.shape[bands_index]):
if num_dims == 2:
array[:, i] = np.nan_to_num(array[:, i], nan=mean_per_band[i])
elif num_dims == 3:
array[:, :, i] = np.nan_to_num(array[:, :, i], nan=mean_per_band[i])
return array


def remove_bands(array: np.ndarray) -> np.ndarray:
"""
Expects the input to be of shape [timesteps, bands] or
@@ -168,30 +163,20 @@ def remove_bands(array: np.ndarray) -> np.ndarray:
raise ValueError(error_message)


def process_test_file(
path_to_file: Path, start_date: datetime, num_timesteps: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
da, slope = load_tif(
path_to_file, start_date=start_date, num_timesteps=num_timesteps
)
def process_test_file(path_to_file: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
da = load_tif(path_to_file)

# Process remote sensing data
x_np = da.values
x_np = x_np.reshape(x_np.shape[0], x_np.shape[1], x_np.shape[2] * x_np.shape[3])
x_np = np.moveaxis(x_np, -1, 0)
x_np = calculate_ndvi(x_np)
x_np = remove_bands(x_np)
final_x = fillna(x_np, slope)
if final_x is None:
raise RuntimeError(
"fillna on the test instance returned None; "
"does the test instance contain NaN only bands?"
)

# Get lat lons
lon, lat = np.meshgrid(da.x.values, da.y.values)
flat_lat, flat_lon = (
np.squeeze(lat.reshape(-1, 1), -1),
np.squeeze(lon.reshape(-1, 1), -1),
)
return final_x, flat_lat, flat_lon
return x_np, flat_lat, flat_lon
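
For context, a minimal self-contained sketch of the fill strategy introduced in _fillna, run on a toy xarray DataArray standing in for a loaded tif. The shapes, values, and variable names below are illustrative only; in the library itself the fill now happens inside load_tif when fillna=True.

import numpy as np
import xarray as xr

# Toy stand-in for a loaded tif: 2 timesteps, 3 bands, 2x2 pixels.
values = np.arange(2 * 3 * 2 * 2, dtype=float).reshape(2, 3, 2, 2)
values[0, 0, 0, 0] = np.nan   # band 0: one missing pixel at one timestep
values[:, 2, :, :] = np.nan   # band 2: missing everywhere
da = xr.DataArray(values, dims=("time", "band", "y", "x"))

# Per-timestep, per-band spatial mean. A band that is entirely NaN at some
# timestep falls back to its mean over all timesteps, and finally to 0 if it
# is NaN everywhere (the same default used for all-NaN bands above).
mean_per_time_band = da.mean(dim=("y", "x"), skipna=True)
if mean_per_time_band.isnull().any():
    mean_per_time_band = mean_per_time_band.fillna(
        mean_per_time_band.mean(dim="time", skipna=True)
    ).fillna(0)

filled = da.fillna(mean_per_time_band)
assert not filled.isnull().any()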
18 changes: 2 additions & 16 deletions openmapflow/inference.py
@@ -1,7 +1,5 @@
import re
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
@@ -43,15 +41,6 @@ def __init__(
"Using PyTorch model but PyTorch is not installed. Please pip install torch"
)

@staticmethod
def start_date_from_str(path: Union[Path, str]) -> datetime:
dates = re.findall(r"\d{4}-\d{2}-\d{2}", str(path))
if len(dates) < 1:
raise ValueError(f"{path} should have at least one date")
start_date_str = dates[0]
start_date = datetime.strptime(start_date_str, "%Y-%m-%d")
return start_date

@staticmethod
def _combine_predictions(
flat_lat: np.ndarray, flat_lon: np.ndarray, batch_predictions: List[np.ndarray]
@@ -84,13 +73,10 @@ def _on_single_batch(self, batch_x_np: np.ndarray) -> np.ndarray:
def run(
self,
local_path: Path,
start_date: Optional[datetime] = None,
dest_path: Optional[Path] = None,
) -> pd.DataFrame:
"""Runs inference on a single tif file."""
if start_date is None:
start_date = self.start_date_from_str(local_path)
x_np, flat_lat, flat_lon = process_test_file(local_path, start_date)
x_np, flat_lat, flat_lon = process_test_file(local_path)

if self.normalizing_dict is not None:
x_np = (x_np - self.normalizing_dict["mean"]) / self.normalizing_dict["std"]
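
Downstream callers no longer pass (or parse) a start date. A hypothetical call to the simplified API; the inference object is assumed to be an already-constructed Inference instance, and both paths are made up for illustration:

from pathlib import Path

preds = inference.run(
    local_path=Path("/tmp/example_eo_file.tif"),      # hypothetical input tif
    dest_path=Path("/tmp/example_eo_file_preds.nc"),  # hypothetical output file
)
# preds is the pd.DataFrame of per-pixel predictions returned by run().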
28 changes: 7 additions & 21 deletions openmapflow/labeled_dataset.py
@@ -2,7 +2,6 @@
import random
import shutil
import tempfile
import warnings
from dataclasses import dataclass
from datetime import date
from pathlib import Path
@@ -42,7 +41,7 @@
EarthEngineExporter,
get_cloud_tif_list,
)
from openmapflow.engineer import calculate_ndvi, fillna, load_tif, remove_bands
from openmapflow.engineer import calculate_ndvi, load_tif, remove_bands
from openmapflow.utils import str_to_np, tqdm

SEED = 42
@@ -145,21 +144,20 @@ def _find_matching_point(
So the function finds the closest grid coordinate to the label coordinate.
Additional value is given to a grid coordinate that is close to the center of the tif.
"""
tif_slope_tuples = []
tifs = []
for p in eo_paths:
blob = tif_bucket.blob(str(p))
local_path = Path(f"{temp_dir}/{p.name}")
if not local_path.exists():
blob.download_to_filename(str(local_path))
tif_slope_tuples.append(load_tif(local_path))
tifs.append(load_tif(local_path))
if local_path.exists():
local_path.unlink()

if len(tif_slope_tuples) > 1:
if len(tifs) > 1:
min_distance_from_point = np.inf
min_distance_from_center = np.inf
for i, tif_slope_tuple in enumerate(tif_slope_tuples):
tif, slope = tif_slope_tuple
for i, tif in enumerate(tifs):
lon, lon_idx = _find_nearest(tif.x, label_lon)
lat, lat_idx = _find_nearest(tif.y, label_lat)
distance_from_point = _distance(label_lat, label_lon, lat, lon)
@@ -173,22 +171,16 @@
min_distance_from_center = distance_from_center
min_distance_from_point = distance_from_point
eo_data = tif.sel(x=lon).sel(y=lat).values
average_slope = slope
eo_file = eo_paths[i].name
else:
tif, slope = tif_slope_tuples[0]
tif = tifs[0]
closest_lon = _find_nearest(tif.x, label_lon)[0]
closest_lat = _find_nearest(tif.y, label_lat)[0]
eo_data = tif.sel(x=closest_lon).sel(y=closest_lat).values
average_slope = slope
eo_file = eo_paths[0].name

eo_data = calculate_ndvi(eo_data)
eo_data = remove_bands(eo_data)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
eo_data = fillna(eo_data, average_slope)

return eo_data, closest_lon, closest_lat, eo_file


@@ -206,22 +198,16 @@ def _find_matching_point_url(
with local_path.open("wb") as f:
shutil.copyfileobj(r.raw, f)

tif, slope = load_tif(local_path)
tif = load_tif(local_path)
if local_path.exists():
local_path.unlink()

closest_lon = _find_nearest(tif.x, label_lon)[0]
closest_lat = _find_nearest(tif.y, label_lat)[0]

eo_data = tif.sel(x=closest_lon).sel(y=closest_lat).values
average_slope = slope

eo_data = calculate_ndvi(eo_data)
eo_data = remove_bands(eo_data)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
eo_data = fillna(eo_data, average_slope)

return eo_data, closest_lon, closest_lat


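
With NaNs now filled inside load_tif, the matching helpers can treat each tif as a single DataArray and simply read off the pixel closest to the label. A minimal sketch of that lookup on a synthetic array; it uses xarray's built-in nearest-neighbour selection instead of the repo's _find_nearest helper, and the coordinates are illustrative:

import numpy as np
import xarray as xr

# Synthetic stand-in for a loaded, NaN-filled tif: 2 timesteps, 1 band,
# on a 3x3 lat/lon grid.
tif = xr.DataArray(
    np.random.rand(2, 1, 3, 3),
    dims=("time", "band", "y", "x"),
    coords={"y": [7.52, 7.51, 7.50], "x": [10.50, 10.51, 10.52]},
)

label_lat, label_lon = 7.511, 10.518
eo_data = tif.sel(x=label_lon, y=label_lat, method="nearest").values
print(eo_data.shape)  # (2, 1): one value per timestep and band, no slope bookkeeping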
7 changes: 1 addition & 6 deletions openmapflow/torchserve_handler.py
@@ -126,13 +126,8 @@ def inference(self, data: str, *args, **kwargs) -> Tuple[str, str]:
local_src_path = download_file(uri)
local_dest_path = Path(tempfile.gettempdir() + f"/pred_{Path(uri).stem}.nc")

start_date = Inference.start_date_from_str(uri)
print(f"HANDLER: Start date: {start_date}")

print("HANDLER: Starting inference")
self.inference_module.run(
local_path=local_src_path, start_date=start_date, dest_path=local_dest_path
)
self.inference_module.run(local_path=local_src_path, dest_path=local_dest_path)
print("HANDLER: Completed inference")
dest_uri = upload_file(
bucket_name=self.dest_bucket_name, local_path=local_dest_path, src_uri=uri
(Diffs for the remaining 3 changed files are not shown.)
