Skip to content

Scala 模式匹配

模式匹配是 Scala 最强大的特性之一,它提供了一种优雅的方式来检查和解构数据。相比传统的 if-else 或 switch 语句,模式匹配更加灵活,表达力也更强。

基本模式匹配

值匹配

scala
/** Demonstrates literal patterns, pattern alternatives and guarded patterns on integers. */
object BasicPatternMatching {

  /** Names a few specific numbers through literal patterns. */
  private def nameOf(candidate: Int): String =
    candidate match {
      case 42 => "the answer"
      case 1  => "one"
      case 2  => "two"
      case _  => "something else" // default case for every other value
    }

  /** Classifies a day number: 1 to 5 are weekdays, 6 and 7 the weekend. */
  private def kindOfDay(dayNumber: Int): String =
    dayNumber match {
      case 6 | 7             => "weekend"
      case 1 | 2 | 3 | 4 | 5 => "weekday"
      case _                 => "invalid day"
    }

  /** Describes the sign and magnitude of a number using guards. */
  private def magnitudeOf(value: Int): String =
    value match {
      case 0            => "zero"
      case v if v < 0   => "negative"
      // Only strictly positive values reach the cases below.
      case v if v < 10  => "single digit positive"
      case v if v < 100 => "double digit positive"
      case _            => "large number"
    }

  /** Prints one result per kind of pattern. */
  def main(args: Array[String]): Unit = {
    // Value matching
    println(nameOf(42)) // "the answer"

    // Matching several values with one case
    println(kindOfDay(3)) // "weekday"

    // Matching with guard conditions
    println(magnitudeOf(15)) // "double digit positive"
  }
}

类型匹配

scala
/** Shows how type patterns dispatch on the runtime type of an `Any` value. */
object TypeMatching {

  // Descriptions for the scalar types this demo recognises.
  private val describeScalar: PartialFunction[Any, String] = {
    case text: String  => s"String: $text"
    case whole: Int    => s"Integer: $whole"
    case real: Double  => s"Double: $real"
    case flag: Boolean => s"Boolean: $flag"
  }

  // Descriptions for the collection types this demo recognises.
  // Element types are erased at runtime, hence the wildcard type arguments.
  private val describeCollection: PartialFunction[Any, String] = {
    case items: List[_]     => s"List with ${items.length} elements"
    case entries: Map[_, _] => s"Map with ${entries.size} entries"
  }

  /**
   * Describes `value` according to its runtime type.
   *
   * @param value any value
   * @return a description for strings, ints, doubles, booleans, lists and maps;
   *         "Unknown type" for everything else (for example arrays)
   */
  def processValue(value: Any): String =
    describeScalar
      .orElse(describeCollection)
      .lift(value)
      .getOrElse("Unknown type")

  /** Prints the description of one sample value per supported type, plus an array. */
  def main(args: Array[String]): Unit = {
    val samples: List[Any] = List(
      "Hello",
      42,
      3.14,
      true,
      List(1, 2, 3),
      Map("a" -> 1, "b" -> 2),
      Array(1, 2, 3)
    )

    samples.map(processValue).foreach(println)
  }
}
```

## 集合模式匹配


### List 模式匹配

```scala
/** Demonstrates cons (`::`) patterns and `List(...)` sequence patterns. */
object ListPatternMatching {

  /**
   * Describes a list by its head/tail structure.
   *
   * @param list the list to inspect
   * @return a description of the empty, single-element or head/tail shape
   */
  def analyzeList(list: List[Int]): String = list match {
    case Nil => "Empty list"
    case head :: Nil => s"Single element: $head"
    case head :: tail => s"Head: $head, Tail: ${tail.mkString(", ")}"
  }

  /**
   * Describes a list using progressively more general patterns; the first
   * matching case wins, so the order of the cases is significant.
   *
   * @param list the list to inspect
   * @return the description belonging to the first pattern that matches
   */
  def processSpecificPatterns(list: List[Int]): String = list match {
    case List() => "Empty list"
    case List(x) => s"Single element: $x"
    case List(x, y) => s"Two elements: $x, $y"
    case List(1, 2, 3) => "Exactly 1, 2, 3"
    case List(1, _*) => "Starts with 1"
    // Must precede the "third element" case: that case accepts every list with
    // three or more elements, and shorter lists are handled above, so placing
    // this guard after it would make it unreachable.
    case x :: y :: _ if x > y => "First two elements are decreasing"
    case List(_, _, third, _*) => s"Third element is $third"
    // Lists of every length are already handled above; kept as a safety net.
    case _ => "Other pattern"
  }

  /** Runs both analyses over a set of sample lists and prints the results. */
  def main(args: Array[String]): Unit = {
    val lists = List(
      List(),
      List(1),
      List(1, 2),
      List(1, 2, 3),
      List(1, 4, 5, 6),
      List(2, 1, 7, 8),
      List(5, 3, 9)
    )

    println("List analysis:")
    lists.foreach(list => println(s"$list -> ${analyzeList(list)}"))

    println("\nSpecific patterns:")
    lists.foreach(list => println(s"$list -> ${processSpecificPatterns(list)}"))
  }
}

Array 和 Vector 模式匹配

scala
/** Demonstrates sequence patterns on `Array` and `Vector`. */
object ArrayVectorMatching {

  /**
   * Describes an integer array by its length and leading elements.
   *
   * @param arr the array to inspect
   * @return the description belonging to the first pattern that matches
   */
  def matchArray(arr: Array[Int]): String = arr match {
    case Array() => "Empty array"
    case Array(x) => s"Single element: $x"
    case Array(x, y) => s"Two elements: $x, $y"
    case Array(1, 2, _*) => "Starts with 1, 2"
    case _ => s"Array with ${arr.length} elements"
  }

  /**
   * Describes a string vector by its length, first element or last element.
   *
   * @param vec the vector to inspect
   * @return the description belonging to the first pattern that matches
   */
  def matchVector(vec: Vector[String]): String = vec match {
    case Vector() => "Empty vector"
    case Vector(single) => s"Single element: $single"
    case Vector("start", _*) => "Starts with 'start'"
    // `:+` splits off the last element, so this holds for any length rather
    // than only for two-element vectors.
    case _ :+ "end" => "Ends with 'end'"
    case _ => s"Vector with ${vec.length} elements"
  }

  /** Runs both matchers over sample data and prints the results. */
  def main(args: Array[String]): Unit = {
    // The element type is spelled out: a bare `Array()` would be inferred as
    // Array[Nothing], and because Array is invariant the list would then not
    // be a List[Array[Int]].
    val arrays = List[Array[Int]](
      Array.empty[Int],
      Array(1),
      Array(1, 2),
      Array(1, 2, 3, 4),
      Array(5, 6, 7)
    )

    arrays.foreach(arr => println(s"${arr.mkString("[", ", ", "]")} -> ${matchArray(arr)}"))

    val vectors = List(
      Vector(),
      Vector("hello"),
      Vector("start", "middle", "end"),
      Vector("begin", "end"),
      Vector("a", "b", "c")
    )

    vectors.foreach(vec => println(s"$vec -> ${matchVector(vec)}"))
  }
}

Case Class 模式匹配

基本 Case Class 匹配

scala
case class Person(name: String, age: Int)
case class Employee(name: String, age: Int, department: String, salary: Double)
case class Student(name: String, age: Int, grade: String)

/** Demonstrates constructor patterns on case classes. */
object CaseClassMatching {

  /**
   * Renders a one-line description of a `Person`, `Employee` or `Student`.
   *
   * @param person any value
   * @return the description, or "Unknown person type" for any other value
   */
  def describePerson(person: Any): String = person match {
    case Person(name, age) => s"Person: $name, $age years old"
    case Employee(name, age, dept, salary) => s"Employee: $name, $age years old, works in $dept, earns $salary"
    case Student(name, age, grade) => s"Student: $name, $age years old, in grade $grade"
    case _ => "Unknown person type"
  }

  /**
   * Buckets a person by age: under 18 is "Minor", 18 to 64 is "Adult",
   * 65 and above is "Senior".
   *
   * @param person the person to categorise
   * @return the age category
   */
  def categorizeByAge(person: Person): String = person match {
    case Person(_, age) if age < 18 => "Minor"
    // Ages below 18 were handled above, so only the upper bound is checked.
    case Person(_, age) if age < 65 => "Adult"
    // Unguarded final case: everything left is 65 or older, and the compiler
    // can now see that the match covers every Person.
    case Person(_, _) => "Senior"
  }

  /**
   * Extracts the name from any of the three person-like case classes.
   *
   * @param person any value
   * @return `Some(name)` for a `Person`, `Employee` or `Student`, otherwise `None`
   */
  def extractName(person: Any): Option[String] = person match {
    case Person(name, _) => Some(name)
    case Employee(name, _, _, _) => Some(name)
    case Student(name, _, _) => Some(name)
    case _ => None
  }

  /** Prints descriptions, age categories and extracted names for sample data. */
  def main(args: Array[String]): Unit = {
    val people = List(
      Person("Alice", 25),
      Employee("Bob", 35, "Engineering", 75000),
      Student("Charlie", 16, "10th"),
      "Not a person"
    )

    people.foreach(person => println(describePerson(person)))

    println("\nAge categories:")
    val persons = List(Person("Child", 12), Person("Adult", 30), Person("Elder", 70))
    persons.foreach(person => println(s"${person.name}: ${categorizeByAge(person)}"))

    println("\nExtracted names:")
    people.foreach(person => println(s"Name: ${extractName(person).getOrElse("Unknown")}"))
  }
}
```

### 嵌套模式匹配

```scala
case class Address(street: String, city: String, country: String)
case class Company(name: String, address: Address)
case class EmployeeWithAddress(name: String, age: Int, company: Company)

/** Demonstrates constructor patterns nested several levels deep. */
object NestedPatternMatching {

  /**
   * Summarises an employee, preferring the most specific rule that applies:
   * a US address, then a San Francisco address, then Google as employer,
   * then seniority by age, and finally a generic fallback.
   *
   * @param emp the employee to describe
   * @return the summary produced by the first rule that applies
   */
  def analyzeEmployee(emp: EmployeeWithAddress): String =
    emp match {
      // Country check: looks three levels into the structure.
      case EmployeeWithAddress(who, years, Company(employer, Address(_, town, nation)))
          if nation == "USA" =>
        s"$who ($years) works at $employer in $town, USA"

      // City check: only reached for addresses outside the USA.
      case EmployeeWithAddress(who, years, Company(employer, Address(_, town, _)))
          if town == "San Francisco" =>
        s"$who ($years) works at $employer in San Francisco"

      // Employer check: the address is irrelevant here.
      case EmployeeWithAddress(who, _, Company(employer, _)) if employer == "Google" =>
        s"$who works at Google"

      // Anything else is decided by age alone.
      case EmployeeWithAddress(who, years, _) =>
        if (years > 50) s"$who is a senior employee ($years years old)"
        else s"$who is an employee"
    }

  /** Prints the summary of four sample employees. */
  def main(args: Array[String]): Unit = {
    val staff = List(
      EmployeeWithAddress("Alice", 30, Company("Google", Address("1600 Amphitheatre", "Mountain View", "USA"))),
      EmployeeWithAddress("Bob", 25, Company("Twitter", Address("1355 Market St", "San Francisco", "USA"))),
      EmployeeWithAddress("Charlie", 55, Company("Microsoft", Address("One Microsoft Way", "Redmond", "USA"))),
      EmployeeWithAddress("Diana", 28, Company("Google", Address("Googleplex", "London", "UK")))
    )

    staff.map(analyzeEmployee).foreach(println)
  }
}

Option 和 Either 模式匹配

Option 模式匹配

scala
/** Demonstrates matching on `Option`, including nested options and guards. */
object OptionPatternMatching {

  /**
   * Describes whether an optional string holds a value.
   *
   * @param opt the optional value
   * @return "Got value: ..." for `Some`, "No value" for `None`
   */
  def processOption(opt: Option[String]): String =
    opt match {
      case None          => "No value"
      case Some(payload) => s"Got value: $payload"
    }

  /**
   * Describes which layer of a doubly wrapped option is populated.
   *
   * @param opt the nested optional value
   * @return a description of the outer/inner `Some`/`None` combination
   */
  def processNestedOption(opt: Option[Option[Int]]): String =
    opt match {
      case None              => "Outer None"
      case Some(None)        => "Outer Some, inner None"
      case Some(Some(inner)) => s"Nested value: $inner"
    }

  /**
   * Practical example: division that signals a zero divisor with `None`
   * instead of producing an infinite or NaN result.
   *
   * @param x the dividend
   * @param y the divisor
   * @return `Some(x / y)` when `y` is non-zero, otherwise `None`
   */
  def safeDivide(x: Double, y: Double): Option[Double] =
    if (y == 0) None else Some(x / y)

  /**
   * Divides `x` by `y` and reports how the quotient compares with 1.
   *
   * @param x the dividend
   * @param y the divisor
   * @return a sentence containing the quotient and its relation to 1, or an
   *         error sentence when `y` is zero
   */
  def calculateAndDescribe(x: Double, y: Double): String =
    safeDivide(x, y) match {
      case None =>
        s"Cannot divide $x by $y (division by zero)"
      case Some(quotient) =>
        val remark =
          if (quotient > 1) "greater than 1"
          else if (quotient == 1) "exactly 1"
          else "less than 1"
        s"$x / $y = $quotient ($remark)"
    }

  /** Prints the results of all three demonstrations. */
  def main(args: Array[String]): Unit = {
    val flat: List[Option[String]] = List(Some("hello"), None, Some("world"))
    flat.map(processOption).foreach(println)

    val nested: List[Option[Option[Int]]] = List(Some(Some(42)), Some(None), None)
    nested.map(processNestedOption).foreach(println)

    val divisions = List((10.0, 2.0), (5.0, 5.0), (3.0, 4.0), (1.0, 0.0))
    for ((dividend, divisor) <- divisions)
      println(calculateAndDescribe(dividend, divisor))
  }
}

Either 模式匹配

scala
/** Demonstrates matching on `Either`, with `Left` carrying an error message. */
object EitherPatternMatching {

  /**
   * Parses a string as an `Int`.
   *
   * @param s the text to parse
   * @return `Right(value)` on success, `Left(message)` when `s` is not a valid integer
   */
  def parseInteger(s: String): Either[String, Int] =
    try Right(s.toInt)
    catch {
      case _: NumberFormatException => Left(s"'$s' is not a valid integer")
    }

  /**
   * Describes a parse result: the error, or the sign of the parsed number.
   *
   * @param either the parse result
   * @return a one-line description
   */
  def processEither(either: Either[String, Int]): String =
    either match {
      case Right(0)               => "Zero"
      case Right(n) if n > 0      => s"Positive number: $n"
      // Zero and positive values were handled above, so this is negative.
      case Right(n)               => s"Negative number: $n"
      case Left(reason)           => s"Error: $reason"
    }

  /**
   * Chained use: parses every input and doubles the ones that are integers.
   *
   * @param inputs the raw strings
   * @return for each input either "input -> doubled" or the parse error message
   */
  def processNumbers(inputs: List[String]): List[String] =
    for (raw <- inputs) yield
      parseInteger(raw).fold(
        message => message,
        number => s"$raw -> ${number * 2}"
      )

  /** Prints parse descriptions and doubled values for sample inputs. */
  def main(args: Array[String]): Unit = {
    val samples = List("42", "-10", "0", "abc", "3.14")

    println("Parsing results:")
    for (raw <- samples)
      println(s"$raw -> ${processEither(parseInteger(raw))}")

    println("\nProcessed numbers:")
    processNumbers(samples).foreach(println)
  }
}

高级模式匹配

变量绑定和提取器

scala
/** Demonstrates custom extractors and variable binding with `@`. */
object AdvancedPatternMatching {

  /** Custom extractor that matches even integers and yields the number itself. */
  object Even {
    def unapply(n: Int): Option[Int] = Some(n).filter(_ % 2 == 0)
  }

  /** Custom extractor that matches odd integers and yields the number itself. */
  object Odd {
    def unapply(n: Int): Option[Int] = Some(n).filter(_ % 2 != 0)
  }

  /** Multi-value extractor yielding the first and last element of a list with at least two elements. */
  object FirstLast {
    def unapply[T](list: List[T]): Option[(T, T)] =
      if (list.length < 2) None
      else Some(list.head -> list.last)
  }

  /**
   * Reports the parity of a number using the `Even` and `Odd` extractors.
   *
   * @param n the number to classify
   * @return "n is even" or "n is odd"
   */
  def analyzeNumber(n: Int): String =
    n match {
      case Odd(value)  => s"$value is odd"
      case Even(value) => s"$value is even"
    }

  /**
   * Describes a list by size, showing first and last element when there are at least two.
   *
   * @param list the list to inspect
   * @return a description of the empty, single-element or first/last shape
   */
  def analyzeList[T](list: List[T]): String =
    list match {
      case Nil                    => "Empty list"
      case only :: Nil            => s"Single element: $only"
      case FirstLast(front, back) => s"First: $front, Last: $back"
    }

  /**
   * Variable binding (`@`): each case names the whole matched value while
   * also destructuring it for the guard.
   *
   * @param data any value
   * @return a description for a three-element list with equal ends, an adult
   *         `Person`, or a `Some` whose content prints longer than five
   *         characters; a fallback message otherwise
   */
  def processComplexPattern(data: Any): String =
    data match {
      case triple @ List(left, _, right) if left == right =>
        s"List $triple has equal first and last elements"
      case adult @ Person(_, years) if years >= 18 =>
        s"Adult person: $adult"
      case wrapped @ Some(content) if content.toString.length > 5 =>
        s"Long value in option: $wrapped"
      case _ =>
        "No special pattern matched"
    }

  /** Exercises the extractors and the binding patterns on sample data. */
  def main(args: Array[String]): Unit = {
    // Custom extractors
    List(1, 2, 3, 4, 5, 6).map(analyzeNumber).foreach(println)

    // List extractor
    val sampleLists = List(
      List(1, 2, 3, 4),
      List("hello"),
      List(),
      List('a', 'b')
    )
    for (sample <- sampleLists) println(analyzeList(sample))

    // Variable binding
    val mixed = List(
      List(1, 2, 1),
      Person("Alice", 25),
      Some("very long string"),
      "other"
    )
    for (item <- mixed) println(processComplexPattern(item))
  }
}
```

### 正则表达式模式匹配

```scala
import scala.util.matching.Regex

/** Demonstrates using `Regex` values as extractors in pattern matches. */
object RegexPatternMatching {
  // Each capture group becomes one binding when the regex is used as a pattern.
  val EmailPattern: Regex = new Regex("""(\w+)@(\w+\.\w+)""")
  val PhonePattern: Regex = new Regex("""(\d{3})-(\d{3})-(\d{4})""")
  val DatePattern: Regex = new Regex("""(\d{4})-(\d{2})-(\d{2})""")

  /**
   * Recognises an input as e-mail address, phone number or date.
   *
   * @param input the text to classify; the whole string has to match a pattern
   * @return a description with the captured parts, or "Invalid format: ..."
   */
  def validateInput(input: String): String =
    input match {
      case DatePattern(yyyy, mm, dd) =>
        s"Valid date: $dd/$mm/$yyyy"
      case PhonePattern(areaCode, prefix, line) =>
        s"Valid phone: ($areaCode) $prefix-$line"
      case EmailPattern(user, host) =>
        s"Valid email: $user at $host"
      case _ =>
        s"Invalid format: $input"
    }

  // A more involved regular expression: date, time, bracketed level, message.
  val LogPattern: Regex = new Regex("""(\d{4}-\d{2}-\d{2}) (\d{2}:\d{2}:\d{2}) \[(\w+)\] (.+)""")

  /**
   * Splits a log line into date, time, level and message.
   *
   * @param log the raw log line
   * @return the re-assembled entry, or "Invalid log format: ..." when the line does not match
   */
  def parseLogEntry(log: String): String =
    log match {
      case LogPattern(day, clock, severity, text) => s"Log entry: $day $clock [$severity] $text"
      case _                                      => s"Invalid log format: $log"
    }

  /** Prints the classification of sample inputs and sample log lines. */
  def main(args: Array[String]): Unit = {
    val candidates = List(
      "john@example.com",
      "555-123-4567",
      "2023-12-25",
      "invalid-input",
      "alice@company.org"
    )
    candidates.map(validateInput).foreach(println)

    val logLines = List(
      "2023-12-25 10:30:45 [INFO] Application started",
      "2023-12-25 10:31:02 [ERROR] Database connection failed",
      "Invalid log entry"
    )
    logLines.map(parseLogEntry).foreach(println)
  }
}

实际应用示例

JSON 解析器

scala
// Simplified JSON AST
sealed trait Json
case object JsonNull extends Json
case class JsonBool(value: Boolean) extends Json
case class JsonNumber(value: Double) extends Json
case class JsonString(value: String) extends Json
case class JsonArray(elements: List[Json]) extends Json
case class JsonObject(fields: Map[String, Json]) extends Json

/** Pretty-printing and querying of the simplified JSON AST via pattern matching. */
object JsonProcessor {

  /** Escapes the characters that would otherwise break a JSON string literal. */
  private def escape(raw: String): String =
    raw.iterator.map {
      case '"'   => "\\\""
      case '\\'  => "\\\\"
      case '\n'  => "\\n"
      case '\r'  => "\\r"
      case '\t'  => "\\t"
      case other => other.toString
    }.mkString

  /** Wraps `raw` in double quotes after escaping it. */
  private def quote(raw: String): String = "\"" + escape(raw) + "\""

  /** Returns the first successful `findValue` result among `candidates`, searching lazily. */
  private def firstMatch(candidates: Iterator[Json], key: String): Option[Json] =
    candidates.map(findValue(_, key)).collectFirst { case Some(found) => found }

  /**
   * Renders a JSON value as indented text, two spaces per nesting level.
   *
   * @param json   the value to render
   * @param indent the nesting level of `json` itself (0 for the document root)
   * @return the rendered text; string values and object keys are escaped
   */
  def prettyPrint(json: Json, indent: Int = 0): String = {
    val outer = "  " * indent
    val inner = "  " * (indent + 1)

    json match {
      case JsonNull => "null"
      case JsonBool(value) => value.toString
      case JsonNumber(value) => value.toString
      case JsonString(value) => quote(value)

      case JsonArray(Nil) => "[]"
      case JsonArray(elements) =>
        val elementsStr = elements.map(prettyPrint(_, indent + 1)).mkString(",\n" + inner)
        s"[\n$inner$elementsStr\n$outer]"

      case JsonObject(fields) if fields.isEmpty => "{}"
      case JsonObject(fields) =>
        val fieldsStr = fields.map { case (key, value) =>
          s"${quote(key)}: ${prettyPrint(value, indent + 1)}"
        }.mkString(",\n" + inner)
        s"{\n$inner$fieldsStr\n$outer}"
    }
  }

  /**
   * Collects every string value found anywhere inside `json`.
   *
   * @param json the value to traverse
   * @return the string values in traversal order (object keys are not included)
   */
  def extractStrings(json: Json): List[String] = json match {
    case JsonString(value) => List(value)
    case JsonArray(elements) => elements.flatMap(extractStrings)
    case JsonObject(fields) => fields.values.toList.flatMap(extractStrings)
    case _ => List.empty
  }

  /**
   * Finds the value stored under `key`, searching nested objects and arrays.
   *
   * A direct field of an object takes precedence over a field of the same
   * name deeper inside it; otherwise the first hit in traversal order wins.
   *
   * @param json the value to search
   * @param key  the field name to look for
   * @return the first value found under `key`, or `None` when no object in
   *         `json` has such a field
   */
  def findValue(json: Json, key: String): Option[Json] = json match {
    case JsonObject(fields) =>
      // Look at this object first, then descend into its values.
      fields.get(key).orElse(firstMatch(fields.values.iterator, key))
    case JsonArray(elements) =>
      // Keep going past elements that do not contain the key.
      firstMatch(elements.iterator, key)
    case _ => None
  }

  /** Builds a sample document and prints the result of each operation. */
  def main(args: Array[String]): Unit = {
    val json = JsonObject(Map(
      "name" -> JsonString("Alice"),
      "age" -> JsonNumber(25),
      "active" -> JsonBool(true),
      "address" -> JsonObject(Map(
        "street" -> JsonString("123 Main St"),
        "city" -> JsonString("New York")
      )),
      "hobbies" -> JsonArray(List(
        JsonString("reading"),
        JsonString("swimming"),
        JsonString("coding")
      )),
      "spouse" -> JsonNull
    ))

    println("Pretty printed JSON:")
    println(prettyPrint(json))

    println("\nExtracted strings:")
    extractStrings(json).foreach(println)

    println("\nFind values:")
    println(s"Name: ${findValue(json, "name")}")
    println(s"City: ${findValue(json, "city")}")
    println(s"Unknown: ${findValue(json, "unknown")}")
  }
}

表达式求值器

scala
// Arithmetic expression AST
sealed trait Expr
case class Num(value: Double) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Sub(left: Expr, right: Expr) extends Expr
case class Mul(left: Expr, right: Expr) extends Expr
case class Div(left: Expr, right: Expr) extends Expr
case class Var(name: String) extends Expr

/** Evaluation, algebraic simplification and rendering of `Expr` trees. */
object ExpressionEvaluator {

  /**
   * Computes the numeric value of an expression.
   *
   * @param expr the expression to evaluate
   * @param env  the values of the variables appearing in `expr`
   * @return the value of `expr`
   * @throws RuntimeException if a variable is missing from `env` or a divisor evaluates to zero
   */
  def eval(expr: Expr, env: Map[String, Double] = Map.empty): Double = expr match {
    case Num(value) => value
    case Var(name) => env.getOrElse(name, throw new RuntimeException(s"Undefined variable: $name"))
    case Add(left, right) => eval(left, env) + eval(right, env)
    case Sub(left, right) => eval(left, env) - eval(right, env)
    case Mul(left, right) => eval(left, env) * eval(right, env)
    case Div(left, right) =>
      // The divisor is evaluated first so a zero is reported before the dividend is touched.
      val rightVal = eval(right, env)
      if (rightVal == 0) throw new RuntimeException("Division by zero")
      else eval(left, env) / rightVal
  }

  /**
   * Applies the identities `x + 0`, `x - 0`, `x * 0`, `x * 1`, `x / 1` and
   * folds operations whose operands are both constants.
   *
   * Operands are simplified before the rules are applied to the parent node,
   * so a simplification deep in the tree can enable another one further up
   * (for example `(0 * x) + y` becomes `y`).
   *
   * @param expr the expression to simplify
   * @return an equivalent expression with the rules above applied throughout;
   *         a constant division by zero is left untouched
   */
  def simplify(expr: Expr): Expr = expr match {
    case Add(left, right) =>
      (simplify(left), simplify(right)) match {
        case (Num(0), r) => r
        case (l, Num(0)) => l
        case (Num(a), Num(b)) => Num(a + b)
        case (l, r) => Add(l, r)
      }

    case Sub(left, right) =>
      (simplify(left), simplify(right)) match {
        case (l, Num(0)) => l
        case (Num(a), Num(b)) => Num(a - b)
        case (l, r) => Sub(l, r)
      }

    case Mul(left, right) =>
      (simplify(left), simplify(right)) match {
        case (Num(0), _) | (_, Num(0)) => Num(0)
        case (Num(1), r) => r
        case (l, Num(1)) => l
        case (Num(a), Num(b)) => Num(a * b)
        case (l, r) => Mul(l, r)
      }

    case Div(left, right) =>
      (simplify(left), simplify(right)) match {
        case (l, Num(1)) => l
        case (Num(a), Num(b)) if b != 0 => Num(a / b)
        case (l, r) => Div(l, r)
      }

    // Constants and variables are already in simplest form.
    case other => other
  }

  /**
   * Renders an expression as a fully parenthesised infix string.
   *
   * @param expr the expression to render
   * @return the textual form, e.g. `(x + 1.0)`
   */
  def toString(expr: Expr): String = expr match {
    case Num(value) => value.toString
    case Var(name) => name
    case Add(left, right) => s"(${toString(left)} + ${toString(right)})"
    case Sub(left, right) => s"(${toString(left)} - ${toString(right)})"
    case Mul(left, right) => s"(${toString(left)} * ${toString(right)})"
    case Div(left, right) => s"(${toString(left)} / ${toString(right)})"
  }

  /** Evaluates one sample expression and simplifies another, printing both. */
  def main(args: Array[String]): Unit = {
    // Expression: (x + 1) * (y - 2) / 3
    val expr = Div(
      Mul(
        Add(Var("x"), Num(1)),
        Sub(Var("y"), Num(2))
      ),
      Num(3)
    )

    println(s"Expression: ${toString(expr)}")

    val env = Map("x" -> 5.0, "y" -> 8.0)
    println(s"Evaluation with x=5, y=8: ${eval(expr, env)}")

    // Expression to simplify: (x + 0) * (1 * y) - (0 + z)
    val complexExpr = Sub(
      Mul(
        Add(Var("x"), Num(0)),
        Mul(Num(1), Var("y"))
      ),
      Add(Num(0), Var("z"))
    )

    println(s"Complex expression: ${toString(complexExpr)}")
    println(s"Simplified: ${toString(simplify(complexExpr))}")
  }
}

最佳实践

  1. 优先使用模式匹配而不是 if-else

    • 更清晰的代码结构
    • 编译器检查完整性
    • 更好的类型安全
  2. 使用 sealed trait 确保完整性

    • 编译器会检查所有情况
    • 避免遗漏分支
    • 更安全的代码
  3. 合理使用守卫条件

    • 添加额外的条件检查
    • 保持模式简洁
    • 避免过于复杂的条件
  4. 利用变量绑定

    • 使用 @ 绑定整个匹配的值
    • 在需要时访问原始数据
    • 提高代码可读性
  5. 创建自定义提取器

    • 封装复杂的匹配逻辑
    • 提高代码复用性
    • 保持模式匹配的简洁

模式匹配是 Scala 函数式编程的核心特性,掌握它对于编写优雅、安全的代码至关重要。

本站内容仅供学习和研究使用。