FIX: function rule

This commit is contained in:
ssung 2023-09-07 15:45:52 +09:00
parent 5536ec473d
commit b09f485c90
4 changed files with 113 additions and 50 deletions

View File

@ -1,13 +1,21 @@
{
"slack_url": "https://hooks.slack.com/services/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",
"slack_message": "Inference Done.",
"amqp_url": "localhost",
"amqp_port": 5672,
"slack_url": "https://hooks.slack.com/services/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",
"slack_message": "Queue is Empty! Restart after 100 second",
"amqp_url": "13.209.39.139",
"amqp_port": 30747,
"amqp_vhost": "/",
"amqp_id": "worker",
"amqp_pw": "gseps1234",
"amqp_taskq": "TaskQ",
"amqp_resultq": "ResultQ",
"amqp_id": "sdt",
"amqp_pw": "251327",
"amqp_TaskQ": "gseps-mq",
"amqp_ResultQ": "gseps-ResultQ",
"amqp_message_expire_time": "300000",
"Minio_url": "http://13.209.39.139:31191",
"AccessKey":"VV2gooVNevRAIg7HrXQr",
"SecretKey":"epJmFWxwfzUUgYeyDqLa8ouitHZaWTwAvPfPNUBL",
"Boto3SignatureVersion":"s3v4",
"Boto3RegionName":"us-east-1",
"BucketName":"gseps-test-a",
"download_data_path": "./data",
"model_config": {
"points_per_side": 36,
"pred_iou_thresh": 0.86,
@ -19,10 +27,10 @@
"area_thresh": 0.1,
"device": "cuda"
},
"model_checkpoints": "/home/sdt/Workspace/gseps_inference/weights/sam_vit_h_4b8939.pth",
"model_checkpoints": "/home/sdt/Workspace/gseps/inference/weights/sam_vit_h_4b8939.pth",
"remote_server_ip": "25.15.14.31",
"remote_server_id": "sdt",
"remote_server_pw": "251327",
"copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps_inference/image_bucket",
"inference_result_path": "/home/sdt/Workspace/gseps_inference/inference_result/"
"copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps/inference/image_bucket",
"inference_result_path": "/home/sdt/Workspace/gseps/inference/result/"
}

View File

@ -2,7 +2,7 @@
Description=FTP Server
[Service]
ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/ftp_server.py
ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps_inference/ftp_server.py
Restart=on-failure
Group=sdt
User=sdt

View File

@ -3,24 +3,57 @@ import time
import json
import requests
import traceback
import logging
import logging.handlers
import cv2
import pika
import boto3
from botocore.client import Config
import paramiko
import torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
###############################################
# Config #
###############################################
with open('/home/sdt/Workspace/gseps_inference/config/inference_config.json', 'r') as f:
with open(os.path.join(os.getcwd(),'config/inference_config.json'), 'r') as f:
info = json.load(f)
###############################################
# Logger Setting #
###############################################
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log_fileHandler = logging.handlers.RotatingFileHandler(
filename=os.path.join(os.getcwd(), "log_inference.log"),
maxBytes=1024000,
backupCount=3,
mode='a')
log_fileHandler.setFormatter(formatter)
logger.addHandler(log_fileHandler)
logging.getLogger("pika").setLevel(logging.WARNING)
################################################################################
# S3 Set up #
################################################################################
s3 = boto3.resource('s3',
endpoint_url = info['Minio_url'],
aws_access_key_id=info['AccessKey'],
aws_secret_access_key=info['SecretKey'],
config=Config(signature_version=info['Boto3SignatureVersion']),
region_name=info['Boto3RegionName']
)
def send_message_to_slack():
data = {"text": info['slack_message']}
@ -89,8 +122,8 @@ class Consumer:
self.__port = info['amqp_port']
self.__vhost = info['amqp_vhost']
self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
self.__TaskQ = info['amqp_taskq']
self.__ResultQ = info['amqp_resultq']
self.__TaskQ = info['amqp_TaskQ']
self.__ResultQ = info['amqp_ResultQ']
self.cfg = info['model_config']
@ -127,44 +160,31 @@ class Consumer:
elif self.cloud_vender == 'NCP':
self.image_upload_to_ncp(path)
def inference(self, image_path):
image = cv2.imread(image_path)
image_name = image_path.split('/')[-1].split('.')[0]
def inference(self, image):
result = self.mask_generator.generate(image)
shape = result[0]['segmentation'].shape
cumulated = np.zeros(shape)
result_image = np.zeros(shape)
count = 0
sizes = []
for n, r in enumerate(result):
if r['area'] < self.cfg['area_thresh'] * shape[0] * shape[1] and r['stability_score'] > self.cfg['stability_score_thresh']:
if np.amax(cumulated + r['segmentation'].astype(int)) < 2:
cumulated = cumulated + r['segmentation'].astype(int)
if np.amax(result_image + r['segmentation'].astype(int)) < 2:
result_image = result_image + r['segmentation'].astype(int)
count += 1
x, y, w, h = r['bbox']
sizes.append(np.mean([w, h]))
save_path = os.path.join(info['inference_result_path'], image_name)
if not os.path.exists(save_path):
os.makedirs(save_path)
return result_image, count, sizes
cv2.imwrite(os.path.join(save_path, f'{image_name}.jpg'), image)
plt.imsave(os.path.join(save_path, f'result_{image_name}.jpg'), cumulated)
result_dict = {'image': os.path.join(save_path, f'{image_name}.jpg'), # 수정필요
'count': count,
'sizes': sizes}
with open(os.path.join(save_path, f'result_{image_name}.json'), 'w') as f:
json.dump(result_dict, f, indent=4)
return result_dict, save_path
def result_publish(self, channel, result):
def result_publish(self, channel, properties, result):
channel.basic_publish(exchange='',
routing_key=self.__ResultQ,
body=json.dumps(result))
body=json.dumps(result),
properties=properties
)
print(f"Done!")
@ -185,32 +205,66 @@ class Consumer:
chan.queue_declare(queue=self.__TaskQ, durable=True)
chan.queue_declare(queue=self.__ResultQ, durable=True)
save_path = info['inference_result_path']
while True:
method, properties, body = chan.basic_get(queue=self.__TaskQ,
auto_ack=True)
amqp_message_properties = pika.BasicProperties(expiration=info['amqp_message_expire_time'])
# if Queue is empty
if not method:
send_message_to_slack()
break
logger.info("Empty Queue sleep for 100s")
time.sleep(100)
if method:
print(f" [x] Received {body}", end=' | ', flush=True)
edge_image_path = get_info(body)
image_path = image_copy_using_SCP(edge_image_path)
if image_path == "":
logger.info(f" [x] Received {body}")
Task_data = json.loads(body)
if(Task_data.get("to")):
download_path = os.path.join(os.getcwd(),info['download_data_path'],Task_data['to']['filename'])
s3.Bucket(Task_data['to']['Bucket']).download_file(Task_data['to']['filename'],download_path)
else:
logger.info("Check Message Data. key 'to' is missing")
continue
result_dict, save_path = self.inference(image_path)
self.result_publish(chan, result_dict)
self.upload_to_database(result_dict)
self.image_upload_to_cloud(save_path)
# read image file
image = cv2.imread(download_path)
# get file name
image_name = download_path.split('/')[-1].split('.')[0]
# run inference
result_image, count, sizes = self.inference(image)
# delete original file
os.remove(download_path)
# save reulst image
result_filename = f'result_{image_name}.jpg'
plt.imsave(os.path.join(save_path, result_filename), result_image)
# message contents set-up
Task_data['Type']="inference_result"
Task_data['result']={
'timestamp':int(time.time()*1000),
'count':count,
'sizes':sizes,
'filename':result_filename
}
Task_data['from']=Task_data['to']
Task_data.pop('to',None)
# send message to AMQP Result Queue
self.result_publish(chan, amqp_message_properties,Task_data)
time.sleep(1)
else:
time.sleep(0.5)
except Exception as e:
print(e)
print(traceback.format_exc())
conn.close()

View File

@ -2,7 +2,8 @@
Description=inference processor
[Service]
ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/rabbitmq_test/inference.py
ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/inference/inference.py
WorkingDirectory=/home/sdt/Workspace/gseps/inference/
Restart=on-failure
Group=sdt
User=sdt