cvat/serverless/pytorch/facebookresearch/sam/nuclio/model_handler.py

23 lines
766 B
Python

# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor
class ModelHandler:
    """Build a SAM ``vit_h`` predictor at construction time and use it to embed images."""

    def __init__(self):
        """Create the SAM model from its checkpoint, move it to the chosen device and wrap it."""
        # Run on the GPU when torch reports one as available, otherwise on the CPU.
        if torch.cuda.is_available():
            device_name = 'cuda'
        else:
            device_name = 'cpu'
        self.device = torch.device(device_name)

        # Static configuration: checkpoint location and the registry key of the model variant.
        self.sam_checkpoint = "/opt/nuclio/sam/sam_vit_h_4b8939.pth"
        self.model_type = "vit_h"
        # Not used inside this class; kept because code outside this block may read it.
        self.latest_image = None

        # The registry maps the variant name to a builder that accepts the checkpoint path.
        build_model = sam_model_registry[self.model_type]
        network = build_model(checkpoint=self.sam_checkpoint)
        network.to(device=self.device)
        self.predictor = SamPredictor(network)

    def handle(self, image):
        """Return the predictor's embedding for ``image``.

        ``image`` is converted with ``np.array`` before being handed to the
        predictor, so it may be anything ``np.array`` accepts.
        NOTE(review): presumably an RGB PIL image supplied by the caller - confirm.
        """
        pixels = np.array(image)
        self.predictor.set_image(pixels)
        return self.predictor.get_image_embedding()