Skip to content

Scala 迭代器

迭代器(Iterator)是 Scala 中用于遍历集合元素的重要工具。它提供了一种惰性(lazy)的方式来处理数据,特别适合处理大型数据集或无限序列。

迭代器基础

创建迭代器

scala
object IteratorCreation {
  /** Demonstrates the common ways of constructing an `Iterator` and prints what each one yields. */
  def main(args: Array[String]): Unit = {
    // Finite iterators: one derived from a collection, the rest built by Iterator factories.
    val fromList     = List(1, 2, 3, 4, 5).iterator
    val fromApply    = Iterator(1, 2, 3, 4, 5)
    val fromRange    = Iterator.range(1, 6) // end is exclusive: yields 1..5
    val fromFill     = Iterator.fill(5)(0)
    val fromTabulate = Iterator.tabulate(5)(n => n * n)

    println("Iterators created from different sources:")
    Seq(
      "From list"       -> fromList,
      "Direct creation" -> fromApply,
      "Range"           -> fromRange,
      "Fill"            -> fromFill,
      "Tabulate"        -> fromTabulate
    ).foreach { case (label, it) => println(s"$label: ${it.toList}") }

    // Degenerate cases: no elements at all, and exactly one element.
    println(s"Empty iterator: ${Iterator.empty[Int].toList}")
    println(s"Single element: ${Iterator.single(42).toList}")

    // Infinite iterators are safe to materialize only after truncating them with take.
    println(s"First 10 from infinite: ${Iterator.from(1).take(10).toList}")
    println(s"First 5 repeated: ${Iterator.continually("hello").take(5).toList}")
  }
}

基本迭代器操作

scala
object BasicIteratorOperations {
  /** Shows manual hasNext/next stepping and the three usual ways (foreach, while, for) to walk an iterator. */
  def main(args: Array[String]): Unit = {
    val it = Iterator(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

    println(s"Has next: ${it.hasNext}")

    // Pull one element by hand, guarding against an already-exhausted iterator.
    if (it.hasNext) println(s"Next element: ${it.next()}")

    // Iterators are single-pass: toList only sees what next() has not consumed yet.
    println(s"Remaining elements: ${it.toList}")

    // Each traversal style below therefore starts from a fresh iterator.
    print("Foreach: ")
    Iterator(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).foreach { n => print(s"$n ") }
    println()

    print("While loop: ")
    val viaWhile = Iterator(1, 2, 3, 4, 5)
    while (viaWhile.hasNext) print(s"${viaWhile.next()} ")
    println()

    print("For loop: ")
    for (n <- Iterator(1, 2, 3, 4, 5)) print(s"$n ")
    println()
  }
}

迭代器变换操作

映射和过滤

scala
object IteratorTransformations {
  /** Demonstrates the lazy transformation operators: map, filter, flatMap, collect and slicing. */
  def main(args: Array[String]): Unit = {
    // Element-wise transformations.
    println(s"Doubled: ${Iterator(1, 2, 3, 4, 5).map(n => n * 2).toList}")
    println(s"Even numbers: ${(1 to 10).iterator.filter(n => n % 2 == 0).toList}")

    // flatMap splices each word's characters into a single stream.
    val characters = Iterator("hello", "world").flatMap(word => word.iterator)
    println(s"All characters: ${characters.toList}")

    // collect = filter + map through a partial function; non-Int elements are skipped.
    val heterogeneous = Iterator(1, "hello", 2, "world", 3)
    val doubledInts = heterogeneous.collect { case n: Int => n * 2 }
    println(s"Numbers only (doubled): ${doubledInts.toList}")

    // Slicing by position.
    println(s"First 5: ${Iterator.range(1, 21).take(5).toList}")
    println(s"Elements 6-10: ${Iterator.range(1, 21).drop(5).take(5).toList}")

    // Slicing by predicate.
    println(s"Less than 8: ${Iterator.range(1, 21).takeWhile(n => n < 8).toList}")
    val pastSeven = Iterator.range(1, 21).dropWhile(n => n < 8).take(5)
    println(s"After dropping < 8, take 5: ${pastSeven.toList}")
  }
}

聚合操作

scala
object IteratorAggregations {
  /** Demonstrates terminal operations that consume an iterator down to a single value. */
  def main(args: Array[String]): Unit = {
    // Every aggregation exhausts its iterator, so these defs hand out a fresh one per call.
    def oneToFive  = Iterator(1, 2, 3, 4, 5)
    def oneToTen   = Iterator.range(1, 11)
    def evensToTen = Iterator(2, 4, 6, 8, 10)
    def shuffled   = Iterator(5, 2, 8, 1, 9, 3)

    println(s"Sum: ${oneToFive.reduce((a, b) => a + b)}")
    println(s"Product: ${oneToFive.fold(1)((a, b) => a * b)}")

    // Searching and predicates.
    println(s"First even: ${oneToTen.find(n => n % 2 == 0)}")
    println(s"Has even numbers: ${evensToTen.exists(n => n % 2 == 0)}")
    println(s"All even: ${evensToTen.forall(n => n % 2 == 0)}")
    println(s"Even count: ${oneToTen.count(n => n % 2 == 0)}")

    // Extremes.
    println(s"Min: ${shuffled.min}")
    println(s"Max: ${shuffled.max}")

    // size has to walk (and therefore exhaust) the iterator.
    println(s"Size: ${oneToFive.size}")
  }
}

迭代器组合

连接和分组

scala
object IteratorCombination {
  /** Demonstrates joining iterators (++, zip, zipWithIndex) and splitting them (partition, grouped, sliding). */
  def main(args: Array[String]): Unit = {
    val joined = Iterator(1, 2, 3) ++ Iterator(4, 5, 6)
    println(s"Concatenated: ${joined.toList}")

    // zip stops at the shorter side, so the trailing 5 never appears.
    val pairs = Iterator('a', 'b', 'c', 'd') zip Iterator(1, 2, 3, 4, 5)
    println(s"Zipped: ${pairs.toList}")

    println(s"With index: ${Iterator("hello", "world", "scala").zipWithIndex.toList}")

    // partition returns two iterators over the same source: matches first, non-matches second.
    val (evens, odds) = Iterator.range(1, 11).partition(n => n % 2 == 0)
    println(s"Evens: ${evens.toList}")
    println(s"Odds: ${odds.toList}")

    // Fixed-size chunks; the last chunk may be shorter.
    println("Grouped by 3:")
    for (group <- Iterator.range(1, 11).grouped(3)) println(s"  ${group.toList}")

    // Overlapping windows that advance one element at a time.
    println("Sliding window of 3:")
    for (window <- Iterator.range(1, 7).sliding(3)) println(s"  ${window.toList}")
  }
}

惰性求值和性能

惰性求值示例

scala
object LazyEvaluation {
  /** Contrasts a lazy Iterator pipeline with an eager List pipeline by logging when each stage runs. */
  def main(args: Array[String]): Unit = {
    println("Demonstrating lazy evaluation:")

    // Stage functions log every call, so the output reveals exactly when they execute.
    val lazyFilter: Int => Boolean = { x =>
      println(s"Filtering $x")
      x % 1000 == 0
    }
    val lazyMap: Int => Int = { x =>
      println(s"Mapping $x")
      x * 2
    }

    // Assembling the pipeline runs nothing: filter/map/take on an Iterator are all lazy.
    val pipeline = Iterator.range(1, 1000000).filter(lazyFilter).map(lazyMap).take(5)

    println("Operations defined, but not executed yet")

    // Work starts only once the result is demanded, and stops after the 5th element.
    println("Now executing:")
    val result = pipeline.toList
    println(s"Result: $result")

    // A strict List finishes each stage over every element before the next stage starts.
    println("\nCompare with List (eager evaluation):")
    val eagerFilter: Int => Boolean = { x =>
      if (x <= 5000) println(s"Eagerly filtering $x") // cap the log volume
      x % 1000 == 0
    }
    val eagerMap: Int => Int = { x =>
      if (x <= 5000) println(s"Eagerly mapping $x")
      x * 2
    }
    val eagerResult = (1 until 1000000).toList.filter(eagerFilter).map(eagerMap).take(5)

    println(s"Eager result: $eagerResult")
  }
}

性能比较

scala
object IteratorPerformance {
  /** Sums the squares of the first `count` multiples of `step` found in `[1, upperExclusive)`.
    *
    * Squares and the running sum are computed as `Long`. With the demo's inputs the multiples
    * reach 100000, and 100000² = 10^10 does not fit in an `Int`. An `Int` pipeline would
    * silently overflow and print a garbage result.
    *
    * @param upperExclusive exclusive upper bound of the scanned range
    * @param step           only values divisible by `step` are kept; must be non-zero
    * @param count          how many matching values to include
    * @return the sum of the squares of the selected values
    * @throws IllegalArgumentException if `step` is zero
    */
  def sumOfSquaresOfMultiples(upperExclusive: Int, step: Int, count: Int): Long = {
    require(step != 0, "step must be non-zero")
    Iterator.range(1, upperExclusive)
      .filter(_ % step == 0)
      .map(x => x.toLong * x)
      .take(count)
      .sum
  }

  /** Prints rough single-shot timings comparing eager List pipelines with lazy Iterator pipelines. */
  def main(args: Array[String]): Unit = {
    val size = 1000000

    // Runs the by-name `operation` once, prints the elapsed wall time in ms, and returns its result.
    // NOTE: one-shot timings include JIT warm-up and GC noise; treat them as rough indicators only.
    def timeOperation[T](name: String)(operation: => T): T = {
      val start = System.nanoTime()
      val result = operation
      val end = System.nanoTime()
      println(f"$name%25s: ${(end - start) / 1000000}%6d ms")
      result
    }

    println("Performance Comparison: Iterator vs List")
    println("=" * 50)

    // Creation cost: the List materializes every element up front.
    val list = timeOperation("List creation") {
      (1 to size).toList
    }

    // The created iterator is not used afterwards; only the creation cost is measured.
    timeOperation("Iterator creation") {
      Iterator.range(1, size + 1)
    }

    // filter+map+take: the List processes every element, the Iterator stops after 100 results.
    timeOperation("List filter+map+take") {
      list.filter(_ % 2 == 0).map(_ * 2).take(100)
    }

    timeOperation("Iterator filter+map+take") {
      Iterator.range(1, size + 1).filter(_ % 2 == 0).map(_ * 2).take(100).toList
    }

    println("\nMemory Usage:")
    println("List: Stores all elements in memory")
    println("Iterator: Generates elements on demand")

    // Streams a 10M-element range without materializing it.
    def processLargeDataset(): Unit = {
      val result = sumOfSquaresOfMultiples(10000000, 1000, 100)
      println(s"Processed large dataset result: $result")
    }

    timeOperation("Large dataset processing") {
      processLargeDataset()
    }
  }
}

自定义迭代器

创建自定义迭代器

scala
object CustomIterators {
  /** Fibonacci numbers 0, 1, 1, 2, 3, 5, ... as `Long` values.
    *
    * The mathematical sequence is infinite, but F(93) no longer fits in a `Long`.
    * This iterator therefore ends after F(92), yielding 93 terms in total.
    * That is better than silently yielding wrapped-around (negative) values.
    */
  class FibonacciIterator extends Iterator[Long] {
    // `current` is the next term to return; `following` is the term after it.
    // The field must not be named `next`: that would clash with the `next()` method below.
    private var current = 0L
    private var following = 1L

    // All real terms are non-negative. An overflowing addition wraps to a negative Long,
    // and that value reaches `current` one step later, ending the iteration.
    def hasNext: Boolean = current >= 0

    def next(): Long = {
      if (!hasNext) throw new NoSuchElementException("next on exhausted FibonacciIterator")
      val result = current
      val sum = current + following // may overflow; detected by hasNext on a later call
      current = following
      following = sum
      result
    }
  }

  /** Infinite iterator over the primes 2, 3, 5, 7, ..., found by trial division. */
  class PrimeIterator extends Iterator[Int] {
    // Smallest candidate that has not been returned yet.
    private var current = 2

    def hasNext: Boolean = true // infinite sequence

    def next(): Int = {
      while (!isPrime(current)) {
        current += 1
      }
      val result = current
      current += 1
      result
    }

    // Trial division by 2, then by odd numbers up to sqrt(n).
    private def isPrime(n: Int): Boolean = {
      if (n < 2) false
      else if (n == 2) true
      else if (n % 2 == 0) false
      else {
        val sqrt = math.sqrt(n).toInt
        !(3 to sqrt by 2).exists(n % _ == 0)
      }
    }
  }

  // Companion factory methods so callers can write FibonacciIterator() without `new`.
  object FibonacciIterator {
    def apply(): FibonacciIterator = new FibonacciIterator()
  }

  object PrimeIterator {
    def apply(): PrimeIterator = new PrimeIterator()
  }

  def main(args: Array[String]): Unit = {
    val fibonacci = FibonacciIterator()
    println(s"First 15 Fibonacci numbers: ${fibonacci.take(15).toList}")

    val primes = PrimeIterator()
    println(s"First 20 prime numbers: ${primes.take(20).toList}")

    // Combining custom iterators: Fibonacci numbers that are among the first 1000 primes.
    // Build the prime set once, instead of regenerating 1000 primes for every Fibonacci term.
    // Compare as Long, so that large terms are never truncated by `toInt` into false matches.
    val smallPrimes: Set[Long] = PrimeIterator().take(1000).map(_.toLong).toSet
    val largestSmallPrime = smallPrimes.max
    // Fibonacci terms never decrease, so nothing beyond the largest prime can match.
    val fibPrimes = FibonacciIterator()
      .takeWhile(_ <= largestSmallPrime)
      .filter(smallPrimes)

    println(s"Fibonacci numbers that are also prime: ${fibPrimes.toList}")
  }
}

迭代器工厂方法

scala
object IteratorFactories {
  /** Infinite geometric series `start, start * ratio, start * ratio^2, ...`. */
  def geometricSeries(start: Double, ratio: Double): Iterator[Double] = {
    Iterator.iterate(start)(_ * ratio)
  }

  /** Infinite stream of pseudo-random Ints in `[0, 100)` from a `Random` seeded with `seed`.
    * The same seed always reproduces the same sequence.
    */
  def randomNumbers(seed: Long = System.currentTimeMillis()): Iterator[Int] = {
    val random = new scala.util.Random(seed)
    Iterator.continually(random.nextInt(100))
  }

  /** Lazily iterates over the lines of `filename`.
    *
    * The underlying `Source` is closed as soon as the iterator reports that no lines remain.
    * If the caller stops early (e.g. via `take`), this iterator never closes the file.
    * In that case, manage the `Source` directly with try/finally.
    */
  def fileLines(filename: String): Iterator[String] = {
    val source = scala.io.Source.fromFile(filename)
    val lines = source.getLines()
    new Iterator[String] {
      // Flips to false once the source is closed, so hasNext never touches a closed reader.
      private var open = true

      def hasNext: Boolean = open && {
        val more = lines.hasNext
        if (!more) {
          source.close()
          open = false
        }
        more
      }

      def next(): String =
        if (hasNext) lines.next()
        else throw new NoSuchElementException(s"no more lines in $filename")
    }
  }

  /** A simple rose tree: a value plus an ordered list of child subtrees. */
  case class TreeNode[T](value: T, children: List[TreeNode[T]] = Nil)

  /** Lazily yields the tree's values in pre-order: node, then each child subtree left to right. */
  def depthFirstTraversal[T](root: TreeNode[T]): Iterator[T] = {
    def traverse(nodes: List[TreeNode[T]]): Iterator[T] = {
      nodes match {
        case Nil => Iterator.empty
        case head :: tail =>
          // `++` takes its argument by name, so subtrees are only visited when reached.
          Iterator.single(head.value) ++ traverse(head.children) ++ traverse(tail)
      }
    }
    traverse(List(root))
  }

  /** Lazily yields the tree's values level by level, left to right (breadth-first). */
  def breadthFirstTraversal[T](root: TreeNode[T]): Iterator[T] = {
    import scala.collection.immutable.Queue

    // An immutable Queue enqueues/dequeues in amortized O(1). The previous List-based frontier
    // (`tail ++ head.children`) copied the whole frontier on every step.
    def traverse(queue: Queue[TreeNode[T]]): Iterator[T] =
      queue.dequeueOption match {
        case None => Iterator.empty
        case Some((head, rest)) =>
          val nextQueue = head.children.foldLeft(rest)((q, child) => q.enqueue(child))
          Iterator.single(head.value) ++ traverse(nextQueue)
      }

    traverse(Queue(root))
  }

  def main(args: Array[String]): Unit = {
    val geometric = geometricSeries(1.0, 2.0)
    println(s"Geometric series (1, 2, 4, 8, ...): ${geometric.take(10).toList}")

    // Fixed seed so the output is reproducible.
    val random = randomNumbers(42)
    println(s"Random numbers: ${random.take(10).toList}")

    val tree = TreeNode(1, List(
      TreeNode(2, List(TreeNode(4), TreeNode(5))),
      TreeNode(3, List(TreeNode(6), TreeNode(7)))
    ))

    println(s"Depth-first traversal: ${depthFirstTraversal(tree).toList}")
    println(s"Breadth-first traversal: ${breadthFirstTraversal(tree).toList}")
  }
}

实际应用示例

数据流处理

scala
object DataStreamProcessing {
  /** One simulated log record. `generateLogStream` fills `timestamp` from System.currentTimeMillis. */
  case class LogEntry(timestamp: Long, level: String, message: String)

  /** Infinite stream of random log entries; level and message are drawn from fixed lists. */
  def generateLogStream(): Iterator[LogEntry] = {
    val levels = Array("INFO", "WARN", "ERROR", "DEBUG")
    val messages = Array("User login", "Database query", "Cache miss", "Network timeout")
    val random = new scala.util.Random()

    Iterator.continually {
      LogEntry(
        // Up to 1s of random jitter, so entries do not all share one timestamp.
        System.currentTimeMillis() + random.nextInt(1000),
        levels(random.nextInt(levels.length)),
        messages(random.nextInt(messages.length))
      )
    }
  }

  /** Prints the first 5 ERROR entries of `logs`, sleeping 100 ms per entry to simulate work. */
  def processLogStream(logs: Iterator[LogEntry]): Unit = {
    val errorLogs = logs
      .filter(_.level == "ERROR")
      .take(5) // handle at most the first 5 errors

    println("Processing error logs:")
    errorLogs.foreach { log =>
      println(s"[${log.timestamp}] ERROR: ${log.message}")
      Thread.sleep(100) // simulated processing time
    }
  }

  /** Splits `logs` into consecutive batches of `batchSize`; the last batch may be smaller.
    *
    * `grouped` yields `Seq`s, not `List`s. Each batch is converted so the result matches the
    * declared `Iterator[List[LogEntry]]` type.
    *
    * @throws IllegalArgumentException (from `grouped`) if `batchSize` is not positive
    */
  def batchProcess(logs: Iterator[LogEntry], batchSize: Int): Iterator[List[LogEntry]] = {
    logs.grouped(batchSize).map(_.toList)
  }

  /** Overlapping windows of `windowSize` consecutive entries, advancing by one each step.
    *
    * As with `batchProcess`, `sliding` yields `Seq`s, so each window is converted to a `List`.
    */
  def slidingWindowProcess(logs: Iterator[LogEntry], windowSize: Int): Iterator[List[LogEntry]] = {
    logs.sliding(windowSize).map(_.toList)
  }

  def main(args: Array[String]): Unit = {
    val logStream = generateLogStream()

    println("Real-time processing:")
    processLogStream(logStream.take(20))

    println("\nBatch processing:")
    val batches = batchProcess(generateLogStream().take(10), 3)
    batches.zipWithIndex.foreach { case (batch, index) =>
      println(s"Batch $index: ${batch.size} logs")
      batch.foreach(log => println(s"  ${log.level}: ${log.message}"))
    }

    println("\nSliding window processing:")
    val windows = slidingWindowProcess(generateLogStream().take(8), 3)
    windows.zipWithIndex.foreach { case (window, index) =>
      println(s"Window $index: ${window.map(_.level).mkString(", ")}")
    }
  }
}

文件处理

scala
import java.io.{File, PrintWriter}

object FileProcessing {
  /** Writes 1000 numbered lines, each ending in a random number in [0, 100), to `filename`. */
  def createTestFile(filename: String): Unit = {
    val writer = new PrintWriter(new File(filename))
    try {
      (1 to 1000).foreach { i =>
        writer.println(s"Line $i: This is line number $i with some random data ${scala.util.Random.nextInt(100)}")
      }
    } finally {
      writer.close()
    }
  }

  /** Streams `filename` and prints the total whitespace-separated word count of lines containing "random". */
  def processLargeFile(filename: String): Unit = {
    val source = scala.io.Source.fromFile(filename)
    try {
      val lines = source.getLines()

      val wordCount = lines
        .filter(_.contains("random"))
        .map(_.split("\\s+").length)
        .sum

      println(s"Total words in lines containing 'random': $wordCount")
    } finally {
      source.close()
    }
  }

  /** Reads `filename` in chunks of `chunkSize` lines and prints each chunk's size and average line length. */
  def processFileInChunks(filename: String, chunkSize: Int): Unit = {
    val source = scala.io.Source.fromFile(filename)
    try {
      val lines = source.getLines()
      val chunks = lines.grouped(chunkSize)

      chunks.zipWithIndex.foreach { case (chunk, index) =>
        val chunkList = chunk.toList
        // grouped never yields an empty chunk, so the division is safe.
        val avgLength = chunkList.map(_.length).sum.toDouble / chunkList.size
        println(f"Chunk $index: ${chunkList.size} lines, avg length: $avgLength%.2f")
      }
    } finally {
      source.close()
    }
  }

  /** Copies the first 10 lines of `inputFile` that contain a digit into `outputFile`, upper-cased.
    *
    * @throws java.io.IOException if writing to `outputFile` fails. PrintWriter itself swallows
    *                             write errors, so they are surfaced explicitly via checkError.
    */
  def filterAndTransform(inputFile: String, outputFile: String): Unit = {
    val source = scala.io.Source.fromFile(inputFile)
    try {
      // Opened inside the outer try: if the writer cannot be created, `source` is still closed.
      val writer = new PrintWriter(new File(outputFile))
      try {
        val processed = source.getLines()
          .filter(_.matches(".*\\d+.*"))
          .map(_.toUpperCase)
          .take(10)

        processed.foreach(writer.println)

        if (writer.checkError()) throw new java.io.IOException(s"failed writing to $outputFile")
        println(s"Processed lines written to $outputFile")
      } finally {
        writer.close()
      }
    } finally {
      source.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val testFile = "test_data.txt"
    val outputFile = "processed_data.txt"

    // Clean up the scratch files even if one of the processing steps throws.
    try {
      createTestFile(testFile)
      println(s"Created test file: $testFile")

      processLargeFile(testFile)

      println("\nProcessing in chunks:")
      processFileInChunks(testFile, 100)

      filterAndTransform(testFile, outputFile)
    } finally {
      new File(testFile).delete()
      new File(outputFile).delete()
    }
  }
}

最佳实践

  1. 何时使用迭代器

    • 处理大型数据集
    • 需要惰性求值
    • 内存使用是关键考虑因素
    • 处理无限序列
  2. 性能考虑

    • 迭代器是一次性的,不能重复使用
    • 惰性求值可以提高性能和内存效率
    • 避免在迭代器上调用 size 或 length(它们会遍历并耗尽整个迭代器)
  3. 内存管理

    • 迭代器不会将所有元素存储在内存中
    • 适合处理流式数据
    • 注意资源管理(如文件句柄)
  4. 函数式编程

    • 使用 map、filter、flatMap 等操作
    • 避免副作用
    • 利用链式操作的惰性特性
  5. 错误处理

    • 检查 hasNext 避免 NoSuchElementException
    • 使用 Option 类型处理可能的空值
    • 正确管理资源(使用 try-finally 或 scala.util.Using(Scala 2.13+))

迭代器是 Scala 中处理数据流和大型数据集的强大工具,掌握其使用方法对于编写高效的程序至关重要。

本站内容仅供学习和研究使用。