23 lines
766 B
Python
23 lines
766 B
Python
# Copyright (C) CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import numpy as np
|
|
import torch
|
|
from segment_anything import sam_model_registry, SamPredictor
|
|
|
|
class ModelHandler:
    """Load a Segment Anything (SAM) model once and compute image embeddings.

    The model is built from ``sam_model_registry`` and wrapped in a
    ``SamPredictor`` at construction time; ``handle`` then reuses that
    predictor for every image it is given.
    """

    # Defaults are the values that used to be hard-coded in ``__init__``, so
    # ``ModelHandler()`` with no arguments behaves exactly as before.
    DEFAULT_CHECKPOINT = "/opt/nuclio/sam/sam_vit_h_4b8939.pth"
    DEFAULT_MODEL_TYPE = "vit_h"

    def __init__(self, *, sam_checkpoint=DEFAULT_CHECKPOINT, model_type=DEFAULT_MODEL_TYPE):
        """Build the SAM model and move it to the selected device.

        Args:
            sam_checkpoint: Path of the checkpoint file passed to the model
                builder as ``checkpoint=``.
            model_type: Key looked up in ``sam_model_registry`` to pick the
                model builder.

        Raises:
            KeyError: If ``model_type`` is not a key of ``sam_model_registry``.
        """
        # Use CUDA when torch reports it as available, otherwise the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.sam_checkpoint = sam_checkpoint
        self.model_type = model_type
        # Not read anywhere in this class; kept because external code may
        # rely on the attribute existing.
        self.latest_image = None

        sam_model = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
        sam_model.to(device=self.device)
        self.predictor = SamPredictor(sam_model)

    def handle(self, image):
        """Return the predictor's embedding for ``image``.

        Args:
            image: Any object ``np.array`` can convert to an array
                (presumably a PIL image or an HxWxC uint8 array; confirm
                against the caller).

        Returns:
            Whatever ``SamPredictor.get_image_embedding()`` returns for the
            image that was just set.
        """
        # set_image must run first: get_image_embedding reads the state it stores.
        self.predictor.set_image(np.array(image))
        return self.predictor.get_image_embedding()
|