/** An immutable set of `Int`s.
  *
  * The two implementations below ([[NonEmpty]] and [[Empty]]) form an
  * unbalanced binary search tree. Every operation returns a new set and
  * never mutates an existing one.
  */
abstract class IntSet {

  /** Returns a set with all elements of this set plus `x`. */
  def add(x: Int): IntSet

  /** Returns `true` iff `x` is an element of this set. */
  def contains(x: Int): Boolean

  /** Applies `f` to every element of this set. */
  def foreach(f: Int => Unit): Unit

  /** Returns a set with every element of this set and of `other`. */
  def union(other: IntSet): IntSet

  /** Returns a set with the elements present in both this set and `other`. */
  def intersect(other: IntSet): IntSet

  /** Returns a set with all elements of this set except `x`. */
  def excl(x: Int): IntSet

  /** Operator alias for [[add]]. */
  def +(x: Int): IntSet = this.add(x)

  /** Operator alias for [[excl]]. */
  def -(x: Int): IntSet = this.excl(x)
}

/** A BST node holding `elem`.
  *
  * Invariant (maintained by `add`): every element of `left` is `< elem` and
  * every element of `right` is `> elem`.
  */
class NonEmpty(elem: Int, left: IntSet, right: IntSet) extends IntSet {

  def add(x: Int): IntSet =
    if (x < elem) new NonEmpty(elem, left.add(x), right)
    else if (x > elem) new NonEmpty(elem, left, right.add(x))
    else this // already present: share the existing node

  def contains(x: Int): Boolean =
    if (x < elem) left.contains(x)
    else if (x > elem) right.contains(x)
    else true

  /** Renders the tree as `(left|elem|right)`, with `-` for empty subtrees. */
  override def toString: String = s"($left|$elem|$right)"

  /** In-order traversal.
    *
    * Because of the BST invariant, `f` sees the elements in ascending order.
    */
  def foreach(f: Int => Unit): Unit = {
    left.foreach(f)
    f(elem)
    right.foreach(f)
  }

  /** Adds every element of this tree to `other`.
    *
    * Each element of this set is passed to `add` exactly once, and each
    * recursive receiver is a strict subtree, so the recursion terminates.
    * The earlier form `((left union right) union other) add elem` rebuilt
    * `left ∪ right` at every level and then re-traversed that new tree,
    * inserting the same elements over and over.
    */
  def union(other: IntSet): IntSet =
    left.union(right.union(other.add(elem)))

  def intersect(other: IntSet): IntSet = {
    // Subtrees hold disjoint ranges, so their intersections can simply be merged.
    val rest = left.intersect(other).union(right.intersect(other))
    if (other.contains(elem)) rest.add(elem) else rest
  }

  def excl(x: Int): IntSet =
    if (x < elem) new NonEmpty(elem, left.excl(x), right)
    else if (x > elem) new NonEmpty(elem, left, right.excl(x))
    else left.union(right) // removing this node: merge its two subtrees
}

/** The empty set, which is also the leaf of every tree. */
object Empty extends IntSet {
  def add(x: Int): IntSet = new NonEmpty(x, Empty, Empty)
  def contains(x: Int): Boolean = false
  override def toString: String = "-"
  def foreach(f: Int => Unit): Unit = ()
  def union(other: IntSet): IntSet = other
  def intersect(other: IntSet): IntSet = Empty
  def excl(x: Int): IntSet = this
}

val t1 = Empty
val t2 = t1 add 3
val t3 = t1 add 4 add 5 add 2 add 6
t3 contains 4

println(Empty)               // prints -
println(Empty.add(3))        // prints (-|3|-)
println(Empty.add(3).add(2)) // prints ((-|2|-)|3|-)

val s = Empty.add(3).add(2).add(7).add(1)
s.foreach(println)

// An in-order traversal of a BST visits elements in ascending order,
// so this prints: 2, 3, 4, 7,
Empty.add(3).add(2).add(6).add(1).foreach(x => print(s"${x + 1}, "))

val s2 = Empty.add(3).add(4).add(6).add(2)
s.union(s2)
// Exercise intersect and excl on the sample sets `s` ({1, 2, 3, 7}) and `s2` ({2, 3, 4, 6}).
s intersect s2 // elements common to both: 2 and 3
s2 excl 0      // 0 is not a member, so the result has the same elements as s2
s2 excl 6      // remove a leaf
s2 excl 2      // remove the other leaf
s2 excl 4      // remove an inner node that has only a right child

// `+` and `-` are operator aliases for `add` and `excl`.
val o1 = Empty + 3 + 4 + 12 + 5
val o2 = o1 - 3 - 4
o2 // ((-|5|-)|12|-)