Skip to content

Commit e65ea35

Browse files
committed
SCL-6315: indirect case-to-case inheritance is prohibited
1 parent 3fd3e68 commit e65ea35

3 files changed

Lines changed: 77 additions & 12 deletions

File tree

scala/scala-impl/src/org/jetbrains/plugins/scala/annotator/element/ScTemplateDefinitionAnnotator.scala

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.jetbrains.plugins.scala.lang.resolve.ScalaResolveResult
2323
import org.jetbrains.plugins.scala.overrideImplement.{ScMethodMember, ScalaOIUtil, ScalaTypedMember}
2424
import org.jetbrains.plugins.scala.{NlsString, ScalaBundle, overrideImplement}
2525

26+
import scala.annotation.tailrec
2627
import scala.util.chaining._
2728

2829
object ScTemplateDefinitionAnnotator extends ElementAnnotator[ScTemplateDefinition] {
@@ -37,6 +38,7 @@ object ScTemplateDefinitionAnnotator extends ElementAnnotator[ScTemplateDefiniti
3738
annotateEnumClassInheritance(element)
3839
annotateTraitPassingConstructorParameters(element)
3940
annotateParentTraitConstructorParameters(element)
41+
annotateCaseToCaseInheritance(element)
4042

4143
if (typeAware) {
4244
annotateNeedsToBeMixin(element)
@@ -235,8 +237,6 @@ object ScTemplateDefinitionAnnotator extends ElementAnnotator[ScTemplateDefiniti
235237
superRefs(element).collect {
236238
case (range, clazz) if clazz.hasFinalModifier =>
237239
(range, NlsString(ScalaBundle.message("illegal.inheritance.from.final.kind", kindOf(clazz, toLowerCase = true), clazz.name)))
238-
case (range, clazz: ScClass) if clazz.isCase && element.asOptionOf[ScClass].exists(_.isCase) =>
239-
(range, NlsString(ScalaBundle.message("illegal.inheritance.from.case.class", element.name, clazz.name)))
240240
case (range, clazz) if ValueClassType.extendsAnyVal(clazz) =>
241241
(range, NlsString(ScalaBundle.message("illegal.inheritance.from.value.class", clazz.name)))
242242
}.foreach {
@@ -245,6 +245,41 @@ object ScTemplateDefinitionAnnotator extends ElementAnnotator[ScTemplateDefiniti
245245
}
246246
}
247247

248+
def annotateCaseToCaseInheritance(element: ScTemplateDefinition)
249+
(implicit holder: ScalaAnnotationHolder): Unit = {
250+
@tailrec
251+
def findCaseAncestor(queue: List[ScClass]): Option[ScClass] = queue match {
252+
case Nil => None
253+
case head :: tail =>
254+
if (head.isCase) Some(head)
255+
else {
256+
val nextLevel = superRefs(head)
257+
.map(_._2)
258+
.collect { case c: ScClass => c }
259+
findCaseAncestor(tail ++ nextLevel)
260+
}
261+
}
262+
263+
element.asOptionOf[ScClass]
264+
.filter(_.isCase)
265+
.flatMap { clazz =>
266+
for {
267+
(range, firstAncestor) <- superRefs(clazz).collectFirst {
268+
case (rng, sc: ScClass) => (rng, sc)
269+
}
270+
ancestorCaseClass <- findCaseAncestor(List(firstAncestor))
271+
} yield (range, ancestorCaseClass)
272+
}
273+
.foreach { case (range, ancestor) =>
274+
val msg = ScalaBundle.message(
275+
"illegal.inheritance.from.case.class",
276+
element.name,
277+
ancestor.name
278+
)
279+
holder.createErrorAnnotation(range, msg)
280+
}
281+
}
282+
248283
def annotateIllegalInheritance(element: ScTemplateDefinition)
249284
(implicit holder: ScalaAnnotationHolder): Unit = {
250285
implicit val tpc: TypePresentationContext = TypePresentationContext(element)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package org.jetbrains.plugins.scala.annotator.template
2+
3+
import org.jetbrains.plugins.scala.ScalaBundle
4+
import org.jetbrains.plugins.scala.annotator.{AnnotatorTestBase, Message, ScalaAnnotationHolder}
5+
import org.jetbrains.plugins.scala.annotator.element.ScTemplateDefinitionAnnotator
6+
import org.jetbrains.plugins.scala.lang.psi.api.toplevel.typedef.ScTemplateDefinition
7+
8+
class CaseToCaseInheritanceTest extends AnnotatorTestBase[ScTemplateDefinition] {
9+
import Message._
10+
11+
def testCaseToCase(): Unit = {
12+
val message = ScalaBundle.message("illegal.inheritance.from.case.class", "B", "A")
13+
14+
val expectation: PartialFunction[List[Message], Unit] = {
15+
case Error("A", `message`) :: Nil =>
16+
}
17+
18+
assertMatches(messages("case class A(); case class B() extends A(); B()"))(expectation)
19+
}
20+
21+
def testIndirectCaseToCase(): Unit = {
22+
val message = ScalaBundle.message("illegal.inheritance.from.case.class", "C", "A")
23+
24+
val expectation: PartialFunction[List[Message], Unit] = {
25+
case Error("B", `message`) :: Nil =>
26+
}
27+
28+
assertMatches(messages(
29+
"""
30+
|case class A(a : Int)
31+
|class B extends A(2)
32+
|case class C(z: Int) extends B
33+
""".stripMargin
34+
))(expectation)
35+
}
36+
37+
override protected def annotate(element: ScTemplateDefinition)
38+
(implicit holder: ScalaAnnotationHolder): Unit =
39+
ScTemplateDefinitionAnnotator.annotateCaseToCaseInheritance(element)
40+
}

scala/scala-impl/test/org/jetbrains/plugins/scala/annotator/template/FinalClassInheritanceTest.scala

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,6 @@ class FinalClassInheritanceTest extends AnnotatorTestBase[ScTemplateDefinition]
5555
assertMatches(messages("class C(val x: Int) extends AnyVal; class X extends C(2)"))(expectation)
5656
}
5757

58-
def testCaseToCase(): Unit = {
59-
val message = ScalaBundle.message("illegal.inheritance.from.case.class", "B", "A")
60-
61-
val expectation: PartialFunction[List[Message], Unit] = {
62-
case Error("A", `message`) :: Nil =>
63-
}
64-
65-
assertMatches(messages("case class A(); case class B() extends A(); B()"))(expectation)
66-
}
67-
6858
override protected def annotate(element: ScTemplateDefinition)
6959
(implicit holder: ScalaAnnotationHolder): Unit =
7060
ScTemplateDefinitionAnnotator.annotateFinalClassInheritance(element)

0 commit comments

Comments
 (0)