Commit

work in progress
b8raoult committed Aug 28, 2024
1 parent 57b63bd commit 86888a6
Showing 2 changed files with 103 additions and 11 deletions.
2 changes: 0 additions & 2 deletions README.md
@@ -4,5 +4,3 @@
See https://microsoft.github.io/aurora/intro.html

See https://github.com/microsoft/aurora
-
-
112 changes: 103 additions & 9 deletions src/ai_models_aurora/model.py
@@ -7,25 +7,119 @@

import logging

-from anemoi.inference.plugin import AIModelPlugin
-from aurora import Aurora, AuroraSmall, rollout
import numpy as np
import torch
from ai_models.model import Model
from aurora import Aurora
from aurora import Batch
from aurora import Metadata
from aurora import rollout

LOG = logging.getLogger(__name__)


-class AuroraModel(AIModelPlugin):
-expver = "auro"
class AuroraModel(Model):

# Input
area = [90, 0, -90, 360 - 0.25]
grid = [0.25, 0.25]

surf_vars = ("2t", "10u", "10v", "msl")
static_vars = ("lsm", "z", "slt")
atmos_vars = ("z", "u", "v", "t", "q")
levels = (1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50)

lagged = (-6, 0)
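# NOTE: the two lagged times (t-6h and t) presumably provide the two input time
# steps that Aurora's pretrained 0.25-degree model is documented to expect, and
# they match the time dimension of 2 used when building the Batch in run() below.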

# Download
download_files = ["checkpoint.ckpt"]
# For the MARS requests
param_sfc = surf_vars + static_vars
param_level_pl = (atmos_vars, levels)
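# For reference, the tuple arithmetic above expands to
#   param_sfc      == ("2t", "10u", "10v", "msl", "lsm", "z", "slt")
#   param_level_pl == (("z", "u", "v", "t", "q"), (1000, 925, ..., 50))
# which drive the surface-level and pressure-level MARS retrievals.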

# Output

expver = "auro"

def __init__(self, **kwargs):
super().__init__(**kwargs)
-self.model =Aurora()
self.model = Aurora()

def run(self):
-model = AuroraSmall()
-model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")
-model = model.to(self.device)
LOG.info("Running Aurora model")
model = Aurora(use_lora=False)  # Model is not fine-tuned.
LOG.info("Downloading Aurora model")
# TODO: control location of cache
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
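# One possible way to address the TODO above (an assumption, not part of this
# commit): the "microsoft/aurora" repo id suggests the checkpoint is fetched
# through huggingface_hub, in which case the cache location can normally be
# redirected before this call with the standard environment variable, e.g.
#
#   import os
#   os.environ.setdefault("HF_HOME", "/path/to/cache")  # path is illustrative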
LOG.info("Loading Aurora model to device %s", self.device)

model = model.to(self.device)
model.eval()

fields_pl = self.fields_pl
fields_sfc = self.fields_sfc

N, W, S, E = self.area
WE, NS = self.grid
Nj = round((N - S) / NS) + 1
Ni = round((E - W) / WE) + 1
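# Worked example for the values above: area = [90, 0, -90, 359.75] and
# grid = [0.25, 0.25], so
#   Nj = round((90 - (-90)) / 0.25) + 1 = 721   latitudes
#   Ni = round((359.75 - 0) / 0.25) + 1 = 1440  longitudes
# i.e. the full 0.25-degree global grid used by the pretrained model.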

to_numpy_kwargs = dict(dtype=np.float32)

# Shape (Batch, Time, Lat, Lon)
surf_vars = {}

for k in self.surf_vars:
f = fields_sfc.sel(param=k).order_by(valid_datetime="ascending")
f = f.to_numpy(**to_numpy_kwargs)
f = torch.from_numpy(f)
f = f.unsqueeze(0) # Add batch dimension
print(f.shape)
surf_vars[k] = f

# Shape (Lat, Lon)
static_vars = {}
for k in self.static_vars:
f = fields_sfc.sel(param=k).order_by(valid_datetime="ascending")
f = f.to_numpy(**to_numpy_kwargs)[-1]  # keep only the most recent time step
f = torch.from_numpy(f)
print(f.shape)
static_vars[k] = f

# Shape (Batch, Time, Level, Lat, Lon)
atmos_vars = {}
for k in self.atmos_vars:
f = fields_pl.sel(param=k).order_by(valid_datetime="ascending", level=self.levels)
f = f.to_numpy(**to_numpy_kwargs).reshape(len(self.lagged), len(self.levels), Nj, Ni)
f = torch.from_numpy(f)
f = f.unsqueeze(0) # Add batch dimension
print(f.shape)
atmos_vars[k] = f


# https://microsoft.github.io/aurora/batch.html

batch = Batch(
surf_vars=surf_vars,
static_vars=static_vars,
atmos_vars=atmos_vars,
metadata=Metadata(
lat=torch.linspace(N, S, Nj),
lon=torch.linspace(W, E, Ni),
time=self.start_datetime,
atmos_levels=self.levels,
),
)

print(batch.metadata.lat.shape)
print(batch.metadata.lon.shape)
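# For this grid the tensors in the Batch should come out as follows, matching
# the shape comments above and the Aurora batch documentation linked above:
#   surf_vars["2t"]    -> (1, 2, 721, 1440)      # (batch, time, lat, lon)
#   static_vars["lsm"] -> (721, 1440)            # (lat, lon)
#   atmos_vars["t"]    -> (1, 2, 13, 721, 1440)  # (batch, time, level, lat, lon)
#   metadata.lat       -> (721,)   decreasing from 90 to -90
#   metadata.lon       -> (1440,)  increasing from 0 to 359.75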

with torch.inference_mode():

for pred in rollout(model, batch, steps=10):
print(pred.metadata.time)
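# Each pred is itself a Batch holding a single forecast step; with the
# pretrained model each step advances the forecast by 6 hours, so steps=10
# corresponds to a 60-hour forecast. Writing the output is not part of this
# commit; a later version would presumably pull the fields back to numpy,
# e.g. (illustrative only):
#   t2m = pred.surf_vars["2t"][0, 0].cpu().numpy()  # (721, 1440) array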


model = AuroraModel
