OpenCV Module 12 : Phát hiện Đối tượng trong ảnh
Import thư viện cần dùng
import numpy as np import matplotlib.pyplot as plt import cv2 import os import urllib
Download Model files from Tensorflow model ZOO
Model files can be downloaded from the Tensorflow Object Detection Model Zoo https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md
The cell given below downloads a mobilenet model
1. Download mobilenet model file directly
The code below will run on Linux / MacOS systems. Please download the file http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz
Uncompress it and put it in models folder.
ssd_mobilenet_v2_coco_2018_03_29
|─ checkpoint
|─ frozen_inference_graph.pb
|─ model.ckpt.data-00000-of-00001
|─ model.ckpt.index
|─ model.ckpt.meta
|─ pipeline.config
|─ saved_model
|─── saved_model.pb
|─── variables
Create constant filename
modelFile = "models/ssd_mobilenet_v2_coco_2018_03_29/frozen_inference_graph.pb" configFile = "models/ssd_mobilenet_v2_coco_2018_03_29.pbtxt" classFile = "coco_class_labels.txt"
2. Ngoài cách dowload trực tiếp như cách trên, bạn có thể sử dụng code dưới đây để dowload và giải nén file
# Kiểm tra xem có tồn tại folder models hay không, nếu có bỏ qua, nếu không thì tạo folder model if not os.path.isdir("models") : os.mkdir("models") # Kiểm tra xem có tồn tại file `modelFile` đã được định nghĩa trước đó hay không, nếu có bỏ qua, nếu không tiến hành tải file về if not os.path.isfile(modelFile) : # `os.chdir` dùng để thay đổi vị trí folder hiện tại os.chdir("models") # Tải file về với urllib urllib.request.urlretrieve("http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz", "ssd_mobilenet_v2_coco_2018_03_29.tar.gz") # Giải nén zip !tar -xvf ssd_mobilenet_v2_coco_2018_03_29.tar.gz # Sau khi giải nén thì ta xóa luôn file zip os.remove("ssd_mobilenet_v2_coco_2018_03_29.tar.gz") os.chdir("..")
Tìm tên class labels trong coco_class_labels.txt
, Nếu bạn chưa có file coco_class_labels.txt
, có thể tải tại đây, kiểm tra các labels có trong file:
with open(classFile) as fp : labels = fp.read().split("\n") print(labels)
Nó sẽ xuất hiện một mảng gồm các labels :
['unlabeled', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'hair brush', '']
Tạo hàm phát hiện đối tượng
def detect_objects(img) : net = cv2.dnn.readNetFromTensorflow(modelFile, configFile) dim = 300 # Tạo một đốm màu từ hình ảnh blob = cv2.dnn.blobFromImage(img,1.0,size=(dim,dim),mean=(0,0,0),swapRB=True,crop=False) # Truyền đốm màu vào mạng net.setInput(blob) # Tiến hành dự đoán objects = net.forward() return objects
FONT_SCALE = 0.5 FONT_FACE = cv2.FONT_HERSHEY_SIMPLEX THICKNESS = 1 We will create `display_text` function in order to when detected object will be showed up the labels of ones def display_text(img, text, x, y) : textSize = cv2.getTextSize(text,FONT_FACE, FONT_FACE,THICKNESS) # textSize sẽ trả về một tuple có dạng ((x,y),baseline) dim, baseline = textSize # Sử dụng kích thước văn bản để tạo một hình chữ nhật màu đen cv2.rectangle(img, (x, y - dim[1] - baseline), (x+ dim[0], y + baseline), (0,0,0), cv2.FILLED) # Hiển thị văn bản bên trong hình chữ nhật đã được vẽ cv2.putText(img, text, (x, y-5), FONT_FACE, FONT_SCALE, (212, 129, 13), THICKNESS, cv2.LINE_AA)
Đây là hàm phát hiện đối tượng hiển thị, yêu cầu đưa bất kỳ hình ảnh nào cho tham số đầu tiên, bạn có thể đặt threshold thông số thứ hai như bạn muốn, không thì mặc định là 0.3 hoặc tùy ý bạn.
def display_objects(img, threshold = 0.3) : objects = detect_objects(img) # Due to image in open cv is width, height, channels, Rows, cols represents for vertical and horizontal rows = img.shape[0] cols = img.shape[1] for i in range(objects.shape[2]) : classId = int(objects[0,0,i,1]) score = float(objects[0,0,i,2]) x = int(objects[0,0,i,3] * cols) y = int(objects[0,0,i,4] * rows) w = int(objects[0,0,i,5] * cols -x) h = int(objects[0,0,i,6] * rows -y) if score > threshold : display_text(img, labels[classId], x,y) cv2.rectangle(img, (x,y), (x+w,y+h), (225,225,225), 2) img_rgb = img[:,:,::-1] plt.figure(figsize=(30,10)) plt.imshow(img_rgb)
Thử nghiệm hình ảnh với phát hiện đối tượng
messi = cv2.imread("DATA/messi-ball.jpg") display_objects(messi) plt.axis(False)
cantho = cv2.imread("DATA/cantho.jpg") display_objects(cantho) plt.axis(False)
Camera object detection ( tùy chọn )
Tạo hàm để xuất frame khi phát hiện đối tượng
def display_object_camera(frame, threshold=0.3) : objects = detect_objects(frame) rows = frame.shape[0] cols = frame.shape[1] for i in range(objects.shape[2]) : classId = int(objects[0,0,i,1]) score = float(objects[0,0,i,2]) x = int(objects[0,0,i,3] * cols) y = int(objects[0,0,i,4] * rows) w = int(objects[0,0,i,5] * cols - x ) h = int(objects[0,0,i,6] * rows - y ) if score > threshold : display_text(frame, labels[classId], x , y) cv2.rectangle(frame, (x,y), (x+w, y+h), (220,220,220), 2) return frame
Code để mở camera kết hợp object detection
video = cv2.VideoCapture(0) while video.isOpened() : ret, frame = video.read() frame = display_object_camera(frame) timer = cv2.getTickCount() FPS = cv2.getTickFrequency() / (cv2.getTickCount() - timer ) / 10000 cv2.putText(frame, f"FPS: {str(int(FPS))}",(50,50), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,0,255),1, cv2.LINE_AA) cv2.imshow("Object Detection", frame) if cv2.waitKey(1) == 27 : break video.release() cv2.destroyAllWindows()