Type Class over Coupled Monads

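Scala supports Monads at the language level: a for comprehension is rewritten by the compiler into calls to map and flatMap on the type being chained. Such a type goes a little something like this: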

case class MyMonad[Type](item: Type){
  def map[Ret](f: Type => Ret): MyMonad[Ret] = ???
  def flatMap[Ret](f: Type => MyMonad[Ret]): MyMonad[Ret] = ??? 
}

The issue here is that the Monadic flatMap (bind, >>=, whatever) is coupled to the data structure: a single operation defines the Monadic behavior for all code that uses that data structure. For instance, Buffer's flatMap flattens a Buffer[Buffer[A]] into a Buffer[A] by concatenating the first inner Buffer with the second, then the third, and so on. This is the only flattening style Buffer offers through Scala's built in for comprehension chaining. This is very limiting.
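To make the coupling concrete, here is a small sketch using only the standard library. The object name ConcatDemo and the sample data are made up for illustration; the point is that the for comprehension and the explicit flatMap call are the same operation, and both concatenate.

```scala
import scala.collection.mutable.Buffer

object ConcatDemo {
  // Illustrative sample data, not taken from the original post.
  val nested: Buffer[Buffer[Int]] = Buffer(Buffer(1, 2, 3), Buffer(4), Buffer(5, 6))

  // A for comprehension over Buffer...
  val viaFor: Buffer[Int] = for {
    inner <- nested
    item  <- inner
  } yield item

  // ...is rewritten into flatMap and map calls on Buffer itself,
  // so the flattening strategy is whatever Buffer.flatMap does.
  val viaFlatMap: Buffer[Int] = nested.flatMap(inner => inner.map(item => item))
}
```

Both values hold the elements 1, 2, 3, 4, 5, 6 in that order: each inner Buffer is appended whole before the next one starts.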

Suppose one wanted a different kind of flattening, such as round-robin: take the first element from each Buffer, then the second from each, and so on. One would need to implement a new Buffer structure and bake a round-robin flatMap into it. Now there are two separate Buffer data structures presenting two separate methods of Monadic chaining, and each function that accepts a Buffer must specify which Buffer type it requires in order to get the flatMap behavior it needs. Type Classes provide a solution to this issue.
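As a rough illustration of the difference between the two flattening styles, the sketch below interleaves nested Buffers with a standalone helper. The names RoundRobinDemo and roundRobin are hypothetical and the sample data is invented; only the standard library is used.

```scala
import scala.collection.mutable.Buffer

object RoundRobinDemo {
  // Hypothetical helper: emit element 0 of every inner Buffer, then element 1, and so on.
  def roundRobin[A](nested: Buffer[Buffer[A]]): Buffer[A] = {
    // max throws on an empty collection, so an empty input counts as length 0.
    val longest = if (nested.isEmpty) 0 else nested.map(_.size).max
    val out = Buffer.empty[A]
    for (index <- 0 until longest; inner <- nested if index < inner.size)
      out += inner(index)
    out
  }

  val nested = Buffer(Buffer(1, 2, 3), Buffer(4), Buffer(5, 6))

  val concatenated: Buffer[Int] = nested.flatten     // 1, 2, 3, 4, 5, 6
  val interleaved: Buffer[Int]  = roundRobin(nested) // 1, 4, 5, 2, 6, 3
}
```

The same input yields two different orderings, yet a for comprehension over Buffer can only ever produce the first one.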

Cats to the rescue

The Typelevel Cats library defines a Monad Type Class, along with implicits that provide the syntax. Keeping to Buffer for our example, we would have two Monads over Buffer which look like:

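import scala.collection.mutable.Buffer
import cats.Monad
import cats.implicits._

// Pieces shared by both instances; each one supplies only its own flatMap.
abstract class BufferMonad extends Monad[Buffer]{
  override def pure[Type](item: Type): Buffer[Type] = Buffer(item)
  // Current Cats versions require tailRecM. This simple version delegates to
  // flatMap and is not stack safe.
  override def tailRecM[Arg, Ret](a: Arg)(f: Arg => Buffer[Either[Arg, Ret]]): Buffer[Ret] =
    flatMap(f(a)){
      case Left(next) => tailRecM(next)(f)
      case Right(ret) => pure(ret)
    }
}
// Concatenating Monad: defers to Buffer's built in flatMap.
val mMonad: Monad[Buffer] = new BufferMonad{
  override def flatMap[Arg, Ret](c: Buffer[Arg])(f: Arg => Buffer[Ret]): Buffer[Ret] = {
    c.flatMap(f)
  }
}
// Round-robin Monad: takes element 0 of every result, then element 1, and so on.
val rrMonad: Monad[Buffer] = new BufferMonad{
  override def flatMap[Arg, Ret](c: Buffer[Arg])(f: Arg => Buffer[Ret]): Buffer[Ret] = {
    val mapped = c.map(f)
    // max throws on an empty collection, so an empty input counts as length 0
    val longest = if(mapped.isEmpty) 0 else mapped.map(_.size).max
    // valid indices run from 0 to longest - 1
    val seq = (0 until longest).flatMap{index =>
      mapped.collect{
        case inner if inner.size > index => inner(index)
      }
    }
    seq.toBuffer
  }
}

Then to use them:

object simple{
  implicit val monad: Monad[Buffer] = mMonad
  def apply(values: Seq[Int]): Unit = {
    // >>= resolves through the Cats syntax and the implicit above;
    // calling .flatMap directly would pick Buffer's own method instead
    val mMulti = values.toBuffer >>= { int =>
      (0 to int).toBuffer
    }
    println(mMulti)
  }
}
object roundRobin{
  implicit val monad: Monad[Buffer] = rrMonad

  def apply(values: Seq[Int]): Unit = {
    val rrMulti = values.toBuffer >>= { int =>
      (0 to int).toBuffer
    }
    println(rrMulti)
  }
}
val values = Seq(1,2,3,4,5)
simple(values)     // ArrayBuffer(0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5)
roundRobin(values) // ArrayBuffer(0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5)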

And that's it! Ad-hoc Monadic chaining any way you'd like, chosen by whichever implicit is in scope, without redefining classes or parameterizing methods.
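One way to see the payoff without pulling in a dependency is a function written once against a Type Class and run with either strategy. The sketch below is an assumption-laden stand-in, not the Cats API: MiniMonad, concatM, roundRobinM, and countUpTo are hypothetical names, and MiniMonad carries only the two operations this example needs.

```scala
import scala.collection.mutable.Buffer

object StrategyDemo {
  // Hypothetical stand-in for cats.Monad, reduced to pure and flatMap.
  trait MiniMonad[F[_]] {
    def pure[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  }

  // Concatenating strategy: defers to Buffer's own flatMap.
  val concatM: MiniMonad[Buffer] = new MiniMonad[Buffer] {
    def pure[A](a: A): Buffer[A] = Buffer(a)
    def flatMap[A, B](fa: Buffer[A])(f: A => Buffer[B]): Buffer[B] = fa.flatMap(f)
  }

  // Round-robin strategy: element 0 of every result, then element 1, and so on.
  val roundRobinM: MiniMonad[Buffer] = new MiniMonad[Buffer] {
    def pure[A](a: A): Buffer[A] = Buffer(a)
    def flatMap[A, B](fa: Buffer[A])(f: A => Buffer[B]): Buffer[B] = {
      val mapped = fa.map(f)
      val longest = if (mapped.isEmpty) 0 else mapped.map(_.size).max
      (0 until longest).flatMap { i =>
        mapped.collect { case inner if inner.size > i => inner(i) }
      }.toBuffer
    }
  }

  // Written once; the instance that is supplied decides how results are flattened.
  def countUpTo(values: Buffer[Int])(implicit M: MiniMonad[Buffer]): Buffer[Int] =
    M.flatMap(values)(n => (0 to n).toBuffer)

  val concatenated = countUpTo(Buffer(1, 2, 3))(concatM)     // 0, 1, 0, 1, 2, 0, 1, 2, 3
  val interleaved  = countUpTo(Buffer(1, 2, 3))(roundRobinM) // 0, 0, 0, 1, 1, 1, 2, 2, 3
}
```

The instances are passed explicitly here to keep the two calls side by side; placing one in implicit scope, as the simple and roundRobin objects do, selects the strategy the same way.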