Skip to content

Commit

Permalink
Sort with wildcard and group first, close #47 (#49)
Browse files Browse the repository at this point in the history
* Refractoring Sortimports

- ImportGroup abstract List[Import] and operations on it
- Use List instead of ListBuffer when possible
- Rename variables for better comprehension

* Allow sorting with priority on wildcard and groups

- Add asciiSort boolean  on configuration (default true)
- Add WildcardAndGroupFirstSort SortWith implementation
  • Loading branch information
tpetillot authored Apr 23, 2020
1 parent 6c92d10 commit be3b4ba
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 71 deletions.
19 changes: 19 additions & 0 deletions input/src/main/scala/fix/asciisort.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
rule = SortImports
SortImports.blocks = [
"java",
"scala",
"*"
]
*/
package fix

import scala._
import scala.Console._

import java.util.{ Base64, HashMap }
import java.util.regex.Matcher

object AsciiSort {
// Add code that needs fixing here.
}
6 changes: 3 additions & 3 deletions input/src/main/scala/fix/commented.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ rule = SortImports
"com.sun"
]
*/
import scala.util._ // foobar
import scala.util._
import scala.collection._
import java.util.Map
import java.util.Map // foo1
import com.oracle.net._
import com.sun.awt._
import com.sun.awt._ // foo2
import java.math.BigInteger

/**
Expand Down
20 changes: 20 additions & 0 deletions input/src/main/scala/fix/nonasciisort.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
rule = SortImports
SortImports.blocks = [
"java",
"scala",
"*"
]
SortImports.asciiSort = false
*/
package fix

import scala.Console._
import scala._

import java.util.regex.Matcher
import java.util.{ Base64, HashMap }

object NonAsciiSort {
// Add code that needs fixing here.
}
11 changes: 11 additions & 0 deletions output/src/main/scala/fix/asciisort.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package fix

import java.util.regex.Matcher
import java.util.{ Base64, HashMap }

import scala.Console._
import scala._

object AsciiSort {
// Add code that needs fixing here.
}
6 changes: 3 additions & 3 deletions output/src/main/scala/fix/commented.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import java.math.BigInteger
import java.util.Map
import java.util.Map // foo1

import scala.collection._
import scala.util._ // foobar
import scala.util._

import com.oracle.net._

import com.sun.awt._
import com.sun.awt._ // foo2

/**
* Bla
Expand Down
11 changes: 11 additions & 0 deletions output/src/main/scala/fix/nonasciisort.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package fix

import java.util.{ Base64, HashMap }
import java.util.regex.Matcher

import scala._
import scala.Console._

object NonAsciiSort {
// Add code that needs fixing here.
}
57 changes: 57 additions & 0 deletions rules/src/main/scala/fix/ImportGroup.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package fix

import scala.collection.mutable.ListBuffer
import scala.meta.contrib.AssociatedComments
import scala.meta.inputs.Position
import scala.meta.tokens.Token
import scala.meta.{ Import, Traverser, Tree }

object ImportGroupTraverser {
def retrieveImportGroups(tree: Tree): List[ImportGroup] = {
val importGroupsBuffer = ListBuffer[ListBuffer[Import]](ListBuffer.empty)
val importTraverser = new ImportGroupTraverser(importGroupsBuffer)
importTraverser(tree)
importGroupsBuffer.map(importGroupBuffer => ImportGroup(importGroupBuffer.toList)).toList
}
}

private class ImportGroupTraverser(listBuffer: ListBuffer[ListBuffer[Import]]) extends Traverser {
override def apply(tree: Tree): Unit = tree match {
case x: Import => listBuffer.last.append(x)
case node =>
listBuffer.append(ListBuffer.empty)
super.apply(node)
}
}

object ImportGroup {

val empty: ImportGroup = ImportGroup(Nil)
}

case class ImportGroup(value: List[Import]) extends Traversable[Import] {

def sortWith(ordering: Ordering[Import]): ImportGroup = ImportGroup(value.sortWith(ordering.lt))

def groupByBlock(blocks: List[String], defaultBlock: String): Map[String, ImportGroup] =
value.groupBy { imp =>
blocks
.find(block => imp.children.head.syntax.startsWith(block))
.getOrElse(defaultBlock)
}.mapValues(ImportGroup(_))

def containPosition(pos: Position): Boolean =
pos.start > value.head.pos.start && pos.end < value.last.pos.end

def trailingComment(comments: AssociatedComments): Map[Import, Token.Comment] =
value
.map(currentImport => currentImport -> comments.trailing(currentImport).headOption)
.collect {
case (imp, comment) if comment.nonEmpty => (imp, comment.get)
}
.toMap

override def nonEmpty: Boolean = value.nonEmpty

override def foreach[U](f: Import => U): Unit = value.foreach(f)
}
37 changes: 37 additions & 0 deletions rules/src/main/scala/fix/ImportOrdering.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package fix

import scala.meta.Import
import WildcardAndGroupFirstSort._

sealed trait ImportOrdering extends Ordering[Import] {

protected def strFirstImport(imp: Import): String =
imp.children.head.syntax
}

class DefaultSort extends ImportOrdering {

override def compare(x: Import, y: Import): Int =
strFirstImport(x).compareTo(strFirstImport(y))
}

object WildcardAndGroupFirstSort {

private val wildcardRegex = "_".r
private val groupRegex = "\\{.+\\}".r
}

class WildcardAndGroupFirstSort extends ImportOrdering {

private def transformForSorting(imp: Import): (String, String) = {
val strImp = strFirstImport(imp)
(strImp, groupRegex.replaceAllIn(wildcardRegex.replaceAllIn(strImp, "\0"), "\1"))
}

override def compare(x: Import, y: Import): Int =
(transformForSorting(x), transformForSorting(y)) match {
case ((strImp1, tranformedStrImp1), (strImp2, tranformedStrImp2)) =>
val transformComparison = tranformedStrImp1.compareTo(tranformedStrImp2)
if (transformComparison != 0) transformComparison else strImp1.compareTo(strImp2)
}
}
135 changes: 70 additions & 65 deletions rules/src/main/scala/fix/Sortimports.scala
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
package fix

import scala.collection.mutable.ListBuffer
import scala.meta._
import scala.meta.tokens.Token.Comment

import metaconfig.Configured
import metaconfig.generic
import scala.meta.tokens.Token.{ Comment, LF }

import fix.SortImportsConfig.Swap
import metaconfig.generic.Surface
import metaconfig.{ generic, ConfDecoder, ConfEncoder, Configured }
import scalafix.v1._

final case class SortImportsConfig(blocks: List[String] = List("*"))
final case class SortImportsConfig(
blocks: List[String] = List(SortImportsConfig.Blocks.Asterisk),
asciiSort: Boolean = true
)

object SortImportsConfig {
val default = SortImportsConfig()
implicit val surface = generic.deriveSurface[SortImportsConfig]
implicit val decoder = generic.deriveDecoder[SortImportsConfig](default)
implicit val encoder = generic.deriveEncoder[SortImportsConfig]

object Blocks {
val Asterisk: String = "*"
}

val default: SortImportsConfig = SortImportsConfig()
implicit val surface: Surface[SortImportsConfig] = generic.deriveSurface[SortImportsConfig]
implicit val decoder: ConfDecoder[SortImportsConfig] = generic.deriveDecoder[SortImportsConfig](default)
implicit val encoder: ConfEncoder[SortImportsConfig] = generic.deriveEncoder[SortImportsConfig]

final class Swap(val value: (Import, String)) extends AnyVal {
def from: Import = value._1
def to: String = value._2
}
}

class SortImports(config: SortImportsConfig) extends SyntacticRule("SortImports") {
Expand All @@ -30,95 +42,88 @@ class SortImports(config: SortImportsConfig) extends SyntacticRule("SortImports"
.getOrElse("SortImports")(this.config)
.map(new SortImports(_))

private val importOrdering: ImportOrdering =
if (config.asciiSort) new DefaultSort else new WildcardAndGroupFirstSort

override def fix(implicit doc: SyntacticDocument): Patch = {

// Traverse full code tree. Stop when import branches are found and add them to last list in buf
// If an empty line is found add an empty list to buf
val buf: ListBuffer[ListBuffer[Import]] = ListBuffer(ListBuffer.empty)
val traverser: Traverser = new Traverser {
override def apply(tree: Tree): Unit = tree match {
case x: Import =>
buf.last.append(x)
case node =>
buf.append(ListBuffer.empty)
super.apply(node)
}
}

traverser(doc.tree)
val importGroupsWithEmptyLines: List[ImportGroup] = ImportGroupTraverser.retrieveImportGroups(doc.tree)

// Contains groups of imports
val unsorted: ListBuffer[ListBuffer[Import]] = buf
.filter(_.length > 0)
val importGroups: List[ImportGroup] = importGroupsWithEmptyLines.filter(_.nonEmpty)

// Trailing comments
val comments: Map[Import, Option[Comment]] =
unsorted.flatten.map(x => (x -> doc.comments.trailing(x).headOption)).filterNot(_._2.isEmpty).toMap
val comments: Map[Import, Comment] = ImportGroup(importGroups.flatten).trailingComment(doc.comments)

// Remove all newlines within import groups
val removeLinesPatch: ListBuffer[Patch] = unsorted.map { i =>
val removeLinesPatch: List[Patch] = importGroups.flatMap { importGroup =>
doc.tokens.collect {
case e
if e.productPrefix == "LF"
&& e.pos.start > i.head.pos.start
&& e.pos.end < i.last.pos.end =>
e
case token: LF if importGroup.containPosition(token.pos) => token
}
}.flatten
.map(Patch.removeToken(_))
}.map(Patch.removeToken)

// Remove comments and whitespace between imports and comments
val removeCommentsPatch: Iterable[Patch] = comments.values.flatten.map(Patch.removeToken _)
val removeCommentsPatch: Iterable[Patch] = comments.values.map(Patch.removeToken)
val removeCommentSpacesPatch: Iterable[Patch] = comments.flatMap {
case (k, v) =>
v.map { v =>
val num = v.pos.start - k.pos.end
((0 to num).map { diff => new Token.Space(Input.None, v.dialect, k.pos.end + diff) }).toList
}.getOrElse(List.empty)
}.map(Patch.removeToken _)
case (imp, comment) =>
(0 to comment.pos.start - imp.pos.end).map { diff =>
new Token.Space(Input.None, comment.dialect, imp.pos.end + diff)
}
}.map(Patch.removeToken)

// Sort each group of imports
val sorted: ListBuffer[ListBuffer[String]] = unsorted.map { importLines =>
val sorted: Seq[Seq[String]] = importGroups.map { importGroup =>
val configBlocksByLengthDesc = config.blocks.sortBy(-_.length)

// Sort all imports then group based on SortImports rule
// In case of import list, the first element in the list is significant
val importsGrouped = importLines.sortWith { (line1, line2) =>
line1.children.head.toString.compareTo(line2.children.head.toString) < 0
}.groupBy(line => configBlocksByLengthDesc.find(block => line.children.head.toString.startsWith(block)))
val importsGrouped: Map[String, ImportGroup] =
importGroup
.sortWith(importOrdering)
.groupByBlock(configBlocksByLengthDesc, SortImportsConfig.Blocks.Asterisk)

// If a start is not found in the SortImports rule, add it to the end
val fixedList: List[String] = config.blocks
.find(_ == "*")
.fold(config.blocks :+ "*")(_ => config.blocks)
val configBlocks: List[String] =
config.blocks
.find(_ == SortImportsConfig.Blocks.Asterisk)
.fold(config.blocks :+ SortImportsConfig.Blocks.Asterisk)(_ => config.blocks)

// Sort grouped imports and convert to strings
val importsSorted = fixedList
.foldLeft(ListBuffer[ListBuffer[String]]()) { (acc, i) =>
importsGrouped
.find(_._1.getOrElse("*") == i) // If key is None, make key *
.fold(acc) { found =>
val commentOrNot = comments.get(found._2.last).map(" " + _.mkString)
acc += (found._2.map(_.toString).init += (found._2.last.toString + commentOrNot.getOrElse("") + "\n"))
}
val strImportsSorted = configBlocks
.foldLeft(Seq[Seq[String]]()) { (acc, configBlock) =>
importsGrouped.find {
case (block, _) => block == configBlock
}.fold(acc) {
case (_, imports) =>
val strImports = imports.map { imp =>
comments.get(imp).fold(imp.syntax)(comment => s"${imp.syntax} ${comment.syntax}")
}.toSeq

acc :+ (strImports.init :+ (strImports.last + '\n'))
}
}
.flatten

// Remove extra newline on end of imports
importsSorted.init :+ importsSorted.last.dropRight(1)
strImportsSorted.init :+ strImportsSorted.last.dropRight(1)
}

val combined: ListBuffer[ListBuffer[(Import, String)]] = unsorted
.zip(sorted)
.map(i => i._1.zip(i._2))
val combined: List[List[Swap]] =
importGroups
.zip(sorted)
.map {
case (importGroup, strImportGroupSorted) => importGroup.value.zip(strImportGroupSorted).map(new Swap(_))
}

// Create patches using sorted - unsorted pairs
// Essentially imports are playing musical chairs
val patches: ListBuffer[Patch] = combined.map { el =>
el.init.map { i =>
Patch.replaceTree(i._1, i._2 + "\n")
} :+ Patch.replaceTree(el.last._1, el.last._2)
}.flatten
val patches: List[Patch] =
combined.flatMap(importSwaps =>
importSwaps.init.map(trade => Patch.replaceTree(trade.from, s"${trade.to}\n")) :+ Patch
.replaceTree(importSwaps.last.from, importSwaps.last.to)
)

List(patches, removeLinesPatch, removeCommentsPatch, removeCommentSpacesPatch).flatten.asPatch
}
Expand Down

0 comments on commit be3b4ba

Please sign in to comment.