FIX: function rule

parent 5536ec473d
commit b09f485c90
config/inference_config.json

@@ -1,13 +1,21 @@
 {
-    "slack_url": "https://hooks.slack.com/services/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",
-    "slack_message": "Inference Done.",
-    "amqp_url": "localhost",
-    "amqp_port": 5672,
+    "slack_url": "https://hooks.slack.com/services/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",
+    "slack_message": "Queue is Empty! Restart after 100 second",
+    "amqp_url": "13.209.39.139",
+    "amqp_port": 30747,
     "amqp_vhost": "/",
-    "amqp_id": "worker",
-    "amqp_pw": "gseps1234",
-    "amqp_taskq": "TaskQ",
-    "amqp_resultq": "ResultQ",
+    "amqp_id": "sdt",
+    "amqp_pw": "251327",
+    "amqp_TaskQ": "gseps-mq",
+    "amqp_ResultQ": "gseps-ResultQ",
+    "amqp_message_expire_time": "300000",
+    "Minio_url": "http://13.209.39.139:31191",
+    "AccessKey": "VV2gooVNevRAIg7HrXQr",
+    "SecretKey": "epJmFWxwfzUUgYeyDqLa8ouitHZaWTwAvPfPNUBL",
+    "Boto3SignatureVersion": "s3v4",
+    "Boto3RegionName": "us-east-1",
+    "BucketName": "gseps-test-a",
+    "download_data_path": "./data",
     "model_config": {
         "points_per_side": 36,
         "pred_iou_thresh": 0.86,
@@ -19,10 +27,10 @@
         "area_thresh": 0.1,
         "device": "cuda"
     },
-    "model_checkpoints": "/home/sdt/Workspace/gseps_inference/weights/sam_vit_h_4b8939.pth",
+    "model_checkpoints": "/home/sdt/Workspace/gseps/inference/weights/sam_vit_h_4b8939.pth",
     "remote_server_ip": "25.15.14.31",
     "remote_server_id": "sdt",
     "remote_server_pw": "251327",
-    "copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps_inference/image_bucket",
-    "inference_result_path": "/home/sdt/Workspace/gseps_inference/inference_result/"
+    "copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps/inference/image_bucket",
+    "inference_result_path": "/home/sdt/Workspace/gseps/inference/result/"
 }
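Note: inference.py (below) loads this file relative to the process working directory. A minimal sketch of how the worker consumes the new keys; the required-key guard is illustrative, not part of the commit:

    import json
    import os

    # Load the worker config the same way inference.py does.
    with open(os.path.join(os.getcwd(), 'config/inference_config.json')) as f:
        info = json.load(f)

    # Keys the new S3/AMQP code path reads; failing fast here is an
    # illustrative guard, not behavior from this commit.
    for key in ('Minio_url', 'AccessKey', 'SecretKey', 'amqp_TaskQ',
                'amqp_ResultQ', 'amqp_message_expire_time', 'download_data_path'):
        assert key in info, f'missing config key: {key}'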
@@ -2,7 +2,7 @@
 Description=FTP Server

 [Service]
-ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/ftp_server.py
+ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps_inference/ftp_server.py
 Restart=on-failure
 Group=sdt
 User=sdt
inference.py (124 lines changed)
@@ -3,24 +3,57 @@ import time
 import json
 import requests
 import traceback
+import logging
+import logging.handlers

 import cv2
 import pika
+import boto3
+from botocore.client import Config
 import paramiko

 import torch
 import numpy as np
 import matplotlib.pyplot as plt


 from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

 ###############################################
 #                   Config                    #
 ###############################################
-with open('/home/sdt/Workspace/gseps_inference/config/inference_config.json', 'r') as f:
+with open(os.path.join(os.getcwd(), 'config/inference_config.json'), 'r') as f:
     info = json.load(f)

+###############################################
+#                Logger Setting               #
+###############################################
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+log_fileHandler = logging.handlers.RotatingFileHandler(
+    filename=os.path.join(os.getcwd(), "log_inference.log"),
+    maxBytes=1024000,
+    backupCount=3,
+    mode='a')
+
+log_fileHandler.setFormatter(formatter)
+logger.addHandler(log_fileHandler)
+
+logging.getLogger("pika").setLevel(logging.WARNING)
+
+################################################################################
+#                                  S3 Set up                                   #
+################################################################################
+s3 = boto3.resource('s3',
+                    endpoint_url=info['Minio_url'],
+                    aws_access_key_id=info['AccessKey'],
+                    aws_secret_access_key=info['SecretKey'],
+                    config=Config(signature_version=info['Boto3SignatureVersion']),
+                    region_name=info['Boto3RegionName']
+                    )
+

 def send_message_to_slack():
     data = {"text": info['slack_message']}
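Note: the hunk cuts off inside send_message_to_slack. A plausible completion, assuming the function simply POSTs the payload to the configured webhook via the requests import above (the real body is not shown in this diff):

    def send_message_to_slack():
        data = {"text": info['slack_message']}
        # Assumed continuation: deliver the payload to the Slack webhook URL.
        requests.post(info['slack_url'], json=data)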
@@ -89,8 +122,8 @@ class Consumer:
         self.__port = info['amqp_port']
         self.__vhost = info['amqp_vhost']
         self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
-        self.__TaskQ = info['amqp_taskq']
-        self.__ResultQ = info['amqp_resultq']
+        self.__TaskQ = info['amqp_TaskQ']
+        self.__ResultQ = info['amqp_ResultQ']

         self.cfg = info['model_config']
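Note: the connection itself is opened outside this hunk. A minimal sketch of how these parameters are typically combined with pika, assuming a BlockingConnection (not shown in the commit):

    import pika

    # Illustrative only: build connection parameters from the same config
    # values the Consumer stores above.
    params = pika.ConnectionParameters(
        host=info['amqp_url'],
        port=info['amqp_port'],
        virtual_host=info['amqp_vhost'],
        credentials=pika.PlainCredentials(info['amqp_id'], info['amqp_pw']))
    conn = pika.BlockingConnection(params)
    chan = conn.channel()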
@@ -127,44 +160,31 @@ class Consumer:
         elif self.cloud_vender == 'NCP':
             self.image_upload_to_ncp(path)

-    def inference(self, image_path):
-        image = cv2.imread(image_path)
-        image_name = image_path.split('/')[-1].split('.')[0]
+    def inference(self, image):

         result = self.mask_generator.generate(image)
         shape = result[0]['segmentation'].shape
-        cumulated = np.zeros(shape)
+        result_image = np.zeros(shape)
         count = 0
         sizes = []

         for n, r in enumerate(result):
             if r['area'] < self.cfg['area_thresh'] * shape[0] * shape[1] and r['stability_score'] > self.cfg['stability_score_thresh']:
-                if np.amax(cumulated + r['segmentation'].astype(int)) < 2:
-                    cumulated = cumulated + r['segmentation'].astype(int)
+                if np.amax(result_image + r['segmentation'].astype(int)) < 2:
+                    result_image = result_image + r['segmentation'].astype(int)
                     count += 1
                     x, y, w, h = r['bbox']
                     sizes.append(np.mean([w, h]))

-        save_path = os.path.join(info['inference_result_path'], image_name)
-        if not os.path.exists(save_path):
-            os.makedirs(save_path)
+        return result_image, count, sizes

-        cv2.imwrite(os.path.join(save_path, f'{image_name}.jpg'), image)
-        plt.imsave(os.path.join(save_path, f'result_{image_name}.jpg'), cumulated)
-
-        result_dict = {'image': os.path.join(save_path, f'{image_name}.jpg'),  # needs fixing
-                       'count': count,
-                       'sizes': sizes}
-
-        with open(os.path.join(save_path, f'result_{image_name}.json'), 'w') as f:
-            json.dump(result_dict, f, indent=4)
-
-        return result_dict, save_path
-
-    def result_publish(self, channel, result):
+    def result_publish(self, channel, properties, result):
         channel.basic_publish(exchange='',
                               routing_key=self.__ResultQ,
-                              body=json.dumps(result))
+                              body=json.dumps(result),
+                              properties=properties
+                              )

         print(f"Done!")
@@ -185,32 +205,66 @@ class Consumer:
             chan.queue_declare(queue=self.__TaskQ, durable=True)
             chan.queue_declare(queue=self.__ResultQ, durable=True)


+            save_path = info['inference_result_path']

             while True:
                 method, properties, body = chan.basic_get(queue=self.__TaskQ,
                                                           auto_ack=True)

+                amqp_message_properties = pika.BasicProperties(expiration=info['amqp_message_expire_time'])
+
                 # if Queue is empty
                 if not method:
                     send_message_to_slack()
-                    break
+                    logger.info("Empty Queue sleep for 100s")
+                    time.sleep(100)

                 if method:
-                    print(f" [x] Received {body}", end=' | ', flush=True)
-                    edge_image_path = get_info(body)
-                    image_path = image_copy_using_SCP(edge_image_path)
-
-                    if image_path == "":
+                    logger.info(f" [x] Received {body}")
+                    Task_data = json.loads(body)
+                    if Task_data.get("to"):
+                        download_path = os.path.join(os.getcwd(), info['download_data_path'], Task_data['to']['filename'])
+                        s3.Bucket(Task_data['to']['Bucket']).download_file(Task_data['to']['filename'], download_path)
+                    else:
+                        logger.info("Check Message Data. key 'to' is missing")
                         continue

-                    result_dict, save_path = self.inference(image_path)
-                    self.result_publish(chan, result_dict)
-                    self.upload_to_database(result_dict)
-                    self.image_upload_to_cloud(save_path)
+                    # read image file
+                    image = cv2.imread(download_path)
+
+                    # get file name
+                    image_name = download_path.split('/')[-1].split('.')[0]
+
+                    # run inference
+                    result_image, count, sizes = self.inference(image)
+
+                    # delete original file
+                    os.remove(download_path)
+
+                    # save result image
+                    result_filename = f'result_{image_name}.jpg'
+                    plt.imsave(os.path.join(save_path, result_filename), result_image)
+
+                    # message contents set-up
+                    Task_data['Type'] = "inference_result"
+                    Task_data['result'] = {
+                        'timestamp': int(time.time() * 1000),
+                        'count': count,
+                        'sizes': sizes,
+                        'filename': result_filename
+                    }
+                    Task_data['from'] = Task_data['to']
+                    Task_data.pop('to', None)
+
+                    # send message to AMQP Result Queue
+                    self.result_publish(chan, amqp_message_properties, Task_data)
+                    time.sleep(1)
+                else:
+                    time.sleep(0.5)

         except Exception as e:
             print(e)
             print(traceback.format_exc())
             conn.close()
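Note: for reference, the shape of the messages this loop consumes and republishes, inferred from the key accesses above; all field values are hypothetical examples:

    # Hypothetical TaskQ message (what the producer side is assumed to enqueue):
    task_message = {
        "to": {"Bucket": "gseps-test-a", "filename": "sample.jpg"}
    }

    # What the loop republishes to ResultQ after inference:
    result_message = {
        "Type": "inference_result",
        "result": {
            "timestamp": 1700000000000,     # int(time.time() * 1000)
            "count": 42,                    # number of accepted masks
            "sizes": [17.5, 23.0],          # mean of bbox w/h per mask
            "filename": "result_sample.jpg"
        },
        "from": {"Bucket": "gseps-test-a", "filename": "sample.jpg"}
    }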
|
@ -2,7 +2,8 @@
|
|||
Description=inference processor
|
||||
|
||||
[Service]
|
||||
ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/rabbitmq_test/inference.py
|
||||
ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/inference/inference.py
|
||||
WorkingDirectory=/home/sdt/Workspace/gseps/inference/
|
||||
Restart=on-failure
|
||||
Group=sdt
|
||||
User=sdt
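Note: the new WorkingDirectory line is load-bearing: inference.py now resolves its config file, rotating log, and download directory against os.getcwd(), so the unit must start the process from the project root. An illustrative check of what the service will open:

    import os

    # Under systemd, os.getcwd() returns the unit's WorkingDirectory;
    # these are the paths inference.py derives from it.
    print(os.getcwd())                                    # /home/sdt/Workspace/gseps/inference
    print(os.path.join(os.getcwd(), 'config/inference_config.json'))
    print(os.path.join(os.getcwd(), 'log_inference.log'))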