Strong Sort (strong deep learning for simple online and realtime tracking, 强化深度学习网络的简单在线与实时跟踪) 算子
Strong Sort (strong deep learning for simple online and realtime tracking, 强化深度学习网络的简单在线与实时跟踪) 使用深度学习来唯一标识边界框，以便在图像流中持续跟踪它们。
输入
- 图像: 高 x 宽 x 4 的 uint8 数组（算子只使用前 3 个通道，即 BGR）。
- bbox: N_BBOX x 6 的数组，每行依次为 X_MIN, X_MAX, Y_MIN, Y_MAX, CONFIDENCE, CLASS。
输出
- obstacles_id: 展平的 int32 数组，每个跟踪目标 7 个值：x1, x2, y1, y2, track_id, class_id, conf。
示例绘制（蓝色的 #id 为对应目标的跟踪 ID）
图描述
- id: strong_sort
  operator:
    outputs:
      - obstacles_id
    inputs:
      image: webcam/image
      bbox: yolov5/bbox
    python: ../../operators/strong_sort_op.py
图可视化
方法
__init__()
源码
def __init__(self):
    """Load the StrongSORT tracker and start with an empty frame buffer.

    The tracker is built from the ``osnet_x0_25_msmt17.pt`` weights on the
    ``cuda`` device, and its inner model is warmed up once before being
    stored on the operator.
    """
    weights = "osnet_x0_25_msmt17.pt"
    device = torch.device("cuda")
    tracker = StrongSORT(weights, device, False)
    tracker.model.warmup()
    self.model = tracker
    # No frame received yet; on_input checks the length before tracking.
    self.frame = []
.on_event(...)
源码
def on_event(
    self,
    dora_event: dict,
    send_output: Callable[[str, bytes], None],
) -> DoraStatus:
    """Dispatch a Dora event; only ``"INPUT"`` events are acted on.

    Args:
        dora_event: Event dict whose ``"type"`` key selects the handling.
        send_output: Output callback, forwarded unchanged to ``on_input``.

    Returns:
        The result of ``on_input`` for input events, otherwise
        ``DoraStatus.CONTINUE``.
    """
    is_input = dora_event["type"] == "INPUT"
    if not is_input:
        return DoraStatus.CONTINUE
    return self.on_input(dora_event, send_output)
.on_input(...)
源码
def on_input(
    self,
    dora_input: dict,
    send_output: Callable[[str, bytes], None],
) -> DoraStatus:
    """Cache incoming image frames and track incoming bounding boxes.

    An ``"image"`` input is stored as the current frame. A bounding-box
    input is run through StrongSORT together with the latest frame, and the
    resulting tracks are sent on ``"obstacles_id"``. Two input ids are
    accepted for bounding boxes: ``"bbox"`` (the id used in the documented
    dataflow graph) and ``"obstacles_bbox"`` (the legacy id). Bounding
    boxes that arrive before any frame has been received are ignored.

    Args:
        dora_input: Dora input event; its ``"id"``, ``"value"`` and
            ``"metadata"`` keys are read.
        send_output: Callback used to emit ``"obstacles_id"``; it is called
            with the output id, a pyarrow array and the input metadata.

    Returns:
        Always ``DoraStatus.CONTINUE``.
    """
    input_id = dora_input["id"]
    if input_id == "image":
        frame = np.array(
            dora_input["value"],
            np.uint8,
        ).reshape((IMAGE_HEIGHT, IMAGE_WIDTH, 4))
        # Only the first three channels are kept; the fourth is dropped.
        self.frame = frame[:, :, :3]
    elif input_id in ("bbox", "obstacles_bbox") and len(self.frame) != 0:
        # Each row: x_min, x_max, y_min, y_max, confidence, class.
        obstacles = np.array(dora_input["value"]).reshape((-1, 6))
        if obstacles.shape[0] == 0:
            # NOTE(review): tracker state is not advanced when a frame has no
            # detections; the call that would do so is disabled here.
            # self.model.increment_ages()
            send_output(
                "obstacles_id",
                # int32 so the empty message matches the dtype of the
                # non-empty track output below.
                pa.array(np.array([], dtype=np.int32)),
                dora_input["metadata"],
            )
            return DoraStatus.CONTINUE
        # Post-process yolov5 boxes into the tracker's input format.
        xywhs = xxyy2xywh(obstacles[:, 0:4])
        confs = obstacles[:, 4]
        clss = obstacles[:, 5]
        with torch.no_grad():
            # NOTE(review): casting to int32 also truncates the confidence
            # column of the tracker output.
            outputs = np.array(
                self.model.update(xywhs, confs, clss, self.frame)
            ).astype("int32")
        if len(outputs) != 0:
            # Reorder columns xyxy -> x1, x2, y1, y2, track_id, class_id, conf
            outputs = outputs[:, [0, 2, 1, 3, 4, 5, 6]]
        # Publish even when there are no tracks, so consumers see an empty
        # (int32) update, as in the no-detection path above.
        send_output(
            "obstacles_id",
            pa.array(outputs.ravel()),
            dora_input["metadata"],
        )
    return DoraStatus.CONTINUE