@@ -23,6 +23,7 @@ import org.jetbrains.plugins.scala.lang.resolve.ScalaResolveResult
2323import org .jetbrains .plugins .scala .overrideImplement .{ScMethodMember , ScalaOIUtil , ScalaTypedMember }
2424import org .jetbrains .plugins .scala .{NlsString , ScalaBundle , overrideImplement }
2525
26+ import scala .annotation .tailrec
2627import scala .util .chaining ._
2728
2829object ScTemplateDefinitionAnnotator extends ElementAnnotator [ScTemplateDefinition ] {
@@ -37,6 +38,7 @@ object ScTemplateDefinitionAnnotator extends ElementAnnotator[ScTemplateDefiniti
3738 annotateEnumClassInheritance(element)
3839 annotateTraitPassingConstructorParameters(element)
3940 annotateParentTraitConstructorParameters(element)
41+ annotateCaseToCaseInheritance(element)
4042
4143 if (typeAware) {
4244 annotateNeedsToBeMixin(element)
@@ -235,8 +237,6 @@ object ScTemplateDefinitionAnnotator extends ElementAnnotator[ScTemplateDefiniti
235237 superRefs(element).collect {
236238 case (range, clazz) if clazz.hasFinalModifier =>
237239 (range, NlsString (ScalaBundle .message(" illegal.inheritance.from.final.kind" , kindOf(clazz, toLowerCase = true ), clazz.name)))
238- case (range, clazz : ScClass ) if clazz.isCase && element.asOptionOf[ScClass ].exists(_.isCase) =>
239- (range, NlsString (ScalaBundle .message(" illegal.inheritance.from.case.class" , element.name, clazz.name)))
240240 case (range, clazz) if ValueClassType .extendsAnyVal(clazz) =>
241241 (range, NlsString (ScalaBundle .message(" illegal.inheritance.from.value.class" , clazz.name)))
242242 }.foreach {
@@ -245,6 +245,41 @@ object ScTemplateDefinitionAnnotator extends ElementAnnotator[ScTemplateDefiniti
245245 }
246246 }
247247
248+ def annotateCaseToCaseInheritance (element : ScTemplateDefinition )
249+ (implicit holder : ScalaAnnotationHolder ): Unit = {
250+ @ tailrec
251+ def findCaseAncestor (queue : List [ScClass ]): Option [ScClass ] = queue match {
252+ case Nil => None
253+ case head :: tail =>
254+ if (head.isCase) Some (head)
255+ else {
256+ val nextLevel = superRefs(head)
257+ .map(_._2)
258+ .collect { case c : ScClass => c }
259+ findCaseAncestor(tail ++ nextLevel)
260+ }
261+ }
262+
263+ element.asOptionOf[ScClass ]
264+ .filter(_.isCase)
265+ .flatMap { clazz =>
266+ for {
267+ (range, firstAncestor) <- superRefs(clazz).collectFirst {
268+ case (rng, sc : ScClass ) => (rng, sc)
269+ }
270+ ancestorCaseClass <- findCaseAncestor(List (firstAncestor))
271+ } yield (range, ancestorCaseClass)
272+ }
273+ .foreach { case (range, ancestor) =>
274+ val msg = ScalaBundle .message(
275+ " illegal.inheritance.from.case.class" ,
276+ element.name,
277+ ancestor.name
278+ )
279+ holder.createErrorAnnotation(range, msg)
280+ }
281+ }
282+
248283 def annotateIllegalInheritance (element : ScTemplateDefinition )
249284 (implicit holder : ScalaAnnotationHolder ): Unit = {
250285 implicit val tpc : TypePresentationContext = TypePresentationContext (element)
0 commit comments