"""Inference worker that consumes image tasks from RabbitMQ and runs SAM.

For each task message the worker downloads the referenced image from the
configured S3-compatible store (``info['Minio_url']``), runs Segment Anything
automatic mask generation, saves a visualization of the selected masks under
``info['inference_result_path']``, and publishes the mask count, mask sizes and
a size histogram to the result queue.
"""
import os
import time
import json
import requests
import traceback
import logging
import logging.handlers

import cv2
import pika
import boto3
from botocore.client import Config
import paramiko

import torch
import numpy as np
import matplotlib.pyplot as plt

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator


###############################################
#                   Config                    #
###############################################
# Read relative to the current working directory, not to this file.
with open(os.path.join(os.getcwd(), 'config/inference_config.json'), 'r', encoding='utf-8') as f:
    info = json.load(f)


###############################################
#               Logger Setting                #
###############################################
# Configures the root logger: INFO and above go to a rotating file in the CWD.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log_fileHandler = logging.handlers.RotatingFileHandler(
    filename=os.path.join(os.getcwd(), "log_inference.log"),
    maxBytes=1024000,
    backupCount=3,
    mode='a')
log_fileHandler.setFormatter(formatter)
logger.addHandler(log_fileHandler)
# Keep pika's per-frame chatter out of the log file.
logging.getLogger("pika").setLevel(logging.WARNING)


################################################################################
#                                  S3 Set up                                   #
################################################################################
# S3 resource pointed at the configured (MinIO) endpoint; used to download task images.
s3 = boto3.resource('s3',
                    endpoint_url=info['Minio_url'],
                    aws_access_key_id=info['AccessKey'],
                    aws_secret_access_key=info['SecretKey'],
                    config=Config(signature_version=info['Boto3SignatureVersion']),
                    region_name=info['Boto3RegionName']
                    )


def send_message_to_slack(message, timeout=10):
    """Post *message* as ``{"text": message}`` to ``info['slack_url']``.

    Best-effort: this is called from the consumer's exception handlers, so
    network failures are logged and swallowed instead of propagating and
    taking the consumer down.

    Args:
        message: text to send.
        timeout: seconds to wait for the webhook before giving up.
    """
    data = {"text": message}
    try:
        resp = requests.post(
            url=info['slack_url'],
            data=json.dumps(data),
            timeout=timeout,
        )
    except requests.RequestException as e:
        logger.warning(f"Failed to send Slack message: {e}")
        return
    if not resp.ok:
        logger.warning(f"Slack webhook returned HTTP {resp.status_code}")


def image_upload_to_s3(path):
    """Upload every file found under directory *path* to bucket ``gseps-data``.

    Object keys are ``<last component of path>/<file name>``; files from nested
    sub-directories are flattened onto that same prefix. Any error is printed
    and swallowed.
    """
    try:
        # NOTE(review): this is a default boto3 client (default credential
        # chain / endpoint), not the MinIO-configured module-level `s3`
        # resource -- confirm the target store is intended.
        s3_client = boto3.client('s3')
        # normpath strips a trailing '/', which would otherwise yield an empty prefix.
        folder_name = os.path.basename(os.path.normpath(path))
        for root, _dirs, files in os.walk(path):
            for file_name in files:
                local_file_path = os.path.join(root, file_name)
                s3_client.upload_file(local_file_path, 'gseps-data', f'{folder_name}/{file_name}')
    except Exception:
        print(traceback.format_exc())


def image_copy_using_SCP(remote_path):
    """Copy *remote_path* from the edge server via SFTP, then delete the remote file.

    Connection details come from ``info['remote_server_ip']``,
    ``info['remote_server_id']`` and ``info['remote_server_pw']``; the file is
    written into ``info['copied_image_path_from_remote_server']`` (created if
    missing).

    Returns:
        The local path of the copied file, or ``""`` if any step fails.
    """
    ssh = paramiko.SSHClient()
    try:
        # NOTE(review): AutoAddPolicy trusts unknown host keys (no MITM
        # protection); consider a known_hosts file for production.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(info['remote_server_ip'],
                    username=info['remote_server_id'],
                    password=info['remote_server_pw'])

        sftp = ssh.open_sftp()
        try:
            image_name = remote_path.split('/')[-1]
            local_path = info['copied_image_path_from_remote_server']
            os.makedirs(local_path, exist_ok=True)
            local_file = os.path.join(local_path, image_name)
            sftp.get(remote_path, local_file)
            # Remove on the edge only after a successful download.
            sftp.remove(remote_path)
        finally:
            sftp.close()
        return local_file
    except Exception as e:
        print(e)
        return ""
    finally:
        ssh.close()


def get_info(body):
    """Parse JSON *body* and return its ``image_path`` field."""
    payload = json.loads(body)
    return payload['image_path']


class Consumer:
    """Polls the AMQP task queue, runs SAM on each image and publishes results."""

    # Images with fewer kept masks than this are skipped (nothing is published).
    MIN_MASKS_TO_REPORT = 30

    def __init__(self):
        # Broker connection settings, all taken from the JSON config.
        self.__url = info['amqp_url']
        self.__port = info['amqp_port']
        self.__vhost = info['amqp_vhost']
        self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
        self.__TaskQ = info['amqp_TaskQ']
        self.__ResultQ = info['amqp_ResultQ']
        # SAM generator settings plus this worker's filtering thresholds.
        self.cfg = info['model_config']
        # Built lazily by initialize().
        self.mask_generator = None
        # Hard-coded for now; 'NCP' is the only other value dispatched on.
        self.cloud_vender = 'S3'

    def initialize(self):
        """Load the ``vit_h`` SAM checkpoint and build the automatic mask generator."""
        start = time.time()
        sam = sam_model_registry['vit_h'](checkpoint=info['model_checkpoints'])
        sam = sam.to(self.cfg['device'])
        self.mask_generator = SamAutomaticMaskGenerator(
            model=sam,
            points_per_side=self.cfg['points_per_side'],
            pred_iou_thresh=self.cfg['pred_iou_thresh'],
            stability_score_thresh=self.cfg['stability_score_thresh'],
            crop_n_layers=self.cfg['crop_n_layers'],
            crop_n_points_downscale_factor=self.cfg['crop_n_points_downscale_factor'],
            box_nms_thresh=self.cfg['box_nms_thresh'],
            min_mask_region_area=self.cfg['min_mask_region_area'],
        )
        elapsed = time.time() - start
        logger.info(f"Initialize {str(info['model_config'])}")
        print(f'Initialize time: {int(elapsed // 60)}m {elapsed % 60:.4f}s')

    def image_upload_to_ncp(self, path):
        """Placeholder for an NCP upload; currently does nothing."""
        pass

    def image_upload_to_cloud(self, path):
        """Upload directory *path* using the backend named by ``self.cloud_vender``."""
        if self.cloud_vender == 'S3':
            image_upload_to_s3(path)
        elif self.cloud_vender == 'NCP':
            self.image_upload_to_ncp(path)
        else:
            logger.warning(f"Unknown cloud vendor: {self.cloud_vender}")

    def inference(self, image):
        """Run SAM on *image* and keep small, stable, non-overlapping masks.

        A mask is kept when its area is below ``cfg['area_thresh']`` times the
        image area, its stability score exceeds ``cfg['stability_score_thresh']``,
        and it does not overlap any mask kept earlier (masks are visited in the
        order the generator returns them).

        Args:
            image: the image array passed straight to
                ``SamAutomaticMaskGenerator.generate`` (assumed HxWx3 uint8 RGB).

        Returns:
            ``(result_image, count, sizes)``: a 2-D float array that is the sum
            of the kept masks (1 where a kept mask covers a pixel, assuming
            binary segmentations), the number of kept masks, and for each kept
            mask the mean of its bbox width and height.
        """
        result = self.mask_generator.generate(image)
        if not result:
            # No masks at all: return an empty canvas instead of failing on result[0].
            return np.zeros(image.shape[:2]), 0, []

        shape = result[0]['segmentation'].shape
        max_area = self.cfg['area_thresh'] * shape[0] * shape[1]
        min_stability = self.cfg['stability_score_thresh']

        result_image = np.zeros(shape)
        count = 0
        sizes = []
        for r in result:
            if r['area'] >= max_area or r['stability_score'] <= min_stability:
                continue
            mask = r['segmentation'].astype(int)
            # Reject the mask if any pixel is already covered by a kept mask.
            if np.amax(result_image + mask) < 2:
                result_image = result_image + mask
                count += 1
                _x, _y, w, h = r['bbox']
                sizes.append(np.mean([w, h]))
        return result_image, count, sizes

    def result_publish(self, channel, properties, result):
        """Publish *result* as JSON to the result queue via the default exchange."""
        channel.basic_publish(exchange='',
                              routing_key=self.__ResultQ,
                              body=json.dumps(result),
                              properties=properties
                              )
        print(f"Done!")

    def upload_to_database(self, results):
        """
        Insert inference result data into DB(postgreSQL).
        Columns consist of file_name(or image_name), count.
        """
        pass

    @staticmethod
    def _size_histogram(sizes, n_bins):
        """Return ``(bin_edges_as_str, counts_as_int)`` for *sizes* with *n_bins* bins."""
        # np.histogram gives the same counts/edges plt.hist computes internally,
        # without creating (and leaking) a matplotlib figure per message.
        counts, edges = np.histogram(sizes, bins=n_bins)
        logger.debug(f"histogram counts={counts}, edges={edges}")
        return [str(edge) for edge in edges.tolist()], counts.astype(np.int32).tolist()

    def _process_task(self, task_data, save_path):
        """Download, segment and summarize the image named by one task message.

        Mutates and returns *task_data* as the result message, or returns
        ``None`` when the task is skipped (missing ``to`` key, or fewer than
        ``MIN_MASKS_TO_REPORT`` masks kept).
        """
        target = task_data.get("to")
        if not target:
            logger.info("Check Message Data. key 'to' is missing")
            return None

        download_path = os.path.join(os.getcwd(), info['download_data_path'], target['filename'])
        # The object key may contain '/', so make sure the local directory exists.
        os.makedirs(os.path.dirname(download_path), exist_ok=True)
        s3.Bucket(target['bucket']).download_file(target['filename'], download_path)

        try:
            image = cv2.imread(download_path)
            if image is None:
                raise ValueError(f"cv2 could not read image: {download_path}")
            # cv2 loads BGR, while SAM's mask generator expects RGB input.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            result_image, count, sizes = self.inference(image)
        finally:
            # Delete the downloaded original even when decoding/inference fails.
            os.remove(download_path)

        image_name = os.path.splitext(os.path.basename(download_path))[0]

        logger.info(f" len(sizes) : {len(sizes)}")
        if len(sizes) < self.MIN_MASKS_TO_REPORT:
            logger.info("PASS")
            return None

        # Save the mask visualization.
        result_filename = f'result_{image_name}.jpg'
        os.makedirs(save_path, exist_ok=True)
        plt.imsave(os.path.join(save_path, result_filename), result_image)

        hist_x, hist_y = self._size_histogram(sizes, info['n_bins_for_histogram'])

        # Turn the task message into a result message.
        task_data['type'] = "inference_result"
        task_data['result'] = {
            'timestamp': int(time.time() * 1000),
            'count': count,
            'sizes': sizes,
            'filename': result_filename,
            'hist': {
                'x': hist_x,
                'y': hist_y,
            },
        }
        task_data['from'] = task_data.pop('to')
        return task_data

    def main(self):
        """Poll the task queue forever, processing one message at a time.

        Messages are fetched with ``basic_get(auto_ack=True)``, so a message
        whose processing fails is not redelivered. Per-message errors are
        printed, logged and reported to Slack and the loop continues; broker
        (AMQP) errors end the loop, and the connection is closed on the way out.
        """
        conn = None
        try:
            conn = pika.BlockingConnection(
                pika.ConnectionParameters(self.__url, self.__port, self.__vhost, self.__cred))
            chan = conn.channel()
            chan.queue_declare(queue=self.__TaskQ, durable=True)
            chan.queue_declare(queue=self.__ResultQ, durable=True)

            save_path = info['inference_result_path']
            # pika encodes `expiration` as an AMQP short string, so numeric
            # config values (e.g. 60000) must be converted to str.
            expire = info['amqp_message_expire_time']
            amqp_message_properties = pika.BasicProperties(
                expiration=None if expire is None else str(expire))

            while True:
                try:
                    method, _properties, body = chan.basic_get(queue=self.__TaskQ, auto_ack=True)

                    # if Queue is empty
                    if not method:
                        send_message_to_slack("Empty Queue")
                        logger.info(f"Empty Queue sleep for {info['amqp_Q_check_interval']}")
                        time.sleep(info['amqp_Q_check_interval'])
                        time.sleep(0.5)
                        continue

                    logger.info(f" [x] Received {body}")
                    task_data = json.loads(body)
                    result_message = self._process_task(task_data, save_path)
                    if result_message is None:
                        continue

                    # send message to AMQP Result Queue
                    self.result_publish(chan, amqp_message_properties, result_message)
                    time.sleep(1)
                except pika.exceptions.AMQPError:
                    # Channel/connection failures typically do not recover by
                    # retrying on the same channel; retrying here would spin and
                    # spam Slack, so defer to the outer handler.
                    raise
                except Exception as e:
                    print(traceback.format_exc())
                    send_message_to_slack(f"who : inference_server // error {str(e)}")
                    logger.error(str(e))
        except Exception as e:
            print(e)
            print(traceback.format_exc())
            send_message_to_slack(f"who : inference_server // error {str(e)}")
            logger.error(str(e))
        finally:
            # conn is None if connecting failed; closing an already-closed
            # connection raises in pika, hence the is_open guard.
            if conn is not None and conn.is_open:
                conn.close()


if __name__ == "__main__":
    consumer = Consumer()
    consumer.initialize()
    consumer.main()