Comprehensive fillna #109

Merged (5 commits) on Oct 3, 2022

103 changes: 44 additions & 59 deletions openmapflow/engineer.py
@@ -17,17 +17,47 @@
)


def _fillna(data):
"""Fill in the missing values in the data array"""
bands_np = np.array(DYNAMIC_BANDS + STATIC_BANDS)
if len(data.shape) != 4:
raise ValueError(
f"Expected data to be 4D (time, band, x, y) - got {data.shape}"
)
if data.shape[1] != len(bands_np):
raise ValueError(
f"Expected data to have {len(bands_np)} bands - got {data.shape[1]}"
)

is_nan = np.isnan(data)
if not is_nan.any().item():
return data
mean_per_time_band = data.mean(axis=(2, 3), skipna=True)
is_nan_any = is_nan.any(axis=(0, 2, 3)).values
is_nan_all = is_nan.all(axis=(0, 2, 3)).values
bands_all_nans = bands_np[is_nan_all]
bands_some_nans = bands_np[is_nan_any & ~is_nan_all]
if bands_all_nans.size > 0:
print(f"WARNING: Bands: {bands_all_nans} have all nan values")
# If a band has all nan values, fill with default: 0
mean_per_time_band[:, is_nan_all] = 0
if bands_some_nans.size > 0:
print(f"WARNING: Bands: {bands_some_nans} have some nan values")
if np.isnan(mean_per_time_band).any():
mean_per_band = mean_per_time_band.mean(axis=0, skipna=True)
return data.fillna(mean_per_band)
return data.fillna(mean_per_time_band)


def load_tif(
filepath: Path,
start_date: Optional[datetime] = None,
num_timesteps: Optional[int] = None,
filepath: Path, start_date: Optional[datetime] = None, fillna: bool = True
):
r"""
The sentinel files exported from google earth have all the timesteps
concatenated together. This function loads a tif file and splits the
timesteps

Returns: The loaded xr.DataArray, and the average slope (used for filling nan slopes)
Returns: The loaded xr.DataArray
"""
xr = import_optional_dependency("xarray")

@@ -40,11 +70,9 @@ def load_tif(
num_dynamic_bands = num_bands - len(STATIC_BANDS)

assert num_dynamic_bands % bands_per_timestep == 0
if num_timesteps is None:
num_timesteps = num_dynamic_bands // bands_per_timestep
num_timesteps = num_dynamic_bands // bands_per_timestep

static_data = da.isel(band=slice(num_bands - len(STATIC_BANDS), num_bands))
average_slope = np.nanmean(static_data.values[STATIC_BANDS.index("slope"), :, :])

for timestep in range(num_timesteps):
time_specific_da = da.isel(
@@ -61,14 +89,15 @@
start_date + timedelta(days=DAYS_PER_TIMESTEP) * i
for i in range(len(da_split_by_time))
]
dynamic_data = xr.concat(da_split_by_time, pd.Index(timesteps, name="time"))
data = xr.concat(da_split_by_time, pd.Index(timesteps, name="time"))
else:
dynamic_data = xr.concat(
data = xr.concat(
da_split_by_time, pd.Index(range(len(da_split_by_time)), name="time")
)
dynamic_data.attrs["band_descriptions"] = BANDS

return dynamic_data, average_slope
if fillna:
data = _fillna(data)
data.attrs["band_descriptions"] = BANDS
return data


def calculate_ndvi(input_array: np.ndarray) -> np.ndarray:
@@ -105,40 +134,6 @@ def calculate_ndvi(input_array: np.ndarray) -> np.ndarray:
return np.append(input_array, np.expand_dims(ndvi, -1), axis=-1)


def fillna(array: np.ndarray, average_slope: float) -> Optional[np.ndarray]:
r"""
Given an input array of shape [timesteps, BANDS]
fill NaN values with the mean of each band across the timestep
The slope values may all be nan so average_slope is manually passed
"""
num_dims = len(array.shape)
if num_dims == 2:
bands_index = 1
mean_per_band = np.nanmean(array, axis=0)
elif num_dims == 3:
bands_index = 2
mean_per_band = np.nanmean(np.nanmean(array, axis=0), axis=0)
else:
raise ValueError(f"Expected num_dims to be 2 or 3 - got {num_dims}")

assert array.shape[bands_index] == len(BANDS)

if np.isnan(mean_per_band).any():
if (sum(np.isnan(mean_per_band)) == bands_index) & (
np.isnan(mean_per_band[BANDS.index("slope")]).all()
):
mean_per_band[BANDS.index("slope")] = average_slope
assert not np.isnan(mean_per_band).any()
else:
return None
for i in range(array.shape[bands_index]):
if num_dims == 2:
array[:, i] = np.nan_to_num(array[:, i], nan=mean_per_band[i])
elif num_dims == 3:
array[:, :, i] = np.nan_to_num(array[:, :, i], nan=mean_per_band[i])
return array


def remove_bands(array: np.ndarray) -> np.ndarray:
"""
Expects the input to be of shape [timesteps, bands] or
@@ -168,30 +163,20 @@ def remove_bands(array: np.ndarray) -> np.ndarray:
raise ValueError(error_message)


def process_test_file(
path_to_file: Path, start_date: datetime, num_timesteps: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
da, slope = load_tif(
path_to_file, start_date=start_date, num_timesteps=num_timesteps
)
def process_test_file(path_to_file: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
da = load_tif(path_to_file)

# Process remote sensing data
x_np = da.values
x_np = x_np.reshape(x_np.shape[0], x_np.shape[1], x_np.shape[2] * x_np.shape[3])
x_np = np.moveaxis(x_np, -1, 0)
x_np = calculate_ndvi(x_np)
x_np = remove_bands(x_np)
final_x = fillna(x_np, slope)
if final_x is None:
raise RuntimeError(
"fillna on the test instance returned None; "
"does the test instance contain NaN only bands?"
)

# Get lat lons
lon, lat = np.meshgrid(da.x.values, da.y.values)
flat_lat, flat_lon = (
np.squeeze(lat.reshape(-1, 1), -1),
np.squeeze(lon.reshape(-1, 1), -1),
)
return final_x, flat_lat, flat_lon
return x_np, flat_lat, flat_lon
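
A minimal sketch (not part of the PR; dimension names and the toy array are assumptions) of the fill strategy _fillna applies above: spatial means per (time, band) fill partial gaps, and bands that are entirely NaN fall back to the default of 0. The PR's extra fallback to a per-band mean across time is omitted here for brevity.

import numpy as np
import xarray as xr

# Toy array shaped (time, band, y, x); dimension names are assumptions.
values = np.random.rand(2, 3, 4, 4)
values[0, 1, :2, :2] = np.nan   # band 1: a few missing pixels at one timestep
values[:, 2, :, :] = np.nan     # band 2: entirely missing
data = xr.DataArray(values, dims=("time", "band", "y", "x"))

# Spatial mean per (time, band), ignoring NaNs, mirroring mean_per_time_band above.
mean_per_time_band = data.mean(dim=("y", "x"), skipna=True)

# Bands that are all NaN have no mean to fall back on, so use the default of 0.
all_nan = np.isnan(data).all(dim=("time", "y", "x"))
mean_per_time_band = mean_per_time_band.where(~all_nan, 0)

# fillna broadcasts the (time, band) means over the spatial dimensions.
filled = data.fillna(mean_per_time_band)
assert not np.isnan(filled).any()
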
18 changes: 2 additions & 16 deletions openmapflow/inference.py
@@ -1,7 +1,5 @@
import re
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
@@ -43,15 +41,6 @@ def __init__(
"Using PyTorch model but PyTorch is not installed. Please pip install torch"
)

@staticmethod
def start_date_from_str(path: Union[Path, str]) -> datetime:
dates = re.findall(r"\d{4}-\d{2}-\d{2}", str(path))
if len(dates) < 1:
raise ValueError(f"{path} should have at least one date")
start_date_str = dates[0]
start_date = datetime.strptime(start_date_str, "%Y-%m-%d")
return start_date

@staticmethod
def _combine_predictions(
flat_lat: np.ndarray, flat_lon: np.ndarray, batch_predictions: List[np.ndarray]
@@ -84,13 +73,10 @@ def _on_single_batch(self, batch_x_np: np.ndarray) -> np.ndarray:
def run(
self,
local_path: Path,
start_date: Optional[datetime] = None,
dest_path: Optional[Path] = None,
) -> pd.DataFrame:
"""Runs inference on a single tif file."""
if start_date is None:
start_date = self.start_date_from_str(local_path)
x_np, flat_lat, flat_lon = process_test_file(local_path, start_date)
x_np, flat_lat, flat_lon = process_test_file(local_path)

if self.normalizing_dict is not None:
x_np = (x_np - self.normalizing_dict["mean"]) / self.normalizing_dict["std"]
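
For context, the model consumes per-pixel time series. A small numpy sketch (toy shapes, not from the PR) of the reshape that process_test_file performs on the array returned by load_tif before the normalization step above:

import numpy as np

T, B, H, W = 12, 14, 3, 2          # toy sizes: timesteps, bands, height, width
x = np.random.rand(T, B, H, W)     # stand-in for da.values from load_tif

x = x.reshape(T, B, H * W)         # flatten the spatial grid
x = np.moveaxis(x, -1, 0)          # -> (num_pixels, timesteps, bands)
assert x.shape == (H * W, T, B)

# Each row of x is now one pixel's full time series across all bands,
# ready for per-band normalization and batched model inference.
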
28 changes: 7 additions & 21 deletions openmapflow/labeled_dataset.py
@@ -2,7 +2,6 @@
import random
import shutil
import tempfile
import warnings
from dataclasses import dataclass
from datetime import date
from pathlib import Path
@@ -42,7 +41,7 @@
EarthEngineExporter,
get_cloud_tif_list,
)
from openmapflow.engineer import calculate_ndvi, fillna, load_tif, remove_bands
from openmapflow.engineer import calculate_ndvi, load_tif, remove_bands
from openmapflow.utils import str_to_np, tqdm

SEED = 42
@@ -145,21 +144,20 @@ def _find_matching_point(
So the function finds the closest grid coordinate to the label coordinate.
Additional value is given to a grid coordinate that is close to the center of the tif.
"""
tif_slope_tuples = []
tifs = []
for p in eo_paths:
blob = tif_bucket.blob(str(p))
local_path = Path(f"{temp_dir}/{p.name}")
if not local_path.exists():
blob.download_to_filename(str(local_path))
tif_slope_tuples.append(load_tif(local_path))
tifs.append(load_tif(local_path))
if local_path.exists():
local_path.unlink()

if len(tif_slope_tuples) > 1:
if len(tifs) > 1:
min_distance_from_point = np.inf
min_distance_from_center = np.inf
for i, tif_slope_tuple in enumerate(tif_slope_tuples):
tif, slope = tif_slope_tuple
for i, tif in enumerate(tifs):
lon, lon_idx = _find_nearest(tif.x, label_lon)
lat, lat_idx = _find_nearest(tif.y, label_lat)
distance_from_point = _distance(label_lat, label_lon, lat, lon)
@@ -173,22 +171,16 @@
min_distance_from_center = distance_from_center
min_distance_from_point = distance_from_point
eo_data = tif.sel(x=lon).sel(y=lat).values
average_slope = slope
eo_file = eo_paths[i].name
else:
tif, slope = tif_slope_tuples[0]
tif = tifs[0]
closest_lon = _find_nearest(tif.x, label_lon)[0]
closest_lat = _find_nearest(tif.y, label_lat)[0]
eo_data = tif.sel(x=closest_lon).sel(y=closest_lat).values
average_slope = slope
eo_file = eo_paths[0].name

eo_data = calculate_ndvi(eo_data)
eo_data = remove_bands(eo_data)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
eo_data = fillna(eo_data, average_slope)

return eo_data, closest_lon, closest_lat, eo_file


@@ -206,22 +198,16 @@ def _find_matching_point_url(
with local_path.open("wb") as f:
shutil.copyfileobj(r.raw, f)

tif, slope = load_tif(local_path)
tif = load_tif(local_path)
if local_path.exists():
local_path.unlink()

closest_lon = _find_nearest(tif.x, label_lon)[0]
closest_lat = _find_nearest(tif.y, label_lat)[0]

eo_data = tif.sel(x=closest_lon).sel(y=closest_lat).values
average_slope = slope

eo_data = calculate_ndvi(eo_data)
eo_data = remove_bands(eo_data)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
eo_data = fillna(eo_data, average_slope)

return eo_data, closest_lon, closest_lat


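The _find_nearest helper used above is not shown in this diff; a minimal sketch of an assumed equivalent (not the project's actual implementation) that returns the closest coordinate value and its index:

import numpy as np

def find_nearest_sketch(coords, value):
    """Return (closest coordinate value, its index) along a 1-D coordinate array."""
    coords = np.asarray(coords)
    idx = int(np.abs(coords - value).argmin())
    return coords[idx], idx

lons = np.linspace(10.00, 10.10, 11)                 # toy longitude grid
closest_lon, lon_idx = find_nearest_sketch(lons, 10.043)
# closest_lon ~= 10.04, lon_idx == 4

The label's pixel values are then read with tif.sel(x=closest_lon).sel(y=closest_lat).values, as in the diff above.
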
7 changes: 1 addition & 6 deletions openmapflow/torchserve_handler.py
@@ -126,13 +126,8 @@ def inference(self, data: str, *args, **kwargs) -> Tuple[str, str]:
local_src_path = download_file(uri)
local_dest_path = Path(tempfile.gettempdir() + f"/pred_{Path(uri).stem}.nc")

start_date = Inference.start_date_from_str(uri)
print(f"HANDLER: Start date: {start_date}")

print("HANDLER: Starting inference")
self.inference_module.run(
local_path=local_src_path, start_date=start_date, dest_path=local_dest_path
)
self.inference_module.run(local_path=local_src_path, dest_path=local_dest_path)
print("HANDLER: Completed inference")
dest_uri = upload_file(
bucket_name=self.dest_bucket_name, local_path=local_dest_path, src_uri=uri