// Scala source, 90 lines, 2.3 KiB
/** A set of integers.
  *
  * Every operation that changes membership (`add`, `excl`, `union`,
  * `intersect`) returns an `IntSet` as its result. `+` and `-` are operator
  * aliases for `add` and `excl`.
  */
abstract class IntSet() {

  // --- queries ---

  /** Tests whether `x` is an element of this set. */
  def contains(x: Int): Boolean

  /** Applies `f` to every element of this set. */
  def foreach(f: Int => Unit): Unit

  // --- single-element updates ---

  /** Returns a set holding the elements of this set together with `x`. */
  def add(x: Int): IntSet

  /** Returns a set holding the elements of this set except `x`. */
  def excl(x: Int): IntSet

  /** Operator alias for `add`. */
  def +(x: Int): IntSet = add(x)

  /** Operator alias for `excl`. */
  def -(x: Int): IntSet = excl(x)

  // --- set algebra ---

  /** Returns a set holding the elements of this set and those of `other`. */
  def union(other: IntSet): IntSet

  /** Returns a set holding the elements found in both this set and `other`. */
  def intersect(other: IntSet): IntSet
}
|
|
|
|
/** A non-empty `IntSet`, stored as one node of a binary search tree.
  *
  * @param elem  the value held at this node
  * @param left  subtree that `add` routes values smaller than `elem` into
  * @param right subtree that `add` routes values greater than `elem` into
  */
class NonEmpty(elem: Int, left: IntSet, right: IntSet) extends IntSet() {

  /** Returns a tree that also contains `x`.
    *
    * Only the nodes on the path to the insertion point are rebuilt; when `x`
    * equals `elem` this very node is returned.
    */
  def add(x: Int): IntSet =
    if (x == elem) this
    else if (x < elem) new NonEmpty(elem, left.add(x), right)
    else new NonEmpty(elem, left, right.add(x))

  /** Looks for `x` by descending left for smaller and right for greater values. */
  def contains(x: Int): Boolean =
    x == elem || (if (x < elem) left.contains(x) else right.contains(x))

  /** Renders the node as `(left|elem|right)`. */
  override def toString: String = s"($left|$elem|$right)"

  /** Applies `f` to the left subtree, then to `elem`, then to the right subtree.
    *
    * For trees built through `add` this visits the elements in ascending order.
    */
  def foreach(f: Int => Unit): Unit = {
    left.foreach(f)
    f(elem)
    right.foreach(f)
  }

  /** Merges both subtrees, then merges `other` in, and finally re-adds `elem`. */
  def union(other: IntSet): IntSet = {
    val children = left.union(right)
    val merged = children.union(other)
    merged.add(elem)
  }

  /** Keeps `elem` only if `other` contains it, then unions in the intersections
    * of the left and of the right subtree with `other`.
    */
  def intersect(other: IntSet): IntSet = {
    val root: IntSet = if (other.contains(elem)) Empty.add(elem) else Empty
    val withLeft = root.union(left.intersect(other))
    withLeft.union(right.intersect(other))
  }

  /** Returns a tree without `x`.
    *
    * When `x` is the value of this node, the node is dropped by merging its two
    * subtrees; otherwise the removal continues in the matching subtree.
    */
  def excl(x: Int): IntSet =
    if (x == elem) left.union(right)
    else if (x < elem) new NonEmpty(elem, left.excl(x), right)
    else new NonEmpty(elem, left, right.excl(x))
}
|
|
|
|
/** The empty `IntSet`: it contains no element and prints as `-`. */
object Empty extends IntSet() {

  /** Returns a single-node tree holding `x` with two empty subtrees. */
  def add(x: Int): IntSet = new NonEmpty(x, this, this)

  /** Always `false`: the empty set has no elements. */
  def contains(x: Int): Boolean = false

  override def toString: String = "-"

  /** Does nothing: there is no element to apply `f` to. */
  def foreach(f: Int => Unit): Unit = ()

  /** The union with the empty set is `other` itself. */
  def union(other: IntSet): IntSet = other

  /** The intersection with the empty set is the empty set. */
  def intersect(other: IntSet): IntSet = this

  /** Removing from the empty set leaves it unchanged. */
  def excl(x: Int): IntSet = this
}
|
|
|
|
// Building sets step by step and querying membership.
val t1 = Empty
val t2 = t1.add(3)
val t3 = t1.add(4).add(5).add(2).add(6)
t3.contains(4)

// String rendering of small trees.
println(Empty)               // prints -
println(Empty.add(3))        // prints (-|3|-)
println(Empty.add(3).add(2)) // prints ((-|2|-)|3|-)

// Printing every element, one per line.
val s = Empty.add(3).add(2).add(7).add(1)
s.foreach(println(_))

// Prints: 2, 3, 4, 7,
// The output is ascending because foreach walks the search tree
// left subtree first, then the node, then the right subtree.
Empty.add(3).add(2).add(6).add(1).foreach(x => print((x + 1).toString + ", "))

val s2 = Empty.add(3).add(4).add(6).add(2)

// Set algebra on the two sets.
s.union(s2)
s.intersect(s2)

// Exclusion: an absent value (0), then values that are present.
s2.excl(0)
s2.excl(6)
s2.excl(2)
s2.excl(4)

// Operator aliases for add / excl.
val o1 = Empty + 3 + 4 + 12 + 5
val o2 = o1 - 3 - 4
o2 // ((-|5|-)|12|-)