Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get rid of conf_threshold parameter as it may introduce "undecided" p… #188

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions notebooks/worldcereal_v1_demo_custom_cropland.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,6 @@
"postprocess_method = \"majority_vote\"\n",
"# Additiona parameters for the majority vote method:\n",
"kernel_size = 3 # default = 5\n",
"conf_threshold = 60 # default = 30\n",
"# Do you want to save the intermediate results (before applying the postprocessing)\n",
"save_intermediate = True #default is False\n",
"# Do you want to save all class probabilities in the final product? (default is False)\n",
Expand All @@ -499,7 +498,6 @@
"postprocess_parameters = PostprocessParameters(enable=postprocess_result,\n",
" method=postprocess_method,\n",
" kernel_size=kernel_size,\n",
" conf_threshold=conf_threshold,\n",
" save_intermediate=save_intermediate,\n",
" keep_class_probs=keep_class_probs)\n",
"\n",
Expand Down
2 changes: 0 additions & 2 deletions notebooks/worldcereal_v1_demo_custom_croptype.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,6 @@
"postprocess_method = \"majority_vote\"\n",
"# Additiona parameters for the majority vote method:\n",
"kernel_size = 5 # default = 5\n",
"conf_threshold = 30 # default = 30\n",
"# Do you want to save the intermediate results (before applying the postprocessing)\n",
"save_intermediate = True #default is False\n",
"# Do you want to save all class probabilities in the final product? (default is False)\n",
Expand All @@ -517,7 +516,6 @@
"postprocess_parameters = PostprocessParameters(enable=postprocess_result,\n",
" method=postprocess_method,\n",
" kernel_size=kernel_size,\n",
" conf_threshold=conf_threshold,\n",
" save_intermediate=save_intermediate,\n",
" keep_class_probs=keep_class_probs)\n",
"\n",
Expand Down
2 changes: 0 additions & 2 deletions notebooks/worldcereal_v1_demo_default_cropland_EXTENDED.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@
"postprocess_method = \"majority_vote\"\n",
"# Additiona parameters for the majority vote method:\n",
"kernel_size = 3 # default = 5\n",
"conf_threshold = 60 # default = 30\n",
"# Do you want to save the intermediate results (before applying the postprocessing)\n",
"save_intermediate = True #default is False\n",
"# Do you want to save all class probabilities in the final product? (default is False)\n",
Expand All @@ -231,7 +230,6 @@
"postprocess_parameters = PostprocessParameters(enable=postprocess_result,\n",
" method=postprocess_method,\n",
" kernel_size=kernel_size,\n",
" conf_threshold=conf_threshold,\n",
" save_intermediate=save_intermediate,\n",
" keep_class_probs=keep_class_probs)\n",
"\n",
Expand Down
14 changes: 3 additions & 11 deletions src/worldcereal/openeo/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ def majority_vote(
base_labels: xr.DataArray,
max_probabilities: xr.DataArray,
kernel_size: int,
conf_threshold: int,
) -> xr.DataArray:
"""Majority vote is performed using a sliding local kernel.
For each pixel, the voting of a final class is done from
neighbours values weighted with the confidence threshold.
For each pixel, the voting of a final class is done by counting
neighbours values.
Pixels that have one of the specified excluded values are
excluded in the voting process and are unchanged.

Expand All @@ -55,8 +54,6 @@ def majority_vote(
The original probabilities of the winning class (ranging between 0 and 100).
kernel_size : int
The size of the kernel used for the neighbour around the pixel.
conf_threshold : int
Pixels under this confidence threshold do not count into the voting process.

Returns
-------
Expand Down Expand Up @@ -93,9 +90,6 @@ def majority_vote(
# Take the binary mask of the interest class, and multiply by the probabilities
class_mask = ((prediction == cls_value) * probability).astype(np.uint16)

# Sets to 0 the class scores where the threshold is lower
class_mask[probability <= conf_threshold] = 0

# Set to 0 the class scores where the label is excluded
for excluded_value in cls.EXCLUDED_VALUES:
class_mask[prediction == excluded_value] = 0
Expand Down Expand Up @@ -156,7 +150,7 @@ def majority_vote(
# Setting excluded values back to their original values
for excluded_value in cls.EXCLUDED_VALUES:
aggregated_predictions[prediction == excluded_value] = excluded_value
aggregated_probabilities[prediction == excluded_value] = cls.NODATA
aggregated_probabilities[prediction == excluded_value] = excluded_value
kvantricht marked this conversation as resolved.
Show resolved Hide resolved

return xr.DataArray(
np.stack((aggregated_predictions, aggregated_probabilities)),
Expand Down Expand Up @@ -286,13 +280,11 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
elif self._parameters.get("method") == "majority_vote":

kernel_size = self._parameters.get("kernel_size")
conf_threshold = self._parameters.get("conf_threshold")

new_labels = PostProcessor.majority_vote(
inarr.sel(bands="classification"),
inarr.sel(bands="probability"),
kernel_size=kernel_size,
conf_threshold=conf_threshold,
)

# Append the per-class probabalities if required
Expand Down
7 changes: 0 additions & 7 deletions src/worldcereal/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,6 @@ class PostprocessParameters(BaseModel):
The method to use for postprocessing. Must be one of ["smooth_probabilities", "majority_vote"]
kernel_size: int (default=5)
Used for majority vote postprocessing. Must be smaller than 25.
conf_threshold: int (default=30)
Used for majority vote postprocessing. Must be between 0 and 100.
save_intermediate: bool (default=False)
Whether to save intermediate results (before applying the postprocessing).
The intermediate results will be saved in the GeoTiff format.
Expand All @@ -180,7 +178,6 @@ class PostprocessParameters(BaseModel):
enable: bool = Field(default=True)
method: str = Field(default="smooth_probabilities")
kernel_size: int = Field(default=5)
conf_threshold: int = Field(default=30)
save_intermediate: bool = Field(default=False)
keep_class_probs: bool = Field(default=False)

Expand Down Expand Up @@ -213,9 +210,5 @@ def check_parameters(self):
raise ValueError(
f"Kernel size must be smaller than 25, got {self.kernel_size}"
)
if self.conf_threshold < 0 or self.conf_threshold > 100:
raise ValueError(
f"Confidence threshold must be between 0 and 100, got {self.conf_threshold}"
)

return self
9 changes: 0 additions & 9 deletions tests/worldcerealtests/test_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def test_cropland_postprocessing_majority_vote(WorldCerealCroplandClassification
"lookup_table": lookup_table,
"method": "majority_vote",
"kernel_size": 7,
"conf_threshold": 30,
},
)

Expand Down Expand Up @@ -90,7 +89,6 @@ def test_croptype_postprocessing_majority_vote(WorldCerealCroptypeClassification
"lookup_table": lookup_table,
"method": "majority_vote",
"kernel_size": 7,
"conf_threshold": 30,
},
)

Expand All @@ -103,7 +101,6 @@ def test_postprocessing_parameters():
"enable": True,
"method": "smooth_probabilities",
"kernel_size": 5,
"conf_threshold": 30,
"save_intermediate": False,
"keep_class_probs": False,
}
Expand All @@ -118,12 +115,6 @@ def test_postprocessing_parameters():
with pytest.raises(ValueError):
PostprocessParameters(**params)

# This one should fail with invalid conf_threshold
params["kernel_size"] = 5
params["conf_threshold"] = 101
with pytest.raises(ValueError):
PostprocessParameters(**params)

# This one should fail with invalid method
params["method"] = "test"
with pytest.raises(ValueError):
Expand Down
Loading