ScalaでJDBCのレコードを取り出す話

JDBC_Scalaメモ

ScalaからJDBC経由でMySQLにつなげた時の健忘録

まずはconnを確立

chatacterEncodingとTimezoneを伝えないと怒られる。SSLは使わないことを明示しないと怒られる。

// OperateJDBC.scala

import java.sql._  // これだとArray使えないので注意
object OperateJDBC {

  val driver = "com.mysql.cj.jdbc.Driver"
  Class.forName(driver)
  val url = "jdbc:mysql://localhost:3306/sakila" +
    "?characterEncoding=UTF-8&" +
    "serverTimezone=JST&" +
    "useSSL=false&" +
    "requireSSL=false"
  val conn = DriverManager.getConnection(url, ___, ___)
}

SELECTからとりだす

データは以下の通り

actor_ID first_name
1 PENELOPE
2 NICK
3 ED
4 JENNIFER
5 JOHNNY

まずはこんな感じでベタに実装してみる。

// OperateJDBC.scala

  // mapper。とりあえずSMALLINTだけ。
  def patternMatchObj(obj: Object, cName: String, cTypeName: String)
: (String, Any) = {
    (obj, cName, cTypeName) match {
      case (Some(o), cname, "SMALLINT UNSIGNED") =>
        cname -> o.toString.toInt
      case (Some(o), cname, _) => 
        cname -> o.toString
      case (None, cname, _) => 
        cname -> None
    }
  }

  def select(sql: String): Seq[Map[String, Any]] = {
    val stmt = conn.createStatement()
    stmt.execute(sql)
    val rs: ResultSet = stmt.getResultSet
    val md: ResultSetMetaData = rs.getMetaData
    val colCount: Int = md.getColumnCount
    val colName: Seq[String] = (1 to colCount)
      .map(md.getColumnName)
    val colTypeName: Seq[String] = (1 to colCount)
      .map(md.getColumnTypeName)

    var ret: Seq[Map[String, Any]] = Seq()
    while (rs.next()) {
      ret = ret :+ (1 to colCount)
        .map(n => Option(rs.getObject(n)))
        .zipWithIndex
        .map(t => patternMatchObj(t._1, colName(t._2), colTypeName(t._2)))
        .toMap
    }
    rs.close()
    stmt.close()

    ret
  }

テスト用コードで動かす

// TestJDBC.scala
object TestJDBC {
  def main(args: Array[String]): Unit = {
    val ret = OperateJDBC
      .select("SELECT actor_id, first_name FROM actor WHERE actor_id BETWEEN 1 AND 5")
    ret.toString.replaceAll("\\),", "),\r\n").tap(println)
  }

  implicit class RichObj[T](obj:T){
      def tap[U](f: T => U):T = {f(obj);obj;}
    }
}

出力

List(Map(actor_id -> 1, first_name -> PENELOPE),
 Map(actor_id -> 2, first_name -> NICK),
 Map(actor_id -> 3, first_name -> ED),
 Map(actor_id -> 4, first_name -> JENNIFER),
 Map(actor_id -> 5, first_name -> JOHNNY))

takeWhileを使ったパターン

varとか使っていてScala警察におこられそうなので、以下を参考にして実装

http://kmizu.hatenablog.com/entry/20121128/1354115941

  def select(sql: String):Seq[Map[String, Any]] = {
    val stmt = conn.createStatement()
    stmt.execute(sql)
    val rs: ResultSet = stmt.getResultSet
    val md: ResultSetMetaData = rs.getMetaData
    val colCount: Int = md.getColumnCount
    val colName: Seq[String] = (1 to colCount).map(md.getColumnName)
    val colTypeName: Seq[String] = (1 to colCount).map(md.getColumnTypeName)

    Iterator
      .continually{
        (1 to colCount)
          .map(n => Option(rs.getObject(n)))
          .zipWithIndex
          .map(t => patternMatchObj(t._1, colName(t._2), colTypeName(t._2)))
          .toMap
      }
      .takeWhile(x => rs.next())
      .toSeq
  }

出力。

Exception in thread "main" java.sql.SQLException: Before start of result set

多分takeWhilers.next()する前に判定しているので例外が出る。最初に判定するwhileにしたいけど方法が分からない。

Iterator.continuallyってなんだ

APIDocumentから引用。無限長Iteratorを返すらしい。

def continuallyA: Iterator[A] Creates an infinite-length iterator returning the results of evaluating an > expression.

ぐぐる

無能なので有能な先例をしらべる。

https://stackoverflow.com/questions/9636545/treating-an-sql-resultset-like-a-scala-stream

無名クラスを使ってIteratorのhasNextとnextをoverrideする方法があるらしい。今回だと以下のようになる。

new Iterator[Map[String, Any]] {
    override def hasNext: Boolean = rs.next()
    override def next(): Map[String, Any] = {
    (1 to colCount)
        .map(n => Option(rs.getObject(n)))
        .zipWithIndex
        .map(t => patternMatchObj(t._1, colName(t._2), colTypeName(t._2)))
        .toMap
    }
}.toStream

IteratorToListなんかでimmutableにすると、hasNext=>next()=>hasNext=>next()=>...と呼びだされ、hasNext == falseの時に止まる。

Iterator継承パターン

最終的にこうなった。

val ret = new Iterator[Map[String, Any]] {
    override def hasNext: Boolean = rs.next()
    override def next(): Map[String, Any] = {
    (1 to colCount)
        .map(n => Option(rs.getObject(n)))
        .zipWithIndex
        .map(t => patternMatchObj(t._1, colName(t._2), colTypeName(t._2)))
        .toMap
    }
}.toList
rs.close()
stmt.close()
ret

めでたしめでたし。

List(Map(actor_id -> 1, first_name -> PENELOPE),
 Map(actor_id -> 2, first_name -> NICK),
 Map(actor_id -> 3, first_name -> ED),
 Map(actor_id -> 4, first_name -> JENNIFER),
 Map(actor_id -> 5, first_name -> JOHNNY))