package magic_squares

/** Backtracking solver for normal magic squares.
  *
  * Cells holding 0 are treated as empty; non-zero cells of `initialGrid` are fixed clues.
  * Empty cells are filled in row-major order with candidates 1 to n², and a partial grid is
  * abandoned as soon as it contains a duplicate value or a row, column or diagonal that can
  * no longer reach the magic sum.
  *
  * Every grid visited by the search is passed to `display` (inherited from [[Displayable]]).
  *
  * NOTE(review): assumes `Grid` (declared elsewhere in the package) is a square
  * `Array[Array[Int]]` — confirm against its definition.
  *
  * @param initialGrid the puzzle to solve; 0 marks an empty cell
  */
abstract class MagicSquareSolver(initialGrid: Grid) extends Displayable {
  protected var _initial: Grid = initialGrid
  protected val _size: Int = _initial.length

  /** Magic constant n(n² + 1)/2: the required sum of every row, column and diagonal. */
  protected val SUM: Int = _size * (_size * _size + 1) / 2

  protected val DEBUG: Boolean = false

  /** Prints `grid` to stdout, one comma-separated row per line. */
  protected def print(grid: Grid): Unit =
    grid.foreach(row => println(row.mkString(",")))

  /** Deep copy: each row is cloned, so writing into the copy never affects `grid`. */
  protected def copy(grid: Grid): Grid =
    grid.map(_.clone())

  /** Runs the search and prints the first solution found, or "No solution". */
  def solve(): Unit =
    solveFrom(_initial, 0, 0) match {
      case Some(solution) => print(solution)
      case None           => println("No solution")
    }

  /** Depth-first search starting at cell (`x`, `y`); cells before it are already assigned.
    *
    * @return the first completed, valid grid reachable from `grid`, or `None`
    */
  private def solveFrom(grid: Grid, x: Int, y: Int): Option[Grid] = {
    if (DEBUG) println(s"Solving from $x, $y")
    if (DEBUG) print(grid)
    display(grid)

    if (!isValid(grid)) {
      if (DEBUG) println(" Grid is invalid")
      None
    } else if (y >= _size) {
      // Every cell has been assigned and the grid passed validation.
      if (DEBUG) println(" Found solution")
      Some(grid)
    } else {
      // A clue from the initial grid is the only candidate; an empty cell tries 1..n².
      val clue: Int = _initial(y)(x)
      val values: Seq[Int] = if (clue == 0) 1 to _size * _size else Seq(clue)
      if (DEBUG) println(s" Values to test: " + values.mkString("[", ", ", "]"))

      // Next cell in row-major order.
      val wraps: Boolean = x + 1 >= _size
      val nextX: Int = if (wraps) x + 1 - _size else x + 1
      val nextY: Int = if (wraps) y + 1 else y

      // One scratch copy is reused for all candidates. The iterator is lazy, so the search
      // stops at the first candidate leading to a solution and `newGrid` is not touched again
      // after a solution has been found.
      val newGrid: Grid = copy(grid)
      val sol: Option[Grid] = values.iterator
        .map { i =>
          if (DEBUG) println(s" Testing $i")
          newGrid(y)(x) = i
          solveFrom(newGrid, nextX, nextY)
        }
        .collectFirst { case Some(found) => found }

      if (sol.isDefined) {
        if (DEBUG) println(" Found solution, collapsing call stack")
      } else {
        if (DEBUG) println(s" No solution for this configuration")
      }
      sol
    }
  }

  /** True when `grid` has no repeated non-zero value and every line can still be completed. */
  private def isValid(grid: Grid): Boolean = {
    // `flatten` (rather than reducing with concat) is linear and also handles a 0x0 grid.
    val filled: Array[Int] = grid.flatten.filter(_ != 0)
    if (filled.distinct.length != filled.length) false
    else rowsAndColumnsValid(grid) && diagonalsValid(grid)
  }

  /** Checks every row and column with [[isLineValid]], stopping at the first failure. */
  private def rowsAndColumnsValid(grid: Grid): Boolean =
    (0 until _size).forall { i =>
      val row: Array[Int] = grid(i)
      val col: Array[Int] = grid.map(_(i))
      if (!isLineValid(row)) {
        if (DEBUG) println(s" -> row $i is invalid: " + row.mkString("[", ", ", "]"))
        false
      } else if (!isLineValid(col)) {
        if (DEBUG) println(s" -> column $i is invalid: " + col.mkString("[", ", ", "]"))
        false
      } else {
        true
      }
    }

  /** Checks the main diagonal and the anti-diagonal with [[isLineValid]]. */
  private def diagonalsValid(grid: Grid): Boolean = {
    val diag1: Array[Int] = Array.tabulate(_size)(i => grid(i)(i))
    // Anti-diagonal: cell (row n-1-i, column i).
    val diag2: Array[Int] = Array.tabulate(_size)(i => grid(_size - i - 1)(i))
    if (!isLineValid(diag1)) {
      if (DEBUG) println(s" -> diag1 is invalid: " + diag1.mkString("[", ", ", "]"))
      false
    } else if (!isLineValid(diag2)) {
      if (DEBUG) println(s" -> diag2 is invalid: " + diag2.mkString("[", ", ", "]"))
      false
    } else {
      true
    }
  }

  /** A complete line must sum to exactly SUM.
    *
    * A line with k empty cells (0) can still be completed only if `sum + k <= SUM`, because
    * every empty cell will eventually hold a value of at least 1.
    */
  private def isLineValid(line: Array[Int]): Boolean = {
    val sum: Int = line.sum
    val empty: Int = line.count(_ == 0)
    if (empty == 0) sum == SUM
    else sum + empty <= SUM
  }
}

object MagicSquareSolver {
  def main(args: Array[String]): Unit = {
    val solver1: MagicSquareSolver = new MagicSquareSolver(Array(
      Array(8, 0, 0),
      Array(0, 0, 7),
      Array(0, 9, 0)
    )) with TextDisplay
    solver1.solve()
    println()
    /*
    8,1,6
    3,5,7
    4,9,2
     */

    val solver2: MagicSquareSolver = new MagicSquareSolver(Array(
      Array(9, 0, 0),
      Array(0, 0, 7),
      Array(0, 8, 0)
    )) with TextDisplay
    solver2.solve()
    println()
    /*
    no solution
     */

    val solver3: MagicSquareSolver = new MagicSquareSolver(Array(
      Array(0, 0, 2),
      Array(0, 5, 7),
      Array(0, 0, 0)
    )) with GraphicsDisplay
    solver3.solve()
    /*
    4,9,2
    3,5,7
    8,1,6
     */

    val solver4: MagicSquareSolver = new MagicSquareSolver(Array(
      Array(11, 0, 0, 20, 3),
      Array(4, 12, 0, 0, 16),
      Array(0, 5, 13, 21, 9),
      Array(10, 0, 0, 14, 22),
      Array(23, 6, 19, 2, 0)
    )) with GraphicsDisplay
    solver4.solve()
  }
}