Skip to content

Scala 提取器

提取器(Extractor)是 Scala 中一个强大的特性,它允许我们定义如何从对象中提取值。提取器通过 unapply 方法实现,是模式匹配的基础。

基本提取器

unapply 方法

scala
/** Introductory extractors over `Int`, showing both value-returning and boolean `unapply`. */
object BasicExtractors {

  /** Matches even integers and binds the matched number itself. */
  object Even {
    def unapply(n: Int): Option[Int] =
      Some(n).filter(_ % 2 == 0)
  }

  /** Matches odd integers (negative ones included) and binds the matched number itself. */
  object Odd {
    def unapply(n: Int): Option[Int] =
      Some(n).filter(_ % 2 != 0)
  }

  /** Matches perfect squares; binds the integer square root instead of the input. */
  object Square {
    def unapply(n: Int): Option[Int] = {
      val candidate = math.sqrt(n).toInt
      // Squaring the truncated root back confirms that n really is a perfect square.
      Some(candidate).filter(root => root * root == n)
    }
  }

  /** Boolean extractor: succeeds for strictly positive numbers and binds nothing. */
  object Positive {
    def unapply(n: Int): Boolean = 0 < n
  }

  /** Boolean extractor: succeeds for strictly negative numbers and binds nothing. */
  object Negative {
    def unapply(n: Int): Boolean = 0 > n
  }

  /** Describes `n` as even or odd. */
  def analyzeNumber(n: Int): String = n match {
    case Even(even) => s"$even 是偶数"
    case Odd(odd)   => s"$odd 是奇数"
  }

  /**
   * Describes `n` using the first matching extractor: perfect square, positive, negative.
   * Zero is itself a perfect square (root 0), so it is reported by the `Square` case.
   */
  def analyzeSpecialNumbers(n: Int): String = n match {
    case Square(sqrtValue) => s"$n 是完全平方数,平方根是 $sqrtValue"
    case Positive()        => s"$n 是正数"
    case Negative()        => s"$n 是负数"
    case _                 => s"$n 是零"
  }

  /** Prints both analyses for a fixed list of sample numbers. */
  def main(args: Array[String]): Unit = {
    val samples = List(1, 2, 3, 4, 9, 16, -5, 0, 25)

    println("数字分析:")
    for (n <- samples) println(s"$n -> ${analyzeNumber(n)}")

    println("\n特殊数字分析:")
    for (n <- samples) println(s"$n -> ${analyzeSpecialNumbers(n)}")
  }
}

多值提取器

scala
/**
 * Extractors that bind several values at once from a `String`:
 * fixed arity via `unapply` returning a tuple, variable arity via `unapplySeq`.
 */
object MultiValueExtractors {

  /** Matches text made of exactly two parts separated by a single space. */
  object FullName {
    def unapply(fullName: String): Option[(String, String)] =
      fullName.split(" ") match {
        case Array(first, last) => Some((first, last))
        case _                  => None
      }
  }

  /** Splits a sentence on whitespace; fails only when there are no words at all. */
  object Words {
    def unapplySeq(sentence: String): Option[Seq[String]] = {
      val words = sentence.split("\\s+").filter(_.nonEmpty)
      if (words.nonEmpty) Some(words.toSeq) else None
    }
  }

  /** Matches "start-end" where both sides are runs of decimal digits. */
  object Range {
    // Compiled once instead of on every match attempt.
    private val RangePattern = """(\d+)-(\d+)""".r

    def unapply(input: String): Option[(Int, Int)] = input match {
      case RangePattern(start, end) =>
        // A digit run may exceed Int range; treat that as "no match" instead of throwing
        // a NumberFormatException out of a pattern match.
        scala.util.Try((start.toInt, end.toInt)).toOption
      case _ => None
    }
  }

  /** Matches "(x, y)" with optional sign, optional fraction and optional space after the comma. */
  object Coordinate {
    // Compiled once instead of on every match attempt.
    private val CoordPattern = """\((-?\d+\.?\d*),\s*(-?\d+\.?\d*)\)""".r

    def unapply(input: String): Option[(Double, Double)] = input match {
      case CoordPattern(x, y) => Some((x.toDouble, y.toDouble))
      case _                  => None
    }
  }

  /**
   * Classifies `input` with the first extractor that matches.
   *
   * The strictly formatted extractors (`Range`, `Coordinate`) are tried before the
   * whitespace-based ones: a coordinate such as "(3.14, 2.71)" contains one space and
   * would otherwise be captured by `FullName` as two "names".
   */
  def processInput(input: String): String = input match {
    case Range(start, end) =>
      s"范围: 从 $start 到 $end"

    case Coordinate(x, y) =>
      s"坐标: ($x, $y)"

    case FullName(first, last) =>
      s"姓名: $first $last"

    // Requires at least two words; any further words are collected into `rest`.
    case Words(first, second, rest @ _*) =>
      s"句子: 第一个词='$first', 第二个词='$second', 其余=${rest.mkString(", ")}"

    case _ =>
      s"未识别的格式: $input"
  }

  /** Runs `processInput` over a fixed list of sample inputs and prints each result. */
  def main(args: Array[String]): Unit = {
    val inputs = List(
      "John Doe",
      "The quick brown fox jumps",
      "1-100",
      "(3.14, 2.71)",
      "single",
      "invalid format"
    )

    inputs.foreach(input => println(s"'$input' -> ${processInput(input)}"))
  }
}

高级提取器模式

嵌套提取器

scala
/** Shows custom extractors nested inside case-class patterns. */
object NestedExtractors {
  case class Address(street: String, city: String, zipCode: String)
  case class Person(name: String, age: Int, address: Address)

  /** Maps an age to an age-group label; always succeeds. */
  object AgeGroup {
    def unapply(age: Int): Option[String] = age match {
      case a if a < 18 => Some("未成年")
      case a if a < 65 => Some("成年人")
      case _ => Some("老年人")
    }
  }

  /** Maps a city name to a city-tier label; always succeeds. */
  object CityType {
    val majorCities = Set("北京", "上海", "广州", "深圳")

    def unapply(city: String): Option[String] = {
      if (majorCities.contains(city)) Some("一线城市")
      else Some("其他城市")
    }
  }

  /** Maps the first two characters of a zip code to a region label; always succeeds. */
  object ZipRegion {
    def unapply(zipCode: String): Option[String] = {
      zipCode.take(2) match {
        case "10" | "11" => Some("华北")
        case "20" | "21" => Some("华东")
        case "30" | "31" => Some("华南")
        case _ => Some("其他")
      }
    }
  }

  /**
   * Builds a one-line description of `person`.
   *
   * NOTE(review): `AgeGroup`, `CityType` and `ZipRegion` each return `Some` for every
   * input, so the first case matches every `Person` and the cases below it are never
   * reached. The order is kept as written; confirm which output is actually intended
   * before reordering.
   */
  def analyzePerson(person: Person): String = person match {
    // Nested pattern: extractors are applied to the fields of Person and Address.
    case Person(name, AgeGroup(ageGroup), Address(_, CityType(cityType), ZipRegion(region))) =>
      s"$name 是$ageGroup,住在${cityType}($region 地区)"

    case Person(name, age, Address(street, city, _)) if age > 60 =>
      s"$name 是老年人,住在 $city 的 $street"

    case Person(name, _, Address(_, "北京", _)) =>
      s"$name 住在首都北京"

    case Person(name, age, _) =>
      s"$name,$age 岁"
  }

  /** Prints the description of a few sample people. */
  def main(args: Array[String]): Unit = {
    val people = List(
      Person("张三", 25, Address("中关村大街", "北京", "100080")),
      Person("李四", 16, Address("南京路", "上海", "200000")),
      Person("王五", 70, Address("天河路", "广州", "310000")),
      Person("赵六", 35, Address("解放路", "武汉", "430000"))
    )

    people.foreach(person => println(analyzePerson(person)))
  }
}

条件提取器

scala
/** Extractors whose success, or bound label, depends on a condition over the input. */
object ConditionalExtractors {

  /** Matches strings containing both "@" and "."; binds the lower-cased address. */
  object ValidEmail {
    def unapply(email: String): Option[String] = {
      if (email.contains("@") && email.contains(".")) Some(email.toLowerCase)
      else None
    }
  }

  /**
   * Matches passwords of at least 8 characters that contain an upper-case letter,
   * a lower-case letter, a digit and one of `!@#$%^&*`.
   */
  object StrongPassword {
    def unapply(password: String): Option[String] = {
      val hasUpper = password.exists(_.isUpper)
      val hasLower = password.exists(_.isLower)
      val hasDigit = password.exists(_.isDigit)
      val hasSpecial = password.exists("!@#$%^&*".contains(_))
      val isLongEnough = password.length >= 8

      if (hasUpper && hasLower && hasDigit && hasSpecial && isLongEnough) {
        Some(password)
      } else None
    }
  }

  /** Labels a value relative to the closed range 0..100; always succeeds. */
  object InRange {
    def unapply(value: Int): Option[String] = value match {
      case v if v >= 0 && v <= 100 => Some("正常范围")
      case v if v > 100 => Some("超出上限")
      case _ => Some("低于下限")
    }
  }

  /** Labels a file name by its extension (case-insensitive); always succeeds. */
  object FileType {
    def unapply(filename: String): Option[String] = {
      // Only text after the last dot counts as an extension. Splitting on "." and taking
      // the last piece would treat a dot-less name such as "jpg" as its own extension.
      val dot = filename.lastIndexOf('.')
      val extension =
        if (dot >= 0 && dot < filename.length - 1) Some(filename.substring(dot + 1).toLowerCase)
        else None

      extension match {
        case Some("jpg" | "jpeg" | "png" | "gif") => Some("图片")
        case Some("mp4" | "avi" | "mov") => Some("视频")
        case Some("txt" | "doc" | "pdf") => Some("文档")
        case Some("mp3" | "wav" | "flac") => Some("音频")
        case _ => Some("其他")
      }
    }
  }

  /** Reports which of the two credentials pass their extractor. */
  def validateUser(email: String, password: String): String = (email, password) match {
    case (ValidEmail(validEmail), StrongPassword(_)) =>
      s"用户验证成功: $validEmail"

    case (ValidEmail(_), _) =>
      "邮箱有效,但密码不够强"

    case (_, StrongPassword(_)) =>
      "密码强度足够,但邮箱无效"

    case _ =>
      "邮箱和密码都不符合要求"
  }

  /** Describes `value` with its `InRange` label. */
  def analyzeValue(value: Int): String = value match {
    case InRange(status) => s"值 $value: $status"
  }

  /** Describes `filename` with its `FileType` label. */
  def classifyFile(filename: String): String = filename match {
    case FileType(fileType) => s"文件 '$filename' 是 $fileType 类型"
  }

  /** Prints sample results for user validation, range labelling and file classification. */
  def main(args: Array[String]): Unit = {
    // User validation samples
    val userTests = List(
      ("user@example.com", "StrongPass123!"),
      ("invalid-email", "StrongPass123!"),
      ("user@example.com", "weak"),
      ("invalid", "weak")
    )

    println("用户验证测试:")
    userTests.foreach { case (email, password) =>
      println(s"$email, $password -> ${validateUser(email, password)}")
    }

    // Range labelling samples
    println("\n数值范围测试:")
    List(-10, 50, 150).foreach(value => println(analyzeValue(value)))

    // File classification samples
    println("\n文件类型测试:")
    List("photo.jpg", "video.mp4", "document.pdf", "music.mp3", "data.csv")
      .foreach(filename => println(classifyFile(filename)))
  }
}

自定义数据结构的提取器

链表提取器

scala
/** Extractors defined for a hand-written linked list and a hand-written binary tree. */
object CustomDataStructureExtractors {
  // Singly linked list: either the empty list or a cell holding a head and the remaining list.
  sealed trait MyList[+T]
  case object MyNil extends MyList[Nothing]
  case class MyCons[T](head: T, tail: MyList[T]) extends MyList[T]

  /** Factory and head/tail extractor for `MyList`. */
  object MyList {
    /** Builds a list holding `elements` in the given order. */
    def apply[T](elements: T*): MyList[T] = {
      val empty: MyList[T] = MyNil
      elements.foldRight(empty)((element, built) => MyCons(element, built))
    }

    /** Binds head and tail of a non-empty list; fails on the empty list. */
    def unapply[T](list: MyList[T]): Option[(T, MyList[T])] = list match {
      case MyNil               => None
      case MyCons(first, rest) => Some((first, rest))
    }
  }

  /** Matches a list holding exactly one element and binds that element. */
  object SingleElement {
    def unapply[T](list: MyList[T]): Option[T] = list match {
      case MyCons(only, MyNil) => Some(only)
      case _                   => None
    }
  }

  /** Matches a list holding at least two elements and binds the first two. */
  object FirstTwo {
    def unapply[T](list: MyList[T]): Option[(T, T)] = list match {
      case MyCons(a, MyCons(b, _)) => Some((a, b))
      case _                       => None
    }
  }

  // Binary tree: either empty or a node holding a value and two subtrees.
  sealed trait BinaryTree[+T]
  case object Empty extends BinaryTree[Nothing]
  case class Node[T](value: T, left: BinaryTree[T], right: BinaryTree[T]) extends BinaryTree[T]

  /** Matches a node whose two subtrees are both empty; binds the node value. */
  object Leaf {
    def unapply[T](tree: BinaryTree[T]): Option[T] = tree match {
      case Node(payload, Empty, Empty) => Some(payload)
      case _                           => None
    }
  }

  /** Matches a node whose right subtree is empty; binds the value and the left subtree. */
  object LeftChild {
    def unapply[T](tree: BinaryTree[T]): Option[(T, BinaryTree[T])] = tree match {
      case Node(payload, leftSubtree, Empty) => Some((payload, leftSubtree))
      case _                                 => None
    }
  }

  /** Matches a node whose left subtree is empty; binds the value and the right subtree. */
  object RightChild {
    def unapply[T](tree: BinaryTree[T]): Option[(T, BinaryTree[T])] = tree match {
      case Node(payload, Empty, rightSubtree) => Some((payload, rightSubtree))
      case _                                  => None
    }
  }

  /** Describes a list using the first extractor that matches, in the order written. */
  def analyzeList[T](list: MyList[T]): String = list match {
    case MyNil              => "空列表"
    case SingleElement(only) => s"单元素列表: $only"
    case FirstTwo(a, b)     => s"前两个元素: $a, $b"
    case MyList(h, rest)    => s"头元素: $h, 尾部: ${analyzeList(rest)}"
  }

  /** Describes a tree recursively, from the most specific node shape to the general one. */
  def analyzeTree[T](tree: BinaryTree[T]): String = tree match {
    case Empty             => "空树"
    case Leaf(v)           => s"叶子节点: $v"
    case LeftChild(v, l)   => s"只有左子树的节点: $v, 左子树: ${analyzeTree(l)}"
    case RightChild(v, r)  => s"只有右子树的节点: $v, 右子树: ${analyzeTree(r)}"
    case Node(v, l, r)     => s"完整节点: $v, 左: ${analyzeTree(l)}, 右: ${analyzeTree(r)}"
  }

  /** Prints the analysis of a few sample lists and trees. */
  def main(args: Array[String]): Unit = {
    // Sample lists
    val sampleLists: List[MyList[Int]] = List(
      MyNil,
      MyList(1),
      MyList(1, 2),
      MyList(1, 2, 3, 4)
    )

    println("链表分析:")
    for (sample <- sampleLists) println(analyzeList(sample))

    // Sample trees
    val sampleTrees: List[BinaryTree[Int]] = List(
      Empty,
      Node(1, Empty, Empty),                                 // leaf
      Node(1, Node(2, Empty, Empty), Empty),                 // left subtree only
      Node(1, Empty, Node(3, Empty, Empty)),                 // right subtree only
      Node(1, Node(2, Empty, Empty), Node(3, Empty, Empty))  // both subtrees
    )

    println("\n二叉树分析:")
    for (sample <- sampleTrees) println(analyzeTree(sample))
  }
}

实际应用示例

URL 解析器

scala
/** Parses http/https URL strings through an extractor and classifies them. */
object URLParser {
  case class URL(protocol: String, host: String, port: Option[Int], path: String, query: Map[String, String])

  object URL {
    // Compiled once; `analyzeURL` may invoke `unapply` several times for a single input.
    // Groups: protocol, host, optional port digits, path, optional query string.
    private val UrlPattern = """^(https?):\/\/([^:\/\s]+)(?::(\d+))?([^?\s]*)(?:\?(.*))?$""".r

    /**
     * Parses `urlString` into a `URL`.
     *
     * @return `None` when the text does not match the URL pattern or when the port digits
     *         do not fit in an `Int`; otherwise the parsed URL, with an empty path
     *         replaced by "/".
     */
    def unapply(urlString: String): Option[URL] = urlString match {
      case UrlPattern(protocol, host, portStr, path, queryStr) =>
        // An unmatched optional group is null. Outer None = no port given;
        // Some(None) = port digits present but not representable as Int.
        val port = Option(portStr).map(digits => scala.util.Try(digits.toInt).toOption)
        port match {
          case Some(None) =>
            // Reject instead of letting NumberFormatException escape from a pattern match.
            None
          case _ =>
            val query = parseQuery(Option(queryStr).getOrElse(""))
            Some(URL(protocol, host, port.flatten, if (path.isEmpty) "/" else path, query))
        }
      case _ => None
    }

    /** Splits "k1=v1&k2=v2" into a map; a key without "=" maps to the empty string. */
    private def parseQuery(queryString: String): Map[String, String] = {
      if (queryString.isEmpty) Map.empty
      else {
        queryString.split("&").map { param =>
          // Limit 2 keeps any further "=" characters inside the value.
          val parts = param.split("=", 2)
          parts(0) -> (if (parts.length > 1) parts(1) else "")
        }.toMap
      }
    }
  }

  /** Matches URLs whose protocol is "https". */
  object HttpsURL {
    def unapply(url: URL): Option[URL] = {
      if (url.protocol == "https") Some(url) else None
    }
  }

  /** Matches URLs whose host is "localhost" or "127.0.0.1". */
  object LocalURL {
    def unapply(url: URL): Option[URL] = {
      if (url.host == "localhost" || url.host == "127.0.0.1") Some(url) else None
    }
  }

  /** Classifies `urlString` as HTTPS, local, ordinary or invalid; the first match wins. */
  def analyzeURL(urlString: String): String = urlString match {
    case URL(HttpsURL(url)) =>
      s"安全HTTPS连接: ${url.host}${url.path}"

    case URL(LocalURL(url)) =>
      // fold keeps the interpolated value a String instead of widening Option[Int] to Any.
      s"本地连接: ${url.protocol}://${url.host}:${url.port.fold("默认端口")(_.toString)}"

    case URL(url) =>
      s"普通URL: ${url.protocol}://${url.host}${url.path}" +
      (if (url.query.nonEmpty) s", 查询参数: ${url.query}" else "")

    case _ =>
      s"无效URL: $urlString"
  }

  /** Prints the classification of a few sample URLs. */
  def main(args: Array[String]): Unit = {
    val urls = List(
      "https://www.example.com/path?param=value",
      "http://localhost:8080/api/users",
      "https://api.github.com/repos/owner/repo",
      "http://127.0.0.1:3000/",
      "invalid-url"
    )

    urls.foreach(url => println(s"$url -> ${analyzeURL(url)}"))
  }
}

日志解析器

scala
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter

/** Parses log lines in two textual formats and classifies the parsed entries. */
object LogParser {
  case class LogEntry(
    timestamp: LocalDateTime,
    level: String,
    logger: String,
    message: String,
    thread: Option[String] = None
  )

  // Both log formats share the same timestamp layout.
  private val TimestampFormat = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")

  /** Parses a timestamp, returning `None` when `LocalDateTime.parse` rejects the text. */
  private def parseTimestamp(text: String): Option[LocalDateTime] =
    try Some(LocalDateTime.parse(text, TimestampFormat))
    catch {
      // Only a parse failure means "not a log line"; anything else should propagate.
      case _: java.time.format.DateTimeParseException => None
    }

  /** Thread name for display; entries produced by `ThreadedLog` always carry one. */
  private def threadName(entry: LogEntry): String = entry.thread.getOrElse("?")

  /** Matches "yyyy-MM-dd HH:mm:ss [LEVEL] Logger: message". */
  object StandardLog {
    private val LinePattern = """(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) \[(\w+)\] (\w+): (.+)""".r

    def unapply(logLine: String): Option[LogEntry] = logLine match {
      case LinePattern(timestampStr, level, logger, message) =>
        parseTimestamp(timestampStr).map(ts => LogEntry(ts, level, logger, message))
      case _ => None
    }
  }

  /** Matches "yyyy-MM-dd HH:mm:ss [LEVEL] [thread] Logger: message". */
  object ThreadedLog {
    private val LinePattern = """(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) \[(\w+)\] \[([^\]]+)\] (\w+): (.+)""".r

    def unapply(logLine: String): Option[LogEntry] = logLine match {
      case LinePattern(timestampStr, level, thread, logger, message) =>
        parseTimestamp(timestampStr).map(ts => LogEntry(ts, level, logger, message, Some(thread)))
      case _ => None
    }
  }

  /** Matches entries whose level is exactly "ERROR". */
  object ErrorLog {
    def unapply(entry: LogEntry): Option[LogEntry] = {
      if (entry.level == "ERROR") Some(entry) else None
    }
  }

  /** Matches entries whose level is exactly "WARN". */
  object WarningLog {
    def unapply(entry: LogEntry): Option[LogEntry] = {
      if (entry.level == "WARN") Some(entry) else None
    }
  }

  /** Matches entries timestamped after one hour before the current local time. */
  object RecentLog {
    def unapply(entry: LogEntry): Option[LogEntry] = {
      val oneHourAgo = LocalDateTime.now().minusHours(1)
      if (entry.timestamp.isAfter(oneHourAgo)) Some(entry) else None
    }
  }

  /**
   * Classifies a raw log line. Cases are tried in order: errors, warnings, recent
   * entries, then any parsable entry; within each group the threaded format is tried
   * before the standard one.
   */
  def analyzeLogEntry(logLine: String): String = logLine match {
    case ThreadedLog(ErrorLog(entry)) =>
      s"🔴 线程错误: [${threadName(entry)}] ${entry.logger} - ${entry.message}"

    case StandardLog(ErrorLog(entry)) =>
      s"🔴 错误: ${entry.logger} - ${entry.message}"

    case ThreadedLog(WarningLog(entry)) =>
      s"🟡 线程警告: [${threadName(entry)}] ${entry.logger} - ${entry.message}"

    case StandardLog(WarningLog(entry)) =>
      s"🟡 警告: ${entry.logger} - ${entry.message}"

    case ThreadedLog(RecentLog(entry)) =>
      s"🕐 最近线程日志: [${threadName(entry)}] ${entry.level} - ${entry.message}"

    case StandardLog(RecentLog(entry)) =>
      s"🕐 最近日志: ${entry.level} - ${entry.message}"

    case ThreadedLog(entry) =>
      s"📝 线程日志: [${threadName(entry)}] ${entry.level} - ${entry.logger}"

    case StandardLog(entry) =>
      s"📝 标准日志: ${entry.level} - ${entry.logger}"

    case _ =>
      s"❓ 无法解析的日志: $logLine"
  }

  /** Prints the classification of a few sample log lines. */
  def main(args: Array[String]): Unit = {
    val logLines = List(
      "2023-12-25 10:30:45 [INFO] UserService: User login successful",
      "2023-12-25 10:31:02 [ERROR] DatabaseService: Connection timeout",
      "2023-12-25 10:31:15 [WARN] [main-thread] CacheService: Cache miss for key: user_123",
      "2023-12-25 10:31:30 [ERROR] [worker-1] PaymentService: Payment processing failed",
      "Invalid log format",
      "2023-12-25 10:32:00 [DEBUG] SecurityService: Token validation passed"
    )

    println("日志分析结果:")
    logLines.foreach(line => println(analyzeLogEntry(line)))
  }
}

最佳实践

提取器设计原则

scala
/** Small extractors illustrating design guidelines, plus a demo that exercises them. */
object ExtractorBestPractices {

  // 1. Keep an extractor simple and focused on one thing.
  /** Binds the text after the first "@", provided both sides of it are non-empty. */
  object EmailDomain {
    def unapply(email: String): Option[String] = {
      val atIndex = email.indexOf('@')
      if (atIndex > 0 && atIndex < email.length - 1) {
        Some(email.substring(atIndex + 1))
      } else None
    }
  }

  // 2. Return a meaningful value.
  /** Maps a Celsius temperature to a descriptive label; always succeeds. */
  object Temperature {
    def unapply(celsius: Double): Option[String] = celsius match {
      case c if c < 0 => Some("冰点以下")
      case c if c < 10 => Some("寒冷")
      case c if c < 25 => Some("凉爽")
      case c if c < 35 => Some("温暖")
      case _ => Some("炎热")
    }
  }

  // 3. Mind performance; avoid heavy computation where a lookup will do.
  /** Boolean extractor that succeeds for prime numbers. */
  object FastPrime {
    private val knownPrimes = Set(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31)

    def unapply(n: Int): Boolean = {
      if (n <= 31) knownPrimes.contains(n)
      else isPrime(n)  // trial division only for numbers above the lookup table
    }

    private def isPrime(n: Int): Boolean = {
      if (n < 2) false
      else !(2 to math.sqrt(n).toInt).exists(n % _ == 0)
    }
  }

  // 4. Combine several checks in one extractor.
  /**
   * Matches a (name, email, age) triple when the name has at least two non-blank
   * characters, the email contains "@" and ".", and the age lies in 0..150.
   * The triple is bound unchanged.
   */
  object ValidUser {
    def unapply(input: (String, String, Int)): Option[(String, String, Int)] = {
      val (name, email, age) = input

      // Measure the trimmed name so surrounding whitespace cannot satisfy the
      // two-character minimum (" a" has length 2 but only one real character).
      val validName = name.trim.length >= 2
      val validEmail = email.contains("@") && email.contains(".")
      val validAge = age >= 0 && age <= 150

      if (validName && validEmail && validAge) Some((name, email, age))
      else None
    }
  }

  // 5. Handle errors inside the extractor.
  /** Binds the parsed integer, or fails when the text is not a valid `Int`. */
  object SafeInt {
    def unapply(s: String): Option[Int] = {
      try {
        Some(s.toInt)
      } catch {
        case _: NumberFormatException => None
      }
    }
  }

  /** Runs each extractor above over a few sample values and prints the results. */
  def demonstrateBestPractices(): Unit = {
    // Email domain extraction
    val emails = List("user@gmail.com", "admin@company.org", "invalid-email")
    emails.foreach {
      case email @ EmailDomain(domain) => println(s"$email 的域名是 $domain")
      case email => println(s"$email 不是有效邮箱")
    }

    // Temperature classification
    val temperatures = List(-5.0, 5.0, 20.0, 30.0, 40.0)
    temperatures.foreach {
      case temp @ Temperature(category) => println(s"${temp}°C 是 $category")
    }

    // Prime detection
    val numbers = List(2, 4, 17, 25, 29)
    numbers.foreach {
      case n @ FastPrime() => println(s"$n 是素数")
      case n => println(s"$n 不是素数")
    }

    // User validation
    val users = List(
      ("Alice", "alice@example.com", 25),
      ("", "invalid", -5),
      ("Bob", "bob@test.com", 30)
    )

    users.foreach {
      case ValidUser(name, email, age) => println(s"有效用户: $name, $email, $age")
      case (name, email, age) => println(s"无效用户: $name, $email, $age")
    }

    // Safe integer parsing
    val numberStrings = List("123", "abc", "456")
    numberStrings.foreach {
      case SafeInt(number) => println(s"解析成功: $number")
      case str => println(s"解析失败: $str")
    }
  }

  def main(args: Array[String]): Unit = {
    demonstrateBestPractices()
  }
}

总结

提取器是 Scala 模式匹配的核心机制:

  1. 基本概念

    • unapply 方法定义如何提取值
    • 返回 Option[T] 或 Boolean
    • unapplySeq 用于可变数量的值
  2. 设计原则

    • 保持简单和专一
    • 提供有意义的返回值
    • 考虑性能影响
    • 合理的错误处理
  3. 应用场景

    • 数据验证和解析
    • 模式匹配增强
    • 领域特定语言
    • API 设计
  4. 最佳实践

    • 组合简单提取器
    • 避免副作用
    • 考虑类型安全
    • 提供清晰的文档

提取器让模式匹配更加强大和灵活,是 Scala 函数式编程的重要工具。

本站内容仅供学习和研究使用。