Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect function params' type infer when there is only @overload #2838

Merged
merged 4 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* `FIX` Fix `VM.OnCompileFunctionParam` function in plugins
* `FIX` Lua 5.1: fix incorrect warning when using setfenv with an int as first parameter
* `FIX` Improve type narrow by checking exact match on literal type params
* `FIX` Incorrect function params' type infer when there is only `@overload` [#2509](https://github.com/LuaLS/lua-language-server/issues/2509) [#2708](https://github.com/LuaLS/lua-language-server/issues/2708) [#2709](https://github.com/LuaLS/lua-language-server/issues/2709)

## 3.10.5
`2024-8-19`
Expand Down
10 changes: 9 additions & 1 deletion script/vm/compiler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,7 @@ local function compileFunctionParam(func, source)

-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
local funcNode = vm.compileNode(func)
local found = false
for n in funcNode:eachObject() do
if n.type == 'doc.type.function' and n.args[aindex] then
local argNode = vm.compileNode(n.args[aindex])
Expand All @@ -1107,9 +1108,16 @@ local function compileFunctionParam(func, source)
vm.setNode(source, an)
end
end
return true
-- NOTE: keep existing behavior for local call which only set type based on the 1st match
if func.parent.type == 'callargs' then
return true
end
found = true
end
end
if found then
return true
end

local derviationParam = config.get(guide.getUri(func), 'Lua.type.inferParamType')
if derviationParam and func.parent.type == 'local' and func.parent.ref then
Expand Down
31 changes: 31 additions & 0 deletions script/vm/function.lua
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ end
---@return number
local function calcFunctionMatchScore(uri, args, func)
if vm.isVarargFunctionWithOverloads(func)
or vm.isFunctionWithOnlyOverloads(func)
or not isAllParamMatched(uri, args, func.args)
then
return -1
Expand Down Expand Up @@ -490,6 +491,36 @@ function vm.isVarargFunctionWithOverloads(func)
return false
end

---@param func table
---@return boolean
function vm.isFunctionWithOnlyOverloads(func)
if func.type ~= 'function' then
return false
end
if func._onlyOverloadFunction ~= nil then
return func._onlyOverloadFunction
end

if not func.bindDocs then
func._onlyOverloadFunction = false
return false
end
local hasOverload = false
for _, doc in ipairs(func.bindDocs) do
if doc.type == 'doc.overload' then
hasOverload = true
elseif doc.type == 'doc.param'
or doc.type == 'doc.return'
then
-- has specified @param or @return, thus not only @overload
func._onlyOverloadFunction = false
return false
end
end
func._onlyOverloadFunction = hasOverload
return true
end

---@param func parser.object
---@return boolean
function vm.isEmptyFunction(func)
Expand Down
38 changes: 38 additions & 0 deletions test/type_inference/param_match.lua
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,44 @@ local v = 'y'
local <?r?> = f(v)
]]

TEST 'string|number' [[
---@overload fun(a: string)
---@overload fun(a: number)
local function f(<?a?>) end
]]

TEST '1|2' [[
---@overload fun(a: 1)
---@overload fun(a: 2)
local function f(<?a?>) end
]]

TEST 'string' [[
---@overload fun(a: 1): string
---@overload fun(a: 2): number
local function f(a) end

local <?r?> = f(1)
]]

TEST 'number' [[
---@overload fun(a: 1): string
---@overload fun(a: 2): number
local function f(a) end

local <?r?> = f(2)
]]

TEST 'string|number' [[
---@overload fun(a: 1): string
---@overload fun(a: 2): number
local function f(a) end

---@type number
local v
local <?r?> = f(v)
]]

TEST 'number' [[
---@overload fun(a: 1, c: fun(x: number))
---@overload fun(a: 2, c: fun(x: string))
Expand Down
Loading