From b09f485c900af08047b0f524e5099e3ccba38161 Mon Sep 17 00:00:00 2001 From: ssung Date: Thu, 7 Sep 2023 15:45:52 +0900 Subject: [PATCH] FIX: function rule --- config/inference_config.json | 30 +++++--- ftp_server.service | 2 +- inference.py | 128 +++++++++++++++++++++++++---------- inference_processor.service | 3 +- 4 files changed, 113 insertions(+), 50 deletions(-) diff --git a/config/inference_config.json b/config/inference_config.json index a0e0374..89ca767 100644 --- a/config/inference_config.json +++ b/config/inference_config.json @@ -1,13 +1,21 @@ { - "slack_url": "https://hooks.slack.com/services/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", - "slack_message": "Inference Done.", - "amqp_url": "localhost", - "amqp_port": 5672, + "slack_url": "https://hooks.slack.com/services/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", + "slack_message": "Queue is Empty! Restart after 100 second", + "amqp_url": "13.209.39.139", + "amqp_port": 30747, "amqp_vhost": "/", - "amqp_id": "worker", - "amqp_pw": "gseps1234", - "amqp_taskq": "TaskQ", - "amqp_resultq": "ResultQ", + "amqp_id": "sdt", + "amqp_pw": "251327", + "amqp_TaskQ": "gseps-mq", + "amqp_ResultQ": "gseps-ResultQ", + "amqp_message_expire_time": "300000", + "Minio_url": "http://13.209.39.139:31191", + "AccessKey":"VV2gooVNevRAIg7HrXQr", + "SecretKey":"epJmFWxwfzUUgYeyDqLa8ouitHZaWTwAvPfPNUBL", + "Boto3SignatureVersion":"s3v4", + "Boto3RegionName":"us-east-1", + "BucketName":"gseps-test-a", + "download_data_path": "./data", "model_config": { "points_per_side": 36, "pred_iou_thresh": 0.86, @@ -19,10 +27,10 @@ "area_thresh": 0.1, "device": "cuda" }, - "model_checkpoints": "/home/sdt/Workspace/gseps_inference/weights/sam_vit_h_4b8939.pth", + "model_checkpoints": "/home/sdt/Workspace/gseps/inference/weights/sam_vit_h_4b8939.pth", "remote_server_ip": "25.15.14.31", "remote_server_id": "sdt", "remote_server_pw": "251327", - "copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps_inference/image_bucket", - "inference_result_path": "/home/sdt/Workspace/gseps_inference/inference_result/" + "copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps/inference/image_bucket", + "inference_result_path": "/home/sdt/Workspace/gseps/inference/result/" } diff --git a/ftp_server.service b/ftp_server.service index 217b0f7..2ac82c9 100644 --- a/ftp_server.service +++ b/ftp_server.service @@ -2,7 +2,7 @@ Description=FTP Server [Service] -ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/ftp_server.py +ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps_inference/ftp_server.py Restart=on-failure Group=sdt User=sdt diff --git a/inference.py b/inference.py index 2a13fd8..2f4b7fe 100644 --- a/inference.py +++ b/inference.py @@ -3,24 +3,57 @@ import time import json import requests import traceback +import logging +import logging.handlers import cv2 import pika import boto3 +from botocore.client import Config import paramiko import torch import numpy as np import matplotlib.pyplot as plt + from segment_anything import sam_model_registry, SamAutomaticMaskGenerator ############################################### # Config # ############################################### -with open('/home/sdt/Workspace/gseps_inference/config/inference_config.json', 'r') as f: +with open(os.path.join(os.getcwd(),'config/inference_config.json'), 'r') as f: info = json.load(f) +############################################### +# Logger Setting # +############################################### +logger = logging.getLogger() +logger.setLevel(logging.INFO) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + +log_fileHandler = logging.handlers.RotatingFileHandler( + filename=os.path.join(os.getcwd(), "log_inference.log"), + maxBytes=1024000, + backupCount=3, + mode='a') + +log_fileHandler.setFormatter(formatter) +logger.addHandler(log_fileHandler) + +logging.getLogger("pika").setLevel(logging.WARNING) + +################################################################################ +# S3 Set up # +################################################################################ +s3 = boto3.resource('s3', + endpoint_url = info['Minio_url'], + aws_access_key_id=info['AccessKey'], + aws_secret_access_key=info['SecretKey'], + config=Config(signature_version=info['Boto3SignatureVersion']), + region_name=info['Boto3RegionName'] +) + def send_message_to_slack(): data = {"text": info['slack_message']} @@ -89,8 +122,8 @@ class Consumer: self.__port = info['amqp_port'] self.__vhost = info['amqp_vhost'] self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw']) - self.__TaskQ = info['amqp_taskq'] - self.__ResultQ = info['amqp_resultq'] + self.__TaskQ = info['amqp_TaskQ'] + self.__ResultQ = info['amqp_ResultQ'] self.cfg = info['model_config'] @@ -127,44 +160,31 @@ class Consumer: elif self.cloud_vender == 'NCP': self.image_upload_to_ncp(path) - def inference(self, image_path): - image = cv2.imread(image_path) - image_name = image_path.split('/')[-1].split('.')[0] - + def inference(self, image): + result = self.mask_generator.generate(image) shape = result[0]['segmentation'].shape - cumulated = np.zeros(shape) + result_image = np.zeros(shape) count = 0 sizes = [] for n, r in enumerate(result): if r['area'] < self.cfg['area_thresh'] * shape[0] * shape[1] and r['stability_score'] > self.cfg['stability_score_thresh']: - if np.amax(cumulated + r['segmentation'].astype(int)) < 2: - cumulated = cumulated + r['segmentation'].astype(int) + if np.amax(result_image + r['segmentation'].astype(int)) < 2: + result_image = result_image + r['segmentation'].astype(int) count += 1 x, y, w, h = r['bbox'] sizes.append(np.mean([w, h])) - save_path = os.path.join(info['inference_result_path'], image_name) - if not os.path.exists(save_path): - os.makedirs(save_path) + return result_image, count, sizes - cv2.imwrite(os.path.join(save_path, f'{image_name}.jpg'), image) - plt.imsave(os.path.join(save_path, f'result_{image_name}.jpg'), cumulated) - result_dict = {'image': os.path.join(save_path, f'{image_name}.jpg'), # 수정필요 - 'count': count, - 'sizes': sizes} - - with open(os.path.join(save_path, f'result_{image_name}.json'), 'w') as f: - json.dump(result_dict, f, indent=4) - - return result_dict, save_path - - def result_publish(self, channel, result): + def result_publish(self, channel, properties, result): channel.basic_publish(exchange='', routing_key=self.__ResultQ, - body=json.dumps(result)) + body=json.dumps(result), + properties=properties + ) print(f"Done!") @@ -185,32 +205,66 @@ class Consumer: chan.queue_declare(queue=self.__TaskQ, durable=True) chan.queue_declare(queue=self.__ResultQ, durable=True) + + save_path = info['inference_result_path'] + while True: method, properties, body = chan.basic_get(queue=self.__TaskQ, auto_ack=True) + amqp_message_properties = pika.BasicProperties(expiration=info['amqp_message_expire_time']) + + # if Queue is empty if not method: send_message_to_slack() - break - + logger.info("Empty Queue sleep for 100s") + time.sleep(100) + if method: - print(f" [x] Received {body}", end=' | ', flush=True) - edge_image_path = get_info(body) - image_path = image_copy_using_SCP(edge_image_path) - - if image_path == "": + logger.info(f" [x] Received {body}") + Task_data = json.loads(body) + if(Task_data.get("to")): + download_path = os.path.join(os.getcwd(),info['download_data_path'],Task_data['to']['filename']) + s3.Bucket(Task_data['to']['Bucket']).download_file(Task_data['to']['filename'],download_path) + else: + logger.info("Check Message Data. key 'to' is missing") continue - result_dict, save_path = self.inference(image_path) - self.result_publish(chan, result_dict) - self.upload_to_database(result_dict) - self.image_upload_to_cloud(save_path) + # read image file + image = cv2.imread(download_path) + # get file name + image_name = download_path.split('/')[-1].split('.')[0] + + # run inference + result_image, count, sizes = self.inference(image) + + # delete original file + os.remove(download_path) + + # save reulst image + result_filename = f'result_{image_name}.jpg' + plt.imsave(os.path.join(save_path, result_filename), result_image) + + # message contents set-up + Task_data['Type']="inference_result" + Task_data['result']={ + 'timestamp':int(time.time()*1000), + 'count':count, + 'sizes':sizes, + 'filename':result_filename + } + Task_data['from']=Task_data['to'] + Task_data.pop('to',None) + + # send message to AMQP Result Queue + self.result_publish(chan, amqp_message_properties,Task_data) time.sleep(1) else: time.sleep(0.5) except Exception as e: + print(e) print(traceback.format_exc()) conn.close() diff --git a/inference_processor.service b/inference_processor.service index dad465b..0a6ece6 100644 --- a/inference_processor.service +++ b/inference_processor.service @@ -2,7 +2,8 @@ Description=inference processor [Service] -ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/rabbitmq_test/inference.py +ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/inference/inference.py +WorkingDirectory=/home/sdt/Workspace/gseps/inference/ Restart=on-failure Group=sdt User=sdt