Haskellでsmallptを作ってみた
こんにちはtatsyです。
お盆休みなので最近勉強しているHaskellで[smallpt][1]を実装してみました(←暇か!)。
作成した手順に従って、ちょっとずつ解説してみたいと思います。
Vecクラス(のようなもの)を作る
「クラスのようなもの」と書いたのはHaskellにはオブジェクト指向型言語で言うところのクラスはないからです。Haskell的には独自定義のオブジェクトは代数的データ型といいます。
ちなみに、Haskellでclassというのは型クラスと呼ばれる他の言語でいうところのインターフェースのようなものを指します。まぎらわしいですね。
それはともかく、コードを見てみます。
-- Vec
data Vec = Vec (Double, Double, Double) deriving (Show)
instance (Num Vec) where
(Vec (x, y, z)) + (Vec (a, b, c)) = Vec (x + a, y + b, z + c)
(Vec (x, y, z)) - (Vec (a, b, c)) = Vec (x - a, y - b, z - c)
(Vec (x, y, z)) * (Vec (a, b, c)) = Vec (x * a, y * b, z * c)
abs = undefined
signum = undefined
fromInteger x = Vec (dx, dx, dx) where dx = fromIntegral x
x :: Vec -> Double
x (Vec (x, _, _)) = x
y :: Vec -> Double
y (Vec (_, y, _)) = y
z :: Vec -> Double
z (Vec (_, _, z)) = z
mul :: Vec -> Double -> Vec
mul (Vec (x, y, z)) s = Vec (x * s, y * s, z * s)
dot :: Vec -> Vec -> Double
dot (Vec (x, y, z)) (Vec (a, b, c)) = x * a + y * b + z * c
norm :: Vec -> Vec
norm (Vec (x, y, z)) = Vec (x * invnrm, y * invnrm, z * invnrm)
where invnrm = 1 / sqrt (x * x + y * y + z * z)
cross :: Vec -> Vec -> Vec
cross (Vec (x, y, z)) (Vec (a, b, c)) = Vec (y * c - b * z, z * a - c * x, x * b - a * y)
まず、最初の数行でVec型の雛形となる定義を書いています。Vec型は足し算とか引き算とか出来てほしいのでNum型クラスを実装してあげます。
ちなみにabsとかsignumは使わないので適当にundefinedにしておきます。
その下は普通の関数定義です。ちなみに、ここではdataを使って新しい型を作っていますが、tupleに対してnewtypeを使ったりしても基本は同じです。
RayとSphereの定義
続いてはRayとSphereを定義します。こちらは単純な代数的データ型の定義なので特に難しいことはなしです。
-- Ray
data Ray = Ray (Vec, Vec) deriving (Show)
type Camera = Ray
org :: Ray -> Vec
org (Ray (org, _)) = org
dir :: Ray -> Vec
dir (Ray (_, dir)) = dir
-- Material
data Refl = Diff
| Spec
| Refr
deriving Show
-- Sphere
data Sphere = Sphere (Double, Vec, Vec, Vec, Refl) deriving (Show)
rad :: Sphere -> Double
rad (Sphere (rad, _, _, _, _ )) = rad
pos :: Sphere -> Vec
pos (Sphere (_ , p, _, _, _ )) = p
emit :: Sphere -> Vec
emit (Sphere (_ , _, e, _, _ )) = e
col :: Sphere -> Vec
col (Sphere (_ , _, _, c, _ )) = c
refl :: Sphere -> Refl
refl (Sphere (_ , _, _, _, refl)) = refl
```
続いてはintersection関数の定義です。これはSphereとRayの交点を求める関数ですが、smallptの関数をナイーブに実装すると次のようになります。
```haskell
Double
intersect sp ray =
let op = (pos sp) - (org ray)
b = op `dot` (dir ray)
det = b * b - (op `dot` op) + ((rad sp) ** 2)
in
if det < 0.0
then inf
else
let sqdet = sqrt det
t1 = b - sqdet
t2 = b + sqdet
in
if t1 > eps
then t1
else if t2 > eps
then t2
else inf
```
まあ、交わらないというのをMaybeモナドのNothingとかで表現してもいいんですが、今回はそこまでしなくてもいいかなという感じで、ナイーブに書いてあります。
## radiance関数での乱数の取り扱い
* * *
今回の実装で一番ハマったポイントといっても過言ではないのがradiance関数における乱数の取り扱いです。
Haskell使いの方々には常識かと思うのですが、Haskellで普通に乱数を使おうとするとIOモナドが出てきます。
一応、解説しておくと、IOモナドというのはHaskellにおいて入出力を行うためのもので、IOを行うと参照透過性が損なわれることから、関数からの戻り値が普通の値ではなく`IO Vec`のようになります。
で、何が問題かというと乱数を取り扱う関数もIOを返してくるんですね。まぁ、乱数には参照透過性はないわけなので当たり前といえば当たり前なのかもですが。
別にIOのまま扱っても良かったのですが、レンダリング結果の画素値を`IO Vec`で持つようにして関数を書くと、IOの中身が積極評価されてメモリをたくさん食ってしまうみたいなのでした。
具体的には
```haskell
pixels <- sequence [ (画素値を計算する関数、IO Vecを返す) | r <- rays ]
```
みたいなリスト内包表記を使っていたのですが、こいつがメモリをバカ食いしてしまたのです。で、評価を遅らせるために、とりあえず
```haskell
let pixels = [ unsafePerformIO (画素値を計算する関数) | r <- rays ]
```
としてやったわけです。これでめでたくメモリのバカ食いは阻止できたのですが`unsafePerformIO`は何ともいえない感じだなぁと。
で、困ってStackOverflowで聞いてみたところ`MonadRandom`という乱数を扱うためのモナドがあるらしい!
こいつを使って実装したradiance関数は次のような感じになります(長いです)。
```haskell
Rand g Vec
radiance scene ray depth = do
let (t, i) = (intersects scene ray)
if inf <= t
then return (Vec (0, 0, 0))
else do
r0 <- nextDouble
r1 <- nextDouble
r2 <- nextDouble
let obj = (scene !! i)
let c = col obj
let prob = (max (x c) (max (y c) (z c)))
if depth >= 5 && r0 >= prob
then return (emit obj)
else do
let rlt = if depth < 5 then 1 else prob
let f = (col obj)
let d = (dir ray)
let p = (org ray) + (d `mul` t)
let n = norm $ p - (pos obj)
let nl = if (d `dot` n) < 0.0 then n else (-n)
nextRad <- case (refl obj) of
Diff -> (radiance scene (Ray (p, ndir)) (succ depth))
where th = 2.0 * pi * r1
r2s = sqrt r2
w = nl
u = norm $ (if (abs (x w)) > eps then Vec (0, 1, 0) else Vec (1, 0, 0)) `cross` w
v = w `cross` u
uu = u `mul` ((cos th) * r2s)
vv = v `mul` ((sin th) * r2s)
ww = w `mul` (sqrt (1.0 - r2))
ndir = norm (uu + vv + ww)
Spec -> (radiance scene (Ray (p, ndir)) (succ depth))
where ndir = d - (nl `mul` (2.0 * nl `dot` d))
Refr -> let rdir = d - (nl `mul` (2.0 * nl `dot` d))
into = (n `dot` nl) > 0
nnt = if into then (nc / nt) else (nt / nc)
ddn = d `dot` nl
cos2t = 1.0 - nnt * nnt * (1.0 - ddn * ddn)
in
if cos2t < 0.0
then (radiance scene (Ray (p, rdir)) (succ depth))
else
let tdir = norm $ ((d `mul` nnt) -) $ n `mul` ((if into then 1 else -1) * (ddn * nnt + (sqrt cos2t)))
a = nt - nc
b = nt + nc
r0 = (a * a) / (b * b)
c = 1.0 - (if into then -ddn else (tdir `dot` n))
re = r0 + (1 - r0) * (c ** 5)
tr = 1.0 - re
pp = 0.25 + 0.5 * re
in
if depth > 2
then
if r1 < pp
then fmap (`mul` (re / pp)) (radiance scene (Ray (p, rdir)) (succ depth))
else fmap (`mul` (tr / (1.0 - pp))) (radiance scene (Ray (p, tdir)) (succ depth))
else
(\r t -> (r `mul` re) + (r `mul` tr)) <$> (radiance scene (Ray (p, rdir)) (succ depth)) <*> (radiance scene (Ray (p, tdir)) (succ depth))
return $ (emit obj) + ((f * nextRad) `mul` (1/rlt))
```
はい、もう`if`と`else`がネストしまくって死にそうですね。時間があったらどうにかします。
## 実行結果
* * *
さてさて、これで主要な部分は説明できたはずですので、結果を見てみます。
![haskell-smallptの結果][2]
はい、確かにレイトレできてますね。素晴らしい!
## まとめ
Haskellでレイトレするとかいう変なものを作って、誰が読むんだ?という気もしなくはないですが、一応目標達成です。
全体のコードはいつも通り自分のGitHubに上げておきますので、宜しければご覧ください。
https://github.com/tatsy/haskell-smallpt.git
最後までお読み頂きありがとうございました!
[1]: http://www.kevinbeason.com/smallpt/
[2]: https://raw.githubusercontent.com/tatsy/haskell-smallpt/master/result.png