Skip to content

Scala 方法与函数

方法和函数是 Scala 编程的核心概念。本章将详细介绍方法定义、函数字面量、高阶函数以及函数式编程的基本概念。

方法定义

基本方法语法

scala
// 基本方法定义
def methodName(parameter1: Type1, parameter2: Type2): ReturnType = {
  // 方法体
  returnValue
}

// 示例
def add(x: Int, y: Int): Int = {
  x + y
}

def greet(name: String): String = {
  s"Hello, $name!"
}

// 单表达式方法可以省略大括号
def multiply(x: Int, y: Int): Int = x * y

def square(x: Int): Int = x * x

无参数方法

scala
// 无参数方法
def getCurrentTime(): Long = System.currentTimeMillis()

def pi(): Double = 3.14159

// 可以省略括号(但调用时也必须省略)
def randomNumber: Int = scala.util.Random.nextInt(100)

// 调用方式
val time1 = getCurrentTime()  // 带括号
val time2 = getCurrentTime    // 不带括号(如果定义时有括号)
val number = randomNumber     // 必须不带括号

返回类型推断

scala
// 返回类型可以推断
def add(x: Int, y: Int) = x + y  // 推断为 Int

def createList() = List(1, 2, 3)  // 推断为 List[Int]

// 递归方法必须显式声明返回类型
def factorial(n: Int): Int = {
  if (n <= 1) 1
  else n * factorial(n - 1)
}

// 复杂方法建议显式声明返回类型
def processData(data: List[String]): Map[String, Int] = {
  data.groupBy(identity).view.mapValues(_.length).toMap
}

方法参数

默认参数

scala
def greet(name: String, greeting: String = "Hello", punctuation: String = "!"): String = {
  s"$greeting, $name$punctuation"
}

// 调用方式
val msg1 = greet("Alice")                    // "Hello, Alice!"
val msg2 = greet("Bob", "Hi")               // "Hi, Bob!"
val msg3 = greet("Charlie", "Hey", ".")     // "Hey, Charlie."

命名参数

scala
def createUser(name: String, age: Int, email: String, active: Boolean = true): String = {
  s"User($name, $age, $email, $active)"
}

// 使用命名参数
val user1 = createUser(
  name = "Alice",
  age = 25,
  email = "alice@example.com"
)

val user2 = createUser(
  email = "bob@example.com",
  name = "Bob",
  age = 30,
  active = false
)

可变参数

scala
def sum(numbers: Int*): Int = {
  numbers.sum
}

def concatenate(separator: String, strings: String*): String = {
  strings.mkString(separator)
}

// 调用方式
val total1 = sum(1, 2, 3, 4, 5)
val total2 = sum()  // 空参数列表

val text1 = concatenate(", ", "apple", "banana", "cherry")
val text2 = concatenate(" - ")  // 只有分隔符

// 传递集合作为可变参数
val numbers = List(1, 2, 3, 4, 5)
val total3 = sum(numbers: _*)  // 展开列表

传名参数(Call-by-Name)

scala
// 传值参数(默认)
def callByValue(x: Int): Int = {
  println("Evaluating call-by-value")
  x + x
}

// 传名参数
def callByName(x: => Int): Int = {
  println("Evaluating call-by-name")
  x + x
}

// 测试差异
def expensiveComputation(): Int = {
  println("Computing...")
  Thread.sleep(1000)
  42
}

// 传值:计算一次,使用两次
callByValue(expensiveComputation())

// 传名:每次使用都重新计算
callByName(expensiveComputation())

函数字面量

匿名函数

scala
// 基本匿名函数语法
val add = (x: Int, y: Int) => x + y
val square = (x: Int) => x * x
val greet = (name: String) => s"Hello, $name!"

// 使用匿名函数
val result1 = add(3, 4)        // 7
val result2 = square(5)        // 25
val message = greet("Alice")   // "Hello, Alice!"

函数类型

scala
// 函数类型声明
val multiply: (Int, Int) => Int = (x, y) => x * y
val isEven: Int => Boolean = x => x % 2 == 0
val printer: String => Unit = s => println(s)

// 复杂函数类型
val processor: List[Int] => List[Int] = list => list.map(_ * 2)
val validator: String => Option[String] = s => 
  if (s.nonEmpty) Some(s) else None

简化语法

scala
val numbers = List(1, 2, 3, 4, 5)

// 完整语法
val doubled1 = numbers.map(x => x * 2)

// 简化语法(占位符)
val doubled2 = numbers.map(_ * 2)
val filtered = numbers.filter(_ > 2)
val sum = numbers.reduce(_ + _)

// 多个占位符
val pairs = numbers.zip(numbers).map { case (x, y) => x + y }
val pairsSimple = numbers.zip(numbers).map(_ + _)  // 错误:不明确

// 正确的多参数占位符
val combined = List((1, 2), (3, 4)).map { case (a, b) => a + b }

高阶函数

接受函数作为参数

scala
def applyOperation(x: Int, y: Int, operation: (Int, Int) => Int): Int = {
  operation(x, y)
}

def applyToList[T, R](list: List[T], f: T => R): List[R] = {
  list.map(f)
}

// 使用示例
val add = (x: Int, y: Int) => x + y
val multiply = (x: Int, y: Int) => x * y

val sum = applyOperation(3, 4, add)      // 7
val product = applyOperation(3, 4, multiply)  // 12

val numbers = List(1, 2, 3, 4, 5)
val squared = applyToList(numbers, (x: Int) => x * x)
val strings = applyToList(numbers, (x: Int) => s"Number: $x")

返回函数的方法

scala
def createMultiplier(factor: Int): Int => Int = {
  (x: Int) => x * factor
}

def createValidator(minLength: Int): String => Boolean = {
  (s: String) => s.length >= minLength
}

// 使用示例
val double = createMultiplier(2)
val triple = createMultiplier(3)

val result1 = double(5)   // 10
val result2 = triple(4)   // 12

val isValidPassword = createValidator(8)
val isValidName = createValidator(2)

val valid1 = isValidPassword("secret123")  // true
val valid2 = isValidName("Al")            // true

函数组合

scala
// 函数组合
def compose[A, B, C](f: B => C, g: A => B): A => C = {
  (x: A) => f(g(x))
}

def andThen[A, B, C](f: A => B, g: B => C): A => C = {
  (x: A) => g(f(x))
}

// 示例函数
val addOne = (x: Int) => x + 1
val double = (x: Int) => x * 2
val toString = (x: Int) => x.toString

// 组合函数
val addOneThenDouble = compose(double, addOne)  // 先加1,再乘2
val doubleAndAddOne = andThen(double, addOne)   // 先乘2,再加1

val result1 = addOneThenDouble(3)  // (3 + 1) * 2 = 8
val result2 = doubleAndAddOne(3)   // (3 * 2) + 1 = 7

// 链式组合
val pipeline = andThen(andThen(addOne, double), toString)
val result3 = pipeline(3)  // "8"

柯里化(Currying)

柯里化函数定义

scala
// 普通多参数函数
def add(x: Int, y: Int): Int = x + y

// 柯里化函数
def addCurried(x: Int)(y: Int): Int = x + y

// 手动柯里化
def addManual(x: Int): Int => Int = (y: Int) => x + y

// 使用示例
val result1 = add(3, 4)           // 7
val result2 = addCurried(3)(4)    // 7
val result3 = addManual(3)(4)     // 7

// 部分应用
val add5 = addCurried(5) _        // Int => Int
val result4 = add5(3)             // 8

val add10 = addManual(10)         // Int => Int
val result5 = add10(7)            // 17

柯里化的实际应用

scala
// 配置函数
def createLogger(level: String)(component: String)(message: String): Unit = {
  println(s"[$level] $component: $message")
}

// 创建特定级别的日志器
val infoLogger = createLogger("INFO") _
val errorLogger = createLogger("ERROR") _

// 创建特定组件的日志器
val dbInfoLogger = infoLogger("Database")
val apiErrorLogger = errorLogger("API")

// 使用
dbInfoLogger("Connection established")
apiErrorLogger("Request failed")

// 数据处理管道
def processData(validator: String => Boolean)
               (transformer: String => String)
               (data: List[String]): List[String] = {
  data.filter(validator).map(transformer)
}

val isNotEmpty = (s: String) => s.nonEmpty
val toUpperCase = (s: String) => s.toUpperCase

val processor = processData(isNotEmpty)(toUpperCase) _
val result = processor(List("hello", "", "world", "scala"))
// List("HELLO", "WORLD", "SCALA")

部分应用函数

基本部分应用

scala
def multiply(x: Int, y: Int, z: Int): Int = x * y * z

// 部分应用
val multiplyBy2 = multiply(2, _, _)      // (Int, Int) => Int
val multiplyBy2And3 = multiply(2, 3, _)  // Int => Int

val result1 = multiplyBy2(3, 4)          // 24
val result2 = multiplyBy2And3(5)         // 30

// 使用下划线进行部分应用
val numbers = List(1, 2, 3, 4, 5)
val doubled = numbers.map(multiply(2, _, 1))  // List(2, 4, 6, 8, 10)

复杂部分应用

scala
def createConnection(host: String, port: Int, timeout: Int, ssl: Boolean): String = {
  s"Connection to $host:$port (timeout: ${timeout}s, ssl: $ssl)"
}

// 创建特定环境的连接函数
val prodConnection = createConnection("prod.example.com", 443, _, true)
val devConnection = createConnection("localhost", 8080, _, false)

val prod = prodConnection(30)  // 生产环境,30秒超时
val dev = devConnection(5)     // 开发环境,5秒超时

递归函数

尾递归优化

scala
import scala.annotation.tailrec

// 非尾递归(可能栈溢出)
def factorial(n: Int): Int = {
  if (n <= 1) 1
  else n * factorial(n - 1)
}

// 尾递归版本
def factorialTailRec(n: Int): Int = {
  @tailrec
  def loop(n: Int, acc: Int): Int = {
    if (n <= 1) acc
    else loop(n - 1, n * acc)
  }
  loop(n, 1)
}

// 斐波那契数列
def fibonacci(n: Int): Int = {
  @tailrec
  def fib(n: Int, a: Int, b: Int): Int = {
    if (n == 0) a
    else fib(n - 1, b, a + b)
  }
  fib(n, 0, 1)
}

相互递归

scala
def isEven(n: Int): Boolean = {
  if (n == 0) true
  else isOdd(n - 1)
}

def isOdd(n: Int): Boolean = {
  if (n == 0) false
  else isEven(n - 1)
}

// 使用示例
val even = isEven(4)  // true
val odd = isOdd(5)    // true

函数式编程概念

纯函数

scala
// 纯函数:相同输入总是产生相同输出,无副作用
def add(x: Int, y: Int): Int = x + y
def multiply(x: Int, y: Int): Int = x * y

// 非纯函数:有副作用
var counter = 0
def impureIncrement(): Int = {
  counter += 1  // 副作用:修改外部状态
  counter
}

def impureRandom(): Int = {
  scala.util.Random.nextInt(100)  // 副作用:不确定性
}

// 纯函数版本
def pureIncrement(current: Int): Int = current + 1
def pureRandom(seed: Long): (Long, Int) = {
  val rng = new scala.util.Random(seed)
  (seed + 1, rng.nextInt(100))
}

不可变性

scala
// 不可变数据结构操作
val numbers = List(1, 2, 3, 4, 5)

// 所有操作都返回新的集合
val doubled = numbers.map(_ * 2)
val filtered = numbers.filter(_ > 2)
val sorted = numbers.sorted
val reversed = numbers.reverse

// 原始列表保持不变
println(numbers)  // List(1, 2, 3, 4, 5)

// 函数式更新
case class Person(name: String, age: Int)

def updateAge(person: Person, newAge: Int): Person = {
  person.copy(age = newAge)
}

val alice = Person("Alice", 25)
val olderAlice = updateAge(alice, 26)
// alice 保持不变,olderAlice 是新实例

实际应用示例

数据处理管道

scala
case class Employee(name: String, department: String, salary: Double, years: Int)

val employees = List(
  Employee("Alice", "Engineering", 75000, 3),
  Employee("Bob", "Sales", 65000, 5),
  Employee("Charlie", "Engineering", 85000, 7),
  Employee("Diana", "Marketing", 60000, 2),
  Employee("Eve", "Engineering", 95000, 10)
)

// 函数式数据处理
def processEmployees(employees: List[Employee]): Map[String, Double] = {
  employees
    .filter(_.years >= 3)                    // 过滤经验丰富的员工
    .groupBy(_.department)                   // 按部门分组
    .view.mapValues(_.map(_.salary).sum)     // 计算各部门总薪资
    .toMap
}

val departmentSalaries = processEmployees(employees)

// 使用高阶函数创建灵活的处理器
def createEmployeeProcessor(
  filter: Employee => Boolean,
  groupBy: Employee => String,
  aggregate: List[Employee] => Double
): List[Employee] => Map[String, Double] = {
  employees =>
    employees
      .filter(filter)
      .groupBy(groupBy)
      .view.mapValues(aggregate)
      .toMap
}

// 创建特定处理器
val seniorEmployeeAvgSalary = createEmployeeProcessor(
  _.years >= 5,                              // 过滤条件
  _.department,                              // 分组条件
  emps => emps.map(_.salary).sum / emps.length  // 聚合函数
)

函数组合示例

scala
// 字符串处理函数
val trim = (s: String) => s.trim
val toLowerCase = (s: String) => s.toLowerCase
val removeSpaces = (s: String) => s.replaceAll("\\s+", "")
val capitalize = (s: String) => s.capitalize

// 组合函数
def pipe[A, B, C, D](f1: A => B, f2: B => C, f3: C => D): A => D = {
  a => f3(f2(f1(a)))
}

val normalizeString = pipe(trim, toLowerCase, removeSpaces)
val formatTitle = pipe(trim, toLowerCase, capitalize)

// 使用示例
val input = "  Hello World  "
val normalized = normalizeString(input)  // "helloworld"
val title = formatTitle(input)           // "Hello world"

// 更复杂的管道
val processText: String => List[String] = { text =>
  text
    .split("\\.")
    .map(trim)
    .filter(_.nonEmpty)
    .map(capitalize)
    .toList
}

val sentences = processText("hello world. this is scala. functional programming is great.")
// List("Hello world", "This is scala", "Functional programming is great")

练习

练习 1:高阶函数实现

scala
object HigherOrderFunctionExercise {
  // 实现 map 函数
  def myMap[A, B](list: List[A], f: A => B): List[B] = {
    list match {
      case Nil => Nil
      case head :: tail => f(head) :: myMap(tail, f)
    }
  }
  
  // 实现 filter 函数
  def myFilter[A](list: List[A], predicate: A => Boolean): List[A] = {
    list match {
      case Nil => Nil
      case head :: tail =>
        if (predicate(head)) head :: myFilter(tail, predicate)
        else myFilter(tail, predicate)
    }
  }
  
  // 实现 reduce 函数
  def myReduce[A](list: List[A], f: (A, A) => A): A = {
    list match {
      case Nil => throw new IllegalArgumentException("Empty list")
      case head :: Nil => head
      case head :: tail => f(head, myReduce(tail, f))
    }
  }
  
  def main(args: Array[String]): Unit = {
    val numbers = List(1, 2, 3, 4, 5)
    
    val doubled = myMap(numbers, (x: Int) => x * 2)
    val evens = myFilter(numbers, (x: Int) => x % 2 == 0)
    val sum = myReduce(numbers, (x: Int, y: Int) => x + y)
    
    println(s"Original: $numbers")
    println(s"Doubled: $doubled")
    println(s"Evens: $evens")
    println(s"Sum: $sum")
  }
}

练习 2:函数式计算器

scala
object FunctionalCalculator {
  type Operation = (Double, Double) => Double
  
  val add: Operation = _ + _
  val subtract: Operation = _ - _
  val multiply: Operation = _ * _
  val divide: Operation = (x, y) => if (y != 0) x / y else throw new ArithmeticException("Division by zero")
  
  def calculate(x: Double, y: Double, op: Operation): Double = op(x, y)
  
  def createCalculator(op: Operation): (Double, Double) => Double = op
  
  // 链式计算
  def chain(initial: Double, operations: List[(Operation, Double)]): Double = {
    operations.foldLeft(initial) { case (acc, (op, value)) =>
      op(acc, value)
    }
  }
  
  def main(args: Array[String]): Unit = {
    // 基本计算
    println(calculate(10, 5, add))      // 15.0
    println(calculate(10, 5, multiply)) // 50.0
    
    // 创建专用计算器
    val adder = createCalculator(add)
    val multiplier = createCalculator(multiply)
    
    println(adder(3, 4))      // 7.0
    println(multiplier(3, 4)) // 12.0
    
    // 链式计算:(10 + 5) * 2 - 3 / 3
    val result = chain(10, List(
      (add, 5),
      (multiply, 2),
      (subtract, 3),
      (divide, 3)
    ))
    println(s"Chain result: $result") // 27.0
  }
}

总结

本章详细介绍了 Scala 中方法和函数的核心概念:

  • 方法定义:基本语法、参数处理、返回类型
  • 函数字面量:匿名函数、函数类型、简化语法
  • 高阶函数:接受和返回函数的方法
  • 柯里化:函数参数的分步应用
  • 递归:尾递归优化和相互递归
  • 函数式编程:纯函数、不可变性、函数组合

掌握这些概念是进行函数式编程的基础,也是理解 Scala 强大表达能力的关键。在下一章中,我们将学习 Scala 闭包,深入了解函数的作用域和变量捕获机制。

本站内容仅供学习和研究使用。