Skip to content

Commit

Permalink
Now working with dependency as zip file and presto code packed as whe…
Browse files Browse the repository at this point in the history
…el file
  • Loading branch information
GriffinBabe committed Jun 11, 2024
1 parent 34e4621 commit df87509
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 40 deletions.
7 changes: 3 additions & 4 deletions scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@
backend_context = BackendContext(Backend.FED)

connection = openeo.connect(
# "https://openeo.creo.vito.be/openeo/"
"https://openeo-staging.creo.vito.be/openeo/"
"https://openeo.creo.vito.be/openeo/"
).authenticate_oidc()

# Preparing the input cube for the inference
Expand Down Expand Up @@ -111,7 +110,7 @@
out_format="NetCDF",
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "8g",
"logging-threshold": "debug",
"executor-memoryOverhead": "12g",
"udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
},
)
44 changes: 30 additions & 14 deletions src/worldcereal/openeo/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ class PrestoFeatureExtractor(PatchFeatureExtractor):
specified, should be set as `False`.
"""

PRESTO_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA
PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA
PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.0-temp-py3-none-any.whl"
BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA
DEPENDENCY_NAME = "wc_presto_onnx_dependencies.zip"
DEPENDENCY_NAME = "worldcereal_deps.zip"

GFMAP_BAND_MAPPING = {
"S2-L2A-B02": "B02",
Expand All @@ -40,12 +41,18 @@ class PrestoFeatureExtractor(PatchFeatureExtractor):
"AGERA5-PRECIP": "precipitation-flux",
}

def dependencies(self) -> list:
"""Gives the presto dependencies from a wheel with all it's subdependencies."""
return [
"torch @ https://download.pytorch.org/whl/cpu/torch-2.0.0%2Bcpu-cp38-cp38-linux_x86_64.whl#sha256=354f281351cddb590990089eced60f866726415f7b287db5105514aa3c5f71ca",
"presto @ https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.0.1-py3-none-any.whl",
]
def unpack_presto_wheel(self, wheel_url: str, destination_dir: str) -> list:
import urllib.request
import zipfile
from pathlib import Path

# Downloads the wheel file
modelfile, _ = urllib.request.urlretrieve(
wheel_url, filename=Path.cwd() / Path(wheel_url).name
)
with zipfile.ZipFile(modelfile, "r") as zip_ref:
zip_ref.extractall(destination_dir)
return destination_dir

def output_labels(self) -> list:
"""Returns the output labels from this UDF, which is the output labels
Expand All @@ -54,13 +61,17 @@ def output_labels(self) -> list:

def execute(self, inarr: xr.DataArray) -> xr.DataArray:
import sys
from pathlib import Path

if self.epsg is None:
raise ValueError(
"EPSG code is required for Presto feature extraction, but was "
"not correctly initialized."
)
presto_url = self._parameters.get("presto_url", self.PRESTO_URL)
presto_model_url = self._parameters.get(
"presto_model_url", self.PRESTO_MODEL_URL
)
presto_wheel_url = self._parameters.get("presot_wheel_url", self.PRESO_WHL_URL)

# The below is required to avoid flipping of the result
# when running on OpenEO backend!
Expand All @@ -76,16 +87,21 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
inarr = inarr.fillna(65535)

# Unzip de dependencies on the backend
# self.logger.info("Unzipping dependencies")
# deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME)
self.logger.info("Unzipping dependencies")
deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME)
self.logger.info("Unpacking presto wheel")
deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir)

self.logger.info("Appending dependencies")
sys.path.append(str(deps_dir))

# self.logger.info("Appending dependencies")
# sys.path.append(str(deps_dir))
# Debug, print the dependency directory
self.logger.info(f"Dependency directory: {list(Path(deps_dir).iterdir())}")

from presto.inference import get_presto_features

self.logger.info("Extracting presto features")
features = get_presto_features(inarr, presto_url, self.epsg)
features = get_presto_features(inarr, presto_model_url, self.epsg)
return features

def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
Expand Down
64 changes: 42 additions & 22 deletions src/worldcereal/openeo/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,61 @@ class CroplandClassifier(ModelInference):
the public Presto model.
"""

import numpy as np

CATBOOST_PATH = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA
BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA
DEPENDENCY_NAME = "wc_presto_onnx_dependencies.zip"

def __init__(self):
super().__init__()

self.onnx_session = None

def dependencies(self) -> list:
"""Gives the presto dependencies from a wheel with all it's subdependencies."""
return [
"onnxruntime",
"torch @ https://download.pytorch.org/whl/cpu/torch-2.0.0%2Bcpu-cp38-cp38-linux_x86_64.whl#sha256=354f281351cddb590990089eced60f866726415f7b287db5105514aa3c5f71ca",
"presto @ https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.0.1-py3-none-any.whl",
]
return [] # Disable the dependencies from PIP install

def output_labels(self) -> list:
return ["classification"]

def execute(self, inarr: xr.DataArray) -> xr.DataArray:
import sys
def predict(self, features: np.ndarray) -> np.ndarray:
"""
Predicts labels using the provided features array.
"""
import numpy as np

classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH)
if self.onnx_session is None:
raise ValueError("Model has not been loaded. Please load a model first.")

# shape and indiches for output
inarr = inarr.transpose("bands", "x", "y")
# Prepare input data for ONNX model
outputs = self.onnx_session.run(None, {"features": features})

# Unzip de dependencies on the backend
# self.logger.info("Unzipping dependencies")
# dep_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME)
# Threshold for binary conversion
threshold = 0.5

# Extract all prediction values and convert them to binary labels
prediction_values = [sublist["True"] for sublist in outputs[1]]
binary_labels = np.array(prediction_values) >= threshold
binary_labels = binary_labels.astype(int)

return binary_labels

def execute(self, inarr: xr.DataArray) -> xr.DataArray:
classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH)

# self.logger.info("Adding dependencies")
# sys.path.append(str(dep_dir))
# shape and indices for output ("xy", "bands")
x_coords, y_coords = inarr.x.values, inarr.y.values
inarr = inarr.transpose("bands", "x", "y").stack(xy=["x", "y"]).transpose()

from presto.inference import classify_with_catboost
self.onnx_session = self.load_ort_session(classifier_url)

# Run catboost classification
self.logger.info("Catboost classification")
classification = classify_with_catboost(inarr, classifier_url)
self.logger.info("Done")
self.logger.info(f"Catboost classification with input shape: {inarr.shape}")
classification = self.predict(inarr.values)
self.logger.info(f"Classification done with shape: {classification.shape}")

classification = xr.DataArray(
classification.reshape((1, len(x_coords), len(y_coords))),
dims=["bands", "x", "y"],
coords={"bands": ["classification"], "x": x_coords, "y": y_coords},
)

return classification

0 comments on commit df87509

Please sign in to comment.