Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create StateT type #25

Merged
merged 13 commits into from
Jun 13, 2017
6 changes: 6 additions & 0 deletions katz/src/main/kotlin/katz/data/State.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package katz

object State {
operator fun <S, A> invoke(run: (S) -> Tuple2<S, A>, MF: Monad<Id.F> = Id): StateT<Id.F, S, A> =
StateT(MF, Id(run.andThen { Id(it) }))
}
89 changes: 89 additions & 0 deletions katz/src/main/kotlin/katz/data/StateT.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package katz

typealias StateTKind<F, S, A> = HK3<StateT.F, F, S, A>
typealias StateTF<F, S> = HK2<StateT.F, F, S>

typealias StateTFun<F, S, A> = (S) -> HK<F, Tuple2<S, A>>
typealias StateTFunKind<F, S, A> = HK<F, StateTFun<F, S, A>>

fun <F, S, A> StateTKind<F, S, A>.ev(): StateT<F, S, A> =
this as StateT<F, S, A>

class StateT<F, S, A>(
val MF: Monad<F>,
val runF: StateTFunKind<F, S, A>
) : StateTKind<F, S, A> {
class F private constructor()

companion object {
inline operator fun <reified F, S, A> invoke(run: StateTFunKind<F, S, A>, MF: Monad<F> = monad<F>()): StateT<F, S, A> =
StateT(MF, run)
}

fun <B> map(f: (A) -> B): StateT<F, S, B> =
transform { (s, a) -> Tuple2(s, f(a)) }

fun <B, Z> map2(sb: StateT<F, S, B>, fn: (A, B) -> Z): StateT<F, S, Z> =
applyF(MF.map2(runF, sb.runF) { (ssa, ssb) ->
ssa.andThen { fsa ->
MF.flatMap(fsa) { (s, a) ->
MF.map(ssb(s)) { (s, b) -> Tuple2(s, fn(a, b)) }
}
}
}, MF)

fun <B, Z> map2Eval(sb: Eval<StateT<F, S, B>>, fn: (A, B) -> Z): Eval<StateT<F, S, Z>> =
MF.map2Eval(runF, sb.map { it.runF }) { (ssa, ssb) ->
ssa.andThen { fsa ->
MF.flatMap(fsa) { (s, a) ->
MF.map(ssb((s))) { (s, b) -> Tuple2(s, fn(a, b)) }
}
}
}.map { applyF(it, MF) }

fun <B> product(sb: StateT<F, S, B>): StateT<F, S, Tuple2<A, B>> =
map2(sb) { a, b -> Tuple2(a, b) }

fun <B> flatMap(fas: (A) -> StateTKind<F, S, B>): StateT<F, S, B> =
applyF(
MF.map(runF) { sfsa ->
sfsa.andThen { fsa ->
MF.flatMap(fsa) {
fas(it.b).ev().run(it.a)
}
}
}
, MF)

fun <B> flatMapF(faf: (A) -> HK<F, B>): StateT<F, S, B> =
applyF(
MF.map(runF) { sfsa ->
sfsa.andThen { fsa ->
MF.flatMap(fsa) { (s, a) ->
MF.map(faf(a)) { b -> Tuple2(s, b) }
}
}
}
, MF)

fun <B> transform(f: (Tuple2<S, A>) -> Tuple2<S, B>): StateT<F, S, B> =
applyF(
MF.map(runF) { sfsa ->
sfsa.andThen { fsa ->
MF.map(fsa, f)
}
}, MF)

fun <F, S, A> applyF(runF: StateTFunKind<F, S, A>, MF: Monad<F>): StateT<F, S, A> =
StateT(MF, runF)

fun run(initial: S): HK<F, Tuple2<S, A>> =
MF.flatMap(runF) { f -> f(initial) }

fun runA(s: S): HK<F, A> =
MF.map(run(s)) { it.b }

fun runS(s: S): HK<F, S> =
MF.map(run(s)) { it.a }
}

35 changes: 35 additions & 0 deletions katz/src/main/kotlin/katz/instances/StateTMonad.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package katz

data class StateTMonad<F, S>(val MF: Monad<F>) : Monad<StateTF<F, S>>, Typeclass {

override fun <A, B> flatMap(fa: HK<StateTF<F, S>, A>, f: (A) -> HK<StateTF<F, S>, B>): StateT<F, S, B> =
fa.ev().flatMap(f)

override fun <A, B> map(fa: HK<StateTF<F, S>, A>, f: (A) -> B): StateT<F, S, B> =
fa.ev().map(f)

override fun <A> pure(a: A): StateT<F, S, A> =
StateT(MF, MF.pure({ s: S -> MF.pure(Tuple2(s, a)) }))

override fun <A, B> ap(fa: HK<StateTF<F, S>, A>, ff: HK<StateTF<F, S>, (A) -> B>): StateT<F, S, B> =
ff.ev().map2(fa.ev()) { f, a -> f(a) }

override fun <A, B, Z> map2(fa: HK<StateTF<F, S>, A>, fb: HK<StateTF<F, S>, B>, f: (Tuple2<A, B>) -> Z): StateT<F, S, Z> =
fa.ev().map2(fb.ev(), { a, b -> f(Tuple2(a, b)) })

@Suppress("UNCHECKED_CAST")
override fun <A, B, Z> map2Eval(fa: HK<StateTF<F, S>, A>, fb: Eval<HK<StateTF<F, S>, B>>, f: (Tuple2<A, B>) -> Z): Eval<StateT<F, S, Z>> =
fa.ev().map2Eval(fb as Eval<StateT<F, S, B>>) { a, b -> f(Tuple2(a, b)) }

override fun <A, B> product(fa: HK<StateTF<F, S>, A>, fb: HK<StateTF<F, S>, B>): HK<StateTF<F, S>, Tuple2<A, B>> =
fa.ev().product(fb.ev())

override fun <A, B> tailRecM(a: A, f: (A) -> HK<StateTF<F, S>, Either<A, B>>): StateT<F, S, B> =
StateT(MF, MF.pure({ s: S ->
MF.tailRecM<Tuple2<S, A>, Tuple2<S, B>>(Tuple2(s, a), { (s, a) ->
MF.map(f(a).ev().run(s)) { (s, ab) ->
ab.bimap({ a -> Tuple2(s, a) }, { b -> Tuple2(s, b) })
}
})
}))
}
49 changes: 49 additions & 0 deletions katz/src/test/kotlin/katz/data/StateTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (C) 2017 The Katz Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package katz

import io.kotlintest.KTestJUnitRunner
import io.kotlintest.matchers.shouldBe
import org.junit.runner.RunWith

@RunWith(KTestJUnitRunner::class)
class StateTests : UnitSpec() {

private val addOne = State<Int, Int>({ n -> Tuple2(n * 2, n) })

init {
"addOne.run(1) should return Pair(2, 1)" {
addOne.run(1).ev().value shouldBe Tuple2(2, 1)
}

"addOne.map(n -> n).run(1) should return same Pair(2, 1)" {
addOne.map { n -> n }.run(1).ev().value shouldBe Tuple2(2, 1)
}

"addOne.map(n -> n.toString).run(1) should return same Pair(2, \"1\")" {
addOne.map(Int::toString).run(1).ev().value shouldBe Tuple2(2, "1")
}

"addOne.runS(1) should return 2" {
addOne.runS(1).ev().value shouldBe 2
}

"addOne.runA(1) should return 1" {
addOne.runA(1).ev().value shouldBe 1
}
}
}