FIX: function rule

This commit is contained in:
ssung 2023-09-07 15:45:52 +09:00
parent 5536ec473d
commit b09f485c90
4 changed files with 113 additions and 50 deletions

View File

@ -1,13 +1,21 @@
{ {
"slack_url": "https://hooks.slack.com/services/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", "slack_url": "https://hooks.slack.com/services/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",
"slack_message": "Inference Done.", "slack_message": "Queue is Empty! Restart after 100 second",
"amqp_url": "localhost", "amqp_url": "13.209.39.139",
"amqp_port": 5672, "amqp_port": 30747,
"amqp_vhost": "/", "amqp_vhost": "/",
"amqp_id": "worker", "amqp_id": "sdt",
"amqp_pw": "gseps1234", "amqp_pw": "251327",
"amqp_taskq": "TaskQ", "amqp_TaskQ": "gseps-mq",
"amqp_resultq": "ResultQ", "amqp_ResultQ": "gseps-ResultQ",
"amqp_message_expire_time": "300000",
"Minio_url": "http://13.209.39.139:31191",
"AccessKey":"VV2gooVNevRAIg7HrXQr",
"SecretKey":"epJmFWxwfzUUgYeyDqLa8ouitHZaWTwAvPfPNUBL",
"Boto3SignatureVersion":"s3v4",
"Boto3RegionName":"us-east-1",
"BucketName":"gseps-test-a",
"download_data_path": "./data",
"model_config": { "model_config": {
"points_per_side": 36, "points_per_side": 36,
"pred_iou_thresh": 0.86, "pred_iou_thresh": 0.86,
@ -19,10 +27,10 @@
"area_thresh": 0.1, "area_thresh": 0.1,
"device": "cuda" "device": "cuda"
}, },
"model_checkpoints": "/home/sdt/Workspace/gseps_inference/weights/sam_vit_h_4b8939.pth", "model_checkpoints": "/home/sdt/Workspace/gseps/inference/weights/sam_vit_h_4b8939.pth",
"remote_server_ip": "25.15.14.31", "remote_server_ip": "25.15.14.31",
"remote_server_id": "sdt", "remote_server_id": "sdt",
"remote_server_pw": "251327", "remote_server_pw": "251327",
"copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps_inference/image_bucket", "copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps/inference/image_bucket",
"inference_result_path": "/home/sdt/Workspace/gseps_inference/inference_result/" "inference_result_path": "/home/sdt/Workspace/gseps/inference/result/"
} }

View File

@ -2,7 +2,7 @@
Description=FTP Server Description=FTP Server
[Service] [Service]
ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/ftp_server.py ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps_inference/ftp_server.py
Restart=on-failure Restart=on-failure
Group=sdt Group=sdt
User=sdt User=sdt

View File

@ -3,24 +3,57 @@ import time
import json import json
import requests import requests
import traceback import traceback
import logging
import logging.handlers
import cv2 import cv2
import pika import pika
import boto3 import boto3
from botocore.client import Config
import paramiko import paramiko
import torch import torch
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
############################################### ###############################################
# Config # # Config #
############################################### ###############################################
with open('/home/sdt/Workspace/gseps_inference/config/inference_config.json', 'r') as f: with open(os.path.join(os.getcwd(),'config/inference_config.json'), 'r') as f:
info = json.load(f) info = json.load(f)
###############################################
# Logger Setting #
###############################################
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log_fileHandler = logging.handlers.RotatingFileHandler(
filename=os.path.join(os.getcwd(), "log_inference.log"),
maxBytes=1024000,
backupCount=3,
mode='a')
log_fileHandler.setFormatter(formatter)
logger.addHandler(log_fileHandler)
logging.getLogger("pika").setLevel(logging.WARNING)
################################################################################
# S3 Set up #
################################################################################
s3 = boto3.resource('s3',
endpoint_url = info['Minio_url'],
aws_access_key_id=info['AccessKey'],
aws_secret_access_key=info['SecretKey'],
config=Config(signature_version=info['Boto3SignatureVersion']),
region_name=info['Boto3RegionName']
)
def send_message_to_slack(): def send_message_to_slack():
data = {"text": info['slack_message']} data = {"text": info['slack_message']}
@ -89,8 +122,8 @@ class Consumer:
self.__port = info['amqp_port'] self.__port = info['amqp_port']
self.__vhost = info['amqp_vhost'] self.__vhost = info['amqp_vhost']
self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw']) self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
self.__TaskQ = info['amqp_taskq'] self.__TaskQ = info['amqp_TaskQ']
self.__ResultQ = info['amqp_resultq'] self.__ResultQ = info['amqp_ResultQ']
self.cfg = info['model_config'] self.cfg = info['model_config']
@ -127,44 +160,31 @@ class Consumer:
elif self.cloud_vender == 'NCP': elif self.cloud_vender == 'NCP':
self.image_upload_to_ncp(path) self.image_upload_to_ncp(path)
def inference(self, image_path): def inference(self, image):
image = cv2.imread(image_path)
image_name = image_path.split('/')[-1].split('.')[0]
result = self.mask_generator.generate(image) result = self.mask_generator.generate(image)
shape = result[0]['segmentation'].shape shape = result[0]['segmentation'].shape
cumulated = np.zeros(shape) result_image = np.zeros(shape)
count = 0 count = 0
sizes = [] sizes = []
for n, r in enumerate(result): for n, r in enumerate(result):
if r['area'] < self.cfg['area_thresh'] * shape[0] * shape[1] and r['stability_score'] > self.cfg['stability_score_thresh']: if r['area'] < self.cfg['area_thresh'] * shape[0] * shape[1] and r['stability_score'] > self.cfg['stability_score_thresh']:
if np.amax(cumulated + r['segmentation'].astype(int)) < 2: if np.amax(result_image + r['segmentation'].astype(int)) < 2:
cumulated = cumulated + r['segmentation'].astype(int) result_image = result_image + r['segmentation'].astype(int)
count += 1 count += 1
x, y, w, h = r['bbox'] x, y, w, h = r['bbox']
sizes.append(np.mean([w, h])) sizes.append(np.mean([w, h]))
save_path = os.path.join(info['inference_result_path'], image_name) return result_image, count, sizes
if not os.path.exists(save_path):
os.makedirs(save_path)
cv2.imwrite(os.path.join(save_path, f'{image_name}.jpg'), image)
plt.imsave(os.path.join(save_path, f'result_{image_name}.jpg'), cumulated)
result_dict = {'image': os.path.join(save_path, f'{image_name}.jpg'), # 수정필요 def result_publish(self, channel, properties, result):
'count': count,
'sizes': sizes}
with open(os.path.join(save_path, f'result_{image_name}.json'), 'w') as f:
json.dump(result_dict, f, indent=4)
return result_dict, save_path
def result_publish(self, channel, result):
channel.basic_publish(exchange='', channel.basic_publish(exchange='',
routing_key=self.__ResultQ, routing_key=self.__ResultQ,
body=json.dumps(result)) body=json.dumps(result),
properties=properties
)
print(f"Done!") print(f"Done!")
@ -185,32 +205,66 @@ class Consumer:
chan.queue_declare(queue=self.__TaskQ, durable=True) chan.queue_declare(queue=self.__TaskQ, durable=True)
chan.queue_declare(queue=self.__ResultQ, durable=True) chan.queue_declare(queue=self.__ResultQ, durable=True)
save_path = info['inference_result_path']
while True: while True:
method, properties, body = chan.basic_get(queue=self.__TaskQ, method, properties, body = chan.basic_get(queue=self.__TaskQ,
auto_ack=True) auto_ack=True)
amqp_message_properties = pika.BasicProperties(expiration=info['amqp_message_expire_time'])
# if Queue is empty
if not method: if not method:
send_message_to_slack() send_message_to_slack()
break logger.info("Empty Queue sleep for 100s")
time.sleep(100)
if method: if method:
print(f" [x] Received {body}", end=' | ', flush=True) logger.info(f" [x] Received {body}")
edge_image_path = get_info(body) Task_data = json.loads(body)
image_path = image_copy_using_SCP(edge_image_path) if(Task_data.get("to")):
download_path = os.path.join(os.getcwd(),info['download_data_path'],Task_data['to']['filename'])
if image_path == "": s3.Bucket(Task_data['to']['Bucket']).download_file(Task_data['to']['filename'],download_path)
else:
logger.info("Check Message Data. key 'to' is missing")
continue continue
result_dict, save_path = self.inference(image_path) # read image file
self.result_publish(chan, result_dict) image = cv2.imread(download_path)
self.upload_to_database(result_dict)
self.image_upload_to_cloud(save_path)
# get file name
image_name = download_path.split('/')[-1].split('.')[0]
# run inference
result_image, count, sizes = self.inference(image)
# delete original file
os.remove(download_path)
# save reulst image
result_filename = f'result_{image_name}.jpg'
plt.imsave(os.path.join(save_path, result_filename), result_image)
# message contents set-up
Task_data['Type']="inference_result"
Task_data['result']={
'timestamp':int(time.time()*1000),
'count':count,
'sizes':sizes,
'filename':result_filename
}
Task_data['from']=Task_data['to']
Task_data.pop('to',None)
# send message to AMQP Result Queue
self.result_publish(chan, amqp_message_properties,Task_data)
time.sleep(1) time.sleep(1)
else: else:
time.sleep(0.5) time.sleep(0.5)
except Exception as e: except Exception as e:
print(e)
print(traceback.format_exc()) print(traceback.format_exc())
conn.close() conn.close()

View File

@ -2,7 +2,8 @@
Description=inference processor Description=inference processor
[Service] [Service]
ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/rabbitmq_test/inference.py ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/inference/inference.py
WorkingDirectory=/home/sdt/Workspace/gseps/inference/
Restart=on-failure Restart=on-failure
Group=sdt Group=sdt
User=sdt User=sdt