From 8aaf5f2a5b12f2eb16959a0bbabafb11d0a22f0c Mon Sep 17 00:00:00 2001 From: almostmeenal Date: Mon, 7 Jun 2021 20:05:27 +0530 Subject: [PATCH 1/6] extend bug fix --- arviz/data/inference_data.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 98e5531d94..573255c80d 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -1394,10 +1394,11 @@ def extend(self, other, join="left"): ) dataset = getattr(other, group) setattr(self, group, dataset) - if group.startswith(WARMUP_TAG): - self._groups_warmup.append(group) - else: - self._groups.append(group) + if not hasattr(self, group): + if group.startswith(WARMUP_TAG): + self._groups_warmup.append(group) + else: + self._groups.append(group) set_index = _extend_xr_method(xr.Dataset.set_index, see_also="reset_index") get_index = _extend_xr_method(xr.Dataset.get_index) From 26b9c7547e241b491c71f67790e00ff2e0f95cd1 Mon Sep 17 00:00:00 2001 From: Meenal Jhajharia Date: Mon, 13 Sep 2021 18:21:24 +0530 Subject: [PATCH 2/6] Update inference_data.py --- arviz/data/inference_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 573255c80d..20ea33ebe5 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -1394,7 +1394,7 @@ def extend(self, other, join="left"): ) dataset = getattr(other, group) setattr(self, group, dataset) - if not hasattr(self, group): + if group not in self._groups and self._groups_warmup: if group.startswith(WARMUP_TAG): self._groups_warmup.append(group) else: From 66527af684eee86fbf762b6caed5967a93f6b24c Mon Sep 17 00:00:00 2001 From: Ari Hartikainen Date: Sun, 16 Jan 2022 11:11:13 +0200 Subject: [PATCH 3/6] Update logic --- arviz/data/inference_data.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 20ea33ebe5..625e31a8a3 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -1394,10 +1394,11 @@ def extend(self, other, join="left"): ) dataset = getattr(other, group) setattr(self, group, dataset) - if group not in self._groups and self._groups_warmup: - if group.startswith(WARMUP_TAG): + if group.startswith(WARMUP_TAG): + if group not in self._groups_warmup: self._groups_warmup.append(group) - else: + else: + if group not in self._groups: self._groups.append(group) set_index = _extend_xr_method(xr.Dataset.set_index, see_also="reset_index") From 3defae7d5a524041963f0248f0fe2f3e0bbab0ac Mon Sep 17 00:00:00 2001 From: Ari Hartikainen Date: Sun, 16 Jan 2022 11:18:28 +0200 Subject: [PATCH 4/6] Add logic to keep the correct order --- arviz/data/inference_data.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 625e31a8a3..650c8960d2 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -1396,10 +1396,22 @@ def extend(self, other, join="left"): setattr(self, group, dataset) if group.startswith(WARMUP_TAG): if group not in self._groups_warmup: - self._groups_warmup.append(group) + supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup] + if (supported_order == self._groups_warmup) and (group in SUPPORTED_GROUPS_ALL): + group_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup + [group]] + group_idx = group_order.index(group) + self._groups_warmup.insert(group_idx, group) + else: + self._groups_warmup.append(group) else: if group not in self._groups: - self._groups.append(group) + supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups] + if (supported_order == self._groups) and (group in SUPPORTED_GROUPS_ALL): + group_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group]] + group_idx = group_order.index(group) + self._groups.insert(group_idx, group) + else: + self._groups.append(group) set_index = _extend_xr_method(xr.Dataset.set_index, see_also="reset_index") get_index = _extend_xr_method(xr.Dataset.get_index) From ec699c3aab81d3f02e6d384547d43371ab567e25 Mon Sep 17 00:00:00 2001 From: Ari Hartikainen Date: Sun, 16 Jan 2022 11:36:09 +0200 Subject: [PATCH 5/6] Update logic for add_groups --- arviz/data/inference_data.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 650c8960d2..348d73f318 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -1358,9 +1358,21 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs): if dataset: setattr(self, group, dataset) if group.startswith(WARMUP_TAG): - self._groups_warmup.append(group) + supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup] + if (supported_order == self._groups_warmup) and (group in SUPPORTED_GROUPS_ALL): + group_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup + [group]] + group_idx = group_order.index(group) + self._groups_warmup.insert(group_idx, group) + else: + self._groups_warmup.append(group) else: - self._groups.append(group) + supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups] + if (supported_order == self._groups) and (group in SUPPORTED_GROUPS_ALL): + group_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group]] + group_idx = group_order.index(group) + self._groups.insert(group_idx, group) + else: + self._groups.append(group) def extend(self, other, join="left"): """Extend InferenceData with groups from another InferenceData. From 06cfa92211db69d211eb449514977419daab4d9e Mon Sep 17 00:00:00 2001 From: Ari Hartikainen Date: Sun, 16 Jan 2022 11:44:06 +0200 Subject: [PATCH 6/6] Fix long lines --- arviz/data/inference_data.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 348d73f318..3c609d06a2 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -1358,9 +1358,15 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs): if dataset: setattr(self, group, dataset) if group.startswith(WARMUP_TAG): - supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup] + supported_order = [ + key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup + ] if (supported_order == self._groups_warmup) and (group in SUPPORTED_GROUPS_ALL): - group_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup + [group]] + group_order = [ + key + for key in SUPPORTED_GROUPS_ALL + if key in self._groups_warmup + [group] + ] group_idx = group_order.index(group) self._groups_warmup.insert(group_idx, group) else: @@ -1368,7 +1374,9 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs): else: supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups] if (supported_order == self._groups) and (group in SUPPORTED_GROUPS_ALL): - group_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group]] + group_order = [ + key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group] + ] group_idx = group_order.index(group) self._groups.insert(group_idx, group) else: @@ -1408,9 +1416,15 @@ def extend(self, other, join="left"): setattr(self, group, dataset) if group.startswith(WARMUP_TAG): if group not in self._groups_warmup: - supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup] + supported_order = [ + key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup + ] if (supported_order == self._groups_warmup) and (group in SUPPORTED_GROUPS_ALL): - group_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup + [group]] + group_order = [ + key + for key in SUPPORTED_GROUPS_ALL + if key in self._groups_warmup + [group] + ] group_idx = group_order.index(group) self._groups_warmup.insert(group_idx, group) else: @@ -1419,7 +1433,9 @@ def extend(self, other, join="left"): if group not in self._groups: supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups] if (supported_order == self._groups) and (group in SUPPORTED_GROUPS_ALL): - group_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group]] + group_order = [ + key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group] + ] group_idx = group_order.index(group) self._groups.insert(group_idx, group) else: