# cvat/serverless/pytorch/mmpose/hrnet32/nuclio/main.py
# (listing metadata: Python, 75 lines, 2.8 KiB)

# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import json
import base64
import io
import yaml
import numpy as np
from PIL import Image
from mmpose.apis import MMPoseInferencer
def init_context(context):
    """Build the person detector + pose estimator and load the CVAT label spec.

    Both results are stored on ``context.user_data`` so that ``handler`` can
    reuse them: ``user_data.inferencer`` holds the ``MMPoseInferencer`` and
    ``user_data.labels`` holds the JSON-decoded label spec read from
    ``/opt/nuclio/function.yaml`` (``metadata.annotations.spec``).
    """
    logger = context.logger

    logger.info("Init detector...")
    # Top-down pipeline: an RTMDet-nano person detector feeds boxes into the
    # HRNet-w32 ubody keypoint model; everything runs on CPU.
    inferencer_kwargs = dict(
        pose2d="/opt/nuclio/mmpose/configs/wholebody_2d_keypoint/topdown_heatmap/ubody2d/td-hm_hrnet-w32_8xb64-210e_ubody-256x192.py",
        pose2d_weights="/opt/nuclio/td-hm_hrnet-w32_8xb64-210e_ubody-coco-256x192-7c227391_20230807.pth",
        det_model="/opt/nuclio/mmpose/projects/rtmpose/rtmdet/person/rtmdet_nano_320-8xb32_coco-person.py",
        det_weights="/opt/nuclio/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth",
        det_cat_ids=[0],  # detector category id of the 'human' class
        device='cpu',
    )
    inferencer = MMPoseInferencer(**inferencer_kwargs)

    logger.info("Init labels...")
    # The label spec is stored as a JSON string inside the function's YAML config.
    with open("/opt/nuclio/function.yaml", "rb") as config_file:
        raw_spec = yaml.safe_load(config_file)["metadata"]["annotations"]["spec"]

    user_data = context.user_data
    user_data.labels = json.loads(raw_spec)
    user_data.inferencer = inferencer
    logger.info("Function initialized")
_DEFAULT_THRESHOLD = 0.55


def _build_element(element, keypoints, keypoint_scores, threshold):
    """Convert one predicted keypoint into a CVAT ``points`` element.

    ``element`` is a sublabel from the label spec; its ``"id"`` is used as the
    index into ``keypoints`` / ``keypoint_scores``.
    """
    kp_index = element["id"]
    score = keypoint_scores[kp_index]
    x, y = keypoints[kp_index][0], keypoints[kp_index][1]
    return {
        "label": element["name"],
        "type": "points",
        # Visible only when the score is strictly above the threshold.
        "outside": 0 if threshold < score else 1,
        "points": [float(x), float(y)],
        "confidence": str(score),
    }


def _build_skeleton(label, pred_instance, threshold):
    """Convert one predicted instance into a CVAT ``skeleton`` for ``label``."""
    keypoints = pred_instance["keypoints"]
    keypoint_scores = pred_instance["keypoint_scores"]
    return {
        "confidence": str(pred_instance["bbox_score"]),
        "label": label["name"],
        "type": "skeleton",
        "elements": [
            _build_element(element, keypoints, keypoint_scores, threshold)
            for element in label["sublabels"]
        ],
    }


def handler(context, event):
    """Run pose estimation on one image and return CVAT skeleton annotations.

    ``event.body`` must contain ``"image"`` (base64-encoded image bytes) and
    may contain ``"threshold"``. A keypoint is marked visible when its score is
    strictly above the threshold. The threshold defaults to 0.55 when it is
    absent or null.

    Returns a nuclio ``Response`` whose JSON body is a list of skeleton dicts,
    one per (predicted instance, skeleton label) pair that has at least one
    visible keypoint.
    """
    context.logger.info("Run mmpose ubody-2d model")
    data = event.body
    buf = io.BytesIO(base64.b64decode(data["image"]))

    # Coerce so that a string or null threshold cannot break the comparisons below.
    threshold = data.get("threshold")
    threshold = _DEFAULT_THRESHOLD if threshold is None else float(threshold)

    with Image.open(buf) as opened:
        image = opened.convert("RGB")

    # Reverse the channel axis (RGB -> BGR), presumably because the MMPose
    # pipeline expects OpenCV-style BGR input.
    outputs = next(context.user_data.inferencer(np.array(image)[..., ::-1]))
    # A single image was submitted, so take that image's instance list.
    pred_instances = outputs["predictions"][0]

    results = []
    for pred_instance in pred_instances:
        for label in context.user_data.labels:
            # Labels without keypoint sublabels can never yield a visible
            # skeleton; skip them instead of failing the whole request.
            if not label.get("sublabels"):
                continue
            skeleton = _build_skeleton(label, pred_instance, threshold)
            # Drop skeletons whose keypoints are all below the threshold.
            if not all(element["outside"] for element in skeleton["elements"]):
                results.append(skeleton)

    return context.Response(body=json.dumps(results), headers={}, content_type="application/json", status_code=200)