From 882d17804bd44569d496d8340671fe920359a7db Mon Sep 17 00:00:00 2001 From: Nicolas Abril Date: Mon, 7 Oct 2024 15:19:26 +0200 Subject: [PATCH] Fix Ctr patterns not being renamed in imports --- src/imports/book.rs | 25 +++++++++++++------ .../import_system/import_type.bend | 6 ++++- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/imports/book.rs b/src/imports/book.rs index 8ed6c0f5..67a171be 100644 --- a/src/imports/book.rs +++ b/src/imports/book.rs @@ -2,7 +2,7 @@ use super::{BindMap, ImportsMap, PackageLoader}; use crate::{ diagnostics::{Diagnostics, DiagnosticsConfig}, fun::{ - parser::ParseBook, Adt, AdtCtr, Book, Definition, HvmDefinition, Name, Rule, Source, SourceKind, Term, + parser::ParseBook, Adt, AdtCtr, Book, Definition, HvmDefinition, Name, Pattern, Source, SourceKind, Term, }, imp::{self, Expr, MatchArm, Stmt}, imports::packages::Packages, @@ -346,19 +346,30 @@ trait Def { impl Def for Definition { fn apply_binds(&mut self, maybe_constructor: bool, binds: &BindMap) { - fn rename_ctr_patterns(rule: &mut Rule, binds: &BindMap) { - for pat in &mut rule.pats { - for bind in pat.binds_mut().flatten() { - if let Some(alias) = binds.get(bind) { - *bind = alias.clone(); + fn rename_ctr_pattern(pat: &mut Pattern, binds: &BindMap) { + for pat in pat.children_mut() { + rename_ctr_pattern(pat, binds); + } + match pat { + Pattern::Ctr(nam, _) => { + if let Some(alias) = binds.get(nam) { + *nam = alias.clone(); + } + } + Pattern::Var(Some(nam)) => { + if let Some(alias) = binds.get(nam) { + *nam = alias.clone(); } } + _ => {} } } for rule in &mut self.rules { if maybe_constructor { - rename_ctr_patterns(rule, binds); + for pat in &mut rule.pats { + rename_ctr_pattern(pat, binds); + } } let bod = std::mem::take(&mut rule.body); rule.body = bod.fold_uses(binds.iter().rev()); diff --git a/tests/golden_tests/import_system/import_type.bend b/tests/golden_tests/import_system/import_type.bend index b79c33fa..fae91d64 100644 --- a/tests/golden_tests/import_system/import_type.bend +++ b/tests/golden_tests/import_system/import_type.bend @@ -1,7 +1,11 @@ from lib/MyOption import (MyOption, MyOption/bind, MyOption/wrap) +unwrap (val : (MyOption u24)) : u24 +unwrap (MyOption/Some x) = x +unwrap (MyOption/None) = 0 + def main() -> MyOption((u24, u24)): with MyOption: a <- MyOption/Some(1) - b <- MyOption/Some(2) + b = unwrap(MyOption/Some(2)) return wrap((a, b))