20 Jul 2016

Scala: Pattern matching & PartialFunction

Scala cung cấp tính năng pattern matching (và được sử dụng rất nhiều), tương ứng với switch của Java. Nhưng syntax có khác đi một chút.


Trong Java chúng ta dùng switch như  bên dưới:
String month;
switch (myNumber) {
  case 1:  month = "thang 1";
  break;
  case 2: month = "thang 2";
  break;
  case 3:
  case 4:
    default: month = "xxxx";
}

return month;
General syntax của swicth trong Java có dạng switch (selector) { alternatives } (1).
Đối với switch của Java, khi dừng câu lệnh switch, chúng ta phải dùng break, nếu không, khi gặp pattern được match, switch sẽ "fall through" đến các case tiếp theo cho đến khi nào gặp break, rất dễ xảy ra lỗi nếu như chúng ta không để ý, và IDE cũng không warning trong trường hợp có return type hợp lệ.


Dùng scala, khi sử dụng pattern matching, syntax sẽ là selector match { alternatives }

Pattern matching của scala cũng sử dụng từ khoá case như Java, tiếp theo là pattern muốn match, và dấu => để tách biệt pattern và biểu thức dùng để xử lí khi pattern được match. match trong scala khác biệt so với Java:
1. đầu tiên do pattern matching của Scala không "fall through" đến các case matching khác, nên không cần dùng break
2. Mỗi một biểu thức alternative trong scala đều phải sinh ra giá trị, cho dù rơi vào default case 
3. Đối với Java, nếu input không rơi vào case nào (không được match) và không có case default, sẽ không có lỗi xảy ra, nhưng có thể sẽ sinh ra lỗi về sau. Đối với Scala, nếu sử dụng pattern matching, mỗi khi dùng match chúng ta buộc phải return giá trị, dù input không match bất cứ case nào đi nữa, nếu không sẽ sinh ra exception MatchError.

Với ví dụ trên, khi dùng với scala sẽ như sau:

val monthName: String = month match {
  case 1 => "thang 1"  case 2 => "thang 2"  case _ => "xxx"}

Pattern matching cuả scala có thể dùng với nhiều dạng pattern: wildcard, constant (như ví dụ trên, 1, 2 là constant, _ là wildcard), case class, type, tuple, sequence, constructor...

Do vậy chúng ta hoàn toàn có thể sử dụng các cách thức sau đối với pattern matching:

case 1 => "thang 1"
hay là đối với list
case _ :: y :: _ if y > 1 => y
Với type
case s: String => s.length
v.v    

Một vấn đề khi sử dụng pattern matching là chúng ta thường không kiểm soát được input, và compiler không có cách nào để thông báo khi chúng ta dùng logic sai, một ví dụ đối với if else thông thường trong scala:

def myFunc(seq : List[Int]): Int = {
  if (seq.length > 1) seq(2)
  else 0}



Khi sử dụng myFunc với input là List Int có length nhỏ hơn 1 hoặc lớn hơn 3 sẽ cho kết quả đúng, nhưng nếu như input là một List với length bằng 2 chương trình sẽ crash:

scala> def myFunc(seq : List[Int]): Int = {
  |     if (seq.length > 1) seq(2)
  |     else 0  |   }
myFunc: (seq: List[Int])Int

scala> myFunc(List(2, 3, 4))
res6: Int = 4
scala> myFunc(List(2, 4))
java.lang.IndexOutOfBoundsException: 2at scala.collection.LinearSeqOptimized$class.apply(LinearSeqOptimized.scala:65)
at scala.collection.immutable.List.apply(List.scala:84)
at .myFunc(<console>:12)  ... 32 elided

Tương tự nếu sử dụng pattern matching mà chúng ta không cover hết tất cả các case, sẽ gặp phải Runtime exception. 

scala> val second: List[Int] => Int = { case _ :: y :: _  if y > 1 => y }
second: List[Int] => Int = <function1>
scala> second(List(1, 2))res12: Int = 2
scala> second(List(1))scala.MatchError: List(1) 
(of class scala.collection.immutable.$colon$colon)at 
$anonfun$1.apply(<console>:11)  
at $anonfun$1.apply(<console>:11)    ... 32 elided


Để khắc phục lỗi này, scala cung cấp PartialFunction.
val second: PartialFunction[List[Int],Int] = {
  case _ :: y :: _  if y > 1 => y
}

Hàm trên khi scala compile sẽ được dịch ra tương tự như 

val second = new PartialFunction[List[Int], Int]  {
  def apply(xs: List[Int]) = xs match {
    case _ :: y :: _  if y > 1 => y
  }

  def isDefinedAt(xs: List[Int]) = xs match {
    case _ :: y :: _ => true    
    case _ => false  
  }
}


Như vậy Partial function cung cấp một method dùng để kiểm tra function matcher của chúng ta có làm việc được với input cụ thể nào đó không. Nếu như không kiểm tra mà cứ dùng như ví dụ ban đầu, chúng ta vẫn sẽ gặp Runtime exception, do vậy khi define function matcher với Partial Function, thì lúc sử dụng, chúng ta dùng method isDefinedAt để kiểm tra giá trị input trước:

scala> val second: PartialFunction[List[Int],Int] = {
    case _ :: y :: _  if y > 1 => y
   }
second: PartialFunction[List[Int],Int] = <function1>

scala>   val l = 1 :: 2 :: 3 :: Nil
l: List[Int] = List(1, 2, 3)

scala>

scala>   val a = if (second.isDefinedAt(l)) second(l) else 0
a: Int = 2
scala>

scala>   val l1 = 1 :: Nil
l1: List[Int] = List(1)

scala>

scala>   val b = if (second.isDefinedAt(l1)) second(l) else 0
b: Int = 0

Ứng dụng:

Scala cung cấp các method để làm việc trên scala standard collection như Seq, List, Set... Đối với các method như filter, map chấp nhận input là các Anonymous function với parameters phù hợp để sinh ra các collection mới theo tiêu chí đặt ra trong Anonymous function, ví dụ khi làm việc với một List[Int] đơn giản như bên dưới:

val l = 1 :: 2 :: 3 :: Nil

Có thể dùng map và filter như sau:

scala> l.map({ case x => x*2 })
res7: List[Int] = List(2, 4, 6)

scala> l.filter(_ > 1)
res8: List[Int] = List(2, 3)

Để ý cấu trúc map method bên trên, ta sử dụng một Anonymous function dùng pattern matching là bất kì input nào, do đó tất cả các case đều được tính đến. Nhưng nếu như có nhu cầu chỉ lấy ra một List các số Int lớn hơn 1 trong List ban đầu, khi dùng map sẽ gặp lỗi:

scala> l.map({ case x if x > 1 => x })
scala.MatchError: 1 (of class java.lang.Integer)
at $anonfun$1.apply$mcII$sp(<console>:13)  
at $anonfun$1.apply(<console>:13)    
at $anonfun$1.apply(<console>:13)      
at scala.collection.immutable.List.map
(List.scala:273)      
... 32 elided

Vì sự khác nhau giữa switch của Java và match của Scala như đã nói từ đầu bài, tất cả các case đều cần được tính đến. Phần matching chúng ta dùng Pattern guard (if x > 1) để lọc giá trị lớn hơn 1, nhưng chưa match giá trị bằng hoặc nhỏ hơn 1, do vậy sẽ gặp MatchError Exception.

Để giải quyết vấn đề này, có thể dùng collect method với Partial Function:
scala> val moreThanOne: PartialFunction[Int,Int] = { case x if x > 1 => x }
moreThanOne: PartialFunction[Int,Int] = <function1>
scala> l.collect(moreThanOne)
res12: List[Int] = List(2, 3)

Như vậy với Partial Function chúng ta có thể giải quyết bài toán này theo một cách cực kì đơn giản.

 (1) Programming in Scala 2nd Edition

Disqus