Detectron2 fine-tuning using a custom COCO dataset
Detectron2 is Facebook AI Research’s (FAIR) next-generation software system that implements state-of-the-art object detection algorithms, including Mask R-CNN. It is written in Python and powered by the PyTorch deep learning framework.
We will be using a Google Colab notebook, so you don’t need to worry about setting up a development environment on your own machine.
Installing Detectron2 and PyTorch
!pip install pyyaml==5.1
!pip install torch==1.8.0+cu101 torchvision==0.9.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
We’re installing an older version of PyTorch because Detectron2 hasn’t released pre-built packages for PyTorch 1.9 yet.
Now let’s install Detectron2:
!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.8/index.html
Once Detectron2 is installed, we have to restart the runtime. Just click the Restart prompt shown in the cell output, and the runtime will restart automatically.
Now let’s check that everything is installed properly:
!nvidia-smi
The output will display the GPU details and the CUDA version. We’re good to go now.
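We can also verify the installation from Python itself. A quick sanity check like the one below confirms that the packages import cleanly, that the versions match what we just installed, and that PyTorch can see the GPU:

import torch
import torchvision
import detectron2

print("torch:", torch.__version__)              # expect 1.8.0+cu101
print("torchvision:", torchvision.__version__)  # expect 0.9.0+cu101
print("detectron2:", detectron2.__version__)
print("CUDA available:", torch.cuda.is_available())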
To train a detection model, we need images, labels, and bounding box annotations. The COCO (Common Objects in Context) format is a popular choice and benchmark since it covers a variety of different objects in different settings. We can use Label Studio to export a COCO dataset; setting up Label Studio and creating the dataset can be covered in a different session. Once we have a COCO dataset, we need to register it in order to use it.
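For reference, a COCO detection JSON has three main sections: images, annotations, and categories. The sketch below shows the minimal structure with purely illustrative placeholder values; your exported result.json will follow the same layout:

# Minimal structure of a COCO detection JSON (illustrative placeholders only)
coco_sketch = {
    "images": [
        {"id": 1, "file_name": "img_001.jpg", "width": 640, "height": 480},
    ],
    "annotations": [
        # bbox is [x, y, width, height] in pixels; category_id points into "categories"
        {"id": 1, "image_id": 1, "category_id": 1,
         "bbox": [10, 20, 100, 80], "area": 8000, "iscrowd": 0},
    ],
    "categories": [
        {"id": 1, "name": "example_tag"},  # hypothetical class name
    ],
}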
Let’s give the dataset a name and define the paths to the JSON annotation file and the image folder of the COCO dataset:
dataset_name = "name_of_the_dataset"
json_path = "path/to/result.json"
image_folder_path = "path/to/cocofolder"
Now we can register the dataset
from detectron2.data.datasets import register_coco_instances
register_coco_instances(dataset_name, {}, json_path, image_folder_path)
We can verify the registration with:
from detectron2.data import MetadataCatalog
MetadataCatalog.get(dataset_name)
This will display the dataset’s metadata, such as the registered class names, so we can verify it.
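As an optional extra check, we can draw the ground-truth annotations on a random training image. This is a quick sketch using Detectron2’s Visualizer together with Colab’s cv2_imshow patch:

import random
import cv2
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer
from google.colab.patches import cv2_imshow

# Pick a random sample and overlay its ground-truth boxes
dataset_dicts = DatasetCatalog.get(dataset_name)
sample = random.choice(dataset_dicts)
img = cv2.imread(sample["file_name"])
v = Visualizer(img[:, :, ::-1], metadata=MetadataCatalog.get(dataset_name), scale=0.8)
cv2_imshow(v.draw_dataset_dict(sample).get_image()[:, :, ::-1])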
The next step is to configure the model. We can choose a pretrained model from the Detectron2 model zoo. Here we are using the "faster_rcnn_R_50_FPN_3x" model, which is comparatively fast and light, suiting our requirement.
from detectron2 import model_zoo
from detectron2.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = (dataset_name,)
cfg.DATASETS.TEST = ()  # no separate validation set registered
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")  # pretrained weights
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 1000
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 6  # number of tags (classes) in the dataset
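As a rough sanity check on these solver settings: the model sees MAX_ITER × IMS_PER_BATCH images in total, so the number of passes over the data depends on your dataset size. For a hypothetical dataset of 500 images:

# Rough epoch estimate; dataset_size = 500 is a hypothetical example value
dataset_size = 500
images_seen = cfg.SOLVER.MAX_ITER * cfg.SOLVER.IMS_PER_BATCH  # 1000 * 2 = 2000
print("approx. epochs:", images_seen / dataset_size)          # ~4 passes over the data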
We can define the output folder path, and once the model configuration is set up, we can start training the custom model.
import os
from detectron2.engine import DefaultTrainer

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
Training may take up to 10 minutes depending on the size of the dataset, the number of iterations, and the GPU. You can also save the model weights manually:
import torch
torch.save(trainer.model.state_dict(), "path/to/directory/model_name.pth")
Once training is finished, we can check the new model’s predictions. Let’s point the config at the newly trained weights and build a predictor:
from detectron2.engine import DefaultPredictor

cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # confidence threshold for predictions
predictor = DefaultPredictor(cfg)
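If you also want quantitative metrics rather than just spot-checking images, Detectron2 provides a COCOEvaluator. A minimal sketch follows, evaluating on the training dataset for simplicity; in practice you would register and use a held-out validation split:

from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset

# Compute standard COCO mAP metrics for the trained model
evaluator = COCOEvaluator(dataset_name, output_dir=cfg.OUTPUT_DIR)
val_loader = build_detection_test_loader(cfg, dataset_name)
print(inference_on_dataset(predictor.model, val_loader, evaluator))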
We will use the cv2 (OpenCV) package to read an image and run inference on it:
import cv2
from detectron2.utils.visualizer import ColorMode, Visualizer

microcontroller_metadata = MetadataCatalog.get(dataset_name)
img = cv2.imread(filename)  # filename: path to a test image
outputs = predictor(img)
vis = Visualizer(img[:, :, ::-1],  # flip BGR (OpenCV) to RGB for the Visualizer
                 metadata=microcontroller_metadata,
                 scale=0.8,
                 instance_mode=ColorMode.IMAGE_BW)
We will use Matplotlib to display the output:
import matplotlib.pyplot as plt

v_draw = vis.draw_instance_predictions(outputs["instances"].to("cpu"))
plt.figure(figsize=(14, 10))
plt.imshow(cv2.cvtColor(v_draw.get_image()[:, :, ::-1], cv2.COLOR_BGR2RGB))
plt.show()
The sample output will look like this. Bounding boxes in different colors indicate different tags.
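To run the same inference over a whole folder of test images, a simple loop like the sketch below works; the input and output folder paths here are hypothetical placeholders:

import glob

# Run the predictor over every JPEG in a folder and save annotated copies
out_dir = "predictions"  # hypothetical output folder
os.makedirs(out_dir, exist_ok=True)
for path in glob.glob("path/to/test_images/*.jpg"):  # hypothetical input folder
    img = cv2.imread(path)
    outputs = predictor(img)
    v = Visualizer(img[:, :, ::-1], metadata=microcontroller_metadata, scale=0.8)
    drawn = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    # get_image() returns RGB; flip back to BGR for cv2.imwrite
    cv2.imwrite(os.path.join(out_dir, os.path.basename(path)), drawn.get_image()[:, :, ::-1])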
Link to the full code in a Colab notebook
Thank you for reading.