import os
import time
import json
import traceback

import requests
import cv2
import pika
import boto3
import paramiko
import torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

###############################################
#                   Config                    #
###############################################
# Loaded once at import time; the helpers and the Consumer below all read `info`.
with open('./config/inference_config.json', 'r', encoding='utf-8') as f:
    info = json.load(f)


def send_message_to_slack(timeout=10):
    """Post ``info['slack_message']`` to the Slack webhook ``info['slack_url']``.

    Args:
        timeout: Seconds to wait for the webhook. Without a timeout,
            ``requests`` can block forever on an unresponsive endpoint.

    Returns:
        The ``requests.Response`` of the webhook call.
    """
    data = {"text": info['slack_message']}
    return requests.post(
        url=info['slack_url'],
        data=json.dumps(data),
        timeout=timeout,
    )


def image_upload_to_s3(path, bucket='gseps-data'):
    """Upload every file under ``path`` to S3.

    Objects are keyed ``<basename of path>/<file path relative to path>`` in
    ``bucket``. Failures are printed and swallowed (best-effort), so an upload
    problem does not stop the caller.

    Args:
        path: Local directory to upload.
        bucket: Destination S3 bucket name.
    """
    try:
        s3 = boto3.client('s3')
        # normpath drops a trailing slash, which would otherwise yield an empty prefix.
        folder_name = os.path.basename(os.path.normpath(path))
        for root, _dirs, files in os.walk(path):
            for file_name in files:
                local_file_path = os.path.join(root, file_name)
                # Keep sub-directory structure so same-named files in different
                # sub-folders don't overwrite each other; S3 keys always use '/'.
                rel_key = os.path.relpath(local_file_path, path).replace(os.sep, '/')
                s3.upload_file(local_file_path, bucket, f'{folder_name}/{rel_key}')
    except Exception:
        print(traceback.format_exc())


def image_copy_using_SCP(remote_path):
    """Download ``remote_path`` from the edge server via SFTP, then delete it there.

    The file is saved into ``info['copied_image_path_from_remote_server']``,
    which is created if missing.

    Args:
        remote_path: POSIX path of the image on the remote server.

    Returns:
        The local path of the copied image, or ``""`` if anything failed
        (the error is printed).
    """
    ssh = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy trusts any unknown host key (MITM risk);
    # consider loading known_hosts and using RejectPolicy instead.
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    sftp = None
    try:
        ssh.connect(info['remote_server_ip'],
                    username=info['remote_server_id'],
                    password=info['remote_server_pw'])

        sftp = ssh.open_sftp()
        image_name = remote_path.split('/')[-1]
        local_path = info['copied_image_path_from_remote_server']
        os.makedirs(local_path, exist_ok=True)
        local_file = os.path.join(local_path, image_name)

        sftp.get(remote_path, local_file)
        # Remove from the edge only after the download succeeded.
        sftp.remove(remote_path)
        return local_file
    except Exception as e:
        print(e)
        return ""
    finally:
        # Always release the SFTP channel and SSH transport, even on failure.
        if sftp is not None:
            sftp.close()
        ssh.close()


def get_info(body):
    """Return the ``image_path`` field of a JSON task-message ``body``."""
    payload = json.loads(body)
    return payload['image_path']


class Consumer:
    """Pull image tasks from RabbitMQ, segment them with SAM and publish results."""

    def __init__(self):
        # AMQP connection settings come from the module-level `info` config.
        self.__url = info['amqp_url']
        self.__port = info['amqp_port']
        self.__vhost = info['amqp_vhost']
        self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
        self.__TaskQ = info['amqp_taskq']
        self.__ResultQ = info['amqp_resultq']

        self.cfg = info['model_config']
        self.mask_generator = None  # built by initialize()
        # TODO: read from info['cloud_vender'] once the config provides it.
        self.cloud_vender = 'S3'

    def initialize(self):
        """Load the SAM checkpoint onto ``cfg['device']`` and build the mask generator.

        The SAM variant defaults to ``'vit_h'`` and can be overridden with
        ``cfg['model_type']``.
        """
        start = time.time()
        model_type = self.cfg.get('model_type', 'vit_h')
        sam = sam_model_registry[model_type](checkpoint=info['model_checkpoints'])
        sam = sam.to(self.cfg['device'])
        self.mask_generator = SamAutomaticMaskGenerator(
            model=sam,
            points_per_side=self.cfg['points_per_side'],
            pred_iou_thresh=self.cfg['pred_iou_thresh'],
            stability_score_thresh=self.cfg['stability_score_thresh'],
            crop_n_layers=self.cfg['crop_n_layers'],
            crop_n_points_downscale_factor=self.cfg['crop_n_points_downscale_factor'],
            box_nms_thresh=self.cfg['box_nms_thresh'],
            min_mask_region_area=self.cfg['min_mask_region_area'],
        )
        minutes, seconds = divmod(time.time() - start, 60)
        print(f'Initialize time: {int(minutes)}m {seconds:.4f}s')

    def image_upload_to_ncp(self, path):
        """Upload ``path`` to NCP object storage (not implemented yet)."""
        pass

    def image_upload_to_cloud(self, path):
        """Dispatch the upload of result directory ``path`` by ``self.cloud_vender``."""
        if self.cloud_vender == 'S3':
            image_upload_to_s3(path)
        elif self.cloud_vender == 'NCP':
            self.image_upload_to_ncp(path)
        else:
            print(f'Unknown cloud vendor {self.cloud_vender!r}; skipping upload of {path}')

    def inference(self, image_path):
        """Segment one image with SAM, keep small non-overlapping masks, save outputs.

        A mask is kept when its area is below ``cfg['area_thresh']`` times the
        image area, its stability score exceeds ``cfg['stability_score_thresh']``,
        and it does not overlap any mask kept earlier (greedy, in generator order).

        Outputs are written to ``info['inference_result_path']/<image_name>/``:
        a copy of the input image, the kept-mask map ``result_<name>.jpg`` and
        ``result_<name>.json``.

        Args:
            image_path: Local path of the image to segment.

        Returns:
            ``(result_dict, save_path)``. ``result_dict`` holds the saved image
            path, the number of kept masks (``count``) and, per kept mask, the
            mean of its bbox width and height (``sizes``). ``save_path`` is the
            output directory.

        Raises:
            ValueError: if OpenCV cannot read the image.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f'Could not read image: {image_path}')
        image_name = os.path.splitext(os.path.basename(image_path))[0]

        # OpenCV loads BGR; SAM's predictor expects RGB input.
        result = self.mask_generator.generate(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        # Masks have the image's spatial size; reading it from the image keeps
        # an empty `result` from raising IndexError.
        shape = image.shape[:2]
        max_area = self.cfg['area_thresh'] * shape[0] * shape[1]
        stability_thresh = self.cfg['stability_score_thresh']

        cumulated = np.zeros(shape)
        count = 0
        sizes = []
        for r in result:
            if r['area'] < max_area and r['stability_score'] > stability_thresh:
                segmentation = r['segmentation'].astype(int)
                # `cumulated` is a 0/1 map of kept masks; a 2 anywhere means overlap.
                if np.amax(cumulated + segmentation) < 2:
                    cumulated = cumulated + segmentation
                    count += 1
                    x, y, w, h = r['bbox']
                    sizes.append(float(np.mean([w, h])))

        save_path = os.path.join(info['inference_result_path'], image_name)
        os.makedirs(save_path, exist_ok=True)
        image_out = os.path.join(save_path, f'{image_name}.jpg')
        cv2.imwrite(image_out, image)
        plt.imsave(os.path.join(save_path, f'result_{image_name}.jpg'), cumulated)

        result_dict = {'image': image_out,  # TODO: needs revision
                       'count': count,
                       'sizes': sizes}
        with open(os.path.join(save_path, f'result_{image_name}.json'), 'w') as f:
            json.dump(result_dict, f, indent=4)

        return result_dict, save_path

    def result_publish(self, channel, result):
        """Publish ``result`` as JSON to the result queue on ``channel``."""
        channel.basic_publish(exchange='',
                              routing_key=self.__ResultQ,
                              body=json.dumps(result))
        print("Done!")

    def upload_to_database(self, results):
        """
        Insert inference result data into DB(postgreSQL).
        Columns consist of file_name(or image_name), count.
        """
        pass

    def main(self):
        """Consume tasks until the task queue is empty, then notify Slack.

        Each message is auto-acked on receipt. Any exception stops the loop; its
        traceback is printed and the connection is closed.
        """
        conn = None
        try:
            conn = pika.BlockingConnection(
                pika.ConnectionParameters(self.__url, self.__port, self.__vhost, self.__cred))
            chan = conn.channel()
            chan.queue_declare(queue=self.__TaskQ, durable=True)
            chan.queue_declare(queue=self.__ResultQ, durable=True)

            while True:
                method, properties, body = chan.basic_get(queue=self.__TaskQ,
                                                          auto_ack=True)
                if not method:
                    # Queue drained: notify and stop consuming.
                    send_message_to_slack()
                    break

                print(f" [x] Received {body}", end=' | ', flush=True)
                edge_image_path = get_info(body)
                image_path = image_copy_using_SCP(edge_image_path)
                if image_path == "":
                    continue

                result_dict, save_path = self.inference(image_path)
                self.result_publish(chan, result_dict)
                self.upload_to_database(result_dict)
                self.image_upload_to_cloud(save_path)
                time.sleep(1)
        except Exception:
            print(traceback.format_exc())
        finally:
            # `conn` stays None if the connection itself failed to open.
            if conn is not None and conn.is_open:
                conn.close()


if __name__ == "__main__":
    consumer = Consumer()
    consumer.initialize()
    consumer.main()