// File: FunProg-Scala/src/Assignment3/IntSet.sc
//
// Listing metadata: 90 lines, 2.3 KiB
// Language: Scala

/**
 * A set of integers with a persistent-style API: operations that change
 * the contents return an [[IntSet]] rather than `Unit`.
 */
abstract class IntSet {

  // ---- membership and traversal ----

  /** Tells whether `x` is an element of this set. */
  def contains(x: Int): Boolean

  /** Applies `f` to the elements of this set. */
  def foreach(f: Int => Unit): Unit

  // ---- single-element updates ----

  /** Returns a set that holds `x` in addition to the elements of this set. */
  def add(x: Int): IntSet

  /** Returns a set that holds the elements of this set except `x`. */
  def excl(x: Int): IntSet

  // ---- set algebra ----

  /** Returns a set holding the elements of this set and of `other`. */
  def union(other: IntSet): IntSet

  /** Returns a set holding the elements common to this set and `other`. */
  def intersect(other: IntSet): IntSet

  // ---- operator aliases ----

  /** Operator alias for [[add]]. */
  def +(x: Int): IntSet = add(x)

  /** Operator alias for [[excl]]. */
  def -(x: Int): IntSet = excl(x)
}
/**
 * A non-empty [[IntSet]] represented as a binary search tree node.
 *
 * The operations below assume that every element of `left` is smaller than
 * `elem` and every element of `right` is larger; `add` and `excl` keep that
 * ordering when they build new nodes.
 *
 * @param elem  the element stored at this node
 * @param left  subtree expected to hold the elements smaller than `elem`
 * @param right subtree expected to hold the elements larger than `elem`
 */
class NonEmpty(elem: Int, left: IntSet, right: IntSet) extends IntSet() {

  /**
   * Returns a set that also contains `x`.
   *
   * Only the nodes on the path to the insertion point are rebuilt; when `x`
   * is already stored at this node, this very instance is returned.
   */
  def add(x: Int): IntSet =
    if (x < elem) new NonEmpty(elem, left.add(x), right)
    else if (x > elem) new NonEmpty(elem, left, right.add(x))
    else this

  /** Searches for `x`, descending left or right according to the ordering. */
  def contains(x: Int): Boolean =
    if (x < elem) left.contains(x)
    else if (x > elem) right.contains(x)
    else true

  /** Renders the node as `(left|elem|right)`. */
  override def toString: String = s"($left|$elem|$right)"

  /** Applies `f` to the left subtree, then to `elem`, then to the right subtree. */
  def foreach(f: Int => Unit): Unit = {
    left.foreach(f)
    f(elem)
    right.foreach(f)
  }

  /**
   * Returns a set holding every element of this set and of `other`.
   *
   * This node's element is inserted into `other` first and the two subtrees
   * are then merged into that result. Each recursive call has a strictly
   * smaller receiver, so `union` is invoked once per node of this tree.
   */
  def union(other: IntSet): IntSet =
    left.union(right.union(other.add(elem)))

  /**
   * Returns a set holding the elements that occur both here and in `other`.
   *
   * The subtrees are intersected with `other` independently; `elem` is added
   * afterwards only if `other` contains it.
   */
  def intersect(other: IntSet): IntSet = {
    val fromSubtrees = left.intersect(other).union(right.intersect(other))
    if (other.contains(elem)) fromSubtrees.add(elem) else fromSubtrees
  }

  /**
   * Returns a set without `x`.
   *
   * When `x` is stored at this node the node is dropped and its two subtrees
   * are merged; otherwise the removal is delegated to the matching subtree.
   */
  def excl(x: Int): IntSet =
    if (x < elem) new NonEmpty(elem, left.excl(x), right)
    else if (x > elem) new NonEmpty(elem, left, right.excl(x))
    else left.union(right)
}
/**
 * The empty [[IntSet]]: a singleton used wherever a tree has no elements.
 */
object Empty extends IntSet {

  /** The empty set is rendered as a single dash. */
  override def toString: String = "-"

  /** Nothing is a member of the empty set. */
  def contains(x: Int): Boolean = false

  /** There are no elements to visit, so `f` is never called. */
  def foreach(f: Int => Unit): Unit = ()

  /** Creates a one-element tree whose two children are both the empty set. */
  def add(x: Int): IntSet = new NonEmpty(x, this, this)

  /** Removing anything from the empty set leaves it empty. */
  def excl(x: Int): IntSet = this

  /** The union with the empty set is `other`, returned as is. */
  def union(other: IntSet): IntSet = other

  /** The intersection with the empty set is always the empty set. */
  def intersect(other: IntSet): IntSet = this
}
// ---- Scratch expressions exercising the IntSet operations defined above ----

// Building sets with infix `add`; the calls chain from left to right.
val t1 = Empty
val t2 = t1 add 3
val t3 = t1 add 4 add 5 add 2 add 6
// true: 4 was the first element added to t3.
t3 contains 4

// toString renders a node as (left|elem|right) and the empty set as "-".
println(Empty) // prints -
println(Empty.add(3)) // prints (-|3|-)
println(Empty.add(3).add(2)) // prints ((-|2|-)|3|-)

// foreach walks left subtree, element, right subtree, so this prints
// 1, 2, 3 and 7, one per line.
val s = Empty.add(3).add(2).add(7).add(1)
s.foreach(println)
(Empty.add(3).add(2).add(6).add(1)) foreach (x => print(x+1 + ", "))
// 2, 3, 4, 7,
// The output is ascending because foreach traverses the BST in order.

// Set algebra and removal; here s = {1, 2, 3, 7} and s2 = {2, 3, 4, 6}.
val s2 = Empty.add(3).add(4).add(6).add(2)
s.union(s2) // elements 1, 2, 3, 4, 6, 7
s.intersect(s2) // elements 2, 3
s2.excl(0) // 0 is not a member: same elements as s2
s2.excl(6) // elements 2, 3, 4
s2.excl(2) // elements 3, 4, 6
s2.excl(4) // elements 2, 3, 6

// `+` and `-` are the operator aliases for add and excl.
val o1 = Empty + 3 + 4 + 12 + 5
val o2 = (o1 - 3 - 4)
o2 // ((-|5|-)|12|-)