
Add multiple dicot pipeline #71

Merged

Conversation

Collaborator

@eberrigan eberrigan commented Mar 27, 2024

  • Added MultipleDicotPipeline
  • Altered the Series class to take in the expected number of plants per series
  • Added functions for computing series and batch traits with the DicotPipeline for multiple plants in a series
  • Added functions for filtering and matching roots
  • Added relevant tests

Summary by CodeRabbit

  • New Features
    • Introduced rules for tracking large files using Git LFS for specific file extensions.
    • Added MultipleDicotPipeline for enhanced pipeline functionality.
    • Expanded Series class functionality to include loading expected plant count from a CSV file.
    • New functions for calculating traits, handling NaN values, and associating lateral roots with primary roots.
    • Introduced a custom NumPy array encoder for JSON serialization.
  • Bug Fixes
    • Added input validation in get_max_length_pts to handle empty arrays and incorrect shapes.
  • Documentation
    • Updated package version.
  • Tests
    • Added fixtures and tests for new functionalities including the minimum distance calculation between two lines and root association methods.
  • Chores
    • Added version control information for Git LFS across multiple test data files.

@eberrigan eberrigan requested a review from talmo March 27, 2024 03:28
Contributor

coderabbitai bot commented Mar 27, 2024

Walkthrough

This update focuses on improved data management with Git LFS, expanded root analysis functionality in sleap_roots, and comprehensive testing of the new features. Key developments include Git LFS tracking for large files, new root analysis functions, a MultipleDicotPipeline for series with multiple plants, and a version bump in sleap_roots reflecting the added capabilities.

Changes

  • .gitattributes, sleap_roots/.../6039_1.h5, 6039_1.lateral.predictions.slp, 7327_2.h5, tests/.../997_1.h5, 997_1.lateral.predictions.slp:
    Introduced rules for tracking large files using Git LFS for .h5 and .slp files. Added version control info for Git LFS.
  • sleap_roots/__init__.py, sleap_roots/series.py, sleap_roots/trait_pipelines.py, tests/fixtures/data.py, tests/test_trait_pipelines.py:
    Added MultipleDicotPipeline to exports; updated version to "0.0.7". Expanded Series class to load expected plant count from CSV; added error handling. Enhanced trait computation methods; introduced custom NumPy array encoder for JSON serialization. Added fixtures for Arabidopsis 11-day-old data handling. Updated tests for MultipleDicotPipeline functionality.
  • sleap_roots/lengths.py, sleap_roots/points.py, tests/test_lengths.py, tests/test_points.py:
    Enhanced get_max_length_pts function; added input validation and get_min_distance_line_to_line function. Added functions for root filtering, association, and plotting; imported visualization modules. Introduced get_min_distance_line_to_line for line distance calculation. Added functions for root association and filtering; updated tests.

Poem

Among bytes and branches, the rabbit did play,
Git LFS and roots, in a digital array.
🌿 Functions sprouted, tests danced in delight,
Data woven in code, a symphony so bright.
🐇 Through CSVs and traits, the project did thrive,
In the garden of updates, where roots come alive.


Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Examples:
    • I pushed a fix in commit <commit_id>.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai generate interesting stats about this repository and render them as a table.
    • @coderabbitai show all the console.log statements in this repository.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (invoked as PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger a review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai resolve to resolve all the CodeRabbit review comments.
  • @coderabbitai help to get help.

Additionally, you can add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • The JSON schema for the configuration file is available here.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/coderabbit-overrides.v2.json

CodeRabbit Discord Community

Join our Discord Community to get help, request features, and share feedback.

Contributor

@coderabbitai coderabbitai bot left a comment


Review Status

Actionable comments generated: 1

Configuration used: CodeRabbit UI

Commits: Files that changed from the base of the PR and between e54d7f7 and de7853d.
Files ignored due to path filters (1)
  • tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv is excluded by: !**/*.csv
Files selected for processing (24)
  • .gitattributes (1 hunks)
  • sleap_roots/__init__.py (1 hunks)
  • sleap_roots/lengths.py (2 hunks)
  • sleap_roots/points.py (2 hunks)
  • sleap_roots/series.py (9 hunks)
  • sleap_roots/trait_pipelines.py (7 hunks)
  • tests/data/multiple_arabidopsis_11do/6039_1.h5 (1 hunks)
  • tests/data/multiple_arabidopsis_11do/6039_1.lateral.predictions.slp (1 hunks)
  • tests/data/multiple_arabidopsis_11do/6039_1.primary.predictions.slp (1 hunks)
  • tests/data/multiple_arabidopsis_11do/7327_2.h5 (1 hunks)
  • tests/data/multiple_arabidopsis_11do/7327_2.lateral.predictions.slp (1 hunks)
  • tests/data/multiple_arabidopsis_11do/7327_2.primary.predictions.slp (1 hunks)
  • tests/data/multiple_arabidopsis_11do/9535_1.h5 (1 hunks)
  • tests/data/multiple_arabidopsis_11do/9535_1.lateral.predictions.slp (1 hunks)
  • tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_lateral.predictions.slp (1 hunks)
  • tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_primary.predictions.slp (1 hunks)
  • tests/data/multiple_arabidopsis_11do/9535_1.primary.predictions.slp (1 hunks)
  • tests/data/multiple_arabidopsis_11do/997_1.h5 (1 hunks)
  • tests/data/multiple_arabidopsis_11do/997_1.lateral.predictions.slp (1 hunks)
  • tests/data/multiple_arabidopsis_11do/997_1.primary.predictions.slp (1 hunks)
  • tests/fixtures/data.py (1 hunks)
  • tests/test_lengths.py (2 hunks)
  • tests/test_points.py (2 hunks)
  • tests/test_trait_pipelines.py (2 hunks)
Files skipped from review due to trivial changes (15)
  • .gitattributes
  • tests/data/multiple_arabidopsis_11do/6039_1.h5
  • tests/data/multiple_arabidopsis_11do/6039_1.lateral.predictions.slp
  • tests/data/multiple_arabidopsis_11do/6039_1.primary.predictions.slp
  • tests/data/multiple_arabidopsis_11do/7327_2.h5
  • tests/data/multiple_arabidopsis_11do/7327_2.lateral.predictions.slp
  • tests/data/multiple_arabidopsis_11do/7327_2.primary.predictions.slp
  • tests/data/multiple_arabidopsis_11do/9535_1.h5
  • tests/data/multiple_arabidopsis_11do/9535_1.lateral.predictions.slp
  • tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_lateral.predictions.slp
  • tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_primary.predictions.slp
  • tests/data/multiple_arabidopsis_11do/9535_1.primary.predictions.slp
  • tests/data/multiple_arabidopsis_11do/997_1.h5
  • tests/data/multiple_arabidopsis_11do/997_1.lateral.predictions.slp
  • tests/data/multiple_arabidopsis_11do/997_1.primary.predictions.slp
Additional comments: 29
sleap_roots/__init__.py (2)
  • 20-20: The addition of MultipleDicotPipeline to the exported entities aligns with the PR's objectives to enhance the package's capabilities. This change makes the new functionality accessible to users.
  • 26-26: Updating the package version from "0.0.6" to "0.0.7" correctly reflects the addition of new features. This version increment follows semantic versioning principles, indicating a minor update.
tests/fixtures/data.py (1)
  • 94-115: The addition of fixtures for multiple Arabidopsis 11-day-old data, including paths for image stacks, primary root predictions, and lateral root predictions, is essential for testing the new functionality related to handling multiple plants. These fixtures provide the necessary paths to test data, supporting comprehensive testing of the new features.
sleap_roots/lengths.py (2)
  • 13-40: The modifications to get_max_length_pts enhance the function's robustness by adding input validation and handling different input shapes. These changes ensure that the function can process both single instances and multiple instances of root landmarks effectively, improving the utility of the function for various use cases.
  • 137-158: The addition of get_min_distance_line_to_line expands the package's geometric analysis capabilities by enabling the calculation of the minimum distance between two LineString objects. This function is valuable for analyzing spatial relationships between different root structures, and the implementation includes necessary input validation for reliability.
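
For intuition, here is a minimal sketch of such a helper, assuming shapely's LineString API (the actual implementation in sleap_roots/lengths.py may differ in details such as the exception type):

from shapely.geometry import LineString

def get_min_distance_line_to_line(line1: LineString, line2: LineString) -> float:
    """Return the minimum Euclidean distance between two LineStrings."""
    if not isinstance(line1, LineString) or not isinstance(line2, LineString):
        raise TypeError("Both inputs must be shapely LineString objects.")
    # shapely's distance() returns 0.0 when the lines intersect.
    return line1.distance(line2)

# Example: two parallel horizontal segments one unit apart.
a = LineString([(0, 0), (1, 0)])
b = LineString([(0, 1), (1, 1)])
print(get_min_distance_line_to_line(a, b))  # 1.0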
tests/test_trait_pipelines.py (3)
  • 1-1: The addition of import numpy as np is necessary for numerical computations in the new test function test_multiple_dicot_pipeline. This change supports the testing of the new functionality.
  • 6-6: Adding MultipleDicotPipeline to the list of imports is essential for testing its functionality within the new test function test_multiple_dicot_pipeline. This change enables comprehensive testing of the new feature.
  • 140-174: The new test function test_multiple_dicot_pipeline is crucial for validating the functionality of the MultipleDicotPipeline. It tests the pipeline's ability to compute traits for multiple dicot plants, ensuring the feature works as expected. This addition aligns with the PR's objectives to enhance the software's capabilities in analyzing root traits of multiple plants.
tests/test_lengths.py (1)
  • 150-171: The new test function test_min_distance_line_to_line provides comprehensive testing for the get_min_distance_line_to_line function, covering various scenarios including non-intersecting lines, intersecting lines, parallel lines, and invalid input types. This thorough testing ensures the function's reliability and correctness.
sleap_roots/series.py (3)
  • 48-48: The addition of the csv_path attribute to the Series class supports the new functionality of loading expected plant counts from a CSV file. This attribute is essential for specifying the path to the CSV file containing the expected plant counts.
  • 133-147: The implementation of the expected_count property enhances the Series class by enabling the retrieval of expected plant counts from a CSV file. This addition is crucial for analyses that depend on the expected number of plants. The error handling for missing CSV files and unmatched series names ensures robustness.
  • 273-281: > 📝 NOTE

This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [251-304]

The error handling in the property getters for primary, lateral, and crown root points improves the usability of the Series class by providing clear feedback when labels are not available. This ensures that users are informed of missing data, enhancing the class's reliability.
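
As a rough illustration of that CSV lookup, here is a hedged sketch with a minimal stand-in class (the CSV column names are assumptions for illustration, not the actual schema):

from pathlib import Path

import numpy as np
import pandas as pd

class SeriesSketch:
    """Minimal stand-in for Series, only to illustrate the expected_count lookup."""

    def __init__(self, series_name: str, csv_path: str):
        self.series_name = series_name
        self.csv_path = csv_path

    @property
    def expected_count(self) -> float:
        # A missing CSV yields NaN instead of raising, so downstream filters
        # can treat "unknown count" uniformly.
        if not self.csv_path or not Path(self.csv_path).exists():
            print("CSV path is not set or the file does not exist.")
            return np.nan
        df = pd.read_csv(self.csv_path)
        # "plant_qr_code" and "number_of_plants_cylinder" are assumed column names.
        match = df[df["plant_qr_code"] == self.series_name]
        if match.empty:
            print(f"Series '{self.series_name}' not found in the CSV.")
            return np.nan
        return float(match["number_of_plants_cylinder"].iloc[0])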

sleap_roots/points.py (4)
  • 4-7: Import statements for matplotlib, Line2D, LineString, and nearest_points are correctly added to support the new functionalities introduced in this file.
  • 294-317: The filter_roots_with_nans function correctly filters out roots containing NaN values. It includes input validation and handles the case where all roots contain NaN values by returning an empty array. This approach ensures robustness in data processing.
  • 320-361: The filter_plants_with_unexpected_ct function correctly filters primary and lateral roots based on an expected count. It includes comprehensive input validation and correctly handles NaN expected counts. The approach of adjusting primary and lateral roots to empty arrays when the count does not match is clear and effective.
  • 536-593: The plot_root_associations function provides a visual representation of the associations between primary and lateral roots, including the minimum distance lines. It correctly uses matplotlib for plotting and customizes the legend and color map. However, the use of a red dashed line ("r--") for minimum distance might conflict with the comment about ensuring the color map does not include red. This is not a critical issue but something to be aware of in terms of visual clarity.
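
A compact sketch of the two filtering helpers described above, assuming the shapes and NaN semantics stated in this review (the actual implementations may differ):

from typing import Tuple

import numpy as np

def filter_roots_with_nans(pts: np.ndarray) -> np.ndarray:
    """Drop every root instance that contains at least one NaN coordinate."""
    if not isinstance(pts, np.ndarray) or pts.ndim != 3:
        raise ValueError("pts must be a numpy array of shape (instances, nodes, 2).")
    keep = ~np.isnan(pts).any(axis=(1, 2))
    # If every root has NaNs, this is an empty (0, nodes, 2) array.
    return pts[keep]

def filter_plants_with_unexpected_ct(
    primary_pts: np.ndarray, lateral_pts: np.ndarray, expected_count: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Blank out both arrays when the primary count differs from the expected count."""
    if not np.isnan(expected_count) and primary_pts.shape[0] != int(expected_count):
        primary_pts = np.empty((0, primary_pts.shape[1], 2))
        lateral_pts = np.empty((0, lateral_pts.shape[1], 2))
    return primary_pts, lateral_pts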
tests/test_points.py (5)
  • 3-16: The import statements are correctly updated to include LineString and additional functions from sleap_roots.points. This ensures that the new functionalities introduced in sleap_roots/points.py are properly tested.
  • 364-382: The test_associate_basic function correctly tests the basic association between one primary and one lateral root. It includes comprehensive assertions to verify the structure and content of the association result. This test effectively validates the expected behavior of the associate_lateral_to_primary function.
  • 385-393: The test_associate_no_primary function correctly tests the scenario where there are no primary roots. It validates that an empty dictionary is returned, which is the expected behavior. This test ensures that the function handles edge cases gracefully.
  • 621-626: The test_filter_roots_with_nans_no_nans function effectively tests the filter_roots_with_nans function with an input array that contains no NaN values. It correctly asserts that the original array should be returned, validating the function's behavior in a scenario without NaN values.
  • 687-696: The test_filter_plants_with_unexpected_ct_valid_input_matching_count function correctly tests the scenario where the number of primary roots matches the expected count. It validates that the original primary and lateral points arrays are returned, ensuring the function behaves as expected when the count matches.
sleap_roots/trait_pipelines.py (8)
  • 3-3: The addition of import json is appropriate for the new functionality related to JSON serialization of NumPy arrays.
  • 118-125: The NumpyArrayEncoder class is a well-implemented custom JSON encoder for handling NumPy arrays. It correctly checks if the object is an instance of np.ndarray and converts it to a list, which is JSON serializable.
  • 127-127: The TraitDef class is a significant addition that provides a structured way to define traits for analysis. It includes comprehensive documentation and a clear structure, enhancing maintainability and readability.
  • 261-268: The csv_traits_multiple_plants property method is a thoughtful addition for handling CSV traits specific to scenarios involving multiple plants. It ensures that only traits marked for inclusion in CSVs are processed, which is crucial for performance and correctness.
  • 376-481: The compute_multiple_dicots_traits function introduces complex logic for computing traits across multiple dicots. It's well-structured and includes error handling for file operations. However, consider adding more detailed error messages to improve debuggability.

Consider enhancing the error messages in the exception handling blocks to include more context about the failure, which can aid in debugging.

  • 525-568: The compute_batch_multiple_dicots_traits function efficiently processes a batch of series with multiple dicots. It demonstrates good use of existing functionality and maintains consistency in handling CSV output. The logging of the processing status for each series is a helpful addition for users.
  • 1977-2046: The MultipleDicotPipeline class is a crucial addition for handling multiple dicot plants. It defines traits specific to this scenario and leverages existing functions for filtering and associating roots. The structure and documentation of this class contribute to its readability and maintainability.
  • 2051-2071: The get_initial_frame_traits method in the MultipleDicotPipeline class correctly gathers initial traits for a plant frame, including handling the expected plant count. This method is essential for initializing the trait computation process and is implemented correctly.
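
As a hedged usage sketch of that batch entry point (the argument names here are assumptions; `all_series` stands for a list of loaded Series objects):

pipeline = MultipleDicotPipeline()
# Summarize traits across every series in the batch and write one CSV.
batch_df = pipeline.compute_batch_multiple_dicots_traits(all_series, write_csv=True)
print(batch_df.head())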

Contributor

@coderabbitai coderabbitai bot left a comment


Review Status

Actionable comments generated: 0

Configuration used: CodeRabbit UI

Commits: Files that changed from the base of the PR and between de7853d and da510f6.
Files selected for processing (1)
  • sleap_roots/lengths.py (2 hunks)
Files skipped from review as they are similar to previous changes (1)
  • sleap_roots/lengths.py


codecov bot commented Mar 27, 2024

Codecov Report

Attention: Patch coverage is 54.80769%, with 141 lines in your changes missing coverage. Please review.

Project coverage is 74.90%. Comparing base (e54d7f7) to head (40fc672).

Files                            Patch %   Lines missing
sleap_roots/trait_pipelines.py   39.50%    98 Missing ⚠️
sleap_roots/points.py            72.38%    29 Missing ⚠️
sleap_roots/series.py            56.66%    13 Missing ⚠️
sleap_roots/lengths.py           90.00%     1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #71      +/-   ##
==========================================
- Coverage   80.65%   74.90%   -5.76%     
==========================================
  Files          13       13              
  Lines        1003     1307     +304     
==========================================
+ Hits          809      979     +170     
- Misses        194      328     +134     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Review Status

Configuration used: CodeRabbit UI

Commits: Files that changed from the base of the PR and between da510f6 and 57071f6.
Files selected for processing (8)
  • .github/workflows/ci.yml (1 hunks)
  • sleap_roots/bases.py (1 hunks)
  • sleap_roots/points.py (2 hunks)
  • sleap_roots/series.py (9 hunks)
  • sleap_roots/trait_pipelines.py (7 hunks)
  • tests/fixtures/data.py (1 hunks)
  • tests/test_series.py (3 hunks)
  • tests/test_trait_pipelines.py (2 hunks)
Files skipped from review as they are similar to previous changes (3)
  • sleap_roots/series.py
  • tests/fixtures/data.py
  • tests/test_trait_pipelines.py
Additional comments not posted (17)
.github/workflows/ci.yml (1)

69-70: Enabling Git LFS during the repository checkout step is a crucial addition for handling large files in CI workflows. This ensures that all necessary files are properly fetched, facilitating more comprehensive testing and integration checks.

tests/test_series.py (3)

9-13: Adding a series_instance fixture simplifies the creation of a Series instance with dummy data for testing. This enhances test modularity and reusability.


50-57: The csv_path fixture for creating a dummy CSV file is a valuable addition for testing series properties. Ensure that the dummy CSV content aligns with the expected format and fields for Series instances.


79-86: Utilizing the new fixtures in tests, as seen in test_series_name and test_expected_count, improves test clarity and maintainability. These modifications ensure that the tests are more focused and easier to understand.

sleap_roots/bases.py (1)

279-282: Initializing default_dists, default_left_bases, and default_right_bases with NaN values using np.full ensures consistent handling of missing data. This modification improves the clarity and robustness of the get_root_widths function by clearly indicating cases with no valid data.
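
For context, the NaN-default pattern looks roughly like this (the array shapes here are illustrative, not taken from get_root_widths):

import numpy as np

# NaN defaults make "no valid measurement" explicit and flow through
# NaN-aware summaries (e.g. np.nanmean) without special-casing.
n_widths = 4
default_dists = np.full((n_widths,), np.nan)
default_left_bases = np.full((n_widths, 2), np.nan)
default_right_bases = np.full((n_widths, 2), np.nan)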

sleap_roots/points.py (4)

4-7: The addition of imports for matplotlib and shapely supports the new functionality for plotting root associations and performing spatial analysis. These imports are essential for the added capabilities.


294-317: The filter_roots_with_nans function is a useful addition for preprocessing root data by removing instances with NaN values. This function enhances data cleanliness and reliability for subsequent analyses.


320-361: The filter_plants_with_unexpected_ct function provides a mechanism to filter primary and lateral roots based on an expected count, which is crucial for ensuring data consistency. This function adds robustness to the preprocessing steps.


537-596: The plot_root_associations function introduces visualization capabilities for root associations, enhancing the interpretability of the analysis. The use of matplotlib for plotting and the careful consideration of plot aesthetics, such as the color map and axis inversion, are commendable.
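
As a rough idea of what such a plot involves, here is a sketch that draws each primary root and its associated laterals in a shared color (styling details are assumptions; it consumes the dict produced by associate_lateral_to_primary):

import matplotlib.pyplot as plt
import numpy as np

def plot_root_associations_sketch(plant_associations: dict):
    """Plot each primary root and its laterals in one color per plant."""
    fig, ax = plt.subplots()
    n = max(len(plant_associations), 1)
    for idx, assoc in plant_associations.items():
        color = plt.cm.viridis(idx / n)
        primary = assoc["primary_points"]  # (nodes, 2)
        ax.plot(primary[:, 0], primary[:, 1], color=color, linewidth=2)
        for lateral in assoc["lateral_points"]:  # iterates (nodes, 2) arrays
            if not np.isnan(lateral).all():
                ax.plot(lateral[:, 0], lateral[:, 1], color=color, linewidth=1)
    ax.invert_yaxis()  # image coordinates: y grows downward
    return fig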

sleap_roots/trait_pipelines.py (8)

3-3: The addition of import json is appropriate for JSON serialization tasks introduced in this update.


118-135: The NumpyArrayEncoder class is well-implemented for custom JSON serialization of NumPy arrays and np.int64 types. It correctly falls back to the base class method for other types.


115-140: > 📝 NOTE

This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [137-211]

The TraitDef class is well-defined, providing a clear structure for trait computation. It includes comprehensive documentation and a flexible design to accommodate various trait computations. Ensure that all functions referenced in fn are implemented and tested.


383-625: > 📝 NOTE

This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [213-699]

The Pipeline class and its subclasses (DicotPipeline, MultipleDicotPipeline) are central to the trait computation logic. They are well-structured and include methods for defining traits, computing traits for frames, and handling multiple plants. However, there are a few areas that could be improved for clarity and efficiency:

  1. In the compute_multiple_dicots_traits method, consider handling exceptions more gracefully when writing JSON or CSV files. Instead of printing errors directly, it might be beneficial to log these errors or raise exceptions to be handled by the caller.
  2. The method get_initial_frame_traits in MultipleDicotPipeline class uses expected_plant_ct which is not defined within the provided context. Ensure that this attribute is correctly initialized and used within the Series class.
  3. The use of print statements for logging (e.g., print(f"Processing series '{series.series_name}'")) is not ideal for production code. Consider using a logging framework that allows for different logging levels and better control over the output.

386-496: The compute_multiple_dicots_traits method is comprehensive and covers various aspects of trait computation for multiple dicots. However, consider adding error handling for file operations and replacing print statements with logging for better control and flexibility in output management.


498-605: The compute_multiple_dicots_traits_for_groups method aggregates traits over groups of samples efficiently. Similar to previous comments, consider enhancing error handling and replacing print statements with a more robust logging approach.


655-698: The compute_batch_multiple_dicots_traits method effectively summarizes traits for a batch of series with multiple dicots. Ensure that the CSV writing process is robust against potential errors, and consider using logging instead of print statements.


2106-2198: The MultipleDicotPipeline class introduces a specialized pipeline for computing traits for multiple dicot plants. The design and implementation are consistent with the rest of the module. Ensure that all referenced methods (e.g., filter_roots_with_nans, filter_plants_with_unexpected_ct) are implemented and tested for correctness.

Comment on lines +414 to +496
def associate_lateral_to_primary(
    primary_pts: np.ndarray, lateral_pts: np.ndarray
) -> dict:
    """Associates each lateral root with the closest primary root.

    Args:
        primary_pts: A numpy array of primary root points with shape
            (instances, nodes, 2), where 'instances' is the number of primary roots,
            'nodes' is the number of points in each root, and '2' corresponds to the
            x and y coordinates. Points cannot have NaN values.
        lateral_pts: A numpy array of lateral root points with a shape similar
            to primary_pts, representing the lateral roots. Points cannot have NaN values.

    Returns:
        dict: A dictionary where each key is an index of a primary root (from the
            primary_pts array) and each value is a dictionary containing
            'primary_points' as the points of the primary root (nodes, 2) and
            'lateral_points' as an array of lateral root points that are closest
            to that primary root. The shape of 'lateral_points' is
            (instances, nodes, 2), where instances is the number of lateral roots
            associated with the primary root.
    """
    # Basic input validation
    if not isinstance(primary_pts, np.ndarray) or not isinstance(
        lateral_pts, np.ndarray
    ):
        raise ValueError("Both primary_pts and lateral_pts must be numpy arrays.")
    if len(primary_pts.shape) != 3 or len(lateral_pts.shape) != 3:
        raise ValueError("Input arrays must have a shape of (instances, nodes, 2).")
    if primary_pts.shape[2] != 2 or lateral_pts.shape[2] != 2:
        raise ValueError(
            "The last dimension of input arrays must be 2, representing x and y coordinates."
        )

    plant_associations = {}

    # Initialize plant associations dictionary
    for i, primary_root in enumerate(primary_pts):
        if not is_line_valid(primary_root):
            continue  # Skip primary roots containing NaN values
        plant_associations[i] = {
            "primary_points": primary_root,
            "lateral_points": [],
        }

    # Associate each lateral root with the closest primary root
    for lateral_root in lateral_pts:
        if not is_line_valid(lateral_root):
            continue  # Skip lateral roots containing NaN values

        lateral_line = LineString(lateral_root)
        min_distance = float("inf")
        closest_primary_index = None

        for primary_index, primary_data in plant_associations.items():
            primary_root = primary_data["primary_points"]
            try:
                primary_line = LineString(primary_root)
                distance = primary_line.distance(lateral_line)
            except Exception as e:
                print(f"Error computing distance: {e}")
                continue

            if distance < min_distance:
                min_distance = distance
                closest_primary_index = primary_index

        if closest_primary_index is not None:
            plant_associations[closest_primary_index]["lateral_points"].append(
                lateral_root
            )

    # Convert lateral points lists into arrays
    for primary_index, data in plant_associations.items():
        lateral_points_list = data["lateral_points"]
        if lateral_points_list:  # Check if there are any lateral points to convert
            lateral_points_array = np.array(lateral_points_list)
            plant_associations[primary_index]["lateral_points"] = lateral_points_array
        else:
            # Create an array of NaNs if there are no lateral points
            shape = (1, lateral_pts.shape[1], 2)  # Shape of lateral points array
            plant_associations[primary_index]["lateral_points"] = np.full(shape, np.nan)

    return plant_associations

The associate_lateral_to_primary function effectively associates lateral roots with the closest primary root. It uses LineString for representing roots and nearest_points for finding the closest points, which is a robust approach for spatial analysis. However, consider catching specific exceptions instead of a general Exception to improve error handling specificity.

-            except Exception as e:
+            except ValueError as e:  # Adjust to the specific exception type if different
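
And a quick usage example of the function exactly as shown above (dummy coordinates chosen so each lateral is clearly nearest one primary; assumes associate_lateral_to_primary is imported from sleap_roots.points):

import numpy as np

# Two vertical primaries at x=0 and x=20, with one lateral near each.
primary = np.array([
    [[0.0, 0.0], [0.0, 10.0]],
    [[20.0, 0.0], [20.0, 10.0]],
])
lateral = np.array([
    [[1.0, 5.0], [3.0, 5.0]],    # nearest to primary 0
    [[19.0, 5.0], [17.0, 5.0]],  # nearest to primary 1
])
assoc = associate_lateral_to_primary(primary, lateral)
print(assoc[0]["lateral_points"].shape)  # (1, 2, 2)
print(assoc[1]["lateral_points"].shape)  # (1, 2, 2)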


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Review Status

Configuration used: CodeRabbit UI

Commits: Files that changed from the base of the PR and between 57071f6 and bc12818.
Files selected for processing (5)
  • MultiDicotPipeline.ipynb (1 hunks)
  • sleap_roots/bases.py (2 hunks)
  • sleap_roots/trait_pipelines.py (7 hunks)
  • tests/test_bases.py (2 hunks)
  • tests/test_trait_pipelines.py (2 hunks)
Files skipped from review as they are similar to previous changes (2)
  • sleap_roots/bases.py
  • tests/test_trait_pipelines.py
Additional Context Used
GitHub Check Runs (1)
codecov/patch success (30)

sleap_roots/trait_pipelines.py: [warning] 130-133: sleap_roots/trait_pipelines.py#L130-L133
Added lines #L130 - L133 were not covered by tests


sleap_roots/trait_pipelines.py: [warning] 135-135: sleap_roots/trait_pipelines.py#L135
Added line #L135 was not covered by tests


sleap_roots/trait_pipelines.py: [warning] 416-416: sleap_roots/trait_pipelines.py#L416
Added line #L416 was not covered by tests


sleap_roots/trait_pipelines.py: [warning] 418-418: sleap_roots/trait_pipelines.py#L418
Added line #L418 was not covered by tests


sleap_roots/trait_pipelines.py: [warning] 467-470: sleap_roots/trait_pipelines.py#L467-L470
Added lines #L467 - L470 were not covered by tests

Additional comments not posted (11)
tests/test_bases.py (4)

379-384: The use of np.nan in the expected results for test_get_root_widths with return_inds=True is correct for handling cases where the primary or lateral roots are not present or cannot be calculated due to missing data. However, it's important to ensure that the handling of np.nan values is consistent across all related functions and that any downstream processing of these results properly accounts for np.nan values to avoid unexpected behavior.


429-439: The test case for get_root_widths with minimum length inputs correctly asserts that the result should be np.nan. This is a good practice for handling edge cases where the input data does not meet the minimum requirements for meaningful computation. It's crucial to document these edge cases and the expected behavior in both the test and the function being tested to ensure clarity for future maintenance.


390-395: In the parameterized test for get_root_widths, handling of np.nan values is consistent with the expected behavior when primary or lateral roots are missing or invalid. This approach ensures robustness in the face of incomplete or corrupted data. It would be beneficial to include a brief comment explaining the rationale behind each test case scenario for future reference and clarity.


429-446: The handling of invalid cases in test_get_root_widths_invalid_cases through explicit ValueError exceptions is a good practice for early detection of incorrect usage patterns. It's important to ensure that these error messages are descriptive enough to aid in debugging and that similar validation is performed in the main function to prevent misuse.

MultiDicotPipeline.ipynb (4)

26-29: The paths for CSV and folder are hardcoded, which might limit the notebook's usability in different environments or datasets. Consider using a configuration file or environment variables to make these paths configurable. This approach enhances the notebook's flexibility and reusability across different datasets or project structures.


117-118: Initializing the MultipleDicotPipeline without any configuration parameters is straightforward for this example. However, if the pipeline supports customization (e.g., different processing strategies, parameters), it would be beneficial to demonstrate how to configure these options or link to documentation for further details. This information aids users in understanding how to adapt the pipeline to their specific needs.


136-137: When computing traits for the first sample and writing the results to JSON and CSV, it's good practice to provide feedback to the user about the operation's success and the location of the output files. This feedback can be particularly helpful in a notebook environment where users might not immediately notice file generation. Consider adding print statements or logging to inform the user.


668-669: The function compute_batch_multiple_dicots_traits_for_groups is called with write_json=True and write_csv=True, which is consistent with the intent to save the computed traits. However, it's unclear where these files are saved. Providing explicit paths or confirming the output directory would improve usability and clarity for users trying to locate the generated files.
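
Putting the notebook's flow together, an end-to-end sketch (the Series loader arguments are assumptions for illustration; the data paths come from this PR's test data):

from sleap_roots import MultipleDicotPipeline
from sleap_roots.series import Series

# Hypothetical loader call: the exact Series loading signature is an assumption.
series = Series.load(
    h5_path="tests/data/multiple_arabidopsis_11do/997_1.h5",
    primary_name="primary",
    lateral_name="lateral",
    csv_path="tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv",
)

pipeline = MultipleDicotPipeline()
result = pipeline.compute_multiple_dicots_traits(
    series, write_json=True, write_csv=True
)
print(result["summary_stats"])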

sleap_roots/trait_pipelines.py (3)

3-6: The added imports json and from typing import Union are correctly used in the context of the new functionality introduced in this file.


137-137: The TraitDef class is well-defined and documented, providing a clear structure for trait computation. Good use of attrs for concise attribute definitions.


137-137: The Pipeline class is correctly implemented with smart use of networkx for determining the computation order of traits. Consider adding more detailed documentation for methods involving complex logic, such as get_computation_order, to enhance maintainability.
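
A sketch of how such a computation order can be derived with networkx, using the trait dependencies this PR defines for MultipleDicotPipeline (the Pipeline's actual implementation may differ):

import networkx as nx

# Each trait mapped to its input traits, per the TraitDefs in this PR.
deps = {
    "primary_pts_no_nans": ["primary_pts"],
    "lateral_pts_no_nans": ["lateral_pts"],
    "filtered_pts_expected_plant_ct": [
        "primary_pts_no_nans",
        "lateral_pts_no_nans",
        "expected_plant_ct",
    ],
    "primary_pts_expected_plant_ct": ["filtered_pts_expected_plant_ct"],
    "lateral_pts_expected_plant_ct": ["filtered_pts_expected_plant_ct"],
    "plant_associations_dict": [
        "primary_pts_expected_plant_ct",
        "lateral_pts_expected_plant_ct",
    ],
}

g = nx.DiGraph()
for trait, inputs in deps.items():
    for inp in inputs:
        g.add_edge(inp, trait)  # edge points from input to dependent trait

# Any topological order is a valid computation order for the traits.
print(list(nx.topological_sort(g)))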

    except Exception as e:
        raise IOError(f"Failed to write computed traits to CSV: {e}")

    return all_group_summaries_df

The DicotPipeline class is well-implemented, providing specific trait computations for dicot plants. Given the complexity of these computations, ensure that unit tests cover this new functionality comprehensively.

+ # TODO: Add comprehensive unit tests for DicotPipeline trait computations


@eberrigan
Collaborator Author

  • The added functions and pipelines have been tested with real data in jupyter notebooks
  • More tests with dummy data as fixtures could be added in the future
  • Functions for aggregating traits per genotype before getting the summary statistics have been added for the MultipleDicotPipeline

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Review Status

Configuration used: CodeRabbit UI

Commits: Files that changed from the base of the PR and between bc12818 and 40fc672.
Files selected for processing (1)
  • sleap_roots/trait_pipelines.py (7 hunks)
Additional Context Used
GitHub Check Runs (1)
codecov/patch success (30)

sleap_roots/trait_pipelines.py: [warning] 130-133: sleap_roots/trait_pipelines.py#L130-L133
Added lines #L130 - L133 were not covered by tests


sleap_roots/trait_pipelines.py: [warning] 135-135: sleap_roots/trait_pipelines.py#L135
Added line #L135 was not covered by tests


sleap_roots/trait_pipelines.py: [warning] 416-416: sleap_roots/trait_pipelines.py#L416
Added line #L416 was not covered by tests


sleap_roots/trait_pipelines.py: [warning] 418-418: sleap_roots/trait_pipelines.py#L418
Added line #L418 was not covered by tests


sleap_roots/trait_pipelines.py: [warning] 467-470: sleap_roots/trait_pipelines.py#L467-L470
Added lines #L467 - L470 were not covered by tests

Comment on lines +118 to +135
class NumpyArrayEncoder(json.JSONEncoder):
    """Custom encoder for NumPy array types."""

    def default(self, obj):
        """Serialize NumPy arrays to lists.

        Args:
            obj: The object to serialize.

        Returns:
            A list representation of the NumPy array.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.int64):
            return int(obj)
        # Let the base class default method raise the TypeError
        return json.JSONEncoder.default(self, obj)

The NumpyArrayEncoder class is correctly implemented to handle NumPy types during JSON serialization. However, some lines in this class were not covered by tests. It's important to ensure comprehensive test coverage for this functionality.

+ # TODO: Add tests to cover the NumpyArrayEncoder functionality
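
A small usage example of the encoder shown above (dummy values; assumes NumpyArrayEncoder is importable from sleap_roots.trait_pipelines):

import json

import numpy as np

traits = {
    "series": "997_1",
    "lateral_count": np.int64(3),
    "primary_length": np.array([1052.4, 1061.7]),
}
# Without cls=NumpyArrayEncoder, json.dumps raises TypeError on the ndarray
# and np.int64 values.
print(json.dumps(traits, cls=NumpyArrayEncoder, indent=2))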


Comment on lines +386 to +496
def compute_multiple_dicots_traits(
    self,
    series: Series,
    write_json: bool = False,
    json_suffix: str = ".all_frames_traits.json",
    write_csv: bool = False,
    csv_suffix: str = ".all_frames_summary.csv",
):
    """Computes plant traits for pipelines with multiple plants over all frames in a series.

    Args:
        series: The Series object containing the primary and lateral root points.
        write_json: Whether to write the aggregated traits to a JSON file. Default is False.
        json_suffix: The suffix to append to the JSON file name. Default is ".all_frames_traits.json".
        write_csv: Whether to write the summary statistics to a CSV file. Default is False.
        csv_suffix: The suffix to append to the CSV file name. Default is ".all_frames_summary.csv".

    Returns:
        A dictionary containing the series name, group, aggregated traits, and summary statistics.
    """
    # Initialize the return structure with the series name and group
    result = {
        "series": str(series.series_name),
        "group": str(series.group),
        "traits": {},
        "summary_stats": {},
    }

    # Check if the series has frames to process
    if len(series) == 0:
        print(f"Series '{series.series_name}' contains no frames to process.")
        # Return early with the initialized structure
        return result

    # Initialize a separate dictionary to hold the aggregated traits across all frames
    aggregated_traits = {}

    # Iterate over frames in series
    for frame in range(len(series)):
        # Get initial points and number of plants per frame
        initial_frame_traits = self.get_initial_frame_traits(series, frame)
        # Compute initial associations and perform filter operations
        frame_traits = self.compute_frame_traits(initial_frame_traits)

        # Instantiate DicotPipeline
        dicot_pipeline = DicotPipeline()

        # Extract the plant associations for this frame
        associations = frame_traits["plant_associations_dict"]

        for primary_idx, assoc in associations.items():
            primary_pts = assoc["primary_points"]
            lateral_pts = assoc["lateral_points"]
            # Get the initial frame traits for this plant using the primary and lateral points
            initial_frame_traits = {
                "primary_pts": primary_pts,
                "lateral_pts": lateral_pts,
            }
            # Use the dicot pipeline to compute the plant traits on this frame
            plant_traits = dicot_pipeline.compute_frame_traits(initial_frame_traits)

            # For each plant's traits in the frame
            for trait_name, trait_value in plant_traits.items():
                # Not all traits are added to the aggregated traits dictionary
                if trait_name in dicot_pipeline.csv_traits_multiple_plants:
                    if trait_name not in aggregated_traits:
                        # Initialize the trait array if it's the first frame
                        aggregated_traits[trait_name] = [np.atleast_1d(trait_value)]
                    else:
                        # Append new trait values for subsequent frames
                        aggregated_traits[trait_name].append(
                            np.atleast_1d(trait_value)
                        )

    # After processing, update the result dictionary with computed traits
    for trait, arrays in aggregated_traits.items():
        aggregated_traits[trait] = np.concatenate(arrays, axis=0)
    result["traits"] = aggregated_traits

    # Write to JSON if requested
    if write_json:
        json_name = f"{series.series_name}{json_suffix}"
        try:
            with open(json_name, "w") as f:
                json.dump(
                    result, f, cls=NumpyArrayEncoder, ensure_ascii=False, indent=4
                )
            print(f"Aggregated traits saved to {json_name}")
        except IOError as e:
            print(f"Error writing JSON file '{json_name}': {e}")

    # Compute summary statistics and update result
    summary_stats = {}
    for trait_name, trait_values in aggregated_traits.items():
        trait_stats = get_summary(trait_values, prefix=f"{trait_name}_")
        summary_stats.update(trait_stats)
    result["summary_stats"] = summary_stats

    # Optionally write summary stats to CSV
    if write_csv:
        csv_name = f"{series.series_name}{csv_suffix}"
        try:
            summary_df = pd.DataFrame([summary_stats])
            summary_df.insert(0, "series", series.series_name)
            summary_df.to_csv(csv_name, index=False)
            print(f"Summary statistics saved to {csv_name}")
        except IOError as e:
            print(f"Failed to write CSV file '{csv_name}': {e}")

    # Return the final result structure
    return result

The compute_multiple_dicots_traits method is well-implemented, but some lines were not covered by tests. Ensuring comprehensive test coverage is crucial for verifying the correctness of this functionality.

+ # TODO: Add tests to cover the compute_multiple_dicots_traits functionality


Comment on lines +2197 to +2289
@attrs.define
class MultipleDicotPipeline(Pipeline):
    """Pipeline for computing traits for multiple dicot plants."""

    def define_traits(self) -> List[TraitDef]:
        """Define the trait computation pipeline for primary roots."""
        trait_definitions = [
            TraitDef(
                name="primary_pts_no_nans",
                fn=filter_roots_with_nans,
                input_traits=["primary_pts"],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Primary roots without any NaNs.",
            ),
            TraitDef(
                name="lateral_pts_no_nans",
                fn=filter_roots_with_nans,
                input_traits=["lateral_pts"],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Lateral roots without any NaNs.",
            ),
            TraitDef(
                name="filtered_pts_expected_plant_ct",
                fn=filter_plants_with_unexpected_ct,
                input_traits=[
                    "primary_pts_no_nans",
                    "lateral_pts_no_nans",
                    "expected_plant_ct",
                ],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Tuple of filtered points with expected plant count.",
            ),
            TraitDef(
                name="primary_pts_expected_plant_ct",
                fn=get_filtered_primary_pts,
                input_traits=["filtered_pts_expected_plant_ct"],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Filtered primary root points with expected plant count.",
            ),
            TraitDef(
                name="lateral_pts_expected_plant_ct",
                fn=get_filtered_lateral_pts,
                input_traits=["filtered_pts_expected_plant_ct"],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Filtered lateral root points with expected plant count.",
            ),
            TraitDef(
                name="plant_associations_dict",
                fn=associate_lateral_to_primary,
                input_traits=[
                    "primary_pts_expected_plant_ct",
                    "lateral_pts_expected_plant_ct",
                ],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Dictionary of plant associations.",
            ),
        ]

        return trait_definitions

    def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, Any]:
        """Return initial traits for a plant frame.

        Args:
            plant: The plant `Series` object.
            frame_idx: The index of the current frame.

        Returns:
            A dictionary of initial traits with keys:
                - "primary_pts": Array of primary root points.
                - "lateral_pts": Array of lateral root points.
                - "expected_plant_ct": Expected number of plants as a float.
        """
        primary_pts = plant.get_primary_points(frame_idx)
        lateral_pts = plant.get_lateral_points(frame_idx)
        expected_plant_ct = plant.expected_count
        return {
            "primary_pts": primary_pts,
            "lateral_pts": lateral_pts,
            "expected_plant_ct": expected_plant_ct,
        }

The MultipleDicotPipeline class introduces important functionality for handling multiple dicot plants, including filtering and association logic. Ensure that this functionality is thoroughly tested, given its critical role in the pipeline.

+ # TODO: Ensure comprehensive testing for filtering and association logic in MultipleDicotPipeline


@talmo talmo merged commit cae212d into main Mar 31, 2024
5 checks passed
@talmo talmo deleted the elizabeth/add-multiple-dicot-pipeline-with-commit-history-from-main branch March 31, 2024 19:26