
K-means in Python 3 on Sentinel 2 data


18 months ago I wrote about unsupervised classification of randomly extracted point data from satellite data. I have been meaning to follow it up by showing how straightforward it is to use the cluster algorithms in Sklearn to classify Sentinel 2 data. I have made this blog into a Jupyter Notebook which is available here.

Set up

To run an unsupervised classification on satellite data using Python you need GDAL, Numpy and Sklearn. If you wish to see the data you will also need Matplotlib. Assuming you have the libraries installed, import them at the start.

import numpy as np
from sklearn import cluster
from osgeo import gdal, gdal_array
import matplotlib.pyplot as plt

# Tell GDAL to throw Python exceptions, and register all drivers
gdal.UseExceptions()
gdal.AllRegister()

Running a classification on a single band

Let’s start by looking at a single band of Sentinel 2. In this example I am looking at band 2, which is the blue band and has a spatial resolution of 10m. Sentinel 2 is currently a pair of multispectral satellites with a variety of spatial resolutions. You can find out more about Sentinel 2 here.

In this example I am using a DOS atmospherically corrected Sentinel 2a dataset over the south of England. To read it into gdal and then to convert to an array use the following code:

# Read in raster image
img_ds = gdal.Open('../S2_may_South_coast.tif', gdal.GA_ReadOnly)

band = img_ds.GetRasterBand(2)

img = band.ReadAsArray()

Remember that I am working with band 2, hence the GetRasterBand(2) call on the second line. The variable img is now a numpy array; in this example its shape is (519, 751), which you can check by printing img.shape (an attribute, not a method) to the command line. To use the classifier we need to reshape this array. The classifier expects a two-dimensional input with one row per sample and one column per feature, so we use .reshape((-1, 1)): the -1 tells numpy to work out the number of rows from the data, and the 1 fixes the number of columns at one. The resulting array takes the form (rows, 1), in our case (389769, 1). The code to do this is shown below:

X = img.reshape((-1,1))
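
As a quick check of what reshape((-1, 1)) does, here is a minimal sketch using a synthetic array with the same shape as the band above. The random values and the demo_ names are stand-ins for illustration only; they are not part of the original dataset or code.

```python
import numpy as np

# Stand-in for the band read with ReadAsArray(); values are synthetic.
rng = np.random.default_rng(0)
demo_band = rng.integers(0, 10000, size=(519, 751))

# -1 lets NumPy work out the number of rows; 1 fixes a single column.
demo_X = demo_band.reshape((-1, 1))

print(demo_band.shape)  # (519, 751)
print(demo_X.shape)     # (389769, 1)

# Row-major flattening: pixel (r, c) ends up at row r * n_cols + c.
r, c = 10, 20
assert demo_X[r * demo_band.shape[1] + c, 0] == demo_band[r, c]
```

Because the flattening is row-major, reshaping a per-pixel result back to (519, 751) puts every value back at its original pixel.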

Now we can run the k-means classifier on our data. First of all we need to choose how many clusters to use; in this case I am selecting eight. Then I fit the classifier to my data (X), which we defined above. Finally I assign the labels from the fit to a new variable called X_cluster and reshape it back to the dimensions of my original image.

k_means = cluster.KMeans(n_clusters=8)
k_means.fit(X)

X_cluster = k_means.labels_
X_cluster = X_cluster.reshape(img.shape)
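
K-means starts from randomly chosen centres, so the class numbers (and sometimes the classes themselves) can change between runs. The sketch below is one way to make a run repeatable and to inspect the result, assuming a small synthetic image in place of the Sentinel 2 band. The demo_ names, the three-region image and the choice of random_state=0 and n_init=10 are my assumptions for illustration.

```python
import numpy as np
from sklearn import cluster

# Synthetic single-band "image": three flat regions plus a little noise.
rng = np.random.default_rng(42)
demo_img = np.concatenate([
    np.full((20, 30), 100.0),
    np.full((20, 30), 500.0),
    np.full((20, 30), 900.0),
]) + rng.normal(0, 5, size=(60, 30))

demo_X = demo_img.reshape((-1, 1))

# random_state makes the run repeatable; n_init is set explicitly
# because its default differs between scikit-learn versions.
demo_kmeans = cluster.KMeans(n_clusters=3, n_init=10, random_state=0)
demo_kmeans.fit(demo_X)

demo_labels = demo_kmeans.labels_.reshape(demo_img.shape)

# One centre per cluster, in the units of the input band.
print(np.sort(demo_kmeans.cluster_centers_.ravel()))
# Number of pixels assigned to each cluster.
print(np.bincount(demo_labels.ravel()))
```

The label values themselves carry no meaning: cluster 0 is not "darker" than cluster 1, so it is the cluster centres that tell you what each class represents.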

Finally to visualise this data use the final three lines of code:

plt.figure(figsize=(20,20))
plt.imshow(X_cluster, cmap="hsv")
plt.show()

This is setting the figsize, assigning the data to plot and setting a colour map; in this case ‘hsv’, but you can choose any from here. The resulting image looks like this:

Running a classification on all the bands

Sentinel 2 has 13 spectral bands, giving us another 12 inputs into our classifier. It would seem a waste not to utilise all of these. The bands have a variety of native spatial resolutions, but in this GeoTIFF they are stacked on a single grid, so every band reads into an array of the same shape. Let’s reload our image and this time read all the bands into an array:

# Read in raster image
img_ds = gdal.Open('../S2_may_South_coast.tif', gdal.GA_ReadOnly)


img = np.zeros((img_ds.RasterYSize, img_ds.RasterXSize, img_ds.RasterCount),
               gdal_array.GDALTypeCodeToNumericTypeCode(img_ds.GetRasterBand(1).DataType))

for b in range(img.shape[2]):
    img[:, :, b] = img_ds.GetRasterBand(b + 1).ReadAsArray()

The first line is the same as before. The second line is a quick way of loading a multi-band image into a numpy array. This is done by initialising an array of zeros (or it could be ones with the np.ones command). It takes as parameters the RasterYSize, RasterXSize and number of bands of the input satellite image, plus the data type. You can print these if unsure of the values.
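
As an alternative to filling the array band by band, GDAL's Dataset.ReadAsArray() reads all bands in one call and, for a multi-band image, returns them shaped (bands, rows, columns). The sketch below shows how that layout could be rearranged into the (rows, columns, bands) layout used in this post. It uses a small synthetic array in place of the real call, so the bands_first and demo_img names and the 4 x 5 size are assumptions for illustration.

```python
import numpy as np

# Stand-in for `img_ds.ReadAsArray()`, which for a multi-band dataset
# gives (bands, rows, cols). Shape and values here are synthetic.
bands_first = np.arange(13 * 4 * 5).reshape((13, 4, 5))

# Move the band axis to the end to get (rows, cols, bands).
demo_img = np.moveaxis(bands_first, 0, -1)

print(demo_img.shape)  # (4, 5, 13)

# Band b of the result matches band b of the input.
assert np.array_equal(demo_img[:, :, 2], bands_first[2])
```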

Then I loop over the number of bands in the image (img.shape[2]), inserting the values into the numpy array. Next we need to reshape our array. Similar to before but we need to keep the columns as 13 this time (img.shape[2]). So, we use the following code:

new_shape = (img.shape[0] * img.shape[1], img.shape[2])

Again, if in doubt, print the variables to the command line. Then, based on this shape, we can build the input value X:

X = img[:, :, :13].reshape(new_shape)
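
To see what this reshape does to the data, here is a minimal sketch on a tiny synthetic cube standing in for the Sentinel 2 array (the 4 x 5 size and the demo_ names are assumptions for illustration). Each row of the reshaped array is the 13-band spectrum of one pixel, and any per-pixel result can be reshaped straight back to the image dimensions.

```python
import numpy as np

# Synthetic (rows, cols, bands) cube standing in for the Sentinel 2 array.
rows, cols, n_bands = 4, 5, 13
demo_img = np.arange(rows * cols * n_bands).reshape((rows, cols, n_bands))

demo_X = demo_img.reshape((rows * cols, n_bands))
print(demo_X.shape)  # (20, 13)

# Each row of demo_X is the full 13-band spectrum of one pixel.
r, c = 2, 3
assert np.array_equal(demo_X[r * cols + c], demo_img[r, c, :])

# Per-pixel results, such as cluster labels, reshape back to (rows, cols).
demo_result = demo_X[:, 0].reshape((rows, cols))
assert np.array_equal(demo_result, demo_img[:, :, 0])
```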

Now we can reuse the clustering code to run the classifier as before. The only difference is the final line that reshapes the result to a single band image, as opposed to a 13 band image:

k_means = cluster.KMeans(n_clusters=8)
k_means.fit(X)

X_cluster = k_means.labels_
X_cluster = X_cluster.reshape(img[:, :, 0].shape)

Use the same code to plot the data as before, shown below:

plt.figure(figsize=(20,20))
plt.imshow(X_cluster, cmap="hsv")

plt.show()

The result looks like the image below:

This seems like an improved result compared to the result of the classification using the single band image. One of the really nice things about this code is how you can change the classifier. So, for example, if you wanted to use the Mini-Batch K-Means clustering algorithm once the data was loaded, you would only need to change one line (the first one shown below):

MB_KMeans = cluster.MiniBatchKMeans(n_clusters=8)
MB_KMeans.fit(X)

X_cluster = MB_KMeans.labels_


X_cluster = X_cluster.reshape(img[:, :, 0].shape)

The resulting image is shown here:

This is a much faster implementation of K-Means and has produced a comparable result, though perhaps with a little more noise.
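
If you want to check the speed difference on your own machine, a rough timing sketch along these lines may help. It runs both classifiers on random data shaped like the (pixels, bands) array X; the array size, the random_state and n_init values and the timings dictionary are assumptions for illustration, and the timings will vary with hardware and library version.

```python
import time
import numpy as np
from sklearn import cluster

# Synthetic stand-in for the (pixels, bands) array X; values are random.
rng = np.random.default_rng(0)
demo_X = rng.random((20000, 13))

timings = {}
for name, model in [
    ("KMeans", cluster.KMeans(n_clusters=8, n_init=3, random_state=0)),
    ("MiniBatchKMeans", cluster.MiniBatchKMeans(n_clusters=8, n_init=3, random_state=0)),
]:
    start = time.perf_counter()
    model.fit(demo_X)
    timings[name] = time.perf_counter() - start
    print(f"{name}: {timings[name]:.2f} s, labels shape {model.labels_.shape}")
```

Random data has no real cluster structure, so this only compares run time; judging the quality of the two classifications still needs the real image.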

Saving your classification

Finally let’s save the result as a geotiff. We do this by first opening the input image to get at its properties. We only need a single band as before (band 2). We convert it to an array of data (line 3) and then read off the number of rows and columns from its shape (line 4).

Set the output format to geotiff and get the matching driver (lines 5 and 6).

Line 7 creates the output raster with the dimensions of the input raster; lines 8 and 9 copy the geotransform (position and pixel size) and the projection from the input.

The last step is to write the classification result to a single band raster image (line 10). Calling FlushCache() then writes any cached data to disk, and deleting the Python object closes the file; the geotiff itself stays on disk.

ds = gdal.Open("../S2_may_South_coast.tif")
band = ds.GetRasterBand(2)
arr = band.ReadAsArray()
rows, cols = arr.shape  # numpy shape is (rows, columns)

format = "GTiff"
driver = gdal.GetDriverByName(format)

# Create() takes the width (columns) first, then the height (rows)
outDataRaster = driver.Create(".../k_means.gtif", cols, rows, 1, gdal.GDT_Byte)
outDataRaster.SetGeoTransform(ds.GetGeoTransform())  # same geotransform as input
outDataRaster.SetProjection(ds.GetProjection())  # same projection as input

outDataRaster.GetRasterBand(1).WriteArray(X_cluster)

outDataRaster.FlushCache()  # write any cached data to disk
del outDataRaster  # close the dataset (the geotiff on disk is kept)
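
The output band is created as gdal.GDT_Byte, which stores values from 0 to 255, while the labels returned by the classifier are usually a wider integer type. Eight classes fit comfortably, but if you raise the number of clusters it may be worth checking the range before writing. A minimal sketch of such a check, using a synthetic array in place of X_cluster (the demo_ names and random labels are assumptions for illustration):

```python
import numpy as np

# Synthetic stand-in for X_cluster; k-means labels run from 0 to n_clusters - 1.
demo_cluster = np.random.default_rng(0).integers(0, 8, size=(519, 751))

# A Byte band stores 0-255, so confirm the labels fit before casting.
assert demo_cluster.min() >= 0 and demo_cluster.max() <= 255
demo_byte = demo_cluster.astype(np.uint8)

print(demo_byte.dtype, demo_byte.shape)  # uint8 (519, 751)
```

With more than 256 clusters you would need a wider output type, such as gdal.GDT_UInt16, when creating the raster.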

That’s it. You should be able to load this classification raster into a GIS package now and compare directly with your input satellite image.

Recap

  • We have loaded a single and a multi-band raster to a numpy array
  • We have reshaped our data to fit the classifier
  • We have run the classifier and then changed the classifier and run again
  • We have displayed the results in matplotlib
  • We have saved the results as a geotiff to use in a GIS package of choice.

The code is available here as a jupyter notebook.

This code shows how to use an unsupervised classification on satellite data with Python. I teach how to adapt this code to build supervised classifications on satellite imagery with Python. If you want to find out more please drop me a line or look on my website.

 


I am a freelancer able to help you with your projects. I offer consultancy, training and writing. I’d be delighted to hear from you. Please check out the books I have written on QGIS 3.4

https://www.packtpub.com/application-development/learn-qgis-fourth-edition

https://www.packtpub.com/application-development/qgis-quick-start-guide


 

I have grouped all my previous blogs (technical stuff / tutorials / opinions / ideas) at http://gis.acgeospatial.co.uk.

Feel free to connect or follow me; I am always keen to talk about Earth Observation.

I am @map_andrew on twitter