Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 125 additions & 24 deletions crates/ide-assists/src/handlers/promote_local_to_const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use hir::HirDisplay;
use ide_db::{assists::AssistId, defs::Definition};
use stdx::to_upper_snake_case;
use syntax::{
AstNode,
AstNode, T,
ast::{self, HasName},
syntax_editor::Position,
};

use crate::{
Expand Down Expand Up @@ -40,7 +41,7 @@ use crate::{
// }
// ```
pub(crate) fn promote_local_to_const(acc: &mut Assists, ctx: &AssistContext<'_, '_>) -> Option<()> {
let pat = ctx.find_node_at_offset::<ast::IdentPat>()?;
let pat = ctx.find_node_at_offset_with_descend::<ast::IdentPat>()?;
let name = pat.name()?;
if !pat.is_simple_ident() {
cov_mark::hit!(promote_local_non_simple_ident);
Expand All @@ -51,52 +52,82 @@ pub(crate) fn promote_local_to_const(acc: &mut Assists, ctx: &AssistContext<'_,
let module = ctx.sema.scope(pat.syntax())?.module();
let local = ctx.sema.to_def(&pat)?;
let ty = ctx.sema.type_of_pat(&pat.into())?.original;

let ty = match ty.display_source_code(ctx.db(), module.into(), false) {
Ok(ty) => ty,
Err(_) => return None,
};
let ty = ty.display_source_code(ctx.db(), module.into(), false).ok()?;

let initializer = let_stmt.initializer()?;
if !utils::is_body_const(&ctx.sema, &initializer) {
cov_mark::hit!(promote_local_non_const);
return None;
}

let let_stmt_range = ctx.sema.original_range_opt(let_stmt.syntax())?.range;
let node_in_source = let_stmt.syntax().text_range() == let_stmt_range;

let usages = Definition::Local(local).usages(&ctx.sema).all();
if let Some(usages) = usages.references.get(&ctx.file_id()) {
for usage in usages {
let Some(usage_name) = usage.name.as_name_ref() else {
continue;
};
let Some(record_field) = ast::RecordExprField::for_field_name(usage_name) else {
continue;
};
let Some(record_expr) = record_field.syntax().parent().and_then(|list| list.parent())
else {
continue;
};
if utils::original_range_in(ctx.file_id(), &ctx.sema, &record_expr).is_none() {
cov_mark::hit!(promote_local_shorthand_in_macro);
return None;
}
}
}

let const_name = to_upper_snake_case(&name.to_string());

acc.add(
AssistId::refactor("promote_local_to_const"),
"Promote local to constant",
let_stmt.syntax().text_range(),
let_stmt_range,
|edit| {
let editor = edit.make_editor(let_stmt.syntax());
let source = ctx.source_file().syntax();
let editor = edit.make_editor(source);
let make = editor.make();
let name = to_upper_snake_case(&name.to_string());
let usages = Definition::Local(local).usages(&ctx.sema).all();
if let Some(usages) = usages.references.get(&ctx.file_id()) {
let name_ref = make.name_ref(&name);

if let Some(usages) = usages.references.get(&ctx.file_id()) {
for usage in usages {
let Some(usage_name) = usage.name.as_name_ref().cloned() else { continue };
if let Some(record_field) = ast::RecordExprField::for_name_ref(&usage_name) {
let path = make.ident_path(&name);
let name_expr = make.expr_path(path);
utils::replace_record_field_expr(ctx, edit, record_field, name_expr);
let Some(usage_name) = usage.name.as_name_ref() else {
continue;
};
let place = utils::cover_edit_range(source, usage.range);
if ast::RecordExprField::for_field_name(usage_name).is_some() {
editor.insert_all(
Position::after(place.end()),
vec![
make.token(T![:]).into(),
make.whitespace(" ").into(),
make.name_ref(&const_name).syntax().clone().into(),
],
);
} else {
let usage_range = usage.range;
edit.replace(usage_range, name_ref.syntax().text());
editor.replace_all(
place,
vec![make.name_ref(&const_name).syntax().clone().into()],
);
}
}
}

let item = make.item_const(None, None, make.name(&name), make.ty(&ty), initializer);
let item =
make.item_const(None, None, make.name(&const_name), make.ty(&ty), initializer);

if let Some((cap, name)) = ctx.config.snippet_cap.zip(item.name()) {
if node_in_source && let Some((cap, name)) = ctx.config.snippet_cap.zip(item.name()) {
let tabstop = edit.make_tabstop_before(cap);
editor.add_annotation(name.syntax().clone(), tabstop);
}

editor.replace(let_stmt.syntax(), item.syntax());

let place = utils::cover_edit_range(source, let_stmt_range);
editor.replace_all(place, vec![item.syntax().clone().into()]);
edit.add_file_edits(ctx.vfs_file_id(), editor);
},
)
Expand Down Expand Up @@ -292,6 +323,76 @@ fn foo() {
);
}

#[test]
fn let_in_macro() {
check_assist(
promote_local_to_const,
r#"
//- proc_macros: identity
#[proc_macros::identity]
fn f() {
let x$0 = 0;
let _ = x;
}
"#,
r"
#[proc_macros::identity]
fn f() {
const X: i32 = 0;
let _ = X;
}
",
);

check_assist(
promote_local_to_const,
r"
macro_rules! id { ($($tt:tt)*) => { $($tt)* }; }

fn f() {
id! {
let x$0 = 0;
let _ = x;
}
}
",
r"
macro_rules! id { ($($tt:tt)*) => { $($tt)* }; }

fn f() {
id! {
const X: i32 = 0;
let _ = X;
}
}
",
);
}

#[test]
fn not_applicable_shorthand_in_macro() {
cov_mark::check!(promote_local_shorthand_in_macro);
check_assist_not_applicable(
promote_local_to_const,
r"
struct Foo {
foo: usize,
}

macro_rules! make_foo {
($v:ident) => {
Foo { $v }
};
}

fn baz() -> Foo {
let $0foo = 2;
make_foo!(foo)
}
",
);
}

#[test]
fn not_applicable_non_simple_ident() {
cov_mark::check!(promote_local_non_simple_ident);
Expand Down