Skip to content

Commit

Permalink
Merge pull request #59 from WorldCereal/flexible-batchsize
Browse files Browse the repository at this point in the history
Make `batch_size` for Presto flexible
  • Loading branch information
GriffinBabe authored Jun 12, 2024
2 parents 77d5efc + 50caab3 commit 679f33c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/worldcereal/openeo/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PrestoFeatureExtractor(PatchFeatureExtractor):
"""

PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA
PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.0-py3-none-any.whl"
PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl"
BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA
DEPENDENCY_NAME = "worldcereal_deps.zip"

Expand Down Expand Up @@ -103,7 +103,9 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
)

self.logger.info("Extracting presto features")
features = get_presto_features(inarr, presto_model_url, self.epsg)
features = get_presto_features(
inarr, presto_model_url, self.epsg, batch_size=4096
)
return features

def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
Expand Down

0 comments on commit 679f33c

Please sign in to comment.