Haskellでsmallptを作ってみた

こんにちはtatsyです。

お盆休みなので最近勉強しているHaskellで[smallpt][1]を実装してみました(←暇か!)。

作成した手順に従って、ちょっとずつ解説してみたいと思います。

Vecクラス(のようなもの)を作る


「クラスのようなもの」と書いたのはHaskellにはオブジェクト指向型言語で言うところのクラスはないからです。Haskell的には独自定義のオブジェクトは代数的データ型といいます。

ちなみに、Haskellでclassというのは型クラスと呼ばれる他の言語でいうところのインターフェースのようなものを指します。まぎらわしいですね。

それはともかく、コードを見てみます。

-- Vec
data Vec = Vec (Double, Double, Double) deriving (Show)
instance (Num Vec) where
    (Vec (x, y, z)) + (Vec (a, b, c)) = Vec (x + a, y + b, z + c)
    (Vec (x, y, z)) - (Vec (a, b, c)) = Vec (x - a, y - b, z - c)
    (Vec (x, y, z)) * (Vec (a, b, c)) = Vec (x * a, y * b, z * c)
    abs = undefined
    signum = undefined
    fromInteger x = Vec (dx, dx, dx) where dx = fromIntegral x

x :: Vec -> Double
x (Vec (x, _, _)) = x

y :: Vec -> Double
y (Vec (_, y, _)) = y

z :: Vec -> Double
z (Vec (_, _, z)) = z

mul :: Vec -> Double -> Vec
mul (Vec (x, y, z)) s = Vec (x * s, y * s, z * s)

dot :: Vec -> Vec -> Double
dot (Vec (x, y, z)) (Vec (a, b, c))  = x * a + y * b + z * c

norm :: Vec -> Vec
norm (Vec (x, y, z)) = Vec (x * invnrm, y * invnrm, z * invnrm)
    where invnrm = 1 / sqrt (x * x + y * y + z * z)

cross :: Vec -> Vec -> Vec
cross (Vec (x, y, z)) (Vec (a, b, c)) = Vec (y * c - b * z, z * a - c * x, x * b - a * y)

まず、最初の数行でVec型の雛形となる定義を書いています。Vec型は足し算とか引き算とか出来てほしいのでNum型クラスを実装してあげます。

ちなみにabsとかsignumは使わないので適当にundefinedにしておきます。

その下は普通の関数定義です。ちなみに、ここではdataを使って新しい型を作っていますが、tupleに対してnewtypeを使ったりしても基本は同じです。

RayとSphereの定義


続いてはRayとSphereを定義します。こちらは単純な代数的データ型の定義なので特に難しいことはなしです。

-- Ray
data Ray = Ray (Vec, Vec) deriving (Show)
type Camera = Ray

org :: Ray -> Vec
org (Ray (org, _)) = org

dir :: Ray -> Vec
dir (Ray (_, dir)) = dir

-- Material
data Refl = Diff
          | Spec
          | Refr
          deriving Show

-- Sphere
data Sphere = Sphere (Double, Vec, Vec, Vec, Refl) deriving (Show)

rad :: Sphere -> Double
rad  (Sphere (rad, _, _, _, _   )) = rad

pos :: Sphere -> Vec
pos  (Sphere (_  , p, _, _, _   )) = p

emit :: Sphere -> Vec
emit (Sphere (_  , _, e, _, _   )) = e

col :: Sphere -> Vec
col  (Sphere (_  , _, _, c, _   )) = c

refl :: Sphere -> Refl
refl (Sphere (_  , _, _, _, refl)) = refl
```

続いてはintersection関数の定義です。これはSphereとRayの交点を求める関数ですが、smallptの関数をナイーブに実装すると次のようになります。

```haskell
 Double
intersect sp ray =
    let op  = (pos sp) - (org ray)
        b   = op `dot` (dir ray)
        det = b * b - (op `dot` op) + ((rad sp) ** 2)
    in
        if det < 0.0
            then inf
            else
                let sqdet = sqrt det
                    t1    = b - sqdet
                    t2    = b + sqdet
                in
                    if t1 > eps
                        then t1
                        else if t2 > eps
                            then t2
                            else inf
```

まあ、交わらないというのをMaybeモナドのNothingとかで表現してもいいんですが、今回はそこまでしなくてもいいかなという感じで、ナイーブに書いてあります。  


## radiance関数での乱数の取り扱い

* * *

今回の実装で一番ハマったポイントといっても過言ではないのがradiance関数における乱数の取り扱いです。

Haskell使いの方々には常識かと思うのですが、Haskellで普通に乱数を使おうとするとIOモナドが出てきます。

一応、解説しておくと、IOモナドというのはHaskellにおいて入出力を行うためのもので、IOを行うと参照透過性が損なわれることから、関数からの戻り値が普通の値ではなく`IO Vec`のようになります。

で、何が問題かというと乱数を取り扱う関数もIOを返してくるんですね。まぁ、乱数には参照透過性はないわけなので当たり前といえば当たり前なのかもですが。

別にIOのまま扱っても良かったのですが、レンダリング結果の画素値を`IO Vec`で持つようにして関数を書くと、IOの中身が積極評価されてメモリをたくさん食ってしまうみたいなのでした。

具体的には

```haskell
pixels <- sequence [ (画素値を計算する関数、IO Vecを返す) | r <- rays ]
```

みたいなリスト内包表記を使っていたのですが、こいつがメモリをバカ食いしてしまたのです。で、評価を遅らせるために、とりあえず

```haskell
let pixels = [ unsafePerformIO (画素値を計算する関数) | r <- rays ]
```

としてやったわけです。これでめでたくメモリのバカ食いは阻止できたのですが`unsafePerformIO`は何ともいえない感じだなぁと。

で、困ってStackOverflowで聞いてみたところ`MonadRandom`という乱数を扱うためのモナドがあるらしい!

こいつを使って実装したradiance関数は次のような感じになります(長いです)。

```haskell
 Rand g Vec
radiance scene ray depth = do
    let (t, i) = (intersects scene ray)
    if inf <= t
        then return (Vec (0, 0, 0))
        else do
            r0 <- nextDouble
            r1 <- nextDouble
            r2 <- nextDouble
            let obj = (scene !! i)
            let c = col obj
            let prob = (max (x c) (max (y c) (z c)))
            if depth >= 5 && r0 >= prob
                then return (emit obj)
                else do
                    let rlt = if depth < 5 then 1 else prob
                    let f = (col obj)
                    let d = (dir ray)
                    let p = (org ray) + (d `mul` t)
                    let n = norm $ p - (pos obj)
                    let nl = if (d `dot` n) < 0.0  then n else (-n)
                    nextRad <- case (refl obj) of
                            Diff -> (radiance scene (Ray (p, ndir)) (succ depth))
                                where th  = 2.0 * pi * r1
                                      r2s = sqrt r2
                                      w = nl
                                      u = norm $ (if (abs (x w)) > eps then Vec (0, 1, 0) else Vec (1, 0, 0)) `cross` w
                                      v = w `cross` u
                                      uu = u `mul` ((cos th) * r2s)
                                      vv = v `mul` ((sin th) * r2s)
                                      ww = w `mul` (sqrt (1.0 - r2))
                                      ndir = norm (uu + vv + ww)

                            Spec -> (radiance scene (Ray (p, ndir)) (succ depth))
                                where ndir = d - (nl `mul` (2.0 * nl `dot` d))

                            Refr -> let rdir = d - (nl `mul` (2.0 * nl `dot` d))
                                        into = (n `dot` nl) > 0
                                        nnt  = if into then (nc / nt) else (nt / nc)
                                        ddn  = d `dot` nl
                                        cos2t = 1.0 - nnt * nnt * (1.0 - ddn * ddn)
                                    in
                                        if cos2t < 0.0
                                            then (radiance scene (Ray (p, rdir)) (succ depth))
                                            else
                                                let tdir = norm $ ((d `mul` nnt) -) $ n `mul` ((if into then 1 else -1) * (ddn * nnt + (sqrt cos2t)))
                                                    a = nt - nc
                                                    b = nt + nc
                                                    r0 = (a * a) / (b * b)
                                                    c = 1.0 - (if into then -ddn else (tdir `dot` n))
                                                    re = r0 + (1 - r0) * (c ** 5)
                                                    tr = 1.0 - re
                                                    pp = 0.25 + 0.5 * re
                                                in
                                                    if depth > 2
                                                        then
                                                            if r1 < pp
                                                               then fmap (`mul` (re / pp)) (radiance scene (Ray (p, rdir)) (succ depth))
                                                               else fmap (`mul` (tr / (1.0 - pp))) (radiance scene (Ray (p, tdir)) (succ depth))
                                                        else
                                                            (\r t -> (r `mul` re) + (r `mul` tr)) <$> (radiance scene (Ray (p, rdir)) (succ depth)) <*> (radiance scene (Ray (p, tdir)) (succ depth))
                    return $ (emit obj) + ((f * nextRad) `mul` (1/rlt))
```

はい、もう`if`と`else`がネストしまくって死にそうですね。時間があったらどうにかします。  


## 実行結果

* * *

さてさて、これで主要な部分は説明できたはずですので、結果を見てみます。

![haskell-smallptの結果][2] 

はい、確かにレイトレできてますね。素晴らしい!

## まとめ

Haskellでレイトレするとかいう変なものを作って、誰が読むんだ?という気もしなくはないですが、一応目標達成です。

全体のコードはいつも通り自分のGitHubに上げておきますので、宜しければご覧ください。

https://github.com/tatsy/haskell-smallpt.git

最後までお読み頂きありがとうございました!

 [1]: http://www.kevinbeason.com/smallpt/
 [2]: https://raw.githubusercontent.com/tatsy/haskell-smallpt/master/result.png