扫码阅读
手机扫码阅读

【2022ModelBox客流分析活动打卡】客流未佩戴口罩识别作业3

258 2023-07-17

题目描述

课程4在判断是否佩戴口罩,只使用了过线那一帧的人脸做口罩判断,但是这种方式不够鲁棒,可能会有漏判或错判。比如过线那一刻没有检测到人脸,或者那一帧人脸质量不高导致判断口罩出错。

例如下图中课程4的测试视频,最后一位过线的人佩戴有口罩,应该跟前面三位一样标记蓝框,但因为过线那一刻未检测到人脸,导致口罩佩戴信息未知从而标记了黄框。

请修改口罩佩戴判断的逻辑,减少漏判和错判(比如可以改成使用过线后的多帧综合判断),贴出改进后的核心代码和最终视频效果。(400分)

完成要求:需要贴出改进后的核心代码和程序运行效果截图。

核心代码
object_tracker.py 文件
workspace\passenger_flow_mask_det\etc\flowunit\object_tracker\object_tracker.py

# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import _flowunit as modelbox
import json
from easy_tracker import EasyTracker, match


class object_trackerFlowUnit(modelbox.FlowUnit):
def __init__(self):
super().__init__()

def open(self, config):
# 获取功能单元的配置参数
self.init_hits = config.get_int("init_hits", 3)
self.max_age = config.get_int("max_age", 5)
self.min_iou = config.get_float("min_iou", 0.5)
self.head_label = config.get_int('head_label')
self.face_label = config.get_int('face_label')
self.face_cover_ratio = config.get_float('face_cover_ratio')
self.line = config.get_float_list('line', [])

self.index = 0 # frame计数
self.flow_count = 0 # 客流计数
return modelbox.Status.StatusCode.STATUS_SUCCESS

def process(self, data_context):
# 从DataContext中获取输入输出BufferList对象
in_bbox = data_context.input("in_bbox")
out_track = data_context.output("out_track")

# 循环处理每一个输入Buffer数据
for buffer_bbox in in_bbox:
# 将输入Buffer转换为Python对象
bbox_str = buffer_bbox.as_object()
head_bboxes, face_bboxes = self.decode_bboxes(bbox_str)

# 业务处理:对图中的人使用头肩部检测框进行跟踪
self.tracker.update(head_bboxes, self.line[1])

# 获取所有的跟踪目标,转化为json数据
tracking_objects, face_info = self.get_tracking_objects(self.line[1], face_bboxes)
track_info = {'passenger_flow': self.flow_count,
'tracking_objects': json.dumps(tracking_objects)}
if face_info:
track_info['face_info'] = json.dumps(face_info)

modelbox.debug(f'track_info for {self.index}-th image is {track_info}')
self.index += 1

# 将业务处理返回的结果数据转换为Buffer
out_buffer = modelbox.Buffer(self.get_bind_device(), json.dumps(track_info))

# 将输出Buffer放入输出BufferList中
out_track.push_back(out_buffer)

# 返回成功标志,ModelBox框架会将数据发送到后续的功能单元
return modelbox.Status.StatusCode.STATUS_SUCCESS

def decode_bboxes(self, bbox_str):
'''从json数据中解码出检测框'''
try:
det_result = json.loads(bbox_str)['det_result']
if det_result == "None":
return [], []
bboxes = json.loads(det_result)
head_bboxes = list(filter(lambda x: int(x[5]) == self.head_label, bboxes))
face_bboxes = list(filter(lambda x: int(x[5]) == self.face_label, bboxes))
except Exception as ex:
modelbox.error(str(ex))
return [], []
else:
return head_bboxes, face_bboxes

def get_tracking_objects(self, line_y, face_bboxes):
'''从跟踪器中获取跟踪目标,保存到结构化数据中'''
def _is_pass_line(bbox, line_y):
'''根据检测框的中心点与线段的水平位置关系判断是否过线'''
center_y = (bbox[1] + bbox[3]) / 2
return center_y > line_y

track_bboxes = [track.det for track in self.tracker.tracks]
matches, _, _ = match(face_bboxes, track_bboxes, self.face_cover_ratio, True)

tracking_objects = [] # 所有跟踪目标
face_info = {} # 刚刚过线的人头目标对应的人脸框,track_id -> face_bbox
for ix, track in enumerate(self.tracker.tracks):
# 只记录CONFIRMED状态的跟踪目标
if track.state != EasyTracker.TrackingState.CONFIRMED:
continue
tracking_obj = {} # 使用字典保存跟踪目标
tracking_obj["id"] = track.track_id # 跟踪id
tracking_obj["bbox"] = track.det # 跟踪框
if not track.passline and _is_pass_line(track.det, line_y): # 刚好过线
track.passline = True
self.flow_count += 1
# if ix in matches: # 找到了匹配的人脸框
# face_info[track.track_id] = face_bboxes[matches[ix]]
if track.passline and ix in matches: #已过线且有配的人脸框
face_info[track.track_id]= face_bboxes[matches[ix]]
tracking_obj["passline"] = track.passline # 记录过线信息
tracking_objects.append(tracking_obj)
return tracking_objects, face_info

def close(self):
return modelbox.Status()

def data_pre(self, data_context):
# 视频流开始前的初始化动作,跟踪器的初始化放在此处
self.tracker = EasyTracker(self.init_hits, self.max_age, self.min_iou)
return modelbox.Status()

def data_post(self, data_context):
# After streaming data ends
return modelbox.Status()

draw_passenger_bbox.py 文件
workspace\passenger_flow_mask_det\etc\flowunit\draw_passenger_bbox\draw_passenger_bbox.py

# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import _flowunit as modelbox
import cv2
import json
import numpy as np
from PIL import Image, ImageDraw, ImageFont


class draw_passenger_bboxFlowUnit(modelbox.FlowUnit):
def __init__(self):
super().__init__()

def open(self, config):
# 获取功能单元的配置参数
line = config.get_float_list('line', [])
resolution = config.get_int_list('resolution', [])
self.line = [int(line[ix] * resolution[ix % 2]) for ix in range(4)]
self.font_path = config.get_string("font_path")
self.mask_info = {}
return modelbox.Status.StatusCode.STATUS_SUCCESS

def process(self, data_context):
# 从DataContext中获取输入输出BufferList对象
in_image = data_context.input("in_image")
out_image = data_context.output("out_image")

# 循环处理每一个输入Buffer数据
for buffer_img in in_image:
# 输入图像Buffer获取宽、高、通道数等属性信息
width = buffer_img.get('width')
height = buffer_img.get('height')
channel = buffer_img.get('channel')

# 将输入Buffer转换为Python对象
img_data = np.array(buffer_img.as_object(), copy=False)
img_data = img_data.reshape((height, width, channel))

track_json = buffer_img.get('track_info')

# 业务处理:从前一个功能单元的json数据中获取检测框,画在图像上
flow_count, tracking_objects = self.decode_bboxes(
track_json, (height, width))
img_out = self.draw_tracking_object(
img_data, flow_count, tracking_objects)

# 将画框后的图像转换为Buffer
out_buffer = modelbox.Buffer(self.get_bind_device(), img_out)

# 设置输出Buffer的Meta信息,此处直接拷贝输入Buffer的Meta信息
out_buffer.copy_meta(buffer_img)

# 将输出Buffer放入输出BufferList中
out_image.push_back(out_buffer)

# 返回成功标志,ModelBox框架会将数据发送到后续的功能单元
return modelbox.Status.StatusCode.STATUS_SUCCESS

def decode_bboxes(self, track_json, input_shape):
'''从json数据中解码出检测框'''
try:
track_result = json.loads(track_json)
flow_count = track_result['passenger_flow']

face_info = {}
if 'face_info' in track_result:
face_info = json.loads(track_result['face_info'])

tracking_objects = []
tracking_list = json.loads(track_result['tracking_objects'])
for track_dict in tracking_list:
tracking_obj = {}
id = track_dict["id"]
tracking_obj["id"] = id
x1 = int(track_dict["bbox"][0] * input_shape[1])
y1 = int(track_dict["bbox"][1] * input_shape[0])
x2 = int(track_dict["bbox"][2] * input_shape[1])
y2 = int(track_dict["bbox"][3] * input_shape[0])
tracking_obj["bbox"] = [x1, y1, x2, y2]
tracking_obj["passline"] = track_dict["passline"]
if str(id) in face_info:
mask_info = face_info[str(track_dict["id"])]
if mask_info == 0: # 类别o表示戴口罩,将表示口罩佩戴信息的计数+1
self.mask_info[id] = self.mask_info[id] + 1 if id in self.mask_info else 1
else: # 如果未戴口罩,将表示口罩佩戴信息的计教-1
self.mask_info[id] = self.mask_info[id] - 1 if id in self.mask_info else -1
tracking_objects.append(tracking_obj)
except Exception as ex:
modelbox.error(str(ex))
return None, None
else:
return flow_count, tracking_objects

def draw_tracking_object(self, img_data, flow_count, tracking_objects):
'''在图中画出跟踪对象的检测框和过线的行人数据'''

GRAY = (117, 117, 117)
RED = (255, 0, 0)
BLUE = (0, 0, 255)
YELLO = (255, 255, 0)
thickness = 2
no_mask_count = 0
for track in tracking_objects:
color = GRAY
if track["passline"]:
color = YELLO
if track["id"] in self.mask_info:
if self.mask_info[track["id"]] 0: # 未戴口罩
no_mask_count += 1
color = RED
else:
color = BLUE
cv2.rectangle(img_data, (track["bbox"][0], track["bbox"][1]),
(track["bbox"][2], track["bbox"][3]), color, thickness)
cv2.line(img_data, (self.line[0], self.line[1]),
(self.line[2], self.line[3]), YELLO, 5)
img_data = self.put_chi_text(
img_data, 'JaneConan の 作业4 当前客流计数:%d' % flow_count, (50, 20), BLUE, 50)
if no_mask_count > 0:
img_data = self.put_chi_text(
img_data, '【警告】%d人未戴口罩' % no_mask_count, (900, 20), RED, 50)

return img_data

def put_chi_text(self, img, text, location, color=(0, 255, 0), size=50):
'''在图片中写汉字'''
if (isinstance(img, np.ndarray)): # 判断是否OpenCV图片类型
img = Image.fromarray(img)

draw = ImageDraw.Draw(img) # 创建一个可以在给定图像上绘图的对象
font_style = ImageFont.truetype(
self.font_path, size, encoding="utf-8") # 字体的格式
draw.text(location, text, color, font=font_style) # 绘制文本
return np.asarray(img) # 转换回OpenCV格式

def close(self):
return modelbox.Status()
原文链接: https://mp.weixin.qq.com/s?__biz=MzI0OTE5NzQxNw==&mid=2247485569&idx=1&sn=f59b5829e0e0ddb8ffaf4c8611851ac0