Skip to content

fix #SCL-23707 #677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: idea251.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ import org.jetbrains.plugins.scala.lang.lexer.ScalaTokenTypes
import org.jetbrains.plugins.scala.lang.parameterInfo.ScalaFunctionParameterInfoHandler.{AnnotationParameters, UniversalApplyCall, UniversalApplyCallContext}
import org.jetbrains.plugins.scala.lang.psi.api.ScalaPsiElement
import org.jetbrains.plugins.scala.lang.psi.api.base.types.{ScParameterizedTypeElement, ScTypeArgs, ScTypeElementExt}
import org.jetbrains.plugins.scala.lang.psi.api.base.{ScConstructorInvocation, ScPrimaryConstructor}
import org.jetbrains.plugins.scala.lang.psi.api.base.{ScConstructorInvocation, ScInterpolatedStringLiteral, ScPrimaryConstructor}
import org.jetbrains.plugins.scala.lang.psi.api.expr._
import org.jetbrains.plugins.scala.lang.psi.api.statements.params.{ScParameter, ScParameterClause, ScTypeParam}
import org.jetbrains.plugins.scala.lang.psi.api.statements.{ScFunction, ScFunctionDefinition}
import org.jetbrains.plugins.scala.lang.psi.api.toplevel.typedef.{ScClass, ScConstructorOwner, ScTypeDefinition}
import org.jetbrains.plugins.scala.lang.psi.api.toplevel.{ScTypeParametersOwner, ScTypedDefinition}
import org.jetbrains.plugins.scala.lang.psi.fake.FakePsiMethod
import org.jetbrains.plugins.scala.lang.psi.impl.base.ScInterpolatedStringLiteralImpl
import org.jetbrains.plugins.scala.lang.psi.light.ScFunctionWrapper
import org.jetbrains.plugins.scala.lang.psi.types._
import org.jetbrains.plugins.scala.lang.psi.types.api.presentation.TypeAnnotationRenderer.ParameterTypeDecorator
Expand Down Expand Up @@ -548,6 +549,114 @@ class ScalaFunctionParameterInfoHandler extends ScalaParameterInfoHandler[PsiEle
private def elementsForParameterInfo(args: Invocation): Seq[Object] = {
implicit val project: ProjectContext = args.element.projectContext

def methodInvocationParameterInfo(call: PsiElement): ArraySeq[Object] = {
val resultBuilder = ArraySeq.newBuilder[Object]

def collectResult(): Unit = {
val canBeUpdate = call.getParent match {
case assignStmt: ScAssignment if call == assignStmt.leftExpression => true
case notExpr if !notExpr.is[ScExpression] || notExpr.is[ScBlockExpr] => true
case _ => false
}
val count = args.invocationCount
val gen = args.callGeneric.getOrElse(null: ScGenericCall)

def collectSubstitutor(element: PsiElement): ScSubstitutor = {
if (gen == null) return ScSubstitutor.empty
val typeParams = element match {
case tpo: ScTypeParametersOwner => tpo.typeParameters.toArray
case ptpo: PsiTypeParameterListOwner => ptpo.getTypeParameters
case _ => return ScSubstitutor.empty
}
ScSubstitutor.bind(typeParams, gen.arguments)(_.calcType)
}

def collectForType(typez: ScType): Unit = {
def process(functionName: String): Unit = {
val i = if (functionName == "update") -1 else 0
val processor: CompletionProcessor = new CompletionProcessor(StdKinds.refExprQualRef, call, withImplicitConversions = true) {

override protected val forName: Option[String] = Some(functionName)
}
processor.processType(typez, call)
val variants: Array[ScalaResolveResult] = processor.candidates
for {
variant <- variants
if !variant.getElement.isInstanceOf[PsiMember] ||
ResolveUtils.isAccessible(variant.getElement.asInstanceOf[PsiMember], call)
} {
variant match {
case ScalaResolveResult(method: ScFunction, subst: ScSubstitutor) =>
val signature: PhysicalMethodSignature = new PhysicalMethodSignature(method, subst.followed(collectSubstitutor(method)))
resultBuilder += ((signature, i))
resultBuilder ++= ScalaParameterInfoEnhancer.enhance(signature, args.arguments).map((_, i))
case _ =>
}
}
}

process("apply")
if (canBeUpdate) process("update")
}

args.callReference match {
case Some(ref: ScReferenceExpression) =>
if (count > 1) {
//todo: missed case with last implicit call
ref.bind() match {
case Some(ScalaResolveResult(function: ScFunction, subst: ScSubstitutor)) if function.
effectiveParameterClauses.length >= count =>
resultBuilder += ((new PhysicalMethodSignature(function, subst.followed(collectSubstitutor(function))), count - 1))
case Some(ScalaResolveResult(function: ScFunction, _)) if function.effectiveParameterClauses.isEmpty =>
function.`type`().foreach(collectForType)
case _ =>
call match {
case invocation: MethodInvocation =>
for (typez <- invocation.getEffectiveInvokedExpr.`type`()) //todo: implicit conversions
{
collectForType(typez)
}
case _ =>
}
}
} else {
val variants = {
val sameName = ref.getSameNameVariants
if (sameName.isEmpty) ref.multiResolveScala(false)
else sameName
}
for {
variant <- variants
if !variant.getElement.isInstanceOf[PsiMember] ||
ResolveUtils.isAccessible(variant.getElement.asInstanceOf[PsiMember], ref)
} {
variant match {
//todo: Synthetic function
case ScalaResolveResult(method: PsiMethod, subst: ScSubstitutor) =>
val signature: PhysicalMethodSignature = new PhysicalMethodSignature(method, subst.followed(collectSubstitutor(method)))
resultBuilder += ((signature, 0))
resultBuilder ++= ScalaParameterInfoEnhancer.enhance(signature, args.arguments).map((_, 0))
case ScalaResolveResult(typed: ScTypedDefinition, subst: ScSubstitutor) =>
val typez = subst(typed.`type`().getOrNothing) //todo: implicit conversions
collectForType(typez)
case _ =>
}
}
}
case None =>
call match {
case call: ScMethodCall =>
for (typez <- call.getEffectiveInvokedExpr.`type`()) { //todo: implicit conversions
collectForType(typez)
}
}
}
}

collectResult()
resultBuilder.result()
}

def elementsForConstructorInvocationParameterInfo(clazz: PsiClass,
subst: ScSubstitutor,
argumentLists: Seq[ScalaPsiElement],
Expand Down Expand Up @@ -615,104 +724,10 @@ class ScalaFunctionParameterInfoHandler extends ScalaParameterInfoHandler[PsiEle
case _ => Seq.empty
}
case call@(_: MethodInvocation | _: ScReferenceExpression) =>
val resultBuilder = ArraySeq.newBuilder[Object]
def collectResult(): Unit = {
val canBeUpdate = call.getParent match {
case assignStmt: ScAssignment if call == assignStmt.leftExpression => true
case notExpr if !notExpr.is[ScExpression] || notExpr.is[ScBlockExpr] => true
case _ => false
}
val count = args.invocationCount
val gen = args.callGeneric.getOrElse(null: ScGenericCall)
def collectSubstitutor(element: PsiElement): ScSubstitutor = {
if (gen == null) return ScSubstitutor.empty
val typeParams = element match {
case tpo: ScTypeParametersOwner => tpo.typeParameters.toArray
case ptpo: PsiTypeParameterListOwner => ptpo.getTypeParameters
case _ => return ScSubstitutor.empty
}
ScSubstitutor.bind(typeParams, gen.arguments)(_.calcType)
}
def collectForType(typez: ScType): Unit = {
def process(functionName: String): Unit = {
val i = if (functionName == "update") -1 else 0
val processor: CompletionProcessor = new CompletionProcessor(StdKinds.refExprQualRef, call, withImplicitConversions = true) {

override protected val forName: Option[String] = Some(functionName)
}
processor.processType(typez, call)
val variants: Array[ScalaResolveResult] = processor.candidates
for {
variant <- variants
if !variant.getElement.isInstanceOf[PsiMember] ||
ResolveUtils.isAccessible(variant.getElement.asInstanceOf[PsiMember], call)
} {
variant match {
case ScalaResolveResult(method: ScFunction, subst: ScSubstitutor) =>
val signature: PhysicalMethodSignature = new PhysicalMethodSignature(method, subst.followed(collectSubstitutor(method)))
resultBuilder += ((signature, i))
resultBuilder ++= ScalaParameterInfoEnhancer.enhance(signature, args.arguments).map((_, i))
case _ =>
}
}
}

process("apply")
if (canBeUpdate) process("update")
}
args.callReference match {
case Some(ref: ScReferenceExpression) =>
if (count > 1) {
//todo: missed case with last implicit call
ref.bind() match {
case Some(ScalaResolveResult(function: ScFunction, subst: ScSubstitutor)) if function.
effectiveParameterClauses.length >= count =>
resultBuilder += ((new PhysicalMethodSignature(function, subst.followed(collectSubstitutor(function))), count - 1))
case Some(ScalaResolveResult(function: ScFunction, _)) if function.effectiveParameterClauses.isEmpty =>
function.`type`().foreach(collectForType)
case _ =>
call match {
case invocation: MethodInvocation =>
for (typez <- invocation.getEffectiveInvokedExpr.`type`()) //todo: implicit conversions
{collectForType(typez)}
case _ =>
}
}
} else {
val variants = {
val sameName = ref.getSameNameVariants
if (sameName.isEmpty) ref.multiResolveScala(false)
else sameName
}
for {
variant <- variants
if !variant.getElement.isInstanceOf[PsiMember] ||
ResolveUtils.isAccessible(variant.getElement.asInstanceOf[PsiMember], ref)
} {
variant match {
//todo: Synthetic function
case ScalaResolveResult(method: PsiMethod, subst: ScSubstitutor) =>
val signature: PhysicalMethodSignature = new PhysicalMethodSignature(method, subst.followed(collectSubstitutor(method)))
resultBuilder += ((signature, 0))
resultBuilder ++= ScalaParameterInfoEnhancer.enhance(signature, args.arguments).map((_, 0))
case ScalaResolveResult(typed: ScTypedDefinition, subst: ScSubstitutor) =>
val typez = subst(typed.`type`().getOrNothing) //todo: implicit conversions
collectForType(typez)
case _ =>
}
}
}
case None =>
call match {
case call: ScMethodCall =>
for (typez <- call.getEffectiveInvokedExpr.`type`()) { //todo: implicit conversions
collectForType(typez)
}
}
}
}
collectResult()
resultBuilder.result()
methodInvocationParameterInfo(call)
case isl: ScInterpolatedStringLiteral if isl.desugaredExpression.nonEmpty =>
val (_, call) = isl.desugaredExpression.get
methodInvocationParameterInfo(call)
case self: ScSelfInvocation =>
val resultBuilder = ArraySeq.newBuilder[Object]
val i = self.arguments.indexOf(args.element)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import org.jetbrains.plugins.scala.ScalaBundle
import org.jetbrains.plugins.scala.lang.lexer.ScalaTokenTypes
import org.jetbrains.plugins.scala.lang.parser.ScalaElementType
import org.jetbrains.plugins.scala.lang.parser.parsing.builder.ScalaPsiBuilder
import org.jetbrains.plugins.scala.lang.parser.parsing.expressions.BlockExpr
import org.jetbrains.plugins.scala.lang.parser.parsing.expressions.{ArgumentExprs, BlockExpr}
import org.jetbrains.plugins.scala.lang.parser.parsing.patterns.Pattern
import org.jetbrains.plugins.scala.lang.parser.util.ParserUtils

Expand Down Expand Up @@ -86,6 +86,8 @@ object CommonUtils {

if (!builder.eof())
builder.advanceLexer()

ArgumentExprs()
}

/** see comments to [[ScalaTokenTypes.tINTERPOLATED_RAW_STRING]] and [[ScalaTokenTypes.tINTERPOLATED_MULTILINE_RAW_STRING]] */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ final class ScInterpolatedStringLiteralImpl(node: ASTNode,
override def referenceName: String = reference.refName

override def hasValidClosingQuotes: Boolean =
getNode.getLastChildNode.getElementType == tINTERPOLATED_STRING_END
getNode.findChildByType(tINTERPOLATED_STRING_END) != null

override def isMultiLineString: Boolean = hasValidClosingQuotes && {
val next = firstNode.getTreeNext
Expand Down Expand Up @@ -77,16 +77,23 @@ final class ScInterpolatedStringLiteralImpl(node: ASTNode,
}
val methodParameters = injectionsValues.commaSeparated(Model.Parentheses)

val closeQuotes = getNode.findChildByType(tINTERPOLATED_STRING_END)
val argumentExprs = if(closeQuotes.getTreeNext != null) {
closeQuotes.getTreeNext.getText
} else {
""
}

val expression =
try {
// FIXME: fails on s"aaa /* ${s"ccc s${s"/*"} ddd"} bbb" (SCL-17625, SCL-18706)
val text = s"$StringContextCanonical$constructorParameters.$methodName$methodParameters"
val text = s"$StringContextCanonical$constructorParameters.$methodName$methodParameters$argumentExprs"
ScalaPsiElementFactory.createExpressionWithContextFromText(text, context, this).asInstanceOf[ScMethodCall]
} catch {
case e: IncorrectOperationException =>
throw new IncorrectOperationException(s"Couldn't desugar interpolated string ${this.getText}", e: Throwable)
}
Some(expression.getInvokedExpr.asInstanceOf[ScReferenceExpression], expression)
Some(expression.deepestInvokedExpr.asInstanceOf[ScReferenceExpression], expression)
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class InterpolatedStringsAnnotatorTest extends base.ScalaLightCodeInsightFixture
emptyMessages("c\"blah blah $s1 ${s2}\"")
}

def testCorrectArgumentInt(): Unit = {
emptyMessages("e\"\"(i1)")
}

def testMultiResolve(): Unit = {
messageExists("_\"blah $s1 blah $s2 blah\"", "Error(_,Value '_' is not a member of StringContext)")
}
Expand All @@ -56,6 +60,14 @@ class InterpolatedStringsAnnotatorTest extends base.ScalaLightCodeInsightFixture
def testMultipleResolve2(): Unit = {
messageExists("c\"blah $i1 blah $s1 $i2\"", "Error(i2,Too many arguments for method c(String, String))")
}

def testInvalidArgumentString(): Unit = {
messageExists("e\"\"(s1)", "Error(e,Type mismatch, expected: Int, actual: String)")
}

def testTooManyArguments(): Unit = {
messageExists("e\"\"(i1, i2)", "Error(e,Too many arguments)")
}
}

object InterpolatedStringsAnnotatorTest {
Expand All @@ -68,6 +80,7 @@ object InterpolatedStringsAnnotatorTest {
| def c(s1: String, s2: String) = s1 + s2
| def d(i1: Int, s1: String) = i1 + s1.length
| def d(i1: Int, i2: Int) = i1 + i2
| def e(args: Any*)(i: Int) = i
|}
|
|implicit def extendStrContext(ctx: StringContext) = new ExtendedContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ class FunctionParameterInfoFunctionTypeTest extends FunctionParameterInfoTestBas
override def getTestDataPath: String =
s"${super.getTestDataPath}functionType/"

def testCustomInterpolatorFunctionType(): Unit = doTest()

def testFunctionType(): Unit = doTest()

def testFunctionTypeTwo(): Unit = doTest()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
implicit class CustomInterpolator(private val stringContext: StringContext) {
def test(values: Any*)(i: Int): Unit = {}
}

test""(<caret>)
//TEXT: v1: Int, STRIKEOUT: false
Loading