commit 46fe2a88f2153899eae3e946e4b56c2fd997b42c Author: ssung Date: Thu Aug 31 18:32:31 2023 +0900 inference server code diff --git a/config/ftp_config.json b/config/ftp_config.json new file mode 100644 index 0000000..af8c473 --- /dev/null +++ b/config/ftp_config.json @@ -0,0 +1,7 @@ +{ + "ftp_ip": "0.0.0.0", + "ftp_port": 21, + "ftp_id": "sdt", + "ftp_pw": "251327", + "ftp_root_dir": "/home/sdt/Workspace/gseps/rabbitmq_test/inference_result/" +} \ No newline at end of file diff --git a/config/inference_config.json b/config/inference_config.json new file mode 100644 index 0000000..2c0a7b1 --- /dev/null +++ b/config/inference_config.json @@ -0,0 +1,28 @@ +{ + "slack_url": "https://hooks.slack.com/services/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", + "slack_message": "Inference Done.", + "amqp_url": "localhost", + "amqp_port": 5672, + "amqp_vhost": "/", + "amqp_id": "worker", + "amqp_pw": "gseps1234", + "amqp_taskq": "TaskQ", + "amqp_resultq": "ResultQ", + "model_config": { + "points_per_side": 36, + "pred_iou_thresh": 0.86, + "stability_score_thresh": 0.9, + "crop_n_layers": 1, + "crop_n_points_downscale_factor": 1, + "box_nms_thresh": 0.8, + "min_mask_region_area": 10, + "area_thresh": 0.1, + "device": "cuda" + }, + "model_checkpoints": "/home/sdt/Workspace/gseps/rabbitmq_test/weights/sam_vit_h_4b8939.pth", + "remote_server_ip": "25.15.14.31", + "remote_server_id": "sdt", + "remote_server_pw": "251327", + "copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps/rabbitmq_test/image_bucket", + "inference_result_path": "/home/sdt/Workspace/gseps/rabbitmq_test/inference_result/" +} \ No newline at end of file diff --git a/ftp_server.py b/ftp_server.py new file mode 100644 index 0000000..548e21d --- /dev/null +++ b/ftp_server.py @@ -0,0 +1,23 @@ +import json +from pyftpdlib.authorizers import DummyAuthorizer +from pyftpdlib.handlers import FTPHandler +from pyftpdlib.servers import FTPServer + +# config +with open('./config/ftp_config.json', 'r') as f: + info = json.load(f) + +# 서버 설정 +authorizer = DummyAuthorizer() +authorizer.add_user(info['ftp_id'], + info['ftp_pw'], + info['ftp_root_dir'], + perm="elradfmw") + +handler = FTPHandler +handler.authorizer = authorizer + +# 서버 시작 +address = (info['ftp_ip'], info['ftp_port']) # 서버 주소와 포트 +server = FTPServer(address, handler) +server.serve_forever() diff --git a/ftp_server.service b/ftp_server.service new file mode 100644 index 0000000..9c37bac --- /dev/null +++ b/ftp_server.service @@ -0,0 +1,9 @@ +[Unit] +Description=FTP Server + +[Service] +ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/ftp_server.py +Restart=on-failure + +[Install] +WantedBy=multi-user.target diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..ba2590b --- /dev/null +++ b/inference.py @@ -0,0 +1,221 @@ +import os +import time +import json +import requests +import traceback + +import cv2 +import pika +import boto3 +import paramiko + +import torch +import numpy as np +import matplotlib.pyplot as plt + +from segment_anything import sam_model_registry, SamAutomaticMaskGenerator + +############################################### +# Config # +############################################### +with open('./config/inference_config.json', 'r') as f: + info = json.load(f) + + +def send_message_to_slack(): + data = {"text": info['slack_message']} + req = requests.post( + url=info['slack_url'], + data=json.dumps(data) + ) + + +def image_upload_to_s3(path): + try: + s3 = boto3.client('s3') + + folder_name = path.split('/')[-1] + for root, dirs, files in os.walk(path): + for f in files: + local_file_path = os.path.join(root, f) + s3.upload_file(local_file_path, 'gseps-data', f'{folder_name}/{f}') + except Exception as e: + print(traceback.format_exc()) + + +def image_copy_using_SCP(remote_path): + try: + # ssh connect + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(info['remote_server_ip'], username=info['remote_server_id'], password=info['remote_server_pw']) + + # copy image from edge + sftp = ssh.open_sftp() + + image_name = remote_path.split('/')[-1] + local_path = info['copied_image_path_from_remote_server'] + + if not os.path.exists(local_path): + os.makedirs(local_path) + + sftp.get(remote_path, os.path.join(local_path, image_name)) + sftp.remove(remote_path) + + sftp.close() + + return os.path.join(local_path, image_name) + except Exception as e: + print(e) + return "" + + +def get_info(body): + payload = json.loads(body) + image_path = payload['image_path'] + + return image_path + + +class Consumer: + def __init__(self): + + # with open("./config.json", "r") as f: + # info = json.load(f.read()) + + # change hard coding to info_dictionary from config.json + + self.__url = info['amqp_url'] + self.__port = info['amqp_port'] + self.__vhost = info['amqp_vhost'] + self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw']) + self.__TaskQ = info['amqp_taskq'] + self.__ResultQ = info['amqp_resultq'] + + self.cfg = info['model_config'] + + self.mask_generator = None + + # self.cloud_vender = info['cloud_vender'] + self.cloud_vender = 'S3' + + def initialize(self): + start = time.time() + + sam = sam_model_registry['vit_h']( + checkpoint=info['model_checkpoints']) + sam = sam.to(self.cfg['device']) + + self.mask_generator = SamAutomaticMaskGenerator(model=sam, + points_per_side=self.cfg['points_per_side'], + pred_iou_thresh=self.cfg['pred_iou_thresh'], + stability_score_thresh=self.cfg['stability_score_thresh'], + crop_n_layers=self.cfg['crop_n_layers'], + crop_n_points_downscale_factor=self.cfg[ + 'crop_n_points_downscale_factor'], + box_nms_thresh=self.cfg['box_nms_thresh'], + min_mask_region_area=self.cfg['min_mask_region_area']) + end = time.time() + print(f'Initialize time: {(end - start) // 60}m {(end - start) % 60:4f}s') + + def image_upload_to_ncp(self, path): + pass + + def image_upload_to_cloud(self, path): + if self.cloud_vender == 'S3': + image_upload_to_s3(path) + elif self.cloud_vender == 'NCP': + self.image_upload_to_ncp(path) + + def inference(self, image_path): + image = cv2.imread(image_path) + image_name = image_path.split('/')[-1].split('.')[0] + + result = self.mask_generator.generate(image) + shape = result[0]['segmentation'].shape + cumulated = np.zeros(shape) + count = 0 + sizes = [] + + for n, r in enumerate(result): + if r['area'] < self.cfg['area_thresh'] * shape[0] * shape[1] and r['stability_score'] > self.cfg['stability_score_thresh']: + if np.amax(cumulated + r['segmentation'].astype(int)) < 2: + cumulated = cumulated + r['segmentation'].astype(int) + count += 1 + x, y, w, h = r['bbox'] + sizes.append(np.mean([w, h])) + + save_path = os.path.join(info['inference_result_path'], image_name) + if not os.path.exists(save_path): + os.makedirs(save_path) + + cv2.imwrite(os.path.join(save_path, f'{image_name}.jpg'), image) + plt.imsave(os.path.join(save_path, f'result_{image_name}.jpg'), cumulated) + + result_dict = {'image': os.path.join(save_path, f'{image_name}.jpg'), # 수정필요 + 'count': count, + 'sizes': sizes} + + with open(os.path.join(save_path, f'result_{image_name}.json'), 'w') as f: + json.dump(result_dict, f, indent=4) + + return result_dict, save_path + + def result_publish(self, channel, result): + channel.basic_publish(exchange='', + routing_key=self.__ResultQ, + body=json.dumps(result)) + + print(f"Done!") + + def upload_to_database(self, results): + """ + Insert inference result data into DB(postgreSQL). + Columns consist of file_name(or image_name), count. + """ + pass + + def main(self): + try: + conn = pika.BlockingConnection(pika.ConnectionParameters(self.__url, + self.__port, + self.__vhost, + self.__cred)) + chan = conn.channel() + chan.queue_declare(queue=self.__TaskQ, durable=True) + chan.queue_declare(queue=self.__ResultQ, durable=True) + + while True: + method, properties, body = chan.basic_get(queue=self.__TaskQ, + auto_ack=True) + + if not method: + send_message_to_slack() + break + + if method: + print(f" [x] Received {body}", end=' | ', flush=True) + edge_image_path = get_info(body) + image_path = image_copy_using_SCP(edge_image_path) + + if image_path == "": + continue + + result_dict, save_path = self.inference(image_path) + self.result_publish(chan, result_dict) + self.upload_to_database(result_dict) + self.image_upload_to_cloud(save_path) + + time.sleep(1) + else: + time.sleep(0.5) + + except Exception as e: + print(traceback.format_exc()) + conn.close() + + +if __name__ == "__main__": + consumer = Consumer() + consumer.initialize() + consumer.main() diff --git a/inference_processor.service b/inference_processor.service new file mode 100644 index 0000000..25f60af --- /dev/null +++ b/inference_processor.service @@ -0,0 +1,9 @@ +[Unit] +Description=inference processor + +[Service] +ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/rabbitmq_test/inference.py +Restart=on-failure + +[Install] +WantedBy=multi-user.target