Skip to content

Commit 1df148d

Browse files
committed
Rewrite imports collection, avoid deprecated code
1 parent 536101a commit 1df148d

3 files changed

Lines changed: 41 additions & 42 deletions

File tree

scalafix-core/src/main/scala/scalafix/internal/patch/ImportPatchOps.scala

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ object ImportPatchOps {
4040
}
4141
}
4242
private def fallbackToken(ctx: RuleCtx): Token = {
43+
@tailrec
4344
def loop(tree: Tree): Token = tree match {
4445
case Source((stat: Pkg) :: _) => loop(stat)
45-
case Source(_) => tree.tokens.head
46-
case Pkg(_, stat :: _) => loop(stat)
46+
case _: Source => tree.tokens.head
47+
case Pkg.Initial(_, stat :: _) => loop(stat)
4748
case els =>
4849
ctx.tokenList.prev(ctx.tokenList.prev(els.tokens.head)) match {
4950
case comment @ Token.Comment(_) =>
@@ -53,19 +54,12 @@ object ImportPatchOps {
5354
}
5455
loop(ctx.tree)
5556
}
56-
private def extractImports(stats: Seq[Stat]): Seq[Import] = {
57-
stats
58-
.takeWhile(_.is[Import])
59-
.collect { case i: Import => i }
60-
}
6157

62-
@tailrec private final def getGlobalImports(ast: Tree): Seq[Import] =
63-
ast match {
64-
case Pkg(_, Seq(pkg: Pkg)) => getGlobalImports(pkg)
65-
case Source(Seq(pkg: Pkg)) => getGlobalImports(pkg)
66-
case Pkg(_, stats) => extractImports(stats)
67-
case Source(stats) => extractImports(stats)
68-
case _ => Nil
58+
@tailrec
59+
private def extractImports(stats: List[Stat]): Seq[Import] =
60+
stats match {
61+
case (pkg: Pkg) :: Nil => extractImports(pkg.body.stats)
62+
case _ => stats.takeWhile(_.is[Import]).collect { case i: Import => i }
6963
}
7064

7165
// NOTE(olafur): This method is the simplest/dummest thing I can think of
@@ -88,7 +82,11 @@ object ImportPatchOps {
8882
lazy val allImporteeSymbols = allImportees.flatMap(importee =>
8983
importee.symbol.map(_.normalized -> importee)
9084
)
91-
val globalImports = getGlobalImports(ctx.tree)
85+
val globalImports = ctx.tree match {
86+
case t: Source => extractImports(t.stats)
87+
case t: Pkg => extractImports(t.body.stats)
88+
case _ => Nil
89+
}
9290
val editToken: Token = {
9391
if (globalImports.isEmpty) fallbackToken(ctx)
9492
else globalImports.last.tokens.last

scalafix-core/src/main/scala/scalafix/internal/patch/ReplaceSymbolOps.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,22 @@ object ReplaceSymbolOps {
2222
}
2323
}
2424

25-
private def extractImports(stats: Seq[Stat]): Seq[Import] = {
26-
stats.collect { case i: Import => i }
25+
@tailrec
26+
private def extractImports(stats: List[Stat]): Seq[Import] = stats match {
27+
case (p: Pkg) :: Nil => extractImports(p.body.stats)
28+
case _ => stats.collect { case i: Import => i }
2729
}
2830

2931
private def getNamesOfExplicitlyImportedSymbols(
3032
tree: Tree,
3133
isMoved: Name => Boolean
3234
): Set[String] = {
33-
@tailrec
34-
def getGlobalImports(ast: Tree): Seq[Import] = ast match {
35-
case Pkg(_, Seq(pkg: Pkg)) => getGlobalImports(pkg)
36-
case Source(Seq(pkg: Pkg)) => getGlobalImports(pkg)
37-
case Pkg(_, stats) => extractImports(stats)
38-
case Source(stats) => extractImports(stats)
35+
val globalImports = tree match {
36+
case t: Pkg => extractImports(t.body.stats)
37+
case t: Source => extractImports(t.stats)
3938
case _ => Nil
4039
}
4140

42-
val globalImports = getGlobalImports(tree)
43-
4441
// pre-compute global imported symbols for O(1) collision detection
4542
// since ctx.addGlobalImport adds imports at global scope
4643
// exclude names whose symbols are being moved, as those imports

scalafix-rules/src/main/scala/scalafix/internal/rule/OrganizeImports.scala

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ class OrganizeImports(
7676

7777
val unusedImporteePositions = new UnusedImporteePositions
7878

79-
val (globalImports, localImports) = collectImports(doc.tree)
79+
val (globalImports, localImports) = doc.tree match {
80+
case t: Source => collectImports(t.stats)
81+
case t: Pkg => collectImports(t.body.stats)
82+
case _ => (Nil, Nil)
83+
}
8084

8185
val globalImportsPatch =
8286
if (globalImports.isEmpty) Patch.empty
@@ -878,22 +882,22 @@ object OrganizeImports {
878882
}
879883

880884
@tailrec private def collectImports(
881-
tree: Tree
882-
): (Seq[Import], Seq[Import]) = {
883-
def extractImports(stats: Seq[Stat]): (Seq[Import], Seq[Import]) = {
884-
val (importStats, otherStats) = stats.span(_.is[Import])
885-
val globalImports = importStats.map { case i: Import => i }
886-
val localImports = otherStats.flatMap(_.collect { case i: Import => i })
887-
(globalImports, localImports)
888-
}
889-
890-
tree match {
891-
case Source(Seq(p: Pkg)) => collectImports(p)
892-
case Pkg(_, Seq(p: Pkg)) => collectImports(p)
893-
case Source(stats) => extractImports(stats)
894-
case Pkg(_, stats) => extractImports(stats)
895-
case _ => (Nil, Nil)
896-
}
885+
stats: List[Stat]
886+
): (Seq[Import], Seq[Import]) = stats match {
887+
case (p: Pkg) :: Nil => collectImports(p.body.stats)
888+
case _ =>
889+
val globalImports = Seq.newBuilder[Import]
890+
val localImports = Seq.newBuilder[Import]
891+
def collectLocalImports(tree: Tree): Unit =
892+
tree.traverse { case i: Import => localImports += i }
893+
val statsiter = stats.iterator
894+
while (statsiter.hasNext) statsiter.next() match {
895+
case i: Import => globalImports += i
896+
case i =>
897+
collectLocalImports(i)
898+
while (statsiter.hasNext) collectLocalImports(statsiter.next())
899+
}
900+
(globalImports.result(), localImports.result())
897901
}
898902

899903
@tailrec private def topQualifierOf(term: Term): Term.Name =

0 commit comments

Comments
 (0)