Detecting Buildings in Satellite Images
Building footprint detection with fastai on the challenging SpaceNet7 dataset.
- Introduction
- Setup
- Preprocessing
- Data Loading Functions
- Visualizing the Data
- Validation Strategy
- Undersampling
- Creating Dataloaders
- Defining the Model
- Training
- Visualizing the Results
- Discussion
Introduction
In this notebook I implement a neural-network-based solution for building footprint detection on the SpaceNet7 dataset. I ignore the temporal aspect of the original challenge and focus on segmenting buildings in single images. I use fastai, a deep learning library built on PyTorch, which makes it possible to train neural networks with modern best practices while reducing the amount of boilerplate code required.
The dataset is stored on AWS. Installation instructions are here.
Defining training parameters:
BATCH_SIZE = 12 # 3 for xresnet50, 12 for xresnet34 with Tesla P100 (16GB)
TILES_PER_SCENE = 16
ARCHITECTURE = xresnet34
EPOCHS = 40
CLASS_WEIGHTS = [0.25,0.75]
LR_MAX = 3e-4
ENCODER_FACTOR = 10
CODES = ['Land','Building']
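To make the role of `CLASS_WEIGHTS` more concrete, here is a minimal pure-Python sketch of a class-weighted cross-entropy. It assumes the weights end up in a weighted cross-entropy loss, which this section does not state explicitly, and the function name `weighted_cross_entropy` is hypothetical. The normalisation (dividing by the sum of the weights of the target pixels) mirrors what PyTorch's `nn.CrossEntropyLoss` does with `weight` and mean reduction.

```python
import math

CLASS_WEIGHTS = [0.25, 0.75]  # same values as above: [Land, Building]

def weighted_cross_entropy(probs, targets, weights):
    """Class-weighted mean of per-pixel negative log-likelihoods.

    probs   -- per-pixel class probabilities, e.g. [[0.9, 0.1], ...]
    targets -- true class index per pixel (0 = Land, 1 = Building)
    weights -- one weight per class
    """
    weighted_nll = sum(-weights[t] * math.log(p[t]) for p, t in zip(probs, targets))
    total_weight = sum(weights[t] for t in targets)
    return weighted_nll / total_weight

# One well-predicted land pixel and one badly missed building pixel
probs = [[0.6, 0.4], [0.9, 0.1]]
targets = [0, 1]

weighted = weighted_cross_entropy(probs, targets, CLASS_WEIGHTS)
unweighted = weighted_cross_entropy(probs, targets, [1.0, 1.0])
print(weighted > unweighted)  # True
```

With these weights, an error on a building pixel contributes three times as much to the loss as the same error on a land pixel, which is one common way to counter the fact that most pixels are land.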
Exploring the dataset structure and displaying some sample scene directories:
scenes = path.ls().sorted()
print(f'Number of scenes: {len(scenes)}')
pprint(list(scenes)[:5])
The folders in each scene (the last three were added later during processing):
sample_scene = (path/'L15-0683E-1006N_2732_4164_13')
pprint(list(sample_scene.ls()))
How many images are in a specific scene:
images_masked = (sample_scene/'images_masked').ls().sorted()
labels = (sample_scene/'labels_match').ls().sorted()
print(f'Number of images in scene: {len(images_masked)}')
pprint(list(images_masked[:5]))
The dataset contains 60 scenes of 4 km x 4 km, each with about 24 images taken over a span of two years.
Let's pick one example image and its polygons:
image, shapes = images_masked[0], labels[0]
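Pairing by index only works if both sorted lists line up exactly. As a hedged alternative, the sketch below pairs files by name instead. The helper `pair_images_with_labels` and the `_Buildings` label suffix are assumptions for illustration, so check them against the actual file names in your copy of the dataset.

```python
from pathlib import Path

def pair_images_with_labels(image_paths, label_paths, label_suffix="_Buildings"):
    """Match each image to the label file that shares its stem.

    Assumed naming convention: label stem == image stem + label_suffix.
    Returns (pairs, missing), where missing lists images without a label.
    """
    labels_by_key = {}
    for lbl in label_paths:
        stem = Path(lbl).stem
        if label_suffix and stem.endswith(label_suffix):
            stem = stem[:-len(label_suffix)]
        labels_by_key[stem] = lbl

    pairs, missing = [], []
    for img in image_paths:
        key = Path(img).stem
        if key in labels_by_key:
            pairs.append((img, labels_by_key[key]))
        else:
            missing.append(img)
    return pairs, missing

# Hypothetical file names: the second image has no label file
demo_images = [Path("scene_2018_01.tif"), Path("scene_2018_02.tif"), Path("scene_2018_03.tif")]
demo_labels = [Path("scene_2018_01_Buildings.geojson"), Path("scene_2018_03_Buildings.geojson")]

pairs, missing = pair_images_with_labels(demo_images, demo_labels)
print(len(pairs), missing)
```

In this toy example, index-based pairing would silently match the second image with the third month's labels, whereas name-based pairing reports it as missing.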
We use the masked images, in which the UDM (unusable data mask) has been applied to blank out areas that were covered by clouds in the original picture:
show_image(PILImage.create(image), figsize=(12,12));
The following function generates binary mask images from GeoJSON vector files. It is adapted from the source credited in the code comment.
import os
import numpy as np
import rasterio
from rasterio.plot import reshape_as_image
import rasterio.mask
from rasterio.features import rasterize
import pandas as pd
import geopandas as gpd
from shapely.geometry import mapping, Point, Polygon

# SOURCE: https://lpsmlgeo.github.io/2019-09-22-binary_mask/
def generate_mask(raster_path, shape_path, output_path=None, file_name=None):
    """Generate a binary mask from a vector file (shp or geojson).

    raster_path = path to the .tif;
    shape_path = path to the shapefile or GeoJson.
    output_path = Path to save the binary mask.
    file_name = Name of the file.
    """
    # Load the raster metadata (size, CRS, affine transform)
    with rasterio.open(raster_path, "r") as src:
        raster_meta = src.meta.copy()

    # Load the shapefile or GeoJSON
    train_df = gpd.read_file(shape_path)

    # Verify that raster and vector use the same CRS
    if train_df.crs != raster_meta['crs']:
        print(" Raster crs : {}, Vector crs : {}.\n Convert vector and raster to the same CRS.".format(raster_meta['crs'], train_df.crs))

    def poly_from_utm(polygon, transform):
        """Convert a polygon from world coordinates to pixel coordinates."""
        # ~transform is the inverse affine transform (world -> pixel)
        poly_pts = [~transform * tuple(pt) for pt in np.array(polygon.exterior.coords)]
        return Polygon(poly_pts)

    poly_shp = []
    im_size = (raster_meta['height'], raster_meta['width'])
    for num, row in train_df.iterrows():
        geom = row['geometry']
        if geom.geom_type == 'Polygon':
            poly_shp.append(poly_from_utm(geom, raster_meta['transform']))
        else:
            # Multi-part geometry: .geoms works in both Shapely 1.x and 2.x
            for p in geom.geoms:
                poly_shp.append(poly_from_utm(p, raster_meta['transform']))

    # Burn the polygons into a raster, or return an empty mask if there are none
    if len(poly_shp) > 0:
        mask = rasterize(shapes=poly_shp, out_shape=im_size)
    else:
        mask = np.zeros(im_size)
    mask = mask.astype("uint8")

    # Save the mask if an output location is given, otherwise return it
    if output_path is not None and file_name is not None:
        bin_mask_meta = raster_meta.copy()
        bin_mask_meta.update({'count': 1, 'dtype': 'uint8'})
        # Build the full path instead of changing the working directory
        with rasterio.open(os.path.join(output_path, file_name), 'w', **bin_mask_meta) as dst:
            dst.write(mask * 255, 1)
    else:
        return mask
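The key step in `generate_mask` is `~transform * (x, y)`, which applies the inverse of the raster's affine transform to turn world coordinates into pixel coordinates. The sketch below spells out that arithmetic in plain Python. The function `world_to_pixel` and the example transform values are hypothetical and only illustrate the math; they are not taken from the dataset.

```python
def world_to_pixel(x, y, transform):
    """Invert an affine transform given as (a, b, c, d, e, f).

    Forward mapping (pixel -> world), in rasterio's coefficient order:
        x = a * col + b * row + c
        y = d * col + e * row + f
    """
    a, b, c, d, e, f = transform
    det = a * e - b * d
    if det == 0:
        raise ValueError("transform is not invertible")
    dx, dy = x - c, y - f
    col = (e * dx - b * dy) / det
    row = (-d * dx + a * dy) / det
    return col, row

# Hypothetical north-up raster: 4 m pixels, top-left corner at (500000, 4000000)
transform = (4.0, 0.0, 500000.0, 0.0, -4.0, 4000000.0)
print(world_to_pixel(500040.0, 3999980.0, transform))  # (10.0, 5.0)
```

Because the y pixel size is negative for a north-up image, moving south in world coordinates increases the row index, which is why the point 20 m below the top edge lands in row 5.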
Show a mask:
mask = generate_mask(image, shapes)
plt.figure(figsize=(12,12))
plt.tight_layout()
plt.xticks([])
plt.yticks([])
plt.imshow(mask,cmap='cividis');
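Besides looking at the mask, it can be useful to measure how much of it is actually building, since that imbalance is presumably what the class weights and the later undersampling step are meant to address. The helper `building_fraction` below is a hypothetical sketch that uses a small toy mask rather than real data.

```python
import numpy as np

def building_fraction(mask):
    """Share of pixels labelled as building (any non-zero value counts)."""
    mask = np.asarray(mask)
    if mask.size == 0:
        return 0.0
    return float(np.count_nonzero(mask)) / mask.size

# Toy 4x4 mask with a 2x2 "building" in the middle
toy_mask = np.zeros((4, 4), dtype="uint8")
toy_mask[1:3, 1:3] = 1
print(building_fraction(toy_mask))  # 0.25
```

Calling `building_fraction(mask)` on the output of `generate_mask` would give the same statistic for a real scene.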