From 7823a106936d032bb835d4b82576d31fda2d5b9e Mon Sep 17 00:00:00 2001 From: Jesse Herrick Date: Mon, 4 May 2026 09:48:48 -0600 Subject: [PATCH 1/3] Add support for struct auto-completion --- internal/lsp/beam_server.exs | 61 +++++ internal/lsp/elixir.go | 416 +++++++++++++++++++++++++++++++ internal/lsp/elixir_test.go | 233 +++++++++++++++++ internal/lsp/formatter.go | 40 +++ internal/lsp/server.go | 469 +++++++++++++++++++++++++++++++++-- internal/lsp/server_test.go | 242 ++++++++++++++++++ 6 files changed, 1440 insertions(+), 21 deletions(-) diff --git a/internal/lsp/beam_server.exs b/internal/lsp/beam_server.exs index 6fbc489..d0cb50e 100644 --- a/internal/lsp/beam_server.exs +++ b/internal/lsp/beam_server.exs @@ -44,6 +44,9 @@ # CodeIntel op 4 (runtime_info) payload: # empty # +# CodeIntel op 5 (struct_fields) payload: +# 2-byte Elixir module length (big-endian) + module +# # Notification 0 (otp_modules_ready) payload: # 2-byte module_count (big-endian) + [name_len(u16) name] # @@ -450,6 +453,7 @@ defmodule Dexter.CodeIntel do @op_warm_otp_modules 2 @op_erlang_exports 3 @op_runtime_info 4 + @op_struct_fields 5 def handle_request(op, payload) do case op do @@ -458,6 +462,7 @@ defmodule Dexter.CodeIntel do @op_warm_otp_modules -> handle_warm_otp_modules(payload) @op_erlang_exports -> handle_erlang_exports(payload) @op_runtime_info -> handle_runtime_info(payload) + @op_struct_fields -> handle_struct_fields(payload) _ -> {1, "unknown code intel op: #{inspect(op)}"} end end @@ -821,6 +826,62 @@ defmodule Dexter.CodeIntel do byte_size(code_root_dir)::unsigned-big-16, code_root_dir::binary>>} end + defp handle_struct_fields(payload) do + case parse_module(payload) do + {:ok, module_name} -> + case fetch_struct_fields(module_name) do + {:ok, fields} -> {0, encode_string_list(fields)} + :error -> {1, "struct fields not found"} + end + + :error -> + {1, "invalid struct_fields payload"} + end + end + + defp fetch_struct_fields(module_name) do + with {:ok, module_atom} <- elixir_module_atom(module_name), + {:module, ^module_atom} <- Code.ensure_loaded(module_atom), + true <- function_exported?(module_atom, :__struct__, 0), + struct when is_map(struct) <- apply(module_atom, :__struct__, []) do + fields = + struct + |> Map.delete(:__struct__) + |> Map.keys() + |> Enum.map(&Atom.to_string/1) + |> Enum.sort() + + {:ok, fields} + else + _ -> :error + end + rescue + _ -> :error + end + + defp elixir_module_atom(module_name) do + if Regex.match?(~r/\A(?:Elixir\.)?[A-Z][A-Za-z0-9_]*(?:\.[A-Z][A-Za-z0-9_]*)*\z/, module_name) do + module_atom = + module_name + |> String.trim_leading("Elixir.") + |> String.split(".") + |> Module.concat() + + {:ok, module_atom} + else + :error + end + end + + defp encode_string_list(values) do + payload = + for value <- values, into: <<>> do + <> + end + + <> + end + defp parse_module_function_arity(payload) do with <> <- payload, {:ok, module_name, rest} <- take_string(rest, module_len), diff --git a/internal/lsp/elixir.go b/internal/lsp/elixir.go index 783585e..9b1d954 100644 --- a/internal/lsp/elixir.go +++ b/internal/lsp/elixir.go @@ -171,6 +171,422 @@ func (tf *TokenizedFile) CompletionContextAtCursor(line, col int) CompletionCont return CompletionContextAtCursor(tf.tokens, tf.source, tf.lineStarts, line, col) } +// StructCompletionContext describes completion inside `%Module{...}` field keys. +type StructCompletionContext struct { + ModuleRef string + FieldPrefix string + StartCol int +} + +type StructModuleRef struct { + ModuleRef string + Line int +} + +// StructModuleRefs returns module references used in struct literals, including +// incomplete `%Module` expressions before the opening brace has been typed. +func (tf *TokenizedFile) StructModuleRefs() []StructModuleRef { + return StructModuleRefs(tf.tokens, tf.source) +} + +func StructModuleRefs(tokens []parser.Token, source []byte) []StructModuleRef { + var refs []StructModuleRef + for i := 0; i < len(tokens); i++ { + if tokens[i].Kind != parser.TokPercent { + continue + } + j := tokNextSig(tokens, len(tokens), i+1) + if j >= len(tokens) || tokens[j].Kind != parser.TokModule { + continue + } + moduleRef, k := tokCollectModuleName(source, tokens, len(tokens), j) + if moduleRef == "" { + continue + } + k = tokNextSig(tokens, len(tokens), k) + refs = append(refs, StructModuleRef{ + ModuleRef: moduleRef, + Line: tokens[i].Line - 1, + }) + if k < len(tokens) && tokens[k].Kind == parser.TokOpenBrace { + i = k + } else { + i = k - 1 + } + } + return refs +} + +// StructCompletionContextAtCursor returns the struct module and current field +// prefix when the cursor is inside a struct literal/update key position. +func (tf *TokenizedFile) StructCompletionContextAtCursor(line, col int) (StructCompletionContext, bool) { + return StructCompletionContextAtCursor(tf.tokens, tf.source, tf.lineStarts, line, col) +} + +// StructValueContextAtCursor reports whether the cursor is in a top-level value +// position inside a struct literal, e.g. `%User{name: |}`. +func (tf *TokenizedFile) StructValueContextAtCursor(line, col int) bool { + return StructValueContextAtCursor(tf.tokens, tf.source, tf.lineStarts, line, col) +} + +func (tf *TokenizedFile) VariableNamesBeforeCursor(line, col int) []string { + return VariableNamesBeforeCursor(tf.tokens, tf.source, tf.lineStarts, line, col) +} + +func VariableNamesBeforeCursor(tokens []parser.Token, source []byte, lineStarts []int, line, col int) []string { + offset := parser.LineColToOffset(lineStarts, line, col) + if offset < 0 { + return nil + } + + defIdx := -1 + for i := 0; i < len(tokens); i++ { + tok := tokens[i] + if tok.Kind == parser.TokEOF || tok.Start >= offset { + break + } + if isFunctionDefinitionToken(tok.Kind) { + defIdx = i + } + } + if defIdx < 0 { + return nil + } + + seen := make(map[string]bool) + var names []string + for i := defIdx + 1; i < len(tokens); i++ { + tok := tokens[i] + if tok.Kind == parser.TokEOF || tok.Start >= offset { + break + } + if tok.Kind != parser.TokIdent { + continue + } + name := parser.TokenText(source, tok) + if strings.HasPrefix(name, "_") || parser.IsElixirKeyword(name) { + continue + } + prev := prevSignificantToken(tokens, i) + if prev >= 0 { + if tokens[prev].Kind == parser.TokDot || isFunctionDefinitionToken(tokens[prev].Kind) { + continue + } + } + next := tokNextSig(tokens, len(tokens), i+1) + if next < len(tokens) { + if tokens[next].Kind == parser.TokColon { + continue + } + if tokens[next].Kind == parser.TokOpenParen { + continue + } + } + if !seen[name] { + seen[name] = true + names = append(names, name) + } + } + return names +} + +func isFunctionDefinitionToken(kind parser.TokenKind) bool { + switch kind { + case parser.TokDef, parser.TokDefp, parser.TokDefmacro, parser.TokDefmacrop, + parser.TokDefguard, parser.TokDefguardp, parser.TokDefdelegate: + return true + default: + return false + } +} + +func StructValueContextAtCursor(tokens []parser.Token, source []byte, lineStarts []int, line, col int) bool { + if line < 0 || line >= len(lineStarts) || col < 0 { + return false + } + offset := parser.LineColToOffset(lineStarts, line, col) + if offset < 0 { + return false + } + + openIdx := enclosingOpenBraceBeforeOffset(tokens, offset) + if openIdx < 0 { + return false + } + if _, ok := structModuleBeforeOpenBrace(tokens, source, openIdx); !ok { + return false + } + + return structValuePositionAtOffset(tokens, openIdx, offset) +} + +// StructCompletionContextAtCursor returns struct-key completion context at the +// given 0-based line/column. It intentionally rejects value positions, so +// `%User{name: |}` does not ask for field completions. +func StructCompletionContextAtCursor(tokens []parser.Token, source []byte, lineStarts []int, line, col int) (StructCompletionContext, bool) { + if line < 0 || line >= len(lineStarts) || col < 0 { + return StructCompletionContext{}, false + } + offset := parser.LineColToOffset(lineStarts, line, col) + if offset < 0 { + return StructCompletionContext{}, false + } + + openIdx := enclosingOpenBraceBeforeOffset(tokens, offset) + if openIdx < 0 { + return StructCompletionContext{}, false + } + + moduleRef, ok := structModuleBeforeOpenBrace(tokens, source, openIdx) + if !ok { + return StructCompletionContext{}, false + } + + fieldPrefix, startOffset, ok := structFieldPrefixAtOffset(tokens, source, openIdx, offset) + if !ok { + return StructCompletionContext{}, false + } + + return StructCompletionContext{ + ModuleRef: moduleRef, + FieldPrefix: fieldPrefix, + StartCol: startOffset - lineStarts[line], + }, true +} + +func enclosingOpenBraceBeforeOffset(tokens []parser.Token, offset int) int { + depth := 0 + for i := len(tokens) - 1; i >= 0; i-- { + tok := tokens[i] + if tok.Kind == parser.TokEOF || tok.Start >= offset { + continue + } + switch tok.Kind { + case parser.TokCloseBrace: + depth++ + case parser.TokOpenBrace: + if depth == 0 { + return i + } + depth-- + } + } + return -1 +} + +func prevSignificantToken(tokens []parser.Token, before int) int { + for i := before - 1; i >= 0; i-- { + switch tokens[i].Kind { + case parser.TokEOL, parser.TokComment: + continue + default: + return i + } + } + return -1 +} + +func structModuleBeforeOpenBrace(tokens []parser.Token, source []byte, openIdx int) (string, bool) { + endIdx := prevSignificantToken(tokens, openIdx) + if endIdx < 0 || tokens[endIdx].Kind != parser.TokModule { + return "", false + } + + startIdx := endIdx + for startIdx >= 2 && tokens[startIdx-1].Kind == parser.TokDot && tokens[startIdx-2].Kind == parser.TokModule { + startIdx -= 2 + } + + percentIdx := prevSignificantToken(tokens, startIdx) + if percentIdx < 0 || tokens[percentIdx].Kind != parser.TokPercent { + return "", false + } + + moduleRef, nextIdx := tokCollectModuleName(source, tokens, len(tokens), startIdx) + if moduleRef == "" || nextIdx != openIdx { + return "", false + } + return moduleRef, true +} + +func structFieldPrefixAtOffset(tokens []parser.Token, source []byte, openIdx, offset int) (string, int, bool) { + segmentStartOffset := tokens[openIdx].End + parenDepth, bracketDepth, braceDepth := 0, 0, 0 + inValue := false + + for i := openIdx + 1; i < len(tokens); i++ { + tok := tokens[i] + if tok.Kind == parser.TokEOF || tok.Start >= offset { + break + } + + if parenDepth == 0 && bracketDepth == 0 && braceDepth == 0 { + switch tok.Kind { + case parser.TokComma: + segmentStartOffset = tok.End + inValue = false + continue + case parser.TokColon, parser.TokAssoc: + inValue = true + continue + case parser.TokOther: + if parser.TokenText(source, tok) == "|" && !inValue { + segmentStartOffset = tok.End + continue + } + } + } + + switch tok.Kind { + case parser.TokOpenParen: + parenDepth++ + case parser.TokCloseParen: + if parenDepth > 0 { + parenDepth-- + } + case parser.TokOpenBracket: + bracketDepth++ + case parser.TokCloseBracket: + if bracketDepth > 0 { + bracketDepth-- + } + case parser.TokOpenBrace: + braceDepth++ + case parser.TokCloseBrace: + if braceDepth == 0 { + return "", 0, false + } + braceDepth-- + } + } + + if parenDepth != 0 || bracketDepth != 0 || braceDepth != 0 { + return "", 0, false + } + if inValue { + return "", 0, false + } + if segmentStartOffset == tokens[openIdx].End && hasTopLevelStructUpdatePipeAhead(tokens, source, openIdx, offset) { + return "", 0, false + } + + if offset > 0 { + if idx := parser.TokenAtOffset(tokens, offset-1); idx >= 0 { + tok := tokens[idx] + if tok.Start >= segmentStartOffset && tok.Kind == parser.TokIdent { + end := offset + if end > tok.End { + end = tok.End + } + if end > tok.Start { + return string(source[tok.Start:end]), tok.Start, true + } + } + switch tok.Kind { + case parser.TokOpenBrace, parser.TokComma, parser.TokPipe, parser.TokEOL, parser.TokComment: + return "", offset, true + } + } + } + + return "", offset, true +} + +func structValuePositionAtOffset(tokens []parser.Token, openIdx, offset int) bool { + parenDepth, bracketDepth, braceDepth := 0, 0, 0 + inValue := false + + for i := openIdx + 1; i < len(tokens); i++ { + tok := tokens[i] + if tok.Kind == parser.TokEOF || tok.Start >= offset { + break + } + + if parenDepth == 0 && bracketDepth == 0 && braceDepth == 0 { + switch tok.Kind { + case parser.TokComma: + inValue = false + continue + case parser.TokColon, parser.TokAssoc: + inValue = true + continue + case parser.TokCloseBrace: + return false + } + } + + switch tok.Kind { + case parser.TokOpenParen: + parenDepth++ + case parser.TokCloseParen: + if parenDepth > 0 { + parenDepth-- + } + case parser.TokOpenBracket: + bracketDepth++ + case parser.TokCloseBracket: + if bracketDepth > 0 { + bracketDepth-- + } + case parser.TokOpenBrace: + braceDepth++ + case parser.TokCloseBrace: + if braceDepth == 0 { + return false + } + braceDepth-- + } + } + + return inValue && parenDepth == 0 && bracketDepth == 0 && braceDepth == 0 +} + +func hasTopLevelStructUpdatePipeAhead(tokens []parser.Token, source []byte, openIdx, offset int) bool { + parenDepth, bracketDepth, braceDepth := 0, 0, 0 + for i := openIdx + 1; i < len(tokens); i++ { + tok := tokens[i] + if tok.Kind == parser.TokEOF { + return false + } + if tok.Start < offset { + continue + } + + if parenDepth == 0 && bracketDepth == 0 && braceDepth == 0 { + switch tok.Kind { + case parser.TokColon, parser.TokAssoc, parser.TokComma, parser.TokCloseBrace: + return false + case parser.TokOther: + if parser.TokenText(source, tok) == "|" { + return true + } + } + } + + switch tok.Kind { + case parser.TokOpenParen: + parenDepth++ + case parser.TokCloseParen: + if parenDepth > 0 { + parenDepth-- + } + case parser.TokOpenBracket: + bracketDepth++ + case parser.TokCloseBracket: + if bracketDepth > 0 { + bracketDepth-- + } + case parser.TokOpenBrace: + braceDepth++ + case parser.TokCloseBrace: + if braceDepth > 0 { + braceDepth-- + } + } + } + return false +} + // CompletionContextAtCursor extracts the token-aware completion context at the // given 0-based line/column. func CompletionContextAtCursor(tokens []parser.Token, source []byte, lineStarts []int, line, col int) CompletionContext { diff --git a/internal/lsp/elixir_test.go b/internal/lsp/elixir_test.go index 1f33c3d..8d573d0 100644 --- a/internal/lsp/elixir_test.go +++ b/internal/lsp/elixir_test.go @@ -390,6 +390,239 @@ func TestCompletionContextAtCursor(t *testing.T) { } } +func TestStructCompletionContextAtCursor(t *testing.T) { + tests := []struct { + name string + code string + line int + col int + wantOK bool + wantModule string + wantPrefix string + wantStart int + }{ + { + name: "field prefix", + code: " %User{em", + line: 0, + col: len(" %User{em"), + wantOK: true, + wantModule: "User", + wantPrefix: "em", + wantStart: len(" %User{"), + }, + { + name: "empty field prefix", + code: " %User{", + line: 0, + col: len(" %User{"), + wantOK: true, + wantModule: "User", + wantPrefix: "", + wantStart: len(" %User{"), + }, + { + name: "qualified module", + code: " %MyApp.User{na", + line: 0, + col: len(" %MyApp.User{na"), + wantOK: true, + wantModule: "MyApp.User", + wantPrefix: "na", + wantStart: len(" %MyApp.User{"), + }, + { + name: "module special", + code: " %__MODULE__{na", + line: 0, + col: len(" %__MODULE__{na"), + wantOK: true, + wantModule: "__MODULE__", + wantPrefix: "na", + wantStart: len(" %__MODULE__{"), + }, + { + name: "after previous field", + code: ` %User{name: "x", em`, + line: 0, + col: len(` %User{name: "x", em`), + wantOK: true, + wantModule: "User", + wantPrefix: "em", + wantStart: len(` %User{name: "x", `), + }, + { + name: "struct update", + code: " %User{user | em", + line: 0, + col: len(" %User{user | em"), + wantOK: true, + wantModule: "User", + wantPrefix: "em", + wantStart: len(" %User{user | "), + }, + { + name: "struct update variable position rejected", + code: " %User{user | em", + line: 0, + col: len(" %User{user"), + wantOK: false, + }, + { + name: "struct update empty key position after pipe", + code: " %User{user | ", + line: 0, + col: len(" %User{user | "), + wantOK: true, + wantModule: "User", + wantPrefix: "", + wantStart: len(" %User{user | "), + }, + { + name: "multiline field prefix", + code: " %User{\n em\n }", + line: 1, + col: len(" em"), + wantOK: true, + wantModule: "User", + wantPrefix: "em", + wantStart: len(" "), + }, + { + name: "value position rejected", + code: " %User{name: ", + line: 0, + col: len(" %User{name: "), + wantOK: false, + }, + { + name: "nested map value rejected", + code: " %User{name: %{em", + line: 0, + col: len(" %User{name: %{em"), + wantOK: false, + }, + { + name: "plain map rejected", + code: " %{em", + line: 0, + col: len(" %{em"), + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens, source, lineStarts := tokenize(tt.code) + ctx, ok := StructCompletionContextAtCursor(tokens, source, lineStarts, tt.line, tt.col) + if ok != tt.wantOK { + t.Fatalf("ok = %v, want %v", ok, tt.wantOK) + } + if !ok { + return + } + if ctx.ModuleRef != tt.wantModule { + t.Errorf("ModuleRef = %q, want %q", ctx.ModuleRef, tt.wantModule) + } + if ctx.FieldPrefix != tt.wantPrefix { + t.Errorf("FieldPrefix = %q, want %q", ctx.FieldPrefix, tt.wantPrefix) + } + if ctx.StartCol != tt.wantStart { + t.Errorf("StartCol = %d, want %d", ctx.StartCol, tt.wantStart) + } + }) + } +} + +func TestStructValueContextAtCursor(t *testing.T) { + tests := []struct { + name string + code string + line int + col int + want bool + }{ + { + name: "empty value position", + code: " %User{name: ", + line: 0, + col: len(" %User{name: "), + want: true, + }, + { + name: "typed value position", + code: " %User{name: na", + line: 0, + col: len(" %User{name: na"), + want: true, + }, + { + name: "key position", + code: " %User{na", + line: 0, + col: len(" %User{na"), + want: false, + }, + { + name: "struct update variable position", + code: " %User{user | name: ", + line: 0, + col: len(" %User{user"), + want: false, + }, + { + name: "nested map value rejected", + code: " %User{meta: %{", + line: 0, + col: len(" %User{meta: %{"), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens, source, lineStarts := tokenize(tt.code) + got := StructValueContextAtCursor(tokens, source, lineStarts, tt.line, tt.col) + if got != tt.want { + t.Fatalf("StructValueContextAtCursor = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStructModuleRefs(t *testing.T) { + code := `defmodule MyApp.Controller do + alias MyApp.Accounts.User + + def run do + %User{name: "A"} + %MyApp.Accounts.Org{ + id: 1 + } + %{} + %__MODULE__{} + %MyApp.Accounts.Pending + end +end` + tokens, source, _ := tokenize(code) + refs := StructModuleRefs(tokens, source) + + want := []StructModuleRef{ + {ModuleRef: "User", Line: 4}, + {ModuleRef: "MyApp.Accounts.Org", Line: 5}, + {ModuleRef: "__MODULE__", Line: 9}, + {ModuleRef: "MyApp.Accounts.Pending", Line: 10}, + } + if len(refs) != len(want) { + t.Fatalf("got %d refs, want %d: %#v", len(refs), len(want), refs) + } + for i := range want { + if refs[i] != want[i] { + t.Errorf("refs[%d] = %#v, want %#v", i, refs[i], want[i]) + } + } +} + func TestFullExpressionAtCursor(t *testing.T) { code := " Foo.Bar.baz(123)" tokens, source, lineStarts := tokenize(code) diff --git a/internal/lsp/formatter.go b/internal/lsp/formatter.go index 550e250..a7b1b17 100644 --- a/internal/lsp/formatter.go +++ b/internal/lsp/formatter.go @@ -50,6 +50,7 @@ const ( codeIntelOpWarmOTPModules byte = 0x02 codeIntelOpErlangExports byte = 0x03 codeIntelOpRuntimeInfo byte = 0x04 + codeIntelOpStructFields byte = 0x05 beamNotificationOTPModulesReady byte = 0x00 beamNotificationOTPModulesFailed byte = 0x01 @@ -657,6 +658,44 @@ func (bp *beamProcess) ErlangExports(ctx context.Context, module string) ([]Erla return exports, err } +// StructFields asks the BEAM's CodeIntel service for the fields of a compiled +// Elixir struct module. +func (bp *beamProcess) StructFields(ctx context.Context, module string) ([]string, error) { + var fields []string + var payload bytes.Buffer + _ = binary.Write(&payload, binary.BigEndian, uint16(len(module))) + payload.WriteString(module) + + err := bp.doRequest(ctx, serviceCodeIntel, codeIntelOpStructFields, payload.Bytes(), func(status byte, payload []byte) error { + if status != 0 { + if len(payload) > 0 { + return fmt.Errorf("struct fields failed: %s", strings.TrimSpace(string(payload))) + } + return fmt.Errorf("struct fields failed") + } + + reader := bytes.NewReader(payload) + var fieldCount uint16 + if err := binary.Read(reader, binary.BigEndian, &fieldCount); err != nil { + return fmt.Errorf("read field count: %w", err) + } + fields = make([]string, 0, fieldCount) + for i := 0; i < int(fieldCount); i++ { + var fieldLen uint16 + if err := binary.Read(reader, binary.BigEndian, &fieldLen); err != nil { + return fmt.Errorf("read field length: %w", err) + } + fieldBuf := make([]byte, fieldLen) + if _, err := io.ReadFull(reader, fieldBuf); err != nil { + return fmt.Errorf("read field: %w", err) + } + fields = append(fields, string(fieldBuf)) + } + return nil + }) + return fields, err +} + // FormatError represents a formatting failure (e.g. syntax error in the source). // The persistent process is still alive — this is not a protocol/crash error. type FormatError struct { @@ -938,6 +977,7 @@ func (s *Server) evictBeam(bp *beamProcess, reason string) { if buildRoot != "" { log.Printf("BEAM: evicting process for %s (pid %d): %s", buildRoot, bp.cmd.process.Pid, reason) + s.clearStructFieldCacheForBuildRoot(buildRoot) } else { log.Printf("BEAM: evicting untracked process (pid %d): %s", bp.cmd.process.Pid, reason) } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 4f1d83e..a5aabee 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -67,6 +67,17 @@ type erlangRuntimeCache struct { readyCh chan struct{} } +type structFieldCacheKey struct { + buildRoot string + module string +} + +type structFieldCacheEntry struct { + fields []string + loaded bool + loading bool +} + type Server struct { store *store.Store docs *DocumentStore @@ -86,6 +97,10 @@ type Server struct { erlangRuntimeCache map[string]*erlangRuntimeCache // runtime key → cached OTP modules/exports erlangRuntimeMu sync.Mutex + structFieldCache map[structFieldCacheKey]*structFieldCacheEntry + structFieldMu sync.Mutex + structFieldGen uint64 + usingCache map[string]*usingCacheEntry // module name → parsed __using__ result usingCacheMu sync.RWMutex @@ -124,6 +139,7 @@ func NewServer(s *store.Store, projectRoot string) *Server { followDelegates: true, erlangBuildRoots: make(map[string]*erlangBuildRootState), erlangRuntimeCache: make(map[string]*erlangRuntimeCache), + structFieldCache: make(map[structFieldCacheKey]*structFieldCacheEntry), usingCache: make(map[string]*usingCacheEntry), depsCache: make(map[string]bool), } @@ -205,6 +221,7 @@ func (s *Server) backgroundReindex() { if !indexRefs { refs = nil } + s.invalidateStructFieldCacheForFile(path, defs) if err := s.store.IndexFileWithRefs(path, defs, refs); err != nil { log.Printf("Warning: reindex %s: %v", path, err) } @@ -229,6 +246,9 @@ func (s *Server) backgroundReindex() { } } if len(toRemove) > 0 { + for _, path := range toRemove { + s.invalidateStructFieldCacheForFile(path, nil) + } _ = s.store.RemoveFiles(toRemove) } } @@ -403,7 +423,7 @@ func (s *Server) Initialize(ctx context.Context, params *protocol.InitializePara RenameProvider: &protocol.RenameOptions{PrepareProvider: true}, CallHierarchyProvider: true, CompletionProvider: &protocol.CompletionOptions{ - TriggerCharacters: []string{"."}, + TriggerCharacters: []string{".", "{", ",", "|", ":", " "}, ResolveProvider: true, }, SignatureHelpProvider: &protocol.SignatureHelpOptions{ @@ -458,8 +478,10 @@ func (s *Server) Exit(ctx context.Context) error { func (s *Server) DidOpen(ctx context.Context, params *protocol.DidOpenTextDocumentParams) error { docURI := string(params.TextDocument.URI) - s.docs.Set(docURI, params.TextDocument.Text) + text := params.TextDocument.Text + s.docs.Set(docURI, text) path := uriToPath(params.TextDocument.URI) + s.maybePrewarmStructFields(docURI, path, text) // Eagerly start the persistent BEAM process so the first format is instant. // Skip deps and stdlib files — we don't format those. @@ -478,7 +500,9 @@ func (s *Server) DidChange(ctx context.Context, params *protocol.DidChangeTextDo if len(params.ContentChanges) > 0 { // Full sync mode — last change contains the full text text := params.ContentChanges[len(params.ContentChanges)-1].Text - s.docs.Set(string(params.TextDocument.URI), text) + docURI := string(params.TextDocument.URI) + s.docs.Set(docURI, text) + s.maybePrewarmStructFields(docURI, uriToPath(params.TextDocument.URI), text) } return nil } @@ -527,6 +551,7 @@ func (s *Server) DidSave(ctx context.Context, params *protocol.DidSaveTextDocume return } + s.invalidateStructFieldCacheForFile(path, defs) if err := s.store.IndexFileWithRefs(path, defs, refs); err != nil { log.Printf("Error indexing %s: %v", path, err) } @@ -1177,6 +1202,270 @@ func (s *Server) getErlangExports(ctx context.Context, filePath, module string) return exports } +func (s *Server) structFieldBuildRoot(filePath string) string { + if filePath != "" { + return s.findBuildRoot(filepath.Dir(filePath)) + } + return s.findBuildRoot(s.projectRoot) +} + +func (s *Server) cachedStructFieldsOrWarm(filePath, module string) ([]string, bool) { + return s.cachedStructFieldsOrWarmWithLogging(filePath, module, true) +} + +func (s *Server) prewarmStructFields(filePath, module string) { + _, _ = s.cachedStructFieldsOrWarmWithLogging(filePath, module, false) +} + +func (s *Server) cachedStructFieldsOrWarmWithLogging(filePath, module string, logCacheState bool) ([]string, bool) { + if module == "" { + return nil, false + } + + key := structFieldCacheKey{ + buildRoot: s.structFieldBuildRoot(filePath), + module: module, + } + + s.structFieldMu.Lock() + generation := s.structFieldGen + if entry := s.structFieldCache[key]; entry != nil { + if entry.loaded { + fields := append([]string(nil), entry.fields...) + s.structFieldMu.Unlock() + if s.debug && logCacheState { + s.debugf("StructFields cache hit") + s.debugf(" module=%s", key.module) + s.debugf(" fields=%d", len(fields)) + } + return fields, true + } + if entry.loading { + s.structFieldMu.Unlock() + if s.debug && logCacheState { + s.debugf("StructFields cache warming") + s.debugf(" module=%s", key.module) + } + return nil, false + } + entry.loading = true + generation = s.structFieldGen + } else { + s.structFieldCache[key] = &structFieldCacheEntry{loading: true} + } + s.structFieldMu.Unlock() + + if s.debug && logCacheState { + s.debugf("StructFields cache miss") + s.debugf(" module=%s", key.module) + s.debugf(" buildRoot=%s", key.buildRoot) + } + s.backgroundWork.Add(1) + go s.warmStructFields(key, generation) + return nil, false +} + +func (s *Server) warmStructFields(key structFieldCacheKey, generation uint64) { + defer s.backgroundWork.Done() + var tWarm time.Time + if s.debug { + tWarm = time.Now() + s.debugf("StructFields lookup start") + s.debugf(" module=%s", key.module) + s.debugf(" buildRoot=%s", key.buildRoot) + } + + ctx, cancel := context.WithTimeout(context.Background(), beamWaitTimeout) + defer cancel() + + var fields []string + status := "ok" + var readyElapsed, lookupElapsed, cacheElapsed time.Duration + + bp := s.getBeamProcess(ctx, key.buildRoot) + if bp != nil { + tReady := s.debugNow() + if err := bp.Ready(ctx); err == nil { + if s.debug { + readyElapsed = time.Since(tReady) + } + tRequest := s.debugNow() + if result, err := bp.StructFields(ctx, key.module); err == nil { + fields = result + if s.debug { + lookupElapsed = time.Since(tRequest) + } + } else { + status = "lookup_error" + s.debugf("StructFields lookup failed for %s: %v", key.module, err) + } + } else { + status = "beam_not_ready" + if s.debug { + readyElapsed = time.Since(tReady) + s.debugf("StructFields BEAM not ready for %s: %v", key.module, err) + } + } + } else { + status = "no_beam" + s.debugf("StructFields no BEAM process for module=%s buildRoot=%s", key.module, key.buildRoot) + } + + tCache := s.debugNow() + s.structFieldMu.Lock() + if generation != s.structFieldGen { + s.structFieldMu.Unlock() + if s.debug { + s.debugf("StructFields lookup discarded") + s.debugf(" module=%s", key.module) + s.debugf(" reason=cache invalidated during lookup") + s.debugf(" total=%s", time.Since(tWarm).Round(time.Microsecond)) + } + return + } + entry := s.structFieldCache[key] + if entry == nil { + entry = &structFieldCacheEntry{} + s.structFieldCache[key] = entry + } + entry.fields = fields + entry.loaded = true + entry.loading = false + s.structFieldMu.Unlock() + if s.debug { + cacheElapsed = time.Since(tCache) + s.debugf("StructFields lookup finished") + s.debugf(" module=%s", key.module) + s.debugf(" status=%s", status) + s.debugf(" fields=%d", len(fields)) + s.debugf(" ready=%s", readyElapsed.Round(time.Microsecond)) + s.debugf(" lookup=%s", lookupElapsed.Round(time.Microsecond)) + s.debugf(" cache_write=%s", cacheElapsed.Round(time.Microsecond)) + s.debugf(" total=%s", time.Since(tWarm).Round(time.Microsecond)) + } +} + +func (s *Server) clearStructFieldCacheForBuildRoot(buildRoot string) { + s.structFieldMu.Lock() + removed := 0 + for key := range s.structFieldCache { + if key.buildRoot == buildRoot { + delete(s.structFieldCache, key) + removed++ + } + } + if removed > 0 { + s.structFieldGen++ + } + s.structFieldMu.Unlock() +} + +func moduleNamesFromDefs(defs []parser.Definition) []string { + seen := make(map[string]bool) + var modules []string + for _, def := range defs { + if def.Module == "" { + continue + } + if !seen[def.Module] { + seen[def.Module] = true + modules = append(modules, def.Module) + } + } + return modules +} + +func (s *Server) invalidateStructFieldCacheForFile(filePath string, newDefs []parser.Definition) { + modules := make(map[string]bool) + if oldModules, err := s.store.LookupModulesInFile(filePath); err == nil { + for _, module := range oldModules { + modules[module] = true + } + } + for _, module := range moduleNamesFromDefs(newDefs) { + modules[module] = true + } + if len(modules) == 0 { + return + } + + removed := 0 + s.structFieldMu.Lock() + for key := range s.structFieldCache { + if modules[key.module] { + delete(s.structFieldCache, key) + removed++ + } + } + if removed > 0 { + s.structFieldGen++ + } + s.structFieldMu.Unlock() + + if s.debug && removed > 0 { + s.debugf("StructFields cache invalidated") + s.debugf(" file=%s", filePath) + s.debugf(" modules=%d", len(modules)) + s.debugf(" entries=%d", removed) + } +} + +func (s *Server) moduleKnown(module string) bool { + if module == "" { + return false + } + results, err := s.store.LookupModule(module) + return err == nil && len(results) > 0 +} + +func (s *Server) maybePrewarmStructFields(docURI, filePath, text string) { + if filePath == "" || !parser.IsElixirFile(filePath) || !s.isProjectFile(filePath) || s.isDepsFile(filePath) { + return + } + if !strings.Contains(text, "%") { + return + } + + s.backgroundWork.Add(1) + go func() { + defer s.backgroundWork.Done() + s.prewarmStructFieldsFromText(docURI, filePath, text) + }() +} + +func (s *Server) prewarmStructFieldsFromText(_ string, filePath, text string) { + tPrewarm := s.debugNow() + tf := NewTokenizedFile(text) + refs := tf.StructModuleRefs() + if len(refs) == 0 { + return + } + + seen := make(map[string]bool, len(refs)) + warmed := 0 + skippedUnknown := 0 + for _, ref := range refs { + aliases := tf.ExtractAliasesInScope(ref.Line) + s.mergeAliasesFromUseTokenized(tf, aliases) + moduleRef := tf.ResolveModuleExpr(ref.ModuleRef, ref.Line) + fullModule := s.resolveModuleWithNesting(moduleRef, aliases, filePath, ref.Line) + if fullModule == "" || seen[fullModule] { + continue + } + seen[fullModule] = true + if !s.moduleKnown(fullModule) { + skippedUnknown++ + continue + } + s.prewarmStructFields(filePath, fullModule) + warmed++ + } + + if s.debug && warmed > 0 { + s.debugf("StructFields prewarm refs=%d queued=%d skipped_unknown=%d (%s)", len(refs), warmed, skippedUnknown, time.Since(tPrewarm).Round(time.Microsecond)) + } +} + func lineRange(line int) protocol.Range { return protocol.Range{ Start: protocol.Position{Line: uint32(line), Character: 0}, @@ -1415,13 +1704,85 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara return nil, nil } + if isSpaceCompletionTrigger(params) && !lineCouldBeStructValueSpaceTrigger(lines[lineNum], col) { + return nil, nil + } + tf := s.docs.GetTokenizedFile(docURI) if tf == nil { tf = NewTokenizedFile(text) } + if structCtx, ok := tf.StructCompletionContextAtCursor(lineNum, col); ok { + tStruct := s.debugNow() + aliases := tf.ExtractAliasesInScope(lineNum) + s.mergeAliasesFromUseTokenized(tf, aliases) + moduleRef := tf.ResolveModuleExpr(structCtx.ModuleRef, lineNum) + fullModule := s.resolveModuleWithNesting(moduleRef, aliases, filePath, lineNum) + if s.debug { + s.debugf("Completion struct context") + s.debugf(" moduleRef=%s", structCtx.ModuleRef) + s.debugf(" resolved=%s", fullModule) + s.debugf(" prefix=%q", structCtx.FieldPrefix) + } + fields, ready := s.cachedStructFieldsOrWarm(filePath, fullModule) + if !ready || len(fields) == 0 { + if s.debug { + s.debugf("Completion struct fields unavailable") + s.debugf(" module=%s", fullModule) + s.debugf(" ready=%v", ready) + s.debugf(" fields=%d", len(fields)) + s.debugf(" elapsed=%s", time.Since(tStruct).Round(time.Microsecond)) + } + return nil, nil + } + + prefixRange := protocol.Range{ + Start: protocol.Position{Line: uint32(lineNum), Character: uint32(structCtx.StartCol)}, + End: protocol.Position{Line: uint32(lineNum), Character: uint32(col)}, + } + var items []protocol.CompletionItem + for _, field := range fields { + if !strings.HasPrefix(field, structCtx.FieldPrefix) { + continue + } + items = append(items, protocol.CompletionItem{ + Label: field, + Kind: protocol.CompletionItemKindField, + Detail: fullModule + " struct field", + TextEdit: &protocol.TextEdit{ + Range: prefixRange, + NewText: field + ": ", + }, + }) + } + if len(items) == 0 { + if s.debug { + s.debugf("Completion struct fields no matches") + s.debugf(" module=%s", fullModule) + s.debugf(" prefix=%q", structCtx.FieldPrefix) + s.debugf(" fields=%d", len(fields)) + s.debugf(" elapsed=%s", time.Since(tStruct).Round(time.Microsecond)) + } + return nil, nil + } + if s.debug { + s.debugf("Completion struct fields returned") + s.debugf(" module=%s", fullModule) + s.debugf(" prefix=%q", structCtx.FieldPrefix) + s.debugf(" items=%d", len(items)) + s.debugf(" cached_fields=%d", len(fields)) + s.debugf(" elapsed=%s", time.Since(tStruct).Round(time.Microsecond)) + } + return &protocol.CompletionList{ + IsIncomplete: false, + Items: items, + }, nil + } + completionCtx := tf.CompletionContextAtCursor(lineNum, col) prefix, afterDot, prefixStartCol := completionCtx.Prefix, completionCtx.AfterDot, completionCtx.StartCol + structValueContext := tf.StructValueContextAtCursor(lineNum, col) // Inside a multi-line alias block: complete child module segments under the parent. if aliasParent, inBlock := tf.ExtractAliasBlockParent(lineNum); inBlock { @@ -1462,7 +1823,7 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara }, nil } - if prefix == "" && !afterDot { + if prefix == "" && !afterDot && !structValueContext { return nil, nil } @@ -1642,9 +2003,32 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara for _, r := range results { addModuleItem(r.Module, "module") } - } else if funcPrefix != "" { + } else if funcPrefix != "" || structValueContext { seen := make(map[string]bool) + addVariableCompletion := func(varName string) { + if !isCompletableVariableName(varName) || !strings.HasPrefix(varName, funcPrefix) || seen[varName] { + return + } + seen[varName] = true + items = append(items, protocol.CompletionItem{ + Label: varName, + Kind: protocol.CompletionItemKindVariable, + Detail: "variable", + SortText: "000_" + varName, + }) + } + + // Variables are usually the intended target in bare value positions. + if tree, src, ok := s.docs.GetTree(docURI); ok { + for _, varName := range treesitter.FindVariablesInScopeWithTree(tree.RootNode(), src, uint(lineNum), uint(col)) { + addVariableCompletion(varName) + } + } + for _, varName := range tf.VariableNamesBeforeCursor(lineNum, col) { + addVariableCompletion(varName) + } + for _, bf := range tf.FindBufferFunctions() { key := funcKey(bf.Name, bf.Arity) if strings.HasPrefix(bf.Name, funcPrefix) && !seen[key] { @@ -1696,22 +2080,6 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara s.addCompletionsFromUsing(resolveModule(usedModule, aliases), funcPrefix, seen, &items, visitedCompletion, inPipe, s.snippetSupport) } - // Variables in scope via tree-sitter - var varsInScope []string - if tree, src, ok := s.docs.GetTree(docURI); ok { - varsInScope = treesitter.FindVariablesInScopeWithTree(tree.RootNode(), src, uint(lineNum), uint(col)) - } - for _, varName := range varsInScope { - if strings.HasPrefix(varName, funcPrefix) && !seen[varName] { - seen[varName] = true - items = append(items, protocol.CompletionItem{ - Label: varName, - Kind: protocol.CompletionItemKindVariable, - Detail: "variable", - }) - } - } - if s.snippetSupport { for name, snippet := range elixirFormSnippets { if strings.HasPrefix(name, funcPrefix) && !seen[name] { @@ -1738,6 +2106,63 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara }, nil } +func isSpaceCompletionTrigger(params *protocol.CompletionParams) bool { + return params.Context != nil && + params.Context.TriggerKind == protocol.CompletionTriggerKindTriggerCharacter && + params.Context.TriggerCharacter == " " +} + +func lineCouldBeStructValueSpaceTrigger(line string, col int) bool { + if col <= 0 { + return false + } + if col > len(line) { + col = len(line) + } + before := line[:col] + if len(before) == 0 || before[len(before)-1] != ' ' { + return false + } + + colonIdx := strings.LastIndexByte(before, ':') + if colonIdx < 0 || (colonIdx > 0 && before[colonIdx-1] == ':') { + return false + } + for i := colonIdx + 1; i < len(before); i++ { + if before[i] != ' ' && before[i] != '\t' { + return false + } + } + + braceIdx := strings.LastIndexByte(before[:colonIdx], '{') + if braceIdx < 0 { + return false + } + i := braceIdx - 1 + for i >= 0 && (before[i] == ' ' || before[i] == '\t') { + i-- + } + if i < 0 || !isStructModuleRefByte(before[i]) { + return false + } + for i >= 0 && isStructModuleRefByte(before[i]) { + i-- + } + return i >= 0 && before[i] == '%' +} + +func isStructModuleRefByte(b byte) bool { + return (b >= 'A' && b <= 'Z') || + (b >= 'a' && b <= 'z') || + (b >= '0' && b <= '9') || + b == '_' || + b == '.' +} + +func isCompletableVariableName(name string) bool { + return name != "" && !strings.HasPrefix(name, "_") +} + // cachedUsing returns the parsed __using__ body for the given module name. // The result is cached by module name; filePath is stored in the entry so // LookupModule is only called on the first access. The cache is invalidated @@ -2620,12 +3045,14 @@ func (s *Server) DidChangeWatchedFiles(ctx context.Context, params *protocol.Did return } + s.invalidateStructFieldCacheForFile(filePath, defs) if err := s.store.IndexFileWithRefs(filePath, defs, refs); err != nil { log.Printf("Error indexing %s: %v", filePath, err) } }(path) case protocol.FileChangeTypeDeleted: go func(filePath string) { + s.invalidateStructFieldCacheForFile(filePath, nil) if err := s.store.RemoveFile(filePath); err != nil { log.Printf("Error removing %s from index: %v", filePath, err) } diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index 48eced8..492e090 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -283,12 +283,18 @@ func referencesAt(t *testing.T, server *Server, uri string, line, col uint32) [] } func completionAt(t *testing.T, server *Server, uri string, line, col uint32) []protocol.CompletionItem { + t.Helper() + return completionAtWithContext(t, server, uri, line, col, nil) +} + +func completionAtWithContext(t *testing.T, server *Server, uri string, line, col uint32, completionCtx *protocol.CompletionContext) []protocol.CompletionItem { t.Helper() result, err := server.Completion(context.Background(), &protocol.CompletionParams{ TextDocumentPositionParams: protocol.TextDocumentPositionParams{ TextDocument: protocol.TextDocumentIdentifier{URI: protocol.DocumentURI(uri)}, Position: protocol.Position{Line: line, Character: col}, }, + Context: completionCtx, }) if err != nil { t.Fatal(err) @@ -308,6 +314,15 @@ func hasCompletionItem(items []protocol.CompletionItem, label string) bool { return false } +func hasString(values []string, value string) bool { + for _, existing := range values { + if existing == value { + return true + } + } + return false +} + func TestCompletion_FunctionAfterDot(t *testing.T) { server, cleanup := setupTestServer(t) defer cleanup() @@ -747,6 +762,214 @@ end`) } } +func TestCompletion_StructFieldsFromWarmCache(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + path := filepath.Join(server.projectRoot, "lib", "controller.ex") + uriStr := string(uri.File(path)) + server.docs.Set(uriStr, `defmodule MyApp.Controller do + alias MyApp.Accounts.User + + def run do + %User{em + end +end`) + + key := structFieldCacheKey{ + buildRoot: server.structFieldBuildRoot(path), + module: "MyApp.Accounts.User", + } + server.structFieldMu.Lock() + server.structFieldCache[key] = &structFieldCacheEntry{ + fields: []string{"active?", "email", "name"}, + loaded: true, + } + server.structFieldMu.Unlock() + + items := completionAt(t, server, uriStr, 4, uint32(len(" %User{em"))) + if !hasCompletionItem(items, "email") { + t.Fatal("expected 'email' struct field completion") + } + if hasCompletionItem(items, "name") { + t.Fatal("did not expect non-matching 'name' completion") + } + + for _, item := range items { + if item.Label != "email" { + continue + } + if item.Kind != protocol.CompletionItemKindField { + t.Errorf("Kind = %v, want Field", item.Kind) + } + if item.TextEdit == nil || item.TextEdit.NewText != "email: " { + t.Fatalf("TextEdit = %#v, want email insertion", item.TextEdit) + } + } + + server.docs.Set(uriStr, `defmodule MyApp.Controller do + alias MyApp.Accounts.User + + def run do + %User{ + end +end`) + items = completionAt(t, server, uriStr, 4, uint32(len(" %User{"))) + for _, label := range []string{"active?", "email", "name"} { + if !hasCompletionItem(items, label) { + t.Fatalf("expected %q for empty struct field prefix", label) + } + } +} + +func TestStructFieldPrewarmFromDocument(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + server.mixBin = "" // keep the cache warm test from starting a real sidecar + + path := filepath.Join(server.projectRoot, "lib", "controller.ex") + docURI := string(uri.File(path)) + text := `defmodule MyApp.Controller do + alias MyApp.Accounts.User + + def run do + %User{} + %__MODULE__{} + %User{name: "duplicate"} + end +end` + indexFile(t, server.store, server.projectRoot, "lib/user.ex", `defmodule MyApp.Accounts.User do + defstruct [:name] +end`) + indexFile(t, server.store, server.projectRoot, "lib/controller.ex", text) + + server.prewarmStructFieldsFromText(docURI, path, text) + + userKey := structFieldCacheKey{ + buildRoot: server.structFieldBuildRoot(path), + module: "MyApp.Accounts.User", + } + moduleKey := structFieldCacheKey{ + buildRoot: server.structFieldBuildRoot(path), + module: "MyApp.Controller", + } + + server.structFieldMu.Lock() + _, hasUser := server.structFieldCache[userKey] + _, hasModule := server.structFieldCache[moduleKey] + cacheLen := len(server.structFieldCache) + server.structFieldMu.Unlock() + + if !hasUser { + t.Fatal("expected prewarm to enqueue aliased User struct fields") + } + if !hasModule { + t.Fatal("expected prewarm to enqueue __MODULE__ struct fields") + } + if cacheLen != 2 { + t.Fatalf("expected duplicate User structs to share one cache entry, got %d entries", cacheLen) + } +} + +func TestStructFieldCacheInvalidatedForReindexedModule(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + indexFile(t, server.store, server.projectRoot, "lib/user.ex", `defmodule MyApp.User do + defstruct [:name] +end`) + + key := structFieldCacheKey{ + buildRoot: server.structFieldBuildRoot(filepath.Join(server.projectRoot, "lib", "user.ex")), + module: "MyApp.User", + } + server.structFieldMu.Lock() + server.structFieldCache[key] = &structFieldCacheEntry{ + fields: []string{"name"}, + loaded: true, + } + beforeGen := server.structFieldGen + server.structFieldMu.Unlock() + + path := filepath.Join(server.projectRoot, "lib", "user.ex") + defs, _, err := parser.ParseText(path, `defmodule MyApp.User do + defstruct [:name, :email] +end`) + if err != nil { + t.Fatal(err) + } + server.invalidateStructFieldCacheForFile(path, defs) + + server.structFieldMu.Lock() + _, stillCached := server.structFieldCache[key] + afterGen := server.structFieldGen + server.structFieldMu.Unlock() + + if stillCached { + t.Fatal("expected struct field cache entry to be invalidated") + } + if afterGen <= beforeGen { + t.Fatal("expected struct field generation to advance") + } +} + +func TestCompletion_StructValueShowsLocalVariables(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + path := filepath.Join(server.projectRoot, "lib", "controller.ex") + uriStr := string(uri.File(path)) + server.docs.Set(uriStr, `defmodule MyApp.Controller do + def helper(value), do: value + + def run(status, actor) do + local_status = status + _ignored = "label only" + + %ContractDocument{status: + end +end`) + + key := structFieldCacheKey{ + buildRoot: server.structFieldBuildRoot(path), + module: "ContractDocument", + } + server.structFieldMu.Lock() + server.structFieldCache[key] = &structFieldCacheEntry{ + fields: []string{"contract_id"}, + loaded: true, + } + server.structFieldMu.Unlock() + + items := completionAtWithContext(t, server, uriStr, 7, uint32(len(" %ContractDocument{status: ")), &protocol.CompletionContext{ + TriggerKind: protocol.CompletionTriggerKindTriggerCharacter, + TriggerCharacter: " ", + }) + for _, label := range []string{"status", "actor", "local_status"} { + if !hasCompletionItem(items, label) { + t.Fatalf("expected local variable %q in struct value completions", label) + } + } + if len(items) == 0 || items[0].Kind != protocol.CompletionItemKindVariable { + t.Fatalf("expected variables to be prioritized first, got %#v", items) + } + if items[0].SortText == "" || !strings.HasPrefix(items[0].SortText, "000_") { + t.Fatalf("expected variable sort text to prioritize variables, got %#v", items[0]) + } + if hasCompletionItem(items, "_ignored") { + t.Fatal("did not expect underscore-prefixed variables in completions") + } + if !hasCompletionItem(items, "helper") { + t.Fatal("expected regular bare function completion in struct value completions") + } + if !hasCompletionItem(items, "if") { + t.Fatal("expected regular special-form completion in struct value completions") + } + if hasCompletionItem(items, "contract_id") { + t.Fatal("did not expect struct field key completion in struct value position") + } +} + func TestCompletion_AliasedModulePrefix(t *testing.T) { server, cleanup := setupTestServer(t) defer cleanup() @@ -1196,6 +1419,19 @@ func TestCompletion_NoResults(t *testing.T) { if len(items) != 0 { t.Errorf("expected no completions on whitespace, got %d", len(items)) } + + server.docs.Set(uri, `defmodule MyModule do + def run do + value = + end +end`) + items = completionAtWithContext(t, server, uri, 2, uint32(len(" value = ")), &protocol.CompletionContext{ + TriggerKind: protocol.CompletionTriggerKindTriggerCharacter, + TriggerCharacter: " ", + }) + if len(items) != 0 { + t.Errorf("expected no completions for ordinary space trigger, got %d", len(items)) + } } func TestCompletion_IgnoresStringsAndComments(t *testing.T) { @@ -4527,6 +4763,12 @@ func TestServer_Capabilities_DocumentSymbolAndWorkspaceSymbol(t *testing.T) { if caps.SignatureHelpProvider == nil { t.Error("SignatureHelpProvider should not be nil") } + if caps.CompletionProvider == nil { + t.Fatal("CompletionProvider should not be nil") + } + if !hasString(caps.CompletionProvider.TriggerCharacters, " ") { + t.Error("CompletionProvider should advertise space trigger for struct value completions") + } } // === DocumentHighlight === From fe0a9ba862764716418928677776df545d15a06e Mon Sep 17 00:00:00 2001 From: Jesse Herrick Date: Mon, 4 May 2026 10:26:52 -0600 Subject: [PATCH 2/3] Add struct dot completion --- internal/lsp/elixir.go | 266 +++++++++++++++++++++ internal/lsp/elixir_test.go | 446 ++++++++++++++++++++++++++++++++++++ internal/lsp/server.go | 55 +++++ 3 files changed, 767 insertions(+) diff --git a/internal/lsp/elixir.go b/internal/lsp/elixir.go index 9b1d954..40914f0 100644 --- a/internal/lsp/elixir.go +++ b/internal/lsp/elixir.go @@ -159,6 +159,88 @@ type CompletionContext struct { StartCol int } +// VariableFieldAccess describes a `variable.field_prefix` context at the cursor. +type VariableFieldAccess struct { + VariableName string + FieldPrefix string + StartCol int // column where the field prefix starts (for textEdit) +} + +// VariableFieldAccessAtCursor detects whether the cursor is in a `variable.` +// or `variable.field_prefix` position and returns the variable name and partial +// field name. Returns ok=false if the cursor is not in such a position. +func (tf *TokenizedFile) VariableFieldAccessAtCursor(line, col int) (VariableFieldAccess, bool) { + return VariableFieldAccessAtCursor(tf.tokens, tf.source, tf.lineStarts, line, col) +} + +func VariableFieldAccessAtCursor(tokens []parser.Token, source []byte, lineStarts []int, line, col int) (VariableFieldAccess, bool) { + if line < 0 || line >= len(lineStarts) || col <= 0 { + return VariableFieldAccess{}, false + } + + lineStart := lineStarts[line] + offset := parser.LineColToOffset(lineStarts, line, col) + if offset <= lineStart { + return VariableFieldAccess{}, false + } + + idx := parser.TokenAtOffset(tokens, offset-1) + if idx < 0 { + return VariableFieldAccess{}, false + } + + tok := tokens[idx] + + // Case 1: cursor right after dot — "variable.|" + if tok.Kind == parser.TokDot { + if idx < 1 { + return VariableFieldAccess{}, false + } + prev := tokens[idx-1] + if prev.Kind != parser.TokIdent { + return VariableFieldAccess{}, false + } + varName := parser.TokenText(source, prev) + if strings.HasPrefix(varName, "_") || parser.IsElixirKeyword(varName) { + return VariableFieldAccess{}, false + } + return VariableFieldAccess{ + VariableName: varName, + FieldPrefix: "", + StartCol: tok.End - lineStart, // right after the dot + }, true + } + + // Case 2: cursor on field prefix — "variable.fie|" + if tok.Kind == parser.TokIdent && idx >= 2 { + dotTok := tokens[idx-1] + if dotTok.Kind != parser.TokDot { + return VariableFieldAccess{}, false + } + varTok := tokens[idx-2] + if varTok.Kind != parser.TokIdent { + return VariableFieldAccess{}, false + } + varName := parser.TokenText(source, varTok) + if strings.HasPrefix(varName, "_") || parser.IsElixirKeyword(varName) { + return VariableFieldAccess{}, false + } + // The field prefix is the portion of the current token up to the cursor + fieldEnd := offset + if fieldEnd > tok.End { + fieldEnd = tok.End + } + fieldPrefix := string(source[tok.Start:fieldEnd]) + return VariableFieldAccess{ + VariableName: varName, + FieldPrefix: fieldPrefix, + StartCol: tok.Start - lineStart, + }, true + } + + return VariableFieldAccess{}, false +} + // Empty returns true if no completion should be offered at the cursor. func (c CompletionContext) Empty() bool { return c.Prefix == "" && !c.AfterDot @@ -233,6 +315,190 @@ func (tf *TokenizedFile) VariableNamesBeforeCursor(line, col int) []string { return VariableNamesBeforeCursor(tf.tokens, tf.source, tf.lineStarts, line, col) } +// VariableStructTypes returns a map of variable names to their struct module +// references for variables that are bound to struct literals via pattern matching +// or assignment before the given cursor position within the current function scope. +// The module references are unresolved (e.g. "User", "MyApp.User", "__MODULE__"). +func (tf *TokenizedFile) VariableStructTypes(line, col int) map[string]string { + return VariableStructTypes(tf.tokens, tf.source, tf.lineStarts, line, col) +} + +// VariableStructTypes scans from the enclosing function definition to the cursor +// position and identifies variables bound to struct types via patterns like: +// +// %User{} = user (match on left, var on right) +// user = %User{...} (var on left, struct on right) +// def foo(%User{} = user) (function head pattern) +// +// Returns a map of variable name -> module reference string. +func VariableStructTypes(tokens []parser.Token, source []byte, lineStarts []int, line, col int) map[string]string { + offset := parser.LineColToOffset(lineStarts, line, col) + if offset < 0 { + return nil + } + + // Find the enclosing function definition. + defIdx := -1 + for i := 0; i < len(tokens); i++ { + tok := tokens[i] + if tok.Kind == parser.TokEOF || tok.Start >= offset { + break + } + if isFunctionDefinitionToken(tok.Kind) { + defIdx = i + } + } + if defIdx < 0 { + return nil + } + + result := make(map[string]string) + + // Scan tokens from the function definition to the cursor. + // We look for two patterns: + // Pattern A: %Module{...} = var (struct on left, variable on right of =) + // Pattern B: var = %Module{...} (variable on left, struct on right of =) + for i := defIdx + 1; i < len(tokens); i++ { + tok := tokens[i] + if tok.Kind == parser.TokEOF || tok.Start >= offset { + break + } + + // Look for % which starts a struct literal + if tok.Kind == parser.TokPercent { + // Collect the module name after % + j := tokNextSig(tokens, len(tokens), i+1) + if j >= len(tokens) || tokens[j].Kind != parser.TokModule { + continue + } + moduleRef, k := tokCollectModuleName(source, tokens, len(tokens), j) + if moduleRef == "" { + continue + } + + // Find the matching close brace (or accept no brace for patterns like %User{} = var) + braceIdx := tokNextSig(tokens, len(tokens), k) + if braceIdx >= len(tokens) || tokens[braceIdx].Start >= offset { + continue + } + + // Must have an open brace to be a struct literal + if tokens[braceIdx].Kind != parser.TokOpenBrace { + continue + } + + // Skip past the struct body to find the closing brace + closeIdx := findMatchingCloseBrace(tokens, braceIdx) + if closeIdx < 0 { + continue + } + + // Pattern A: %Module{...} = var (or %Module{...} = var = ...) + // Look for = after the struct, then a variable + afterClose := tokNextSig(tokens, len(tokens), closeIdx+1) + if afterClose < len(tokens) && tokens[afterClose].Start < offset && + tokens[afterClose].Kind == parser.TokOther && tokenText(source, tokens[afterClose]) == "=" { + // Look for variable after = + varIdx := tokNextSig(tokens, len(tokens), afterClose+1) + if varIdx < len(tokens) && tokens[varIdx].Start < offset && tokens[varIdx].Kind == parser.TokIdent { + varName := parser.TokenText(source, tokens[varIdx]) + if !strings.HasPrefix(varName, "_") && !parser.IsElixirKeyword(varName) { + // Exclude function calls (ident followed by open paren) + nextAfterVar := tokNextSig(tokens, len(tokens), varIdx+1) + if nextAfterVar < len(tokens) && tokens[nextAfterVar].Kind == parser.TokOpenParen { + // This is a function call like get_user(), not a variable + } else { + result[varName] = moduleRef + } + } + } + } + + i = closeIdx + continue + } + + // Pattern B: var = %Module{...} or var \\ %Module{...} (default arg) + if tok.Kind == parser.TokIdent { + varName := parser.TokenText(source, tok) + if strings.HasPrefix(varName, "_") || parser.IsElixirKeyword(varName) { + continue + } + + // Check if next significant token is = or \\ + eqIdx := tokNextSig(tokens, len(tokens), i+1) + if eqIdx >= len(tokens) || tokens[eqIdx].Start >= offset { + continue + } + isEquals := tokens[eqIdx].Kind == parser.TokOther && tokenText(source, tokens[eqIdx]) == "=" + isDefault := tokens[eqIdx].Kind == parser.TokBackslash + if !isEquals && !isDefault { + continue + } + + // Check if next significant token after = is % + pctIdx := tokNextSig(tokens, len(tokens), eqIdx+1) + if pctIdx >= len(tokens) || tokens[pctIdx].Start >= offset { + continue + } + if tokens[pctIdx].Kind != parser.TokPercent { + continue + } + + // Collect module name + modIdx := tokNextSig(tokens, len(tokens), pctIdx+1) + if modIdx >= len(tokens) || tokens[modIdx].Kind != parser.TokModule { + continue + } + moduleRef, k := tokCollectModuleName(source, tokens, len(tokens), modIdx) + if moduleRef == "" { + continue + } + + // Verify there's an open brace (confirms it's a struct literal, not just %Module) + braceIdx := tokNextSig(tokens, len(tokens), k) + if braceIdx >= len(tokens) || tokens[braceIdx].Kind != parser.TokOpenBrace { + continue + } + + result[varName] = moduleRef + + // Skip past the struct body + closeIdx := findMatchingCloseBrace(tokens, braceIdx) + if closeIdx >= 0 { + i = closeIdx + } + } + } + + return result +} + +// findMatchingCloseBrace finds the matching } for the { at tokens[openIdx]. +// Returns -1 if not found. +func findMatchingCloseBrace(tokens []parser.Token, openIdx int) int { + depth := 1 + for i := openIdx + 1; i < len(tokens); i++ { + switch tokens[i].Kind { + case parser.TokOpenBrace: + depth++ + case parser.TokCloseBrace: + depth-- + if depth == 0 { + return i + } + case parser.TokEOF: + return -1 + } + } + return -1 +} + +// tokenText returns the source text for a token as a string. +func tokenText(source []byte, tok parser.Token) string { + return string(source[tok.Start:tok.End]) +} + func VariableNamesBeforeCursor(tokens []parser.Token, source []byte, lineStarts []int, line, col int) []string { offset := parser.LineColToOffset(lineStarts, line, col) if offset < 0 { diff --git a/internal/lsp/elixir_test.go b/internal/lsp/elixir_test.go index 8d573d0..c569632 100644 --- a/internal/lsp/elixir_test.go +++ b/internal/lsp/elixir_test.go @@ -669,6 +669,452 @@ func TestExpressionAtCursor_ExprBounds(t *testing.T) { } } +func TestVariableStructTypes(t *testing.T) { + tests := []struct { + name string + code string + line int + col int + want map[string]string + }{ + { + name: "direct struct assignment", + code: `def run do + user = %User{name: "A"} + user +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "pattern match left side", + code: `def run do + %User{} = get_user() +end`, + line: 1, + col: len(" %User{} = get_user()"), + want: map[string]string{}, + }, + { + name: "pattern match with named variable", + code: `def run do + %User{} = user = get_user() + user +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "match operator struct on left with var", + code: `def run do + %User{name: name} = user = get_user() + user +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "function head pattern match", + code: `def run(%User{} = user) do + user +end`, + line: 1, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "function head with qualified module", + code: `def run(%MyApp.Accounts.User{} = user) do + user +end`, + line: 1, + col: len(" user"), + want: map[string]string{"user": "MyApp.Accounts.User"}, + }, + { + name: "multiple variables", + code: `def run do + user = %User{name: "A"} + org = %Organization{name: "B"} + user +end`, + line: 3, + col: len(" user"), + want: map[string]string{ + "user": "User", + "org": "Organization", + }, + }, + { + name: "variable after cursor excluded", + code: `def run do + user = %User{name: "A"} + cursor_here + org = %Organization{name: "B"} +end`, + line: 2, + col: len(" cursor_here"), + want: map[string]string{"user": "User"}, + }, + { + name: "struct literal on right side of assignment", + code: `def run do + user = %User{} + user +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "no struct patterns", + code: `def run do + x = 1 + y = "hello" + y +end`, + line: 3, + col: len(" y"), + want: map[string]string{}, + }, + { + name: "case clause pattern", + code: `def run(thing) do + case thing do + %User{} = user -> user + end +end`, + line: 2, + col: len(" %User{} = user -> user"), + want: map[string]string{"user": "User"}, + }, + { + name: "with clause pattern", + code: `def run do + with %User{} = user <- fetch_user() do + user + end +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "var = struct on right side", + code: `def run do + result = %Result{ok: true} + result +end`, + line: 2, + col: len(" result"), + want: map[string]string{"result": "Result"}, + }, + { + name: "reassignment overrides type", + code: `def run do + thing = %User{name: "A"} + thing = %Organization{name: "B"} + thing +end`, + line: 3, + col: len(" thing"), + want: map[string]string{"thing": "Organization"}, + }, + { + name: "__MODULE__ struct", + code: `def run do + self = %__MODULE__{} + self +end`, + line: 2, + col: len(" self"), + want: map[string]string{"self": "__MODULE__"}, + }, + { + name: "scoped to current function", + code: `def first do + user = %User{} +end + +def second do + user +end`, + line: 5, + col: len(" user"), + want: map[string]string{}, + }, + { + name: "defp function head", + code: `defp handle_user(%User{} = user) do + user +end`, + line: 1, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "reverse pattern: var = %Struct{}", + code: `def run do + user = %User{name: "test"} + user +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "for comprehension pattern", + code: `def run(users) do + for %User{} = user <- users do + user + end +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "nested struct in value position does not leak", + code: `def run do + user = %User{address: %Address{city: "NY"}} + user +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "pipeline with struct does not bind", + code: `def run do + result = thing |> Map.merge(%User{}) + result +end`, + line: 2, + col: len(" result"), + want: map[string]string{}, + }, + { + name: "struct inside function call args does not bind", + code: `def run do + Repo.insert(%User{name: "test"}) + x +end`, + line: 2, + col: len(" x"), + want: map[string]string{}, + }, + { + name: "struct in list pattern", + code: `def run do + [%User{} = user] = list + user +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "struct in tuple pattern", + code: `def run do + {:ok, %User{} = user} = fetch() + user +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "plain map not treated as struct", + code: `def run do + map = %{name: "test"} + map +end`, + line: 2, + col: len(" map"), + want: map[string]string{}, + }, + { + name: "pinned variable not bound", + code: `def run(existing) do + %User{} = ^existing + existing +end`, + line: 2, + col: len(" existing"), + want: map[string]string{}, + }, + { + name: "struct as keyword arg value does not bind outer var", + code: `def run do + result = func(key: %User{}) + result +end`, + line: 2, + col: len(" result"), + want: map[string]string{}, + }, + { + name: "deeply nested struct in value does not bind extra vars", + code: `def run do + org = %Org{owner: %User{address: %Address{city: "NY"}}} + org +end`, + line: 2, + col: len(" org"), + want: map[string]string{"org": "Org"}, + }, + { + name: "default argument with struct", + code: `def changeset(leave_type \\ %__MODULE__{}, attrs) do + leave_type +end`, + line: 1, + col: len(" leave_type"), + want: map[string]string{"leave_type": "__MODULE__"}, + }, + { + name: "default argument with named struct", + code: `def new(user \\ %User{}) do + user +end`, + line: 1, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens, source, lineStarts := tokenize(tt.code) + got := VariableStructTypes(tokens, source, lineStarts, tt.line, tt.col) + if len(got) != len(tt.want) { + t.Fatalf("got %d entries %v, want %d entries %v", len(got), got, len(tt.want), tt.want) + } + for varName, wantModule := range tt.want { + if gotModule, ok := got[varName]; !ok { + t.Errorf("missing variable %q (want module %q)", varName, wantModule) + } else if gotModule != wantModule { + t.Errorf("variable %q: got module %q, want %q", varName, gotModule, wantModule) + } + } + }) + } +} + +func TestVariableFieldAccessAtCursor(t *testing.T) { + tests := []struct { + name string + code string + line int + col int + wantOK bool + wantVar string + wantPfx string + wantStart int + }{ + { + name: "variable dot", + code: " user.", + line: 0, + col: len(" user."), + wantOK: true, + wantVar: "user", + wantPfx: "", + wantStart: len(" user."), + }, + { + name: "variable dot with field prefix", + code: " user.na", + line: 0, + col: len(" user.na"), + wantOK: true, + wantVar: "user", + wantPfx: "na", + wantStart: len(" user."), + }, + { + name: "variable dot partial field", + code: " user.email_addr", + line: 0, + col: len(" user.email"), + wantOK: true, + wantVar: "user", + wantPfx: "email", + wantStart: len(" user."), + }, + { + name: "module dot not variable", + code: " Enum.", + line: 0, + col: len(" Enum."), + wantOK: false, + }, + { + name: "module dot function not variable", + code: " Enum.ma", + line: 0, + col: len(" Enum.ma"), + wantOK: false, + }, + { + name: "underscore variable rejected", + code: " _user.", + line: 0, + col: len(" _user."), + wantOK: false, + }, + { + name: "keyword rejected", + code: " do.", + line: 0, + col: len(" do."), + wantOK: false, + }, + { + name: "no dot", + code: " user", + line: 0, + col: len(" user"), + wantOK: false, + }, + { + name: "multiline", + code: "def run do\n user.na", + line: 1, + col: len(" user.na"), + wantOK: true, + wantVar: "user", + wantPfx: "na", + wantStart: len(" user."), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens, source, lineStarts := tokenize(tt.code) + access, ok := VariableFieldAccessAtCursor(tokens, source, lineStarts, tt.line, tt.col) + if ok != tt.wantOK { + t.Fatalf("ok = %v, want %v", ok, tt.wantOK) + } + if !ok { + return + } + if access.VariableName != tt.wantVar { + t.Errorf("VariableName = %q, want %q", access.VariableName, tt.wantVar) + } + if access.FieldPrefix != tt.wantPfx { + t.Errorf("FieldPrefix = %q, want %q", access.FieldPrefix, tt.wantPfx) + } + if access.StartCol != tt.wantStart { + t.Errorf("StartCol = %d, want %d", access.StartCol, tt.wantStart) + } + }) + } +} + func TestCursorContext_Expr(t *testing.T) { tests := []struct { mod, fn, want string diff --git a/internal/lsp/server.go b/internal/lsp/server.go index a5aabee..edf8389 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -1780,6 +1780,60 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara }, nil } + // "variable." or "variable.field_prefix" — struct field access on a typed variable. + if fieldAccess, ok := tf.VariableFieldAccessAtCursor(lineNum, col); ok { + varStructTypes := tf.VariableStructTypes(lineNum, col) + if structModule, ok := varStructTypes[fieldAccess.VariableName]; ok { + tStruct := s.debugNow() + aliases := tf.ExtractAliasesInScope(lineNum) + s.mergeAliasesFromUseTokenized(tf, aliases) + resolvedModule := tf.ResolveModuleExpr(structModule, lineNum) + fullModule := s.resolveModuleWithNesting(resolvedModule, aliases, filePath, lineNum) + if s.debug { + s.debugf("Completion variable struct type") + s.debugf(" variable=%s", fieldAccess.VariableName) + s.debugf(" structModule=%s", structModule) + s.debugf(" resolved=%s", fullModule) + s.debugf(" fieldPrefix=%q", fieldAccess.FieldPrefix) + } + fields, ready := s.cachedStructFieldsOrWarm(filePath, fullModule) + if ready && len(fields) > 0 { + fieldPrefixRange := protocol.Range{ + Start: protocol.Position{Line: uint32(lineNum), Character: uint32(fieldAccess.StartCol)}, + End: protocol.Position{Line: uint32(lineNum), Character: uint32(col)}, + } + var items []protocol.CompletionItem + for _, field := range fields { + if !strings.HasPrefix(field, fieldAccess.FieldPrefix) { + continue + } + items = append(items, protocol.CompletionItem{ + Label: field, + Kind: protocol.CompletionItemKindField, + Detail: fullModule + " struct field", + TextEdit: &protocol.TextEdit{ + Range: fieldPrefixRange, + NewText: field, + }, + }) + } + if len(items) > 0 { + if s.debug { + s.debugf("Completion variable struct fields returned") + s.debugf(" variable=%s", fieldAccess.VariableName) + s.debugf(" module=%s", fullModule) + s.debugf(" items=%d", len(items)) + s.debugf(" elapsed=%s", time.Since(tStruct).Round(time.Microsecond)) + } + return &protocol.CompletionList{ + IsIncomplete: false, + Items: items, + }, nil + } + } + } + } + completionCtx := tf.CompletionContextAtCursor(lineNum, col) prefix, afterDot, prefixStartCol := completionCtx.Prefix, completionCtx.AfterDot, completionCtx.StartCol structValueContext := tf.StructValueContextAtCursor(lineNum, col) @@ -1900,6 +1954,7 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara // "Module.func." or "variable." — dot after a function call result or // map/struct field access. We have no type info to complete the result. + // (Variable struct types were already handled by VariableFieldAccessAtCursor above.) if afterDot && (funcPrefix != "" || moduleRef == "") { return nil, nil } From f18cc20c19bf957628b1076f203c5299ea5c7235 Mon Sep 17 00:00:00 2001 From: Jesse Herrick Date: Mon, 11 May 2026 17:26:12 -0400 Subject: [PATCH 3/3] Infer struct types from @spec and ExCk return types Adds two more sources of struct-type inference for variable dot-completion: - @spec parameter typespecs: parameters annotated `t()` or `Module.t()` resolve to that struct (pattern matches still take precedence). - ExCk return-type lookup: `var = Mod.func(...)` resolves `var.` against the struct returned by the compiled function, queried over a new `return_type_struct` op on the persistent BEAM process. --- internal/lsp/beam_server.exs | 113 +++++++ internal/lsp/elixir.go | 533 +++++++++++++++++++++++++++++++++ internal/lsp/elixir_test.go | 310 +++++++++++++++++++ internal/lsp/formatter.go | 51 +++- internal/lsp/formatter_test.go | 143 +++++++++ internal/lsp/server.go | 133 ++++++-- internal/lsp/server_test.go | 188 ++++++++++++ 7 files changed, 1444 insertions(+), 27 deletions(-) diff --git a/internal/lsp/beam_server.exs b/internal/lsp/beam_server.exs index d0cb50e..6db88f2 100644 --- a/internal/lsp/beam_server.exs +++ b/internal/lsp/beam_server.exs @@ -47,6 +47,12 @@ # CodeIntel op 5 (struct_fields) payload: # 2-byte Elixir module length (big-endian) + module # +# CodeIntel op 6 (return_type_struct) payload: +# 2-byte module length (big-endian) + module + +# 2-byte function length (big-endian) + function + +# 1-byte arity +# Response: 2-byte struct module name length + name (empty string = not a struct) +# # Notification 0 (otp_modules_ready) payload: # 2-byte module_count (big-endian) + [name_len(u16) name] # @@ -454,6 +460,7 @@ defmodule Dexter.CodeIntel do @op_erlang_exports 3 @op_runtime_info 4 @op_struct_fields 5 + @op_return_type_struct 6 def handle_request(op, payload) do case op do @@ -463,6 +470,7 @@ defmodule Dexter.CodeIntel do @op_erlang_exports -> handle_erlang_exports(payload) @op_runtime_info -> handle_runtime_info(payload) @op_struct_fields -> handle_struct_fields(payload) + @op_return_type_struct -> handle_return_type_struct(payload) _ -> {1, "unknown code intel op: #{inspect(op)}"} end end @@ -859,6 +867,111 @@ defmodule Dexter.CodeIntel do _ -> :error end + defp handle_return_type_struct(payload) do + case parse_module_function_arity(payload) do + {:ok, module_name, function_name, arity} -> + case fetch_return_type_struct(module_name, function_name, arity) do + {:ok, struct_module} -> + {0, <>} + + :none -> + {0, <<0::unsigned-big-16>>} + + :error -> + {1, "return type struct lookup failed"} + end + + :error -> + {1, "invalid return_type_struct payload"} + end + end + + defp fetch_return_type_struct(module_name, function_name, arity) do + with {:ok, module_atom} <- elixir_module_atom(module_name), + {:module, ^module_atom} <- Code.ensure_loaded(module_atom), + {:ok, exports} <- fetch_exck_exports(module_atom), + {:ok, sig} <- find_export_sig(exports, function_name, arity), + {:ok, struct_module} <- extract_struct_from_sig(sig) do + {:ok, inspect(struct_module)} + else + :none -> :none + _ -> :none + end + rescue + _ -> :error + end + + defp fetch_exck_exports(module_atom) do + case :code.get_object_code(module_atom) do + {^module_atom, binary, _filename} -> + case :beam_lib.chunks(binary, [~c"ExCk"]) do + {:ok, {^module_atom, [{~c"ExCk", chunk}]}} -> + {_version, contents} = :erlang.binary_to_term(chunk) + + case contents do + %{exports: exports} when is_list(exports) -> {:ok, exports} + _ -> :error + end + + _ -> + :error + end + + :error -> + :error + end + end + + defp find_export_sig(exports, function_name, arity) do + function_atom = String.to_atom(function_name) + + case List.keyfind(exports, {function_atom, arity}, 0) do + {_, %{sig: {:infer, _domain, clauses}}} when is_list(clauses) -> + {:ok, clauses} + + _ -> + :none + end + end + + defp extract_struct_from_sig(clauses) do + struct_modules = + clauses + |> Enum.map(fn {_args, return_type} -> extract_struct_from_type(return_type) end) + |> Enum.reject(&is_nil/1) + |> Enum.uniq() + + case struct_modules do + [single_module] -> {:ok, single_module} + _ -> :none + end + end + + defp extract_struct_from_type(type) when is_map(type) do + # Check the direct map descriptor for a struct + extract_struct_from_map_desc(type[:map]) || + # Check inside dynamic wrapper + extract_struct_from_dynamic(type[:dynamic]) + end + + defp extract_struct_from_type(_), do: nil + + defp extract_struct_from_dynamic(%{map: map_desc}), do: extract_struct_from_map_desc(map_desc) + defp extract_struct_from_dynamic(:term), do: nil + defp extract_struct_from_dynamic(_), do: nil + + defp extract_struct_from_map_desc({:closed, fields}) when is_map(fields) do + case fields do + %{__struct__: %{atom: {:union, union_map}}} when map_size(union_map) == 1 -> + union_map |> Map.keys() |> hd() + + _ -> + nil + end + end + + defp extract_struct_from_map_desc(_), do: nil + defp elixir_module_atom(module_name) do if Regex.match?(~r/\A(?:Elixir\.)?[A-Z][A-Za-z0-9_]*(?:\.[A-Z][A-Za-z0-9_]*)*\z/, module_name) do module_atom = diff --git a/internal/lsp/elixir.go b/internal/lsp/elixir.go index 40914f0..93ab29f 100644 --- a/internal/lsp/elixir.go +++ b/internal/lsp/elixir.go @@ -354,6 +354,19 @@ func VariableStructTypes(tokens []parser.Token, source []byte, lineStarts []int, result := make(map[string]string) + // Typespec inference: look backward from the def for a preceding @spec. + // Parse parameter types and match positionally to function param names. + specTypes := parseSpecParamTypes(tokens, source, defIdx) + paramNames := parseFunctionParamNames(tokens, source, defIdx) + if len(specTypes) == len(paramNames) && len(specTypes) > 0 { + for i, specType := range specTypes { + if specType != "" { + // Only set if not already overridden by pattern match (added later) + result[paramNames[i]] = specType + } + } + } + // Scan tokens from the function definition to the cursor. // We look for two patterns: // Pattern A: %Module{...} = var (struct on left, variable on right of =) @@ -474,6 +487,526 @@ func VariableStructTypes(tokens []parser.Token, source []byte, lineStarts []int, return result } +// knownNonStructTypes lists modules whose .t() type does not represent a struct. +var knownNonStructTypes = map[string]bool{ + "String": true, + "Integer": true, + "Float": true, + "Atom": true, + "BitString": true, + "Reference": true, + "Port": true, + "PID": true, + "Exception": true, + "Macro": true, + "Macro.Env": true, +} + +// parseSpecParamTypes looks backward from defIdx for a preceding @spec and +// extracts the struct-like types from it. Returns a slice where each element +// is the inferred module for that parameter position, or "" if not a struct type. +// +// Recognized patterns: +// - t() → "__MODULE__" +// - Module.t() → "Module" (unless it's a known non-struct type) +// - anything else → "" +func parseSpecParamTypes(tokens []parser.Token, source []byte, defIdx int) []string { + // Walk backward from defIdx to find the nearest preceding @spec. + // Skip all tokens until we find @spec, another def, or a module-level boundary. + specIdx := -1 + for i := defIdx - 1; i >= 0; i-- { + tok := tokens[i] + if tok.Kind == parser.TokAttrSpec { + specIdx = i + break + } + if isFunctionDefinitionToken(tok.Kind) || tok.Kind == parser.TokDefmodule || tok.Kind == parser.TokEnd { + break + } + } + if specIdx < 0 { + return nil + } + + // After @spec, we expect: func_name ( param_types ) :: return_type + // Find the open paren of the spec + j := tokNextSig(tokens, len(tokens), specIdx+1) + if j >= len(tokens) || tokens[j].Kind != parser.TokIdent { + return nil + } + + parenIdx := tokNextSig(tokens, len(tokens), j+1) + if parenIdx >= len(tokens) || tokens[parenIdx].Kind != parser.TokOpenParen { + return nil + } + + // Find the matching close paren + closeIdx := findMatchingCloseParen(tokens, parenIdx) + if closeIdx < 0 { + return nil + } + + // Check if there are no params (empty parens) + first := tokNextSig(tokens, len(tokens), parenIdx+1) + if first == closeIdx { + return nil + } + + // Parse each parameter type (comma-separated at depth 0) + var paramTypes []string + depth := 0 + typeStart := parenIdx + 1 + + for i := parenIdx + 1; i <= closeIdx; i++ { + tok := tokens[i] + switch tok.Kind { + case parser.TokOpenParen, parser.TokOpenBracket, parser.TokOpenBrace: + depth++ + case parser.TokCloseParen: + if depth == 0 { + // End of params — process last type + paramTypes = append(paramTypes, classifySpecType(tokens, source, typeStart, i)) + } else { + depth-- + } + case parser.TokCloseBracket, parser.TokCloseBrace: + depth-- + case parser.TokComma: + if depth == 0 { + paramTypes = append(paramTypes, classifySpecType(tokens, source, typeStart, i)) + typeStart = i + 1 + } + } + } + + return paramTypes +} + +// classifySpecType examines the tokens from start (inclusive) to end (exclusive) +// and determines if it represents a struct type. +// +// Returns "__MODULE__" for bare t(), the module name for Module.t(), +// or "" for anything else. +func classifySpecType(tokens []parser.Token, source []byte, start, end int) string { + // Collect significant tokens in this range + var sigTokens []int + for i := start; i < end; i++ { + if tokens[i].Kind != parser.TokEOL && tokens[i].Kind != parser.TokComment { + sigTokens = append(sigTokens, i) + } + } + + if len(sigTokens) == 0 { + return "" + } + + // Pattern: t() — just TokIdent("t"), TokOpenParen, TokCloseParen + if len(sigTokens) == 3 { + if tokens[sigTokens[0]].Kind == parser.TokIdent && + parser.TokenText(source, tokens[sigTokens[0]]) == "t" && + tokens[sigTokens[1]].Kind == parser.TokOpenParen && + tokens[sigTokens[2]].Kind == parser.TokCloseParen { + return "__MODULE__" + } + } + + // Pattern: Module.t() or Module.Sub.t() + // Tokens: Module, Dot, ... , Dot, t, (, ) + // The last 4 tokens should be: Dot, Ident("t"), OpenParen, CloseParen + if len(sigTokens) >= 5 { + last := len(sigTokens) - 1 + if tokens[sigTokens[last]].Kind == parser.TokCloseParen && + tokens[sigTokens[last-1]].Kind == parser.TokOpenParen && + tokens[sigTokens[last-2]].Kind == parser.TokIdent && + parser.TokenText(source, tokens[sigTokens[last-2]]) == "t" && + tokens[sigTokens[last-3]].Kind == parser.TokDot && + tokens[sigTokens[0]].Kind == parser.TokModule { + + // Collect the module name from the leading tokens (everything before the last .t()) + moduleRef, _ := tokCollectModuleName(source, tokens, len(tokens), sigTokens[0]) + if moduleRef != "" && !knownNonStructTypes[moduleRef] { + return moduleRef + } + } + } + + return "" +} + +// parseFunctionParamNames extracts parameter names from a function definition +// head starting at defIdx. Returns a slice of parameter names in order. +func parseFunctionParamNames(tokens []parser.Token, source []byte, defIdx int) []string { + // After def/defp, expect: func_name ( params ) + funcIdx := tokNextSig(tokens, len(tokens), defIdx+1) + if funcIdx >= len(tokens) || tokens[funcIdx].Kind != parser.TokIdent { + return nil + } + + parenIdx := tokNextSig(tokens, len(tokens), funcIdx+1) + if parenIdx >= len(tokens) || tokens[parenIdx].Kind != parser.TokOpenParen { + return nil + } + + closeIdx := findMatchingCloseParen(tokens, parenIdx) + if closeIdx < 0 { + return nil + } + + // Check for empty parens + first := tokNextSig(tokens, len(tokens), parenIdx+1) + if first == closeIdx { + return nil + } + + // Parse comma-separated params at depth 0. + // For each param, find the "root" identifier — the actual param name. + // This handles patterns like: %User{} = user, user \\ default, user + var names []string + depth := 0 + paramStart := parenIdx + 1 + + for i := parenIdx + 1; i <= closeIdx; i++ { + tok := tokens[i] + switch tok.Kind { + case parser.TokOpenParen, parser.TokOpenBracket, parser.TokOpenBrace: + depth++ + case parser.TokCloseParen: + if depth == 0 { + names = append(names, extractParamName(tokens, source, paramStart, i)) + } else { + depth-- + } + case parser.TokCloseBracket, parser.TokCloseBrace: + depth-- + case parser.TokComma: + if depth == 0 { + names = append(names, extractParamName(tokens, source, paramStart, i)) + paramStart = i + 1 + } + } + } + + return names +} + +// extractParamName finds the parameter name from a function head parameter +// expression. Handles patterns like: +// - user → "user" +// - %User{} = user → "user" +// - user \\ %User{} → "user" +// - %User{name: name} = user → "user" +func extractParamName(tokens []parser.Token, source []byte, start, end int) string { + // Strategy: find identifiers at depth 0 that aren't keywords and aren't + // preceded by a dot. Prefer the one after = if present, otherwise the first one. + var firstIdent string + var afterEquals string + sawEquals := false + depth := 0 + + for i := start; i < end; i++ { + tok := tokens[i] + switch tok.Kind { + case parser.TokOpenParen, parser.TokOpenBracket, parser.TokOpenBrace: + depth++ + case parser.TokCloseParen, parser.TokCloseBracket, parser.TokCloseBrace: + depth-- + case parser.TokOther: + if depth == 0 && tokenText(source, tok) == "=" { + sawEquals = true + } + case parser.TokBackslash: + // default arg — the param name is whatever we already found + if firstIdent != "" { + return firstIdent + } + case parser.TokIdent: + if depth != 0 { + continue + } + name := parser.TokenText(source, tok) + if strings.HasPrefix(name, "_") || parser.IsElixirKeyword(name) { + continue + } + // Skip if preceded by dot (struct field access / module function) + prev := prevSignificantToken(tokens, i) + if prev >= 0 && tokens[prev].Kind == parser.TokDot { + continue + } + // Skip if followed by open paren (function call) + next := tokNextSig(tokens, len(tokens), i+1) + if next < end && tokens[next].Kind == parser.TokOpenParen { + continue + } + if firstIdent == "" { + firstIdent = name + } + if sawEquals { + afterEquals = name + } + } + } + + if afterEquals != "" { + return afterEquals + } + return firstIdent +} + +// VariableFunctionCall describes a variable assigned from a module function call, +// e.g. `user = Accounts.get_user(id)`. +type VariableFunctionCall struct { + VarName string + Module string // unresolved, e.g. "Accounts" + Function string + Arity int + Line int // 0-based line of the assignment +} + +// VariableFunctionCalls scans from the enclosing function definition to the cursor +// and finds variables assigned from module function calls: `var = Module.func(...)`. +// Only detects simple top-level assignments, not nested or piped expressions. +func (tf *TokenizedFile) VariableFunctionCalls(line, col int) []VariableFunctionCall { + return VariableFunctionCalls(tf.tokens, tf.source, tf.lineStarts, line, col) +} + +func VariableFunctionCalls(tokens []parser.Token, source []byte, lineStarts []int, line, col int) []VariableFunctionCall { + offset := parser.LineColToOffset(lineStarts, line, col) + if offset < 0 { + return nil + } + + defIdx := -1 + for i := 0; i < len(tokens); i++ { + tok := tokens[i] + if tok.Kind == parser.TokEOF || tok.Start >= offset { + break + } + if isFunctionDefinitionToken(tok.Kind) { + defIdx = i + } + } + if defIdx < 0 { + return nil + } + + var results []VariableFunctionCall + seen := make(map[string]int) // varName -> index in results (last wins) + + for i := defIdx + 1; i < len(tokens); i++ { + tok := tokens[i] + if tok.Kind == parser.TokEOF || tok.Start >= offset { + break + } + + // Look for: ident = Module.func(...) + if tok.Kind != parser.TokIdent { + continue + } + + varName := parser.TokenText(source, tok) + if strings.HasPrefix(varName, "_") || parser.IsElixirKeyword(varName) { + continue + } + + // Next must be = + eqIdx := tokNextSig(tokens, len(tokens), i+1) + if eqIdx >= len(tokens) || tokens[eqIdx].Start >= offset { + continue + } + if tokens[eqIdx].Kind != parser.TokOther || tokenText(source, tokens[eqIdx]) != "=" { + continue + } + + // Next must be Module (uppercase) + modIdx := tokNextSig(tokens, len(tokens), eqIdx+1) + if modIdx >= len(tokens) || tokens[modIdx].Start >= offset { + continue + } + if tokens[modIdx].Kind != parser.TokModule { + continue + } + + // Collect the full module name (e.g. "MyApp.Accounts") + moduleRef, afterMod := tokCollectModuleName(source, tokens, len(tokens), modIdx) + if moduleRef == "" { + continue + } + + // Next must be . then function name + dotIdx := tokNextSig(tokens, len(tokens), afterMod) + if dotIdx >= len(tokens) || tokens[dotIdx].Kind != parser.TokDot { + continue + } + + funcIdx := tokNextSig(tokens, len(tokens), dotIdx+1) + if funcIdx >= len(tokens) || tokens[funcIdx].Start >= offset { + continue + } + if tokens[funcIdx].Kind != parser.TokIdent { + continue + } + funcName := parser.TokenText(source, tokens[funcIdx]) + + // Check if next token is ( for parenthesized call, or an argument for no-paren call + nextIdx := tokNextSig(tokens, len(tokens), funcIdx+1) + var arity int + var skipTo int + + if nextIdx < len(tokens) && tokens[nextIdx].Kind == parser.TokOpenParen { + // Parenthesized call: count args inside parens + arity = countCallArity(tokens, source, nextIdx) + closeIdx := findMatchingCloseParen(tokens, nextIdx) + if closeIdx >= 0 { + skipTo = closeIdx + } else { + skipTo = nextIdx + } + } else if nextIdx < len(tokens) && isCallArgStartToken(tokens[nextIdx].Kind) { + // No-paren call: count args until newline or closing delimiter + arity, skipTo = countNoParenCallArity(tokens, nextIdx) + } else { + continue + } + + call := VariableFunctionCall{ + VarName: varName, + Module: moduleRef, + Function: funcName, + Arity: arity, + Line: tok.Line - 1, + } + + if idx, ok := seen[varName]; ok { + results[idx] = call + } else { + seen[varName] = len(results) + results = append(results, call) + } + + i = skipTo + } + + return results +} + +// isCallArgStartToken returns true if the token kind can start a function argument +// in a no-paren call. This excludes operators, closing delimiters, and newlines. +func isCallArgStartToken(k parser.TokenKind) bool { + switch k { + case parser.TokIdent, parser.TokModule, parser.TokAtom, parser.TokNumber, + parser.TokString, parser.TokHeredoc, parser.TokSigil, parser.TokCharLiteral, + parser.TokOpenParen, parser.TokOpenBracket, parser.TokOpenBrace, parser.TokOpenAngle, + parser.TokPercent, parser.TokAttr, + parser.TokFn: + return true + default: + return false + } +} + +// countNoParenCallArity counts arguments in a no-paren function call starting +// at firstArg. Arguments end at a newline, closing delimiter, or certain keywords. +// Returns the arity and the index of the last token in the call. +func countNoParenCallArity(tokens []parser.Token, firstArg int) (int, int) { + depth := 0 + commas := 0 + lastIdx := firstArg + + for i := firstArg; i < len(tokens); i++ { + tok := tokens[i] + switch tok.Kind { + case parser.TokOpenParen, parser.TokOpenBracket, parser.TokOpenBrace: + depth++ + lastIdx = i + case parser.TokCloseParen, parser.TokCloseBracket, parser.TokCloseBrace: + if depth == 0 { + // Hit an outer closing delimiter — end of call + return commas + 1, lastIdx + } + depth-- + lastIdx = i + case parser.TokComma: + if depth == 0 { + commas++ + } + lastIdx = i + case parser.TokEOL: + if depth == 0 { + return commas + 1, lastIdx + } + case parser.TokEOF: + return commas + 1, lastIdx + case parser.TokDo: + if depth == 0 { + return commas + 1, lastIdx + } + default: + lastIdx = i + } + } + return commas + 1, lastIdx +} + +// countCallArity counts the number of arguments in a function call starting +// at the open paren. Returns 0 for empty parens, 1+ for calls with arguments. +func countCallArity(tokens []parser.Token, source []byte, openParen int) int { + depth := 1 + hasContent := false + commas := 0 + for i := openParen + 1; i < len(tokens); i++ { + switch tokens[i].Kind { + case parser.TokOpenParen, parser.TokOpenBracket, parser.TokOpenBrace: + depth++ + hasContent = true + case parser.TokCloseParen: + depth-- + if depth == 0 { + if hasContent { + return commas + 1 + } + return 0 + } + case parser.TokCloseBracket, parser.TokCloseBrace: + depth-- + case parser.TokComma: + if depth == 1 { + commas++ + } + hasContent = true + case parser.TokEOF: + return 0 + default: + if !isWhitespaceToken(tokens[i].Kind) { + hasContent = true + } + } + } + return 0 +} + +func isWhitespaceToken(k parser.TokenKind) bool { + return k == parser.TokEOL || k == parser.TokComment +} + +// findMatchingCloseParen finds the matching ) for the ( at tokens[openIdx]. +func findMatchingCloseParen(tokens []parser.Token, openIdx int) int { + depth := 1 + for i := openIdx + 1; i < len(tokens); i++ { + switch tokens[i].Kind { + case parser.TokOpenParen: + depth++ + case parser.TokCloseParen: + depth-- + if depth == 0 { + return i + } + case parser.TokEOF: + return -1 + } + } + return -1 +} + // findMatchingCloseBrace finds the matching } for the { at tokens[openIdx]. // Returns -1 if not found. func findMatchingCloseBrace(tokens []parser.Token, openIdx int) int { diff --git a/internal/lsp/elixir_test.go b/internal/lsp/elixir_test.go index c569632..64995b9 100644 --- a/internal/lsp/elixir_test.go +++ b/internal/lsp/elixir_test.go @@ -984,6 +984,112 @@ end`, col: len(" user"), want: map[string]string{"user": "User"}, }, + // Typespec-based parameter inference + { + name: "typespec t() infers __MODULE__", + code: `@spec changeset(t(), map()) :: Ecto.Changeset.t() +def changeset(leave_type, attrs) do + leave_type +end`, + line: 2, + col: len(" leave_type"), + want: map[string]string{"leave_type": "__MODULE__"}, + }, + { + name: "typespec Module.t() infers remote module", + code: `@spec process(User.t(), map()) :: :ok +def process(user, attrs) do + user +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "User"}, + }, + { + name: "typespec qualified Module.t()", + code: `@spec process(MyApp.Accounts.User.t(), map()) :: :ok +def process(user, attrs) do + user +end`, + line: 2, + col: len(" user"), + want: map[string]string{"user": "MyApp.Accounts.User"}, + }, + { + name: "typespec second param t()", + code: `@spec update(map(), t()) :: t() +def update(attrs, schema) do + schema +end`, + line: 2, + col: len(" schema"), + want: map[string]string{"schema": "__MODULE__"}, + }, + { + name: "typespec does not override pattern match", + code: `@spec changeset(t(), map()) :: Ecto.Changeset.t() +def changeset(%LeaveType{} = leave_type, attrs) do + leave_type +end`, + line: 2, + col: len(" leave_type"), + want: map[string]string{"leave_type": "LeaveType"}, + }, + { + name: "typespec with defp", + code: `@spec do_work(t()) :: :ok +defp do_work(item) do + item +end`, + line: 2, + col: len(" item"), + want: map[string]string{"item": "__MODULE__"}, + }, + { + name: "typespec non-struct types ignored", + code: `@spec run(String.t(), integer(), atom()) :: :ok +def run(name, count, label) do + name +end`, + line: 2, + col: len(" name"), + want: map[string]string{}, + }, + { + name: "typespec with no params", + code: `@spec run() :: :ok +def run do + :ok +end`, + line: 2, + col: len(" :ok"), + want: map[string]string{}, + }, + { + name: "typespec multiple struct params", + code: `@spec merge(t(), User.t()) :: t() +def merge(schema, user) do + schema +end`, + line: 2, + col: len(" schema"), + want: map[string]string{"schema": "__MODULE__", "user": "User"}, + }, + { + name: "typespec only matches preceding spec", + code: `@spec unrelated(t()) :: :ok +def unrelated(x) do + x +end + +@spec actual(map()) :: :ok +def actual(data) do + data +end`, + line: 7, + col: len(" data"), + want: map[string]string{}, + }, } for _, tt := range tests { @@ -1004,6 +1110,210 @@ end`, } } +func TestVariableFunctionCalls(t *testing.T) { + tests := []struct { + name string + code string + line int + col int + want []VariableFunctionCall + }{ + { + name: "simple module function call", + code: `def run do + user = Accounts.get_user(id) + user +end`, + line: 2, + col: len(" user"), + want: []VariableFunctionCall{ + {VarName: "user", Module: "Accounts", Function: "get_user", Arity: 1}, + }, + }, + { + name: "qualified module call", + code: `def run do + user = MyApp.Accounts.get_user(id) + user +end`, + line: 2, + col: len(" user"), + want: []VariableFunctionCall{ + {VarName: "user", Module: "MyApp.Accounts", Function: "get_user", Arity: 1}, + }, + }, + { + name: "zero arity", + code: `def run do + users = Repo.all() + users +end`, + line: 2, + col: len(" users"), + want: []VariableFunctionCall{ + {VarName: "users", Module: "Repo", Function: "all", Arity: 0}, + }, + }, + { + name: "multiple arguments", + code: `def run do + user = Repo.get(User, id) + user +end`, + line: 2, + col: len(" user"), + want: []VariableFunctionCall{ + {VarName: "user", Module: "Repo", Function: "get", Arity: 2}, + }, + }, + { + name: "multiple calls", + code: `def run do + user = Accounts.get_user(id) + org = Organizations.get_org(slug) + org +end`, + line: 3, + col: len(" org"), + want: []VariableFunctionCall{ + {VarName: "user", Module: "Accounts", Function: "get_user", Arity: 1}, + {VarName: "org", Module: "Organizations", Function: "get_org", Arity: 1}, + }, + }, + { + name: "reassignment keeps last call", + code: `def run do + user = Accounts.get_user(id) + user = Accounts.get_admin(id) + user +end`, + line: 3, + col: len(" user"), + want: []VariableFunctionCall{ + {VarName: "user", Module: "Accounts", Function: "get_admin", Arity: 1}, + }, + }, + { + name: "bare function call ignored", + code: `def run do + user = get_user(id) + user +end`, + line: 2, + col: len(" user"), + want: nil, + }, + { + name: "after cursor excluded", + code: `def run do + cursor + user = Accounts.get_user(id) +end`, + line: 1, + col: len(" cursor"), + want: nil, + }, + { + name: "scoped to current function", + code: `def first do + user = Accounts.get_user(id) +end + +def second do + user +end`, + line: 5, + col: len(" user"), + want: nil, + }, + { + name: "nested call in args counted correctly", + code: `def run do + user = Repo.get(User, String.to_integer(id)) + user +end`, + line: 2, + col: len(" user"), + want: []VariableFunctionCall{ + {VarName: "user", Module: "Repo", Function: "get", Arity: 2}, + }, + }, + { + name: "pipeline not detected", + code: `def run do + result = id |> Accounts.get_user() + result +end`, + line: 2, + col: len(" result"), + want: nil, + }, + { + name: "no-paren call single arg", + code: `def run do + user = Repo.get! User + user +end`, + line: 2, + col: len(" user"), + want: []VariableFunctionCall{ + {VarName: "user", Module: "Repo", Function: "get!", Arity: 1}, + }, + }, + { + name: "no-paren call multiple args", + code: `def run do + user = Repo.get User, id + user +end`, + line: 2, + col: len(" user"), + want: []VariableFunctionCall{ + {VarName: "user", Module: "Repo", Function: "get", Arity: 2}, + }, + }, + { + name: "no-paren call with do block", + code: `def run do + changeset = Ecto.Changeset.change user do + :ok + end + changeset +end`, + line: 4, + col: len(" changeset"), + want: []VariableFunctionCall{ + {VarName: "changeset", Module: "Ecto.Changeset", Function: "change", Arity: 1}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens, source, lineStarts := tokenize(tt.code) + got := VariableFunctionCalls(tokens, source, lineStarts, tt.line, tt.col) + if len(got) != len(tt.want) { + t.Fatalf("got %d calls %v, want %d", len(got), got, len(tt.want)) + } + for i, want := range tt.want { + g := got[i] + if g.VarName != want.VarName { + t.Errorf("[%d] VarName = %q, want %q", i, g.VarName, want.VarName) + } + if g.Module != want.Module { + t.Errorf("[%d] Module = %q, want %q", i, g.Module, want.Module) + } + if g.Function != want.Function { + t.Errorf("[%d] Function = %q, want %q", i, g.Function, want.Function) + } + if g.Arity != want.Arity { + t.Errorf("[%d] Arity = %d, want %d", i, g.Arity, want.Arity) + } + } + }) + } +} + func TestVariableFieldAccessAtCursor(t *testing.T) { tests := []struct { name string diff --git a/internal/lsp/formatter.go b/internal/lsp/formatter.go index a7b1b17..03cefe6 100644 --- a/internal/lsp/formatter.go +++ b/internal/lsp/formatter.go @@ -45,12 +45,13 @@ const ( formatterOpFormat byte = 0x00 - codeIntelOpErlangSource byte = 0x00 - codeIntelOpErlangDocs byte = 0x01 - codeIntelOpWarmOTPModules byte = 0x02 - codeIntelOpErlangExports byte = 0x03 - codeIntelOpRuntimeInfo byte = 0x04 - codeIntelOpStructFields byte = 0x05 + codeIntelOpErlangSource byte = 0x00 + codeIntelOpErlangDocs byte = 0x01 + codeIntelOpWarmOTPModules byte = 0x02 + codeIntelOpErlangExports byte = 0x03 + codeIntelOpRuntimeInfo byte = 0x04 + codeIntelOpStructFields byte = 0x05 + codeIntelOpReturnTypeStruct byte = 0x06 beamNotificationOTPModulesReady byte = 0x00 beamNotificationOTPModulesFailed byte = 0x01 @@ -696,6 +697,44 @@ func (bp *beamProcess) StructFields(ctx context.Context, module string) ([]strin return fields, err } +// ReturnTypeStruct queries the ExCk chunk of a compiled module to determine if +// a function's return type is a struct. Returns the struct module name (e.g. +// "MyApp.User") or empty string if the return type is not a single struct type. +// Gracefully returns empty string if the ExCk chunk is not available. +func (bp *beamProcess) ReturnTypeStruct(ctx context.Context, module, function string, arity int) (string, error) { + var payload bytes.Buffer + _ = binary.Write(&payload, binary.BigEndian, uint16(len(module))) + payload.WriteString(module) + _ = binary.Write(&payload, binary.BigEndian, uint16(len(function))) + payload.WriteString(function) + payload.WriteByte(byte(arity)) + + var structModule string + err := bp.doRequest(ctx, serviceCodeIntel, codeIntelOpReturnTypeStruct, payload.Bytes(), func(status byte, respPayload []byte) error { + if status != 0 { + // Non-zero status means lookup failed (e.g. old Elixir without ExCk). + // Treat as graceful "not a struct". + return nil + } + + reader := bytes.NewReader(respPayload) + var nameLen uint16 + if err := binary.Read(reader, binary.BigEndian, &nameLen); err != nil { + return fmt.Errorf("read struct module name length: %w", err) + } + if nameLen == 0 { + return nil + } + nameBuf := make([]byte, nameLen) + if _, err := io.ReadFull(reader, nameBuf); err != nil { + return fmt.Errorf("read struct module name: %w", err) + } + structModule = string(nameBuf) + return nil + }) + return structModule, err +} + // FormatError represents a formatting failure (e.g. syntax error in the source). // The persistent process is still alive — this is not a protocol/crash error. type FormatError struct { diff --git a/internal/lsp/formatter_test.go b/internal/lsp/formatter_test.go index 6d00ad2..3504c23 100644 --- a/internal/lsp/formatter_test.go +++ b/internal/lsp/formatter_test.go @@ -713,6 +713,149 @@ func TestFindFormatterConfig_PerAppOverridesRoot(t *testing.T) { } } +func TestBeamProcess_ReturnTypeStruct(t *testing.T) { + reqReader, reqWriter := io.Pipe() + respReader, respWriter := io.Pipe() + + bp := newTestBeamProcess(reqWriter, respReader, nil) + + readLoopDone := make(chan struct{}) + go func() { + bp.readLoop() + close(readLoopDone) + }() + + tests := []struct { + name string + module string + function string + arity int + respStatus byte + respName string // struct module name in response ("" = not a struct) + wantResult string + }{ + { + name: "returns struct module name", + module: "URI", + function: "parse", + arity: 1, + respStatus: 0, + respName: "URI", + wantResult: "URI", + }, + { + name: "returns empty for non-struct", + module: "Map", + function: "new", + arity: 0, + respStatus: 0, + respName: "", + wantResult: "", + }, + { + name: "handles error status gracefully", + module: "NoModule", + function: "nope", + arity: 0, + respStatus: 1, + respName: "", + wantResult: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + + frameType, err := readByte(reqReader) + if err != nil { + t.Error(err) + return + } + if frameType != frameRequest { + t.Errorf("expected request frame, got %d", frameType) + return + } + + reqID, err := readUint32(reqReader) + if err != nil { + t.Error(err) + return + } + + header := make([]byte, 6) + if _, err := io.ReadFull(reqReader, header); err != nil { + t.Error(err) + return + } + if header[0] != serviceCodeIntel || header[1] != codeIntelOpReturnTypeStruct { + t.Errorf("unexpected service=%d op=%d", header[0], header[1]) + return + } + payloadLen := binary.BigEndian.Uint32(header[2:]) + payloadBuf := make([]byte, payloadLen) + if _, err := io.ReadFull(reqReader, payloadBuf); err != nil { + t.Error(err) + return + } + + // Decode and verify the request payload + r := bytes.NewReader(payloadBuf) + var modLen uint16 + _ = binary.Read(r, binary.BigEndian, &modLen) + modBuf := make([]byte, modLen) + _, _ = io.ReadFull(r, modBuf) + if string(modBuf) != tt.module { + t.Errorf("module = %q, want %q", string(modBuf), tt.module) + } + + var fnLen uint16 + _ = binary.Read(r, binary.BigEndian, &fnLen) + fnBuf := make([]byte, fnLen) + _, _ = io.ReadFull(r, fnBuf) + if string(fnBuf) != tt.function { + t.Errorf("function = %q, want %q", string(fnBuf), tt.function) + } + + arityByte, _ := r.ReadByte() + if int(arityByte) != tt.arity { + t.Errorf("arity = %d, want %d", arityByte, tt.arity) + } + + // Build response + var respPayload bytes.Buffer + if tt.respStatus == 0 { + _ = binary.Write(&respPayload, binary.BigEndian, uint16(len(tt.respName))) + respPayload.WriteString(tt.respName) + } else { + respPayload.WriteString("lookup failed") + } + writeTestResponseFrame(t, respWriter, reqID, tt.respStatus, respPayload.Bytes()) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, err := bp.ReturnTypeStruct(ctx, tt.module, tt.function, tt.arity) + if err != nil { + t.Fatalf("ReturnTypeStruct error: %v", err) + } + if result != tt.wantResult { + t.Errorf("ReturnTypeStruct = %q, want %q", result, tt.wantResult) + } + + <-serverDone + }) + } + + _ = reqWriter.Close() + _ = reqReader.Close() + _ = respWriter.Close() + <-readLoopDone +} + func TestBeamProcess_DoRequestHandlesNotificationBeforeResponse(t *testing.T) { reqReader, reqWriter := io.Pipe() respReader, respWriter := io.Pipe() diff --git a/internal/lsp/server.go b/internal/lsp/server.go index edf8389..23761c8 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -1209,6 +1209,61 @@ func (s *Server) structFieldBuildRoot(filePath string) string { return s.findBuildRoot(s.projectRoot) } +// resolveVariableTypeFromExCk checks if a variable was assigned from a module +// function call and queries the ExCk chunk to determine if the function's return +// type is a struct. Returns the fully-qualified struct module name or "". +func (s *Server) resolveVariableTypeFromExCk(ctx context.Context, tf *TokenizedFile, filePath, varName string, lineNum, col int) string { + calls := tf.VariableFunctionCalls(lineNum, col) + if len(calls) == 0 { + return "" + } + + // Find the call for this variable. + var call *VariableFunctionCall + for i := range calls { + if calls[i].VarName == varName { + call = &calls[i] + break + } + } + if call == nil { + return "" + } + + // Resolve the module reference through aliases. + aliases := tf.ExtractAliasesInScope(call.Line) + s.mergeAliasesFromUseTokenized(tf, aliases) + resolvedModule := tf.ResolveModuleExpr(call.Module, call.Line) + fullCallModule := s.resolveModuleWithNesting(resolvedModule, aliases, filePath, call.Line) + if fullCallModule == "" { + return "" + } + + buildRoot := s.structFieldBuildRoot(filePath) + if buildRoot == "" { + return "" + } + + bp := s.getBeamProcess(ctx, buildRoot) + if bp == nil { + return "" + } + if err := bp.Ready(ctx); err != nil { + return "" + } + + structModule, err := bp.ReturnTypeStruct(ctx, fullCallModule, call.Function, call.Arity) + if err != nil || structModule == "" { + return "" + } + + if s.debug { + s.debugf("ExCk return type struct: %s.%s/%d -> %s", fullCallModule, call.Function, call.Arity, structModule) + } + + return structModule +} + func (s *Server) cachedStructFieldsOrWarm(filePath, module string) ([]string, bool) { return s.cachedStructFieldsOrWarmWithLogging(filePath, module, true) } @@ -1410,14 +1465,6 @@ func (s *Server) invalidateStructFieldCacheForFile(filePath string, newDefs []pa } } -func (s *Server) moduleKnown(module string) bool { - if module == "" { - return false - } - results, err := s.store.LookupModule(module) - return err == nil && len(results) > 0 -} - func (s *Server) maybePrewarmStructFields(docURI, filePath, text string) { if filePath == "" || !parser.IsElixirFile(filePath) || !s.isProjectFile(filePath) || s.isDepsFile(filePath) { return @@ -1436,14 +1483,12 @@ func (s *Server) maybePrewarmStructFields(docURI, filePath, text string) { func (s *Server) prewarmStructFieldsFromText(_ string, filePath, text string) { tPrewarm := s.debugNow() tf := NewTokenizedFile(text) - refs := tf.StructModuleRefs() - if len(refs) == 0 { - return - } - seen := make(map[string]bool, len(refs)) + seen := make(map[string]bool) warmed := 0 - skippedUnknown := 0 + + // Pre-warm from struct literal references (%Module{}). + refs := tf.StructModuleRefs() for _, ref := range refs { aliases := tf.ExtractAliasesInScope(ref.Line) s.mergeAliasesFromUseTokenized(tf, aliases) @@ -1453,16 +1498,52 @@ func (s *Server) prewarmStructFieldsFromText(_ string, filePath, text string) { continue } seen[fullModule] = true - if !s.moduleKnown(fullModule) { - skippedUnknown++ + s.prewarmStructFields(filePath, fullModule) + warmed++ + } + + // Pre-warm from function call return types via ExCk. + // Scan the entire file for var = Module.func(...) patterns. + lines := strings.Count(text, "\n") + endCol := len(text) + calls := tf.VariableFunctionCalls(lines, endCol) + for _, call := range calls { + aliases := tf.ExtractAliasesInScope(call.Line) + s.mergeAliasesFromUseTokenized(tf, aliases) + resolvedModule := tf.ResolveModuleExpr(call.Module, call.Line) + fullCallModule := s.resolveModuleWithNesting(resolvedModule, aliases, filePath, call.Line) + if fullCallModule == "" { continue } - s.prewarmStructFields(filePath, fullModule) + + buildRoot := s.structFieldBuildRoot(filePath) + if buildRoot == "" { + continue + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + bp := s.getBeamProcess(ctx, buildRoot) + if bp == nil { + cancel() + continue + } + if err := bp.Ready(ctx); err != nil { + cancel() + continue + } + + structModule, err := bp.ReturnTypeStruct(ctx, fullCallModule, call.Function, call.Arity) + cancel() + if err != nil || structModule == "" || seen[structModule] { + continue + } + seen[structModule] = true + s.prewarmStructFields(filePath, structModule) warmed++ } if s.debug && warmed > 0 { - s.debugf("StructFields prewarm refs=%d queued=%d skipped_unknown=%d (%s)", len(refs), warmed, skippedUnknown, time.Since(tPrewarm).Round(time.Microsecond)) + s.debugf("StructFields prewarm queued=%d (%s)", warmed, time.Since(tPrewarm).Round(time.Microsecond)) } } @@ -1782,17 +1863,27 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara // "variable." or "variable.field_prefix" — struct field access on a typed variable. if fieldAccess, ok := tf.VariableFieldAccessAtCursor(lineNum, col); ok { + fullModule := "" + + // Tier 1: check pattern-match inference (pure tokens, no BEAM needed). varStructTypes := tf.VariableStructTypes(lineNum, col) if structModule, ok := varStructTypes[fieldAccess.VariableName]; ok { - tStruct := s.debugNow() aliases := tf.ExtractAliasesInScope(lineNum) s.mergeAliasesFromUseTokenized(tf, aliases) resolvedModule := tf.ResolveModuleExpr(structModule, lineNum) - fullModule := s.resolveModuleWithNesting(resolvedModule, aliases, filePath, lineNum) + fullModule = s.resolveModuleWithNesting(resolvedModule, aliases, filePath, lineNum) + } + + // Tier 2: check ExCk return type inference for function call assignments. + if fullModule == "" { + fullModule = s.resolveVariableTypeFromExCk(ctx, tf, filePath, fieldAccess.VariableName, lineNum, col) + } + + if fullModule != "" { + tStruct := s.debugNow() if s.debug { s.debugf("Completion variable struct type") s.debugf(" variable=%s", fieldAccess.VariableName) - s.debugf(" structModule=%s", structModule) s.debugf(" resolved=%s", fullModule) s.debugf(" fieldPrefix=%q", fieldAccess.FieldPrefix) } diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index 492e090..0e437fc 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -822,6 +822,128 @@ end`) } } +func TestCompletion_VariableDotFromPatternMatch(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + path := filepath.Join(server.projectRoot, "lib", "controller.ex") + uriStr := string(uri.File(path)) + server.docs.Set(uriStr, `defmodule MyApp.Controller do + alias MyApp.Accounts.User + + def run do + user = %User{name: "test"} + user. + end +end`) + + key := structFieldCacheKey{ + buildRoot: server.structFieldBuildRoot(path), + module: "MyApp.Accounts.User", + } + server.structFieldMu.Lock() + server.structFieldCache[key] = &structFieldCacheEntry{ + fields: []string{"active?", "email", "name"}, + loaded: true, + } + server.structFieldMu.Unlock() + + items := completionAt(t, server, uriStr, 5, uint32(len(" user."))) + if !hasCompletionItem(items, "email") { + t.Fatal("expected 'email' struct field completion on variable.dot") + } + if !hasCompletionItem(items, "name") { + t.Fatal("expected 'name' struct field completion on variable.dot") + } + for _, item := range items { + if item.Kind != protocol.CompletionItemKindField { + t.Errorf("item %q Kind = %v, want Field", item.Label, item.Kind) + } + } +} + +func TestCompletion_VariableDotFromPatternMatchWithPrefix(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + path := filepath.Join(server.projectRoot, "lib", "controller.ex") + uriStr := string(uri.File(path)) + server.docs.Set(uriStr, `defmodule MyApp.Controller do + alias MyApp.Accounts.User + + def run do + user = %User{name: "test"} + user.na + end +end`) + + key := structFieldCacheKey{ + buildRoot: server.structFieldBuildRoot(path), + module: "MyApp.Accounts.User", + } + server.structFieldMu.Lock() + server.structFieldCache[key] = &structFieldCacheEntry{ + fields: []string{"active?", "email", "name"}, + loaded: true, + } + server.structFieldMu.Unlock() + + items := completionAt(t, server, uriStr, 5, uint32(len(" user.na"))) + if !hasCompletionItem(items, "name") { + t.Fatal("expected 'name' completion for user.na prefix") + } + if hasCompletionItem(items, "email") { + t.Fatal("did not expect 'email' for user.na prefix") + } +} + +func TestCompletion_VariableDotFromExCk(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + path := filepath.Join(server.projectRoot, "lib", "controller.ex") + uriStr := string(uri.File(path)) + server.docs.Set(uriStr, `defmodule MyApp.Controller do + def run(url) do + parsed = URI.parse(url) + parsed. + end +end`) + + // This test requires a real BEAM process to resolve URI.parse/1 -> URI struct + buildRoot := server.structFieldBuildRoot(path) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + bp := server.getBeamProcess(ctx, buildRoot) + if bp == nil { + t.Skip("BEAM process not available") + } + if err := bp.Ready(ctx); err != nil { + t.Skipf("BEAM process not ready: %v", err) + } + + // Pre-warm the URI struct fields cache so completion doesn't have to wait + key := structFieldCacheKey{buildRoot: buildRoot, module: "URI"} + fields, err := bp.StructFields(ctx, "URI") + if err != nil || len(fields) == 0 { + t.Skipf("could not fetch URI struct fields: %v", err) + } + server.structFieldMu.Lock() + server.structFieldCache[key] = &structFieldCacheEntry{fields: fields, loaded: true} + server.structFieldMu.Unlock() + + items := completionAt(t, server, uriStr, 3, uint32(len(" parsed."))) + if !hasCompletionItem(items, "host") { + t.Fatalf("expected 'host' struct field from ExCk return type, got %v", items) + } + if !hasCompletionItem(items, "scheme") { + t.Fatal("expected 'scheme' struct field from ExCk return type") + } + if !hasCompletionItem(items, "path") { + t.Fatal("expected 'path' struct field from ExCk return type") + } +} + func TestStructFieldPrewarmFromDocument(t *testing.T) { server, cleanup := setupTestServer(t) defer cleanup() @@ -5570,3 +5692,69 @@ end`) t.Errorf("expected doc content, got %q", hover.Contents.Value) } } + +func TestReturnTypeStruct_Integration(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + buildRoot := server.projectRoot + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + bp := server.getBeamProcess(ctx, buildRoot) + if bp == nil { + t.Skip("BEAM process not available (mix not in PATH)") + } + if err := bp.Ready(ctx); err != nil { + t.Skipf("BEAM process not ready: %v", err) + } + + tests := []struct { + name string + module string + function string + arity int + want string + }{ + { + name: "URI.parse returns URI struct", + module: "URI", + function: "parse", + arity: 1, + want: "URI", + }, + { + name: "Map.new returns no struct", + module: "Map", + function: "new", + arity: 0, + want: "", + }, + { + name: "nonexistent module returns empty", + module: "NonExistentModule12345", + function: "foo", + arity: 0, + want: "", + }, + { + name: "nonexistent function returns empty", + module: "URI", + function: "nonexistent_function_xyz", + arity: 0, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := bp.ReturnTypeStruct(ctx, tt.module, tt.function, tt.arity) + if err != nil { + t.Fatalf("ReturnTypeStruct error: %v", err) + } + if result != tt.want { + t.Errorf("ReturnTypeStruct(%s.%s/%d) = %q, want %q", tt.module, tt.function, tt.arity, result, tt.want) + } + }) + } +}