Solving a printed sudoku puzzle with a vision-based application is not valuable in itself and may look like an easy problem, but you will gain a lot of knowledge by working through it, and it can take time to produce a robust solution. In this post, I will explain how to build a printed sudoku solver that works on a real-time camera feed: if the captured image contains a sudoku puzzle, the puzzle is solved and the result is printed on the empty cells.
Figure 1 shows the major steps of the algorithm. The captured image first passes through the preprocessing steps described in figure 2. The sudoku grid is then localized under the assumption that it is the largest rectangle in the image. Next, the Hough line transform is applied to find the vertical and horizontal lines that form the grid, and the K-means clustering algorithm groups them into 10 vertical and 10 horizontal lines. The intersection points between these lines are computed, and from these points the cells of the grid are localized. After that, for each cell we extract the largest blob (constrained by a threshold), assuming that this blob is the digit. The digits are recognized by a convolutional neural network. Finally, we form a matrix representing the sudoku puzzle and feed it to a sudoku-solving algorithm.
The project contains five classes, listed in the next table with their functions.

| Class Name | Its Function |
| --- | --- |
| SudokuSolver | Main class of the project. |
| NetworkTrainer | Trains the model using the dataset (MNIST + generated). |
| Sudoku | Solves the sudoku puzzle; returns the solution, given an array of the sudoku. |
| GenerateDataset | Generates a printed (not handwritten) digit dataset using operating-system fonts. |
| LinesComparator | Helper class to compare and sort lines. |
The model used in this project is based on the LeNet CNN architecture for handwritten digit recognition; you can see the model in the NetworkTrainer.java class. I first tried to recognize the sudoku digits with the MNIST dataset alone, but it gave bad results, so I combined MNIST with a dataset I generated using different fonts included in the OS (Mac, Windows, or Linux); you can generate this printed digit dataset with the GenerateDataset.java class. The complete dataset (MNIST plus the generated one) is included in the project repository Here.
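The idea behind GenerateDataset.java can be sketched with plain java.awt: render each digit with an installed font into a 28x28 grayscale image, matching the MNIST input size. The class below is an illustrative sketch under that assumption, not the repo's actual implementation (the real class iterates over many fonts and writes the images to disk).

```java
import java.awt.Color;
import java.awt.Font;
import java.awt.FontMetrics;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

public class DigitRenderer {
    /** Render a single digit as a 28x28 grayscale image (black digit on white background). */
    public static BufferedImage renderDigit(int digit, String fontName) {
        BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = img.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_TEXT_ANTIALIASING,
                           RenderingHints.VALUE_TEXT_ANTIALIAS_ON);
        g.setColor(Color.WHITE);
        g.fillRect(0, 0, 28, 28);
        g.setColor(Color.BLACK);
        g.setFont(new Font(fontName, Font.PLAIN, 22));
        // Roughly center the glyph using the font metrics
        FontMetrics fm = g.getFontMetrics();
        String s = Integer.toString(digit);
        int x = (28 - fm.stringWidth(s)) / 2;
        int y = (28 - fm.getHeight()) / 2 + fm.getAscent();
        g.drawString(s, x, y);
        g.dispose();
        return img;
    }
}
```

Looping this over the digits 1-9 and every font returned by `GraphicsEnvironment.getAvailableFontFamilyNames()` yields a printed-digit dataset in the same 28x28 format as MNIST.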
private static MultiLayerNetwork loadNetwork() {
    MultiLayerNetwork network = null;
    try {
        String pathtoexe = System.getProperty("user.dir");
        File net = new File(pathtoexe, "cnn-model.data");
        network = ModelSerializer.restoreMultiLayerNetwork(net);
    } catch (IOException ex) {
        log.error("Error While Loading Pretrained Network: " + ex.getMessage());
    }
    return network;
}
I put the capture object in an AtomicReference so that camera streaming can be stopped and started in a thread-safe way.
final AtomicReference<VideoCapture> capture = new AtomicReference<>(new VideoCapture());
capture.get().set(CV_CAP_PROP_FRAME_WIDTH, 1280);
capture.get().set(CV_CAP_PROP_FRAME_HEIGHT, 720);
if (!capture.get().open(0)) {
    log.error("Can not open the cam !!!");
}
.
.
.
while (true) {
    while (start.get() && capture.get().read(colorimg)) {
        if (mainframe.isVisible()) {
The preprocessing steps include converting the frame to grayscale, applying a Gaussian filter, and binarizing the image via adaptive thresholding.
/*Convert to grayscale mode*/
Mat sourceGrey = new Mat(colorimg.size(), CV_8UC1);
cvtColor(colorimg, sourceGrey, COLOR_BGR2GRAY);
//imwrite("gray.jpg", sourceGrey); // Save gray version of image
/*Apply Gaussian Filter*/
Mat blurimg = new Mat(colorimg.size(), CV_8UC1);
GaussianBlur(sourceGrey, blurimg, new Size(5, 5), 0);
//imwrite("blur.jpg", blurimg);
/*Binarize the image*/
Mat binimg = new Mat(colorimg.size());
adaptiveThreshold(blurimg, binimg, 255, ADAPTIVE_THRESH_GAUSSIAN_C, THRESH_BINARY_INV, 19, 3);
//imwrite("binarise.jpg", binimg);
In this step, we extract the largest blob (rectangle), assuming that it is the sudoku grid.
private static Rect getLargestRect(Mat img) {
    MatVector contours = new MatVector();
    List<Rect> rects = new ArrayList<>();
    List<Double> areas = new ArrayList<>();
    findContours(img, contours, CV_RETR_TREE, CV_CHAIN_APPROX_SIMPLE, new Point(0, 0));
    for (int i = 0; i < contours.size(); i++) {
        Mat c = contours.get(i);
        double area = contourArea(c);
        Rect boundingRect = boundingRect(c);
        areas.add(area);
        rects.add(boundingRect);
    }
    if (areas.isEmpty() || Collections.max(areas) < 4000) {
        return new Rect(0, 0, img.cols(), img.rows());
    } else {
        Double d = Collections.max(areas);
        return rects.get(areas.indexOf(d));
    }
}
This step deskews the extracted rectangle by computing its rotation angle and correcting it (this is done inside the deskewImage method), and then applies a perspective warp using the four corners of the rectangle.
private static Mat warpPrespectivePuzzle(Mat image) {
    image = deskewImage(image);
    Rect rect = getLargestRect(image);
    Point2f srcPts = new Point2f(4);
    srcPts.position(0).x((float) rect.x()).y((float) rect.y());
    srcPts.position(1).x((float) rect.x() + rect.width()).y((float) rect.y());
    srcPts.position(2).x((float) rect.x() + rect.width()).y((float) rect.y() + rect.height());
    srcPts.position(3).x((float) rect.x()).y((float) rect.y() + rect.height());
    Point2f dstPts = new Point2f(4);
    dstPts.position(0).x(0).y(0);
    dstPts.position(1).x(600 - 2).y(0);
    dstPts.position(2).x(600 - 2).y(600 - 2);
    dstPts.position(3).x(0).y(600 - 2);
    Mat p = getPerspectiveTransform(srcPts.position(0), dstPts.position(0));
    Mat img = new Mat(new Size(600, 600), image.type());
    warpPerspective(image, img, p, img.size());
    return img;
}
private static Mat deskewImage(Mat img) {
    MatVector contours = new MatVector();
    List<Double> areas = new ArrayList<>();
    findContours(img, contours, CV_RETR_TREE, CV_CHAIN_APPROX_SIMPLE, new Point(0, 0));
    for (int i = 0; i < contours.size(); i++) {
        Mat c = contours.get(i);
        double area = contourArea(c);
        areas.add(area);
    }
    if (areas.isEmpty()) {
        return img;
    } else {
        Double d = Collections.max(areas);
        RotatedRect minAreaRect = minAreaRect(contours.get(areas.indexOf(d)));
        float angle = minAreaRect.angle();
        if (angle < -45) {
            angle = -(90 + angle);
        } else {
            angle = -angle;
        }
        Mat rot = getRotationMatrix2D(minAreaRect.center(), angle, 1);
        Mat dst = new Mat(img.size(), img.type());
        warpAffine(img, dst, rot, dst.size(), WARP_INVERSE_MAP | INTER_LINEAR, BORDER_CONSTANT, new Scalar(0, 0, 0, 0));
        return dst;
    }
}
Next, we apply the Canny edge detector to the processed grid image.
Mat canimg = new Mat(procimg.size());
Canny(procimg, canimg, 30, 90);
In this step, we apply the Hough line transform to the output of the Canny edge detector. It returns an array (lines, in the code) of (rho, theta) pairs, where rho is the perpendicular distance from the origin to the line and theta is the angle between that perpendicular and the horizontal axis, measured counter-clockwise.
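The (rho, theta) form does not give endpoints directly; to draw or reason about such a line, a standard trick is to take the foot of the perpendicular at (rho·cosθ, rho·sinθ) and extend in both directions along the line's direction vector (−sinθ, cosθ). This helper is an illustrative sketch, not part of the repo's code:

```java
public class HoughLineUtil {
    /**
     * Convert a (rho, theta) line to two endpoints, e.g. for drawing with line().
     * Returns {x1, y1, x2, y2}.
     */
    public static int[] toSegment(double rho, double theta, int len) {
        double a = Math.cos(theta), b = Math.sin(theta);
        double x0 = a * rho, y0 = b * rho;           // foot of the perpendicular from the origin
        int x1 = (int) Math.round(x0 + len * (-b));  // extend along the line direction (-sin, cos)
        int y1 = (int) Math.round(y0 + len * a);
        int x2 = (int) Math.round(x0 - len * (-b));
        int y2 = (int) Math.round(y0 - len * a);
        return new int[]{x1, y1, x2, y2};
    }
}
```

For theta = 0 this yields a vertical line at x = rho, and for theta = π/2 a horizontal line at y = rho, which matches the sin/cos test used below to separate horizontal from vertical lines.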
Mat lines = new Mat();//vector stores the parameters (rho,theta) of the detected lines
HoughLines(canimg, lines, 1, CV_PI / 180, 100);
In this step, we cluster the lines output by the Hough transform using K-means into 10 vertical and 10 horizontal lines, removing redundant lines. The inputs to K-means are the (rho, theta) pairs of the detected lines.
FloatRawIndexer srcIndexer = lines.createIndexer();
/*One list for horizontal lines and one for vertical lines*/
List<org.deeplearning4j.clustering.cluster.Point> hpoints = new ArrayList<>();
List<org.deeplearning4j.clustering.cluster.Point> vpoints = new ArrayList<>();
for (int i = 0; i < srcIndexer.rows(); i++) {
    float[] data = new float[2]; //data[0] is rho and data[1] is theta
    srcIndexer.get(0, i, data);
    double d[] = {data[0], data[1]};
    if (Math.sin(data[1]) > 0.8) {//a perfectly horizontal line has sin(theta) == 1; I consider > 0.8 horizontal
        hpoints.add(new org.deeplearning4j.clustering.cluster.Point("hrho" + Math.sin(data[1]), "hrho", d));
    } else if (Math.cos(data[1]) > 0.8) {//a perfectly vertical line has cos(theta) == 1
        vpoints.add(new org.deeplearning4j.clustering.cluster.Point("vrho" + Math.cos(data[1]), "vrho", d));
    }
}
/*Cluster vertical and horizontal lines into 10 lines each using k-means with 10 iterations*/
KMeansClustering kmeans = KMeansClustering.setup(10, 10, "euclidean");
log.info("Lines Number " + vpoints.size() + " " + hpoints.size());
if (vpoints.size() >= 10 && hpoints.size() >= 10) {
    ClusterSet hcs = kmeans.applyTo(hpoints);
    List<Cluster> hlines = hcs.getClusters();
    Collections.sort(hlines, new LinesComparator());
    ClusterSet vcs = kmeans.applyTo(vpoints);
    List<Cluster> vlines = vcs.getClusters();
    Collections.sort(vlines, new LinesComparator());
private static boolean checkLines(List<Cluster> vlines, List<Cluster> hlines) {
    final int diff = 40;//this may vary if you change the image width and height (600) in method warpPrespectivePuzzle
    if (!(vlines.size() == 10 && hlines.size() == 10)) {
        return false;
    }
    for (int i = 0; i < hlines.size() - 1; i++) {
        double r1 = hlines.get(i).getCenter().getArray().getDouble(0);
        double r2 = hlines.get(i + 1).getCenter().getArray().getDouble(0);
        if (Math.abs(r1 - r2) < diff) {
            return false;
        }
    }
    for (int i = 0; i < vlines.size() - 1; i++) {
        double r1 = vlines.get(i).getCenter().getArray().getDouble(0);
        double r2 = vlines.get(i + 1).getCenter().getArray().getDouble(0);
        if (Math.abs(r1 - r2) < diff) {
            return false;
        }
    }
    return true;
}
private static List<Point> getPoint(List<Cluster> vlines, List<Cluster> hlines) {
    List<Point> points = new ArrayList<>();
    for (int i = 0; i < hlines.size(); i++) {
        Cluster get = hlines.get(i);
        double r1 = get.getCenter().getArray().getDouble(0);
        double t1 = get.getCenter().getArray().getDouble(1);
        for (int j = 0; j < vlines.size(); j++) {
            Cluster get1 = vlines.get(j);
            double r2 = get1.getCenter().getArray().getDouble(0);
            double t2 = get1.getCenter().getArray().getDouble(1);
            Point o = parametricIntersect(r1, t1, r2, t2);
            if (o.y() != -1 && o.x() != -1) {
                points.add(o);
            }
        }
    }
    /*Merge near-duplicate intersection points (closer than 20 pixels)*/
    for (int i = 0; i < points.size() - 1; i++) {
        if (getDistance(points.get(i), points.get(i + 1)) < 20) {
            points.remove(i);
            i--; //re-check after the shift caused by removal
        }
    }
    return points;
}
/*Get the intersection point of two lines given their rhos and thetas*/
private static Point parametricIntersect(Double r1, Double t1, Double r2, Double t2) {
    double ct1 = Math.cos(t1); //matrix element a
    double st1 = Math.sin(t1); //b
    double ct2 = Math.cos(t2); //c
    double st2 = Math.sin(t2); //d
    double d = ct1 * st2 - st1 * ct2;//determinant of the 2x2 system
    if (d != 0.0f) {
        int x = (int) ((st2 * r1 - st1 * r2) / d);
        int y = (int) ((-ct2 * r1 + ct1 * r2) / d);
        return new Point(x, y);
    } else { //lines are parallel and will never intersect
        return new Point(-1, -1);
    }
}
static double getDistance(Point p1, Point p2) {
    return Math.sqrt(Math.pow((p1.x() - p2.x()), 2) + Math.pow((p1.y() - p2.y()), 2));
}
private static Mat detectDigit(Mat img) {
    Mat res = new Mat();
    MatVector contours = new MatVector();
    List<Rect> rects = new ArrayList<>();
    List<Double> areas = new ArrayList<>();
    bitwise_not(img, img);
    findContours(img, contours, opencv_imgproc.CV_RETR_TREE, CV_CHAIN_APPROX_SIMPLE, new Point(0, 0));
    for (int i = 0; i < contours.size(); i++) {
        Mat c = contours.get(i);
        Rect boundbox = boundingRect(c);
        if (boundbox.height() > 20 && boundbox.height() < 50 && boundbox.width() > 15 && boundbox.width() < 40) {
            double aspectRatio = (double) boundbox.height() / boundbox.width();//cast to avoid integer division
            if (aspectRatio >= 1 && aspectRatio < 3) {
                rects.add(boundbox);
                double area = contourArea(c);
                areas.add(area);
            }
        }
    }
    if (!areas.isEmpty()) {
        bitwise_not(img, img);
        Double d = Collections.max(areas);
        res = img.apply(rects.get(areas.indexOf(d)));
        copyMakeBorder(res, res, 10, 10, 10, 10, BORDER_CONSTANT, new Scalar(255, 255, 255, 255));
        resize(res, res, new Size(28, 28));
        return res;
    } else {
        return img;
    }
}
/*Recognize a digit given its image*/
private static int recogniseDigit(Mat digit) {
    int idx = 0;
    try {
        NativeImageLoader loader = new NativeImageLoader(28, 28, 1);
        bitwise_not(digit, digit);//make the digit white and the background black, as in the training data
        INDArray dig = loader.asMatrix(digit);
        INDArray flaten = dig.reshape(new int[]{1, 784});
        flaten = flaten.div(255);
        INDArray output = NETWORK.output(flaten);
        idx = Nd4j.getExecutioner().execAndReturn(new IAMax(output)).getFinalResult();//argmax over the class scores
        digit.release();
    } catch (IOException ex) {
        log.error(ex.getMessage());
    }
    return DIGITS[idx];
}
In this step, we form the sudoku matrix from the recognized digits, putting zero in the empty cells, and then send it to the solving algorithm in the Sudoku.java class. I relied on code from here and here.
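The repo delegates solving to Sudoku.java; for completeness, the solving idea can be sketched as a standard backtracking search over the 9x9 matrix, with zeros marking the empty cells (this is an illustrative sketch, not the repo's implementation):

```java
public class SudokuSketch {
    /** Solve the 9x9 grid in place; zeros mark empty cells. Returns true if a solution exists. */
    public static boolean solve(int[][] g) {
        for (int r = 0; r < 9; r++) {
            for (int c = 0; c < 9; c++) {
                if (g[r][c] != 0) continue; // cell already filled
                for (int v = 1; v <= 9; v++) {
                    if (fits(g, r, c, v)) {
                        g[r][c] = v;
                        if (solve(g)) return true;
                        g[r][c] = 0; // backtrack
                    }
                }
                return false; // no digit fits this empty cell
            }
        }
        return true; // no empty cells left
    }

    /** Check the row, column, and 3x3 box constraints for placing v at (r, c). */
    private static boolean fits(int[][] g, int r, int c, int v) {
        for (int i = 0; i < 9; i++) {
            if (g[r][i] == v || g[i][c] == v) return false;
            if (g[3 * (r / 3) + i / 3][3 * (c / 3) + i % 3] == v) return false;
        }
        return true;
    }
}
```

This is exponential in the worst case, but for valid 9x9 puzzles it finishes in milliseconds, which is fast enough for a real-time camera loop.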
In this step, we print the result onto the image: the blue digits are the solution, and the red digits are the recognized digits of the original puzzle.
private static void printResult(Mat img, INDArray result, INDArray puzzle, List<Rect> rects) {
    for (int i = 0; i < rects.size(); i++) {
        Rect rect = rects.get(i);
        int x = rect.x();
        int y = rect.y();
        int d = (int) result.getDouble(i / 9, i % 9);
        int d1 = (int) puzzle.getDouble(i / 9, i % 9);
        if (d != d1) {//print a solved digit
            putText(img, d + "", new Point(x + 20, y + 50),
                    FONT_HERSHEY_COMPLEX, 1.3, new Scalar(255, 0, 0, 0), 3, 2, false);
        } else {//print a recognized puzzle digit
            putText(img, d + "", new Point(x + 10, y + 40),
                    FONT_HERSHEY_COMPLEX, 1, new Scalar(0, 0, 255, 0), 2, 2, false);
        }
    }
}