Skip to content

Commit

Permalink
Performance improvements, particularly for deeply nested shapes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin-Dobell committed Jan 3, 2023
1 parent f01e7eb commit b8c20ac
Show file tree
Hide file tree
Showing 20 changed files with 218 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class AssignTypeInspection : StrictInspection() {
// Get owner class
val assigneeOwnerType = assignee.guessParentType(context)

if (assigneeOwnerType is TyTable && resolvedValue is TyTable && assigneeOwnerType.table == resolvedValue.table) {
if (assigneeOwnerType is TyTable && resolvedValue is TyTable && assigneeOwnerType.psi == resolvedValue.psi) {
return
}

Expand Down Expand Up @@ -87,7 +87,7 @@ class AssignTypeInspection : StrictInspection() {

val variableType = assignee.guessType(context)

if (variableType == null || (variableType is TyTable && resolvedValue is TyTable && variableType.table == resolvedValue.table)) {
if (variableType == null || (variableType is TyTable && resolvedValue is TyTable && variableType.psi == resolvedValue.psi)) {
return
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class LuaCommentImpl(node: ASTNode) : ASTWrapperPsiElement(node), LuaComment {
val map = list.associateBy { it.varName }

object : TySubstitutor() {
override val name = "name substitutor"

override fun substitute(context: SearchContext, clazz: ITyClass): ITy {
return map[clazz.className] ?: super.substitute(context, clazz)
}
Expand Down
9 changes: 6 additions & 3 deletions src/main/java/com/tang/intellij/lua/psi/LuaParamInfo.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ import com.tang.intellij.lua.Constants
import com.tang.intellij.lua.search.SearchContext
import com.tang.intellij.lua.stubs.readTyNullable
import com.tang.intellij.lua.stubs.writeTyNullable
import com.tang.intellij.lua.ty.*
import com.tang.intellij.lua.ty.ITy
import com.tang.intellij.lua.ty.ITySubstitutor
import com.tang.intellij.lua.ty.Primitives
import com.tang.intellij.lua.ty.TyMultipleResults

/**
* parameter info
Expand All @@ -36,14 +39,14 @@ class LuaParamInfo(val name: String, val ty: ITy?) {
return other is LuaParamInfo && other.ty == ty
}

fun equals(context: SearchContext, other: LuaParamInfo): Boolean {
fun equals(context: SearchContext, other: LuaParamInfo, equalityFlags: Int): Boolean {
if (ty == null) {
return other.ty == null
} else if (other.ty == null) {
return false
}

return ty.equals(context, other.ty)
return ty.equals(context, other.ty, equalityFlags)
}

override fun hashCode(): Int {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ private class ScopedTypeTreeScope(override val psi: LuaTypeScope, override val t
val classTag = if (cls is TySerializedClass) {
LuaClassIndex.find(context, cls.className)
} else if (cls is TyPsiDocClass) {
cls.tagClass
cls.psi
} else null

// Need to ensure we don't check the same scope *without* beforeIndex
Expand Down
5 changes: 2 additions & 3 deletions src/main/java/com/tang/intellij/lua/psi/PsiExtension.kt
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,8 @@ private fun LuaExpression<*>.shouldBeInternal(context: SearchContext): ITy? {
var ret: ITy = Primitives.VOID
Ty.eachResolved(context, fTy) {
if (it is ITyFunction) {
var sig = it.matchSignature(context, p2)?.substitutedSignature ?: it.mainSignature
val substitutor = p2.createSubstitutor(context, sig)
sig = sig.substitute(context, substitutor)
val sig = it.matchSignature(context, p2)?.substitutedSignature
?: it.mainSignature.substitute(context, p2.createSubstitutor(context, it.mainSignature))
ret = ret.union(context, sig.getArgTy(idx))
}
}
Expand Down
14 changes: 10 additions & 4 deletions src/main/java/com/tang/intellij/lua/ty/ProblemUtil.kt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ object ProblemUtil {
return false
}

// We perform a non-structural (i.e. inheritance) check first as a happy path optimization. This is a *very*
// significant optimization when you've deeply nested shapes whose members are other (non-anonymous) shapes.
if (target.contravariantOf(context, source, varianceFlags or TyVarianceFlags.NON_STRUCTURAL)) {
return true;
}

val sourceSubstitutor = source.getMemberSubstitutor(context)
val targetSubstitutor = target.getMemberSubstitutor(context)

Expand Down Expand Up @@ -204,7 +210,7 @@ object ProblemUtil {
val targetMemberTy = (if (indexTy != null) {
val targetMember = target.findIndexer(context, indexTy)

if (targetMember?.guessIndexType(context)?.equals(context, indexTy) == true) {
if (targetMember?.guessIndexType(context)?.equals(context, indexTy, 0) == true) {
// If the target index type == source index type, then we have already checked compatibility of this member above.
return@processMembers true
}
Expand Down Expand Up @@ -232,7 +238,7 @@ object ProblemUtil {
// TODO: DRY
if (varianceFlags and TyVarianceFlags.STRICT_UNKNOWN != 0 || !sourceMemberTy.isUnknown) {
if (varianceFlags and TyVarianceFlags.WIDEN_TABLES == 0) {
if (!targetMemberTy.equals(context, sourceMemberTy)) {
if (!targetMemberTy.equals(context, sourceMemberTy, 0)) {
isContravariant = false

if (processProblem != null && sourceElement != null) {
Expand Down Expand Up @@ -311,7 +317,7 @@ object ProblemUtil {
resolvedSourceTy.lazyInit(context)

if ((varianceFlags and TyVarianceFlags.NON_STRUCTURAL == 0 || resolvedSourceTy.isAnonymousTable) && resolvedSourceTy.isShape(context)) {
val sourceIsInline = resolvedSourceTy is TyTable && resolvedSourceTy.table == sourceElement
val sourceIsInline = resolvedSourceTy is TyTable && resolvedSourceTy.psi == sourceElement
val indexes = sortedMapOf<Int, PsiElement>()
var foundNumberIndexer = false

Expand Down Expand Up @@ -453,7 +459,7 @@ object ProblemUtil {
base.lazyInit(context)
}

if (varianceFlags and TyVarianceFlags.NON_STRUCTURAL == 0 && resolvedSourceTy is TyTable && resolvedSourceTy.table == sourceElement && base.isShape(context)) {
if (varianceFlags and TyVarianceFlags.NON_STRUCTURAL == 0 && resolvedSourceTy is TyTable && resolvedSourceTy.psi == sourceElement && base.isShape(context)) {
isContravariant = contravariantOfShape(
context,
resolvedTargetTy,
Expand Down
58 changes: 41 additions & 17 deletions src/main/java/com/tang/intellij/lua/ty/Ty.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ import com.tang.intellij.lua.codeInsight.inspection.MatchFunctionSignatureInspec
import com.tang.intellij.lua.ext.recursionGuard
import com.tang.intellij.lua.project.LuaSettings
import com.tang.intellij.lua.psi.LuaCallExpr
import com.tang.intellij.lua.psi.LuaPsiElement
import com.tang.intellij.lua.psi.LuaTableExpr
import com.tang.intellij.lua.psi.argList
import com.tang.intellij.lua.search.SearchContext
import conditionallyCached
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract

Expand Down Expand Up @@ -68,6 +70,7 @@ class TyFlags {
const val UNKNOWN = 0x20 // Unless STRICT_UNKNOWN is enabled, this type is covariant of all other types.
}
}

class TyVarianceFlags {
companion object {
const val STRICT_UNKNOWN = 0x1 // When enabled UNKNOWN types are no longer treated as covariant of all types.
Expand All @@ -78,10 +81,28 @@ class TyVarianceFlags {
}
}

class TyEqualityFlags {
companion object {
const val NON_STRUCTURAL = 0x1 // Treat shapes as classes i.e. a shape is only covariant of another shape if it explicitly inherits from it.

fun fromVarianceFlags(varianceFlags: Int): Int {
return if (varianceFlags and TyVarianceFlags.NON_STRUCTURAL != 0) {
TyEqualityFlags.NON_STRUCTURAL
} else {
0
}
}
}
}

data class SignatureMatchResult(val signature: IFunSignature?, val substitutedSignature: IFunSignature?, val returnTy: ITy)

typealias ProcessTypeMember = (ownerTy: ITy, member: TypeMember) -> Boolean

interface IPsiTy<T : LuaPsiElement> {
val psi: T
}

interface ITy : Comparable<ITy> {
val kind: TyKind

Expand All @@ -91,7 +112,7 @@ interface ITy : Comparable<ITy> {

val booleanType: ITy

fun equals(context: SearchContext, other: ITy): Boolean
fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean

fun union(context: SearchContext, ty: ITy): ITy

Expand Down Expand Up @@ -501,8 +522,9 @@ abstract class Ty(override val kind: TyKind) : ITy {

final override var flags: Int = 0

override val displayName: String
get() = TyRenderer.SIMPLE.render(this)
override val displayName by conditionallyCached({ !(this is IPsiTy<*>) || !SearchContext.get(psi.project).isDumb }) {
TyRenderer.SIMPLE.render(this)
}

// Lazy initialization because Primitives.TRUE is itself a Ty that needs to be instantiated and refers to itself.
override val booleanType: ITy by lazy { Primitives.TRUE }
Expand Down Expand Up @@ -548,7 +570,19 @@ abstract class Ty(override val kind: TyKind) : ITy {

val resolvedOther = resolve(context, other)

if (this.equals(context, resolvedOther)) {
if (varianceFlags and TyVarianceFlags.NON_STRUCTURAL == 0 && isShape(context)) {
val isContravariant: Boolean? = recursionGuard(resolvedOther, {
// Note: ProblemUtil.contravariantOfShape will call back into this method with
// TyVarianceFlags.NON_STRUCTURAL set as a fast nominal check, before checking structurally.
ProblemUtil.contravariantOfShape(context, this, resolvedOther, varianceFlags)
})

if (isContravariant != null) {
return isContravariant
}
}

if (this.equals(context, resolvedOther, TyEqualityFlags.fromVarianceFlags(varianceFlags))) {
return true
}

Expand All @@ -566,16 +600,6 @@ abstract class Ty(override val kind: TyKind) : ITy {
return true
}

if ((varianceFlags and TyVarianceFlags.NON_STRUCTURAL == 0 || other.isAnonymousTable) && isShape(context)) {
val isContravariant: Boolean? = recursionGuard(resolvedOther, {
ProblemUtil.contravariantOfShape(context, this, resolvedOther, varianceFlags)
})

if (isContravariant != null) {
return isContravariant
}
}

val otherSuper = other.getSuperType(context)
return otherSuper != null && contravariantOf(context, otherSuper, varianceFlags)
}
Expand Down Expand Up @@ -871,7 +895,7 @@ class TyUnknown : Ty(TyKind.Unknown) {
this.flags = this.flags or TyFlags.UNKNOWN
}

override fun equals(context: SearchContext, other: ITy): Boolean {
override fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean {
if (other === Primitives.UNKNOWN) {
return true
}
Expand Down Expand Up @@ -904,7 +928,7 @@ class TyNil : Ty(TyKind.Nil) {

override val booleanType = Primitives.FALSE

override fun equals(context: SearchContext, other: ITy): Boolean {
override fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean {
if (other === Primitives.NIL) {
return true
}
Expand All @@ -931,7 +955,7 @@ class TyNil : Ty(TyKind.Nil) {

class TyVoid : Ty(TyKind.Void) {

override fun equals(context: SearchContext, other: ITy): Boolean {
override fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean {
if (other === Primitives.VOID) {
return true
}
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/tang/intellij/lua/ty/TyAlias.kt
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ class TyAlias(override val name: String,
return other is ITyAlias && other.name == name && other.flags == flags
}

override fun equals(context: SearchContext, other: ITy): Boolean {
override fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean {
if (this === other) {
return true
}

return ty.equals(context, other)
return ty.equals(context, other, equalityFlags)
}

override fun hashCode(): Int {
Expand Down
18 changes: 9 additions & 9 deletions src/main/java/com/tang/intellij/lua/ty/TyArray.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ open class TyArray(override val base: ITy) : Ty(TyKind.Array), ITyArray {
return other is ITyArray && base == other.base
}

override fun equals(context: SearchContext, other: ITy): Boolean {
override fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean {
if (this === other) {
return true
}

val resolvedOther = Ty.resolve(context, other)
return resolvedOther is ITyArray && base.equals(context, resolvedOther.base)
return resolvedOther is ITyArray && base.equals(context, resolvedOther.base, equalityFlags)
}

override fun hashCode(): Int {
Expand All @@ -56,7 +56,7 @@ open class TyArray(override val base: ITy) : Ty(TyKind.Array), ITyArray {
val resolvedBase = Ty.resolve(context, base)

if (other is ITyArray) {
return resolvedBase.equals(context, other.base)
return resolvedBase.equals(context, other.base, TyEqualityFlags.fromVarianceFlags(varianceFlags))
|| (varianceFlags and TyVarianceFlags.WIDEN_TABLES != 0 && resolvedBase.contravariantOf(context, other.base, varianceFlags))
}

Expand Down Expand Up @@ -88,7 +88,7 @@ open class TyArray(override val base: ITy) : Ty(TyKind.Array), ITyArray {
}

return varianceFlags and TyVarianceFlags.WIDEN_TABLES != 0
|| Ty.resolve(context, resolvedBase).equals(context, indexedMemberType)
|| Ty.resolve(context, resolvedBase).equals(context, indexedMemberType, TyEqualityFlags.fromVarianceFlags(varianceFlags))
|| (resolvedBase.isUnknown && varianceFlags and TyVarianceFlags.STRICT_UNKNOWN == 0)
}

Expand Down Expand Up @@ -185,14 +185,14 @@ object TyArraySerializer : TySerializer<ITyArray>() {
}
}

class TyDocArray(val luaDocArrTy: LuaDocArrTy, base: ITy = luaDocArrTy.ty.getType()) : TyArray(base) {
class TyDocArray(override val psi: LuaDocArrTy, base: ITy = psi.ty.getType()) : TyArray(base), IPsiTy<LuaDocArrTy> {
override fun processIndexer(context: SearchContext, indexTy: ITy, exact: Boolean, deep: Boolean, process: ProcessTypeMember): Boolean {
if (exact) {
if (Primitives.NUMBER.equals(context, indexTy)) {
return process(this, luaDocArrTy)
if (Primitives.NUMBER.equals(context, indexTy, 0)) {
return process(this, psi)
}
} else if (Primitives.NUMBER.contravariantOf(context, indexTy, TyVarianceFlags.STRICT_UNKNOWN)) {
return process(this, luaDocArrTy)
return process(this, psi)
}

return true
Expand All @@ -202,7 +202,7 @@ class TyDocArray(val luaDocArrTy: LuaDocArrTy, base: ITy = luaDocArrTy.ty.getTyp
val substitutedBase = TyMultipleResults.getResult(context, base.substitute(context, substitutor))

return if (substitutedBase !== base) {
TyDocArray(luaDocArrTy, substitutedBase)
TyDocArray(psi, substitutedBase)
} else {
this
}
Expand Down
Loading

0 comments on commit b8c20ac

Please sign in to comment.