Skip to content

Commit

Permalink
Per #2550, update to version 0.7.0 of tc_diag_driver to make the land file optional.
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnHalleyGotway committed Oct 6, 2023
1 parent a45d5d7 commit 559fd12
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 44 deletions.
2 changes: 1 addition & 1 deletion scripts/python/tc_diag/tc_diag_driver/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.6.1"
__version__ = "0.7.0"
137 changes: 94 additions & 43 deletions scripts/python/tc_diag/tc_diag_driver/post_resample_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import inspect
import io
import pathlib
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import sys

import numpy as np
import pandas as pd
Expand All @@ -21,12 +22,12 @@
ATCF_DELIM_CHAR = ","
ATCF_TECH_ID_COL = 4


@dataclasses.dataclass
class DriverConfig:
pressure_independent_computation_specs: List[Dict[str, Any]]
sounding_computation_specs: List[Dict[str, Any]]

land_lut_file: pathlib.Path
in_forecast_time_name: str
in_levels_name: str
in_radii_name: str
Expand All @@ -36,9 +37,15 @@ class DriverConfig:
init_time_name: str
init_time_format: str
full_track_line_name: str
radii_to_validate: List[float]
"""List of radii to check against the min/max radii found in the file."""

land_lut_file: Optional[pathlib.Path] = None
"""Land LUT file can be provided in config or as a calling arg to diag_calcs"""

def __post_init__(self):
self.land_lut_file = pathlib.Path(self.land_lut_file)
if self.land_lut_file is not None:
self.land_lut_file = pathlib.Path(self.land_lut_file)


def main():
Expand All @@ -47,7 +54,8 @@ def main():
config = config_from_file(args.config_file, DriverConfig)

results = diag_calcs(
config, args.data_file, suppress_exceptions=args.suppress_exceptions)
config, args.data_file, suppress_exceptions=args.suppress_exceptions
)
if args.out_dir is not None:
_dump_results(results, args.out_dir)

Expand All @@ -66,62 +74,71 @@ def config_from_file(filename: pathlib.Path, config_class: Any) -> Any:

def _get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=
"Driver to perform diag computations from model data resampled to a cylindrical grid."
description="Driver to perform diag computations from model data resampled to a cylindrical grid."
)
parser.add_argument(
"config_file",
type=pathlib.Path,
help="YAML config file specifying how to process diag vars.")
help="YAML config file specifying how to process diag vars.",
)
parser.add_argument(
"data_file",
type=pathlib.Path,
help="NetCDF file containing model data resampled to cylindrical grid."
help="NetCDF file containing model data resampled to cylindrical grid.",
)
parser.add_argument(
"-o",
"--out_dir",
type=pathlib.Path,
default=None,
help="Optional directory to write results to for debugging purposes.")
parser.add_argument("-s", "--suppress_exceptions", action="store_true",
default=False, help="If this flag is set, then "
"exceptions encountered during diagnostic computations "
"will be logged and then ignored.")
help="Optional directory to write results to for debugging purposes.",
)
parser.add_argument(
"-s",
"--suppress_exceptions",
action="store_true",
default=False,
help="If this flag is set, then "
"exceptions encountered during diagnostic computations "
"will be logged and then ignored.",
)

return parser.parse_args()


def populate_missing_results(
config: DriverConfig, forecast_hour: int,
levels_hPa: List[int]) -> fcresults.ForecastHourResults:
_batches, results = _prep_diag_calculations(config, forecast_hour,
levels_hPa)
config: DriverConfig, forecast_hour: int, levels_hPa: List[int]
) -> fcresults.ForecastHourResults:
_batches, results = _prep_diag_calculations(config, forecast_hour, levels_hPa)
return results


def diag_calcs(
config: DriverConfig,
data_path: pathlib.Path,
suppress_exceptions: bool = False) -> fcresults.ForecastHourResults:
config: DriverConfig,
data_path: pathlib.Path,
suppress_exceptions: bool = False,
land_lut_override: Optional[pathlib.Path] = None,
) -> fcresults.ForecastHourResults:
# Gather various data necessary to perform diagnostic calculations
input_data = xr.load_dataset(data_path, engine="netcdf4")
forecast_hour = _get_forecast_hour(config, input_data)
levels_hPa = _get_pressure_levels(config, input_data)
init_time = _get_init_time(config, input_data)
land_lut = diag_vars.get_land_lut(config.land_lut_file)
land_lut_file = _get_land_lut_filename(config, land_lut_override)
land_lut = diag_vars.get_land_lut(land_lut_file)
radii_1d = input_data[config.in_radii_name]
_validate_radii(config.radii_to_validate, radii_1d, data_path)
azimuth_1d = input_data[config.in_azimuth_name]
theta_2d, radii_2d = np.meshgrid(azimuth_1d, radii_1d)
atcf_tech_id = _parse_atcf_id(input_data[config.full_track_line_name])
track = _dataset_track_lines_to_track(
input_data[config.full_track_line_name], atcf_tech_id)
input_data[config.full_track_line_name], atcf_tech_id
)

lon = input_data[config.lon_input_name][0]
lat = input_data[config.lat_input_name][0]

batches, results = _prep_diag_calculations(config, forecast_hour,
levels_hPa)
batches, results = _prep_diag_calculations(config, forecast_hour, levels_hPa)

call_args = {
"input_data": input_data,
Expand All @@ -135,7 +152,7 @@ def diag_calcs(
"forecast_hour": forecast_hour,
"init_time": init_time,
"track": track,
"results": results
"results": results,
}

for batch in batches:
Expand All @@ -144,30 +161,41 @@ def diag_calcs(
call_args,
forecast_hour,
levels_hPa,
suppress_computation_exceptions=suppress_exceptions)
suppress_computation_exceptions=suppress_exceptions,
)

return results


def _get_land_lut_filename(
config: DriverConfig, land_lut_override: Optional[pathlib.Path]
) -> pathlib.Path:
if land_lut_override is None:
return config.land_lut_file

return land_lut_override


def _prep_diag_calculations(
config: DriverConfig, forecast_hour: int, levels_hPa: List[int]
) -> Tuple[List[ce.ComputationBatch], fcresults.ForecastHourResults]:
pi_comps = ce.diag_computations_from_entry(
config.pressure_independent_computation_specs)
snd_comps = ce.diag_computations_from_entry(
config.sounding_computation_specs)

pi_result_names, snd_result_names = ce.get_all_result_names(
pi_comps, snd_comps)
results = fcresults.ForecastHourResults([forecast_hour], levels_hPa,
pi_result_names, snd_result_names)
config.pressure_independent_computation_specs
)
snd_comps = ce.diag_computations_from_entry(config.sounding_computation_specs)

pi_result_names, snd_result_names = ce.get_all_result_names(pi_comps, snd_comps)
results = fcresults.ForecastHourResults(
[forecast_hour], levels_hPa, pi_result_names, snd_result_names
)
batches = ce.get_computation_batches(pi_comps, snd_comps)

return batches, results


def _dump_results(results: fcresults.ForecastHourResults,
out_dir: pathlib.Path) -> None:
def _dump_results(
results: fcresults.ForecastHourResults, out_dir: pathlib.Path
) -> None:
sounding_filename = out_dir / "sounding.nc"
pressue_independent_filename = out_dir / "pressure_independent.nc"

Expand All @@ -179,22 +207,21 @@ def _get_forecast_hour(config: DriverConfig, input_data: xr.Dataset) -> int:
return int(input_data[config.in_forecast_time_name][0]) // LEAD_TIME_TO_HRS


def _get_pressure_levels(config: DriverConfig,
input_data: xr.Dataset) -> List[int]:
def _get_pressure_levels(config: DriverConfig, input_data: xr.Dataset) -> List[int]:
    """Return the dataset's vertical levels as whole-number pressure values.

    LEVEL_EPSILON is added before rounding to nudge values that sit just
    below an integer boundary up to the intended level.
    """
    return [
        round(float(level + LEVEL_EPSILON))
        for level in input_data[config.in_levels_name]
    ]


def _get_init_time(config: DriverConfig,
input_data: xr.Dataset) -> dt.datetime:
def _get_init_time(config: DriverConfig, input_data: xr.Dataset) -> dt.datetime:
init_time_var = input_data[config.init_time_name]
init_time_str = str(init_time_var.values)

return dt.datetime.strptime(init_time_str, config.init_time_format)


def _dataset_track_lines_to_track(track_lines: xr.DataArray,
atcf_tech_id: str) -> pd.DataFrame:
def _dataset_track_lines_to_track(
track_lines: xr.DataArray, atcf_tech_id: str
) -> pd.DataFrame:
lines = []
for line in track_lines:
lines.append(str(line.values))
Expand All @@ -203,11 +230,35 @@ def _dataset_track_lines_to_track(track_lines: xr.DataArray,
track = track_tools.get_adeck_track(line_buffer, atcf_tech_id)
return track


def _parse_atcf_id(track_lines: xr.DataArray) -> str:
    """Extract the ATCF technique id from the first track line.

    Track lines are comma-delimited ATCF records; the tech id lives at
    column ATCF_TECH_ID_COL.
    """
    fields = str(track_lines[0].values).split(ATCF_DELIM_CHAR)
    return fields[ATCF_TECH_ID_COL].strip()


def _validate_radii(
radii_to_validate: List[float], radii_1d: List[float], data_path: pathlib.Path
) -> None:
min_radii = float(min(radii_1d))
max_radii = float(max(radii_1d))

for i, radius in enumerate(radii_to_validate):
if radius < min_radii - sys.float_info.epsilon:
msg = (
f"Radius: {radius} at index: {i} of config param radii_to_validate is < "
f"min radius: {min_radii} in file: {data_path}"
)
raise ValueError(msg)

if radius > max_radii + sys.float_info.epsilon:
msg = (
f"Radius: {radius} at index: {i} of config param radii_to_validate is > "
f"max radius: {max_radii} in file: {data_path}"
)
raise ValueError(msg)


if __name__ == "__main__":
main()
main()

0 comments on commit 559fd12

Please sign in to comment.