Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Aug 28, 2024
1 parent 86888a6 commit f1cd7eb
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions src/ai_models_aurora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,43 +61,45 @@ def run(self):
N, W, S, E = self.area
WE, NS = self.grid
Nj = round((N - S) / NS) + 1
Ni = round((E - W) / WE) + 1
Ni = round((E - W) / WE) + 1

to_numpy_kwargs = dict(dtype=np.float32)



templates = {}

# Shape (Batch, Time, Lat, Lon)
surf_vars = {}

for k in self.surf_vars:
f = fields_sfc.sel(param=k).order_by(valid_datetime="ascending")
templates[k] = f[-1]
f = f.to_numpy(**to_numpy_kwargs)
f = torch.from_numpy(f)
f = f.unsqueeze(0) # Add batch dimension
print(f.shape)
f = f.unsqueeze(0) # Add batch dimension
surf_vars[k] = f

# Shape (Lat, Lon)
static_vars = {}
for k in self.static_vars:
f = fields_sfc.sel(param=k).order_by(valid_datetime="ascending")
f =f.to_numpy(**to_numpy_kwargs)[-1]
f = f.to_numpy(**to_numpy_kwargs)[-1]
f = torch.from_numpy(f)
print(f.shape)
static_vars[k] = f

# Shape (Batch, Time, Level, Lat, Lon)
atmos_vars = {}
for k in self.atmos_vars:
f = fields_pl.sel(param=k).order_by(valid_datetime="ascending", level=self.levels)

for level in self.levels:
templates[(k, level)] = f.sel(level=level)[-1]

f = f.to_numpy(**to_numpy_kwargs).reshape(len(self.lagged), len(self.levels), Nj, Ni)
f = torch.from_numpy(f)
f = f.unsqueeze(0) # Add batch dimension
print(f.shape)
f = f.unsqueeze(0) # Add batch dimension
atmos_vars[k] = f

self.write_input_fields(fields_pl + fields_sfc)

# https://microsoft.github.io/aurora/batch.html

Expand All @@ -108,18 +110,28 @@ def run(self):
metadata=Metadata(
lat=torch.linspace(N, S, Nj),
lon=torch.linspace(W, E, Ni),
time=self.start_datetime,
time=(self.start_datetime,),
atmos_levels=self.levels,
),
)

print(batch.metadata.lat.shape)
print(batch.metadata.lon.shape)

LOG.info("Starting inference")
with torch.inference_mode():

for pred in rollout(model, batch, steps=10):
print(pred.metadata.time)
with self.stepper(6) as stepper:
for i, pred in enumerate(rollout(model, batch, steps=self.lead_time // 6)):
step = (i + 1) * 6

for k, v in pred.surf_vars.items():
v = v.cpu().numpy()
self.write(v, template=templates[k], step=step)

for k, v in pred.atmos_vars.items():
v = v.cpu().numpy()
for i, level in enumerate(self.levels):
self.write(v[:, :, i], template=templates[(k, level)], step=step)

stepper(i, step)


model = AuroraModel

0 comments on commit f1cd7eb

Please sign in to comment.