Skip to content

Commit

Permalink
class name constructed from task_type; cleaned unused pieces
Browse files Browse the repository at this point in the history
  • Loading branch information
Butsko Christina committed Aug 5, 2024
1 parent 532ec67 commit a350280
Showing 1 changed file with 23 additions and 58 deletions.
81 changes: 23 additions & 58 deletions presto/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,57 +9,46 @@
import pandas as pd
import torch
from catboost import CatBoostClassifier, Pool

from .hierarchical_classification import (
CatBoostClassifierWrapper,
LocalClassifierPerNodeWrapper,
LocalClassifierPerParentNodeWrapper,
)
from hiclass import LocalClassifierPerParentNode, LocalClassifierPerNode

from hiclass import LocalClassifierPerNode, LocalClassifierPerParentNode
from sklearn.base import BaseEstimator, clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
from sklearn.metrics import classification_report, f1_score, precision_score, recall_score
from torch import nn
from torch.optim import AdamW, lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from .dataset import (
NORMED_BANDS,
CLASS_MAPPINGS,
NORMED_BANDS,
WorldCerealInferenceDataset,
WorldCerealLabelled10DDataset,
WorldCerealLabelledDataset,
)
from .presto import (
Presto,
PrestoFineTuningModel,
get_sinusoid_encoding_table,
param_groups_lrd,
from .hierarchical_classification import (
CatBoostClassifierWrapper,
LocalClassifierPerNodeWrapper,
LocalClassifierPerParentNodeWrapper,
)
from .presto import Presto, PrestoFineTuningModel, get_sinusoid_encoding_table, param_groups_lrd
from .utils import DEFAULT_SEED, device


logger = logging.getLogger("__main__")

SklearnStyleModel = Union[BaseEstimator, CatBoostClassifier]


@dataclass
class Hyperparams:
# lr: float = 2e-5
lr: float = 0.0001
lr: float = 2e-5
max_epochs: int = 1000
batch_size: int = 256
# batch_size: int = 2048
patience: int = 20
num_workers: int = 8


class WorldCerealEval:
name = "WorldCerealCroptype"
threshold = 0.5
regression = False

Expand All @@ -82,11 +71,9 @@ def __init__(
finetune_classes: str = "CROPTYPE0",
):
self.seed = seed

if name is not None:
self.name = name
self.target_function = target_function
self.task_type = task_type
self.name = f"WorldCereal{task_type.title()}"

train_data, val_data = WorldCerealLabelledDataset.split_df(train_data, val_size=val_size)
self.train_df = self.prep_dataframe(train_data, filter_function, dekadal=dekadal)
Expand Down Expand Up @@ -200,8 +187,6 @@ def _construct_finetuning_model(self, pretrained_model: Presto) -> PrestoFineTun
@torch.no_grad()
def finetune_sklearn_model(
self,
# dl: DataLoader,
# val_dl: DataLoader,
pretrained_model: PrestoFineTuningModel,
models: List[str] = ["Regression", "Random Forest"],
) -> Union[Sequence[BaseEstimator], Dict]:
Expand Down Expand Up @@ -406,7 +391,7 @@ def _inference_for_dl(
probs = preds.copy()
if task_type == "croptype":
preds = finetuned_model.predict(encodings)

# for hierarchical classification, get predictions on the most granular level
if preds.ndim > 1:
preds = preds[:, -1]
Expand Down Expand Up @@ -454,14 +439,14 @@ def spatial_inference(
if self.task_type == "croptype":
if pretrained_model is None:
temp_croptype_map = pd.DataFrame(
CLASS_MAPPINGS["CROPTYPE0"].items(), columns=["ewoc_code", "name"]
CLASS_MAPPINGS["CROPTYPE19"].items(), columns=["ewoc_code", "name"]
)
test_preds_np = np.argmax(test_preds_np, axis=-1)
test_preds_str = np.array([self.croptype_list[xx] for xx in test_preds_np])
else:
test_preds_np = np.argmax(test_probs_np, axis=-1)
temp_croptype_map = pd.DataFrame(
CLASS_MAPPINGS["CROPTYPE9"].items(), columns=["ewoc_code", "name"]
CLASS_MAPPINGS["CROPTYPE19"].items(), columns=["ewoc_code", "name"]
)

temp_croptype_map.sort_values(
Expand All @@ -486,9 +471,17 @@ def spatial_inference(
df = ds.combine_predictions(
latlons, test_preds_np, test_preds_np, test_preds_np, y, ndvi, b2, b3, b4
)
if self.task_type == "croptype":
if self.task_type == "croptype":
df = ds.combine_predictions(
latlons, test_preds_np, test_preds_ewoc_code, test_probs_np, y, ndvi, b2, b3, b4
latlons,
test_preds_np,
test_preds_ewoc_code,
test_probs_np,
y,
ndvi,
b2,
b3,
b4,
)
prefix = f"{self.name}_{ds.all_files[i].stem}"
if pretrained_model is None:
Expand Down Expand Up @@ -778,35 +771,7 @@ def finetuning_results_sklearn(
):
results_df = pd.DataFrame()
if len(sklearn_model_modes) > 0:
# dl = DataLoader(
# WorldCerealLabelledDataset(
# self.train_df,
# countries_to_remove=self.countries_to_remove,
# years_to_remove=self.years_to_remove,
# target_function=self.target_function,
# task_type=self.task_type,
# croptype_list=[],
# ),
# batch_size=2048,
# shuffle=False,
# num_workers=8,
# )
# val_dl = DataLoader(
# WorldCerealLabelledDataset(
# self.val_df,
# countries_to_remove=self.countries_to_remove,
# years_to_remove=self.years_to_remove,
# target_function=self.target_function,
# task_type=self.task_type,
# croptype_list=[],
# ),
# batch_size=2048,
# shuffle=False,
# num_workers=8,
# )
sklearn_models = self.finetune_sklearn_model(
# dl,
# val_dl,
finetuned_model,
models=sklearn_model_modes,
)
Expand Down

0 comments on commit a350280

Please sign in to comment.