FIX: function rule
parent 5536ec473d
commit b09f485c90

config/inference_config.json
@@ -1,13 +1,21 @@
 {
-    "slack_url": "https://hooks.slack.com/services/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",
+    "slack_url": "https://hooks.slack.com/services/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",
-    "slack_message": "Inference Done.",
+    "slack_message": "Queue is Empty! Restart after 100 second",
-    "amqp_url": "localhost",
+    "amqp_url": "13.209.39.139",
-    "amqp_port": 5672,
+    "amqp_port": 30747,
     "amqp_vhost": "/",
-    "amqp_id": "worker",
+    "amqp_id": "sdt",
-    "amqp_pw": "gseps1234",
+    "amqp_pw": "251327",
-    "amqp_taskq": "TaskQ",
+    "amqp_TaskQ": "gseps-mq",
-    "amqp_resultq": "ResultQ",
+    "amqp_ResultQ": "gseps-ResultQ",
+    "amqp_message_expire_time": "300000",
+    "Minio_url": "http://13.209.39.139:31191",
+    "AccessKey":"VV2gooVNevRAIg7HrXQr",
+    "SecretKey":"epJmFWxwfzUUgYeyDqLa8ouitHZaWTwAvPfPNUBL",
+    "Boto3SignatureVersion":"s3v4",
+    "Boto3RegionName":"us-east-1",
+    "BucketName":"gseps-test-a",
+    "download_data_path": "./data",
     "model_config": {
         "points_per_side": 36,
         "pred_iou_thresh": 0.86,
@@ -19,10 +27,10 @@
         "area_thresh": 0.1,
         "device": "cuda"
     },
-    "model_checkpoints": "/home/sdt/Workspace/gseps_inference/weights/sam_vit_h_4b8939.pth",
+    "model_checkpoints": "/home/sdt/Workspace/gseps/inference/weights/sam_vit_h_4b8939.pth",
     "remote_server_ip": "25.15.14.31",
     "remote_server_id": "sdt",
     "remote_server_pw": "251327",
-    "copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps_inference/image_bucket",
+    "copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps/inference/image_bucket",
-    "inference_result_path": "/home/sdt/Workspace/gseps_inference/inference_result/"
+    "inference_result_path": "/home/sdt/Workspace/gseps/inference/result/"
 }
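The queue keys are renamed here (amqp_taskq becomes amqp_TaskQ, amqp_resultq becomes amqp_ResultQ) and MinIO/S3, message-expiry and download-path keys are added, so the config and the consumer have to agree on key names. A minimal sketch (not part of the commit) that loads the config the same way the updated inference.py does and checks that the keys it now reads are present; it assumes it is run from the project directory so the relative path resolves:

# Sketch only: verify the updated config exposes the keys inference.py reads.
import json
import os

REQUIRED_KEYS = [
    "slack_url", "slack_message",
    "amqp_url", "amqp_port", "amqp_vhost", "amqp_id", "amqp_pw",
    "amqp_TaskQ", "amqp_ResultQ", "amqp_message_expire_time",
    "Minio_url", "AccessKey", "SecretKey",
    "Boto3SignatureVersion", "Boto3RegionName", "BucketName",
    "download_data_path", "model_config", "inference_result_path",
]

with open(os.path.join(os.getcwd(), "config/inference_config.json"), "r") as f:
    info = json.load(f)

missing = [key for key in REQUIRED_KEYS if key not in info]
if missing:
    raise SystemExit(f"config is missing keys: {missing}")
print("config OK")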
FTP server service file (systemd unit)
@@ -2,7 +2,7 @@
 Description=FTP Server

 [Service]
-ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/ftp_server.py
+ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps_inference/ftp_server.py
 Restart=on-failure
 Group=sdt
 User=sdt
inference.py
@@ -3,24 +3,57 @@ import time
 import json
 import requests
 import traceback
+import logging
+import logging.handlers

 import cv2
 import pika
 import boto3
+from botocore.client import Config
 import paramiko

 import torch
 import numpy as np
 import matplotlib.pyplot as plt


 from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

 ###############################################
 # Config #
 ###############################################
-with open('/home/sdt/Workspace/gseps_inference/config/inference_config.json', 'r') as f:
+with open(os.path.join(os.getcwd(),'config/inference_config.json'), 'r') as f:
     info = json.load(f)

+###############################################
+# Logger Setting #
+###############################################
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+log_fileHandler = logging.handlers.RotatingFileHandler(
+    filename=os.path.join(os.getcwd(), "log_inference.log"),
+    maxBytes=1024000,
+    backupCount=3,
+    mode='a')
+
+log_fileHandler.setFormatter(formatter)
+logger.addHandler(log_fileHandler)
+
+logging.getLogger("pika").setLevel(logging.WARNING)
+
+################################################################################
+# S3 Set up #
+################################################################################
+s3 = boto3.resource('s3',
+                    endpoint_url = info['Minio_url'],
+                    aws_access_key_id=info['AccessKey'],
+                    aws_secret_access_key=info['SecretKey'],
+                    config=Config(signature_version=info['Boto3SignatureVersion']),
+                    region_name=info['Boto3RegionName']
+                    )
+

 def send_message_to_slack():
     data = {"text": info['slack_message']}
@@ -89,8 +122,8 @@ class Consumer:
         self.__port = info['amqp_port']
         self.__vhost = info['amqp_vhost']
         self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
-        self.__TaskQ = info['amqp_taskq']
+        self.__TaskQ = info['amqp_TaskQ']
-        self.__ResultQ = info['amqp_resultq']
+        self.__ResultQ = info['amqp_ResultQ']

         self.cfg = info['model_config']

@@ -127,44 +160,31 @@ class Consumer:
         elif self.cloud_vender == 'NCP':
             self.image_upload_to_ncp(path)

-    def inference(self, image_path):
+    def inference(self, image):
-        image = cv2.imread(image_path)
-        image_name = image_path.split('/')[-1].split('.')[0]

         result = self.mask_generator.generate(image)
         shape = result[0]['segmentation'].shape
-        cumulated = np.zeros(shape)
+        result_image = np.zeros(shape)
         count = 0
         sizes = []

         for n, r in enumerate(result):
             if r['area'] < self.cfg['area_thresh'] * shape[0] * shape[1] and r['stability_score'] > self.cfg['stability_score_thresh']:
-                if np.amax(cumulated + r['segmentation'].astype(int)) < 2:
+                if np.amax(result_image + r['segmentation'].astype(int)) < 2:
-                    cumulated = cumulated + r['segmentation'].astype(int)
+                    result_image = result_image + r['segmentation'].astype(int)
                     count += 1
                     x, y, w, h = r['bbox']
                     sizes.append(np.mean([w, h]))

-        save_path = os.path.join(info['inference_result_path'], image_name)
+        return result_image, count, sizes
-        if not os.path.exists(save_path):
-            os.makedirs(save_path)
-
-        cv2.imwrite(os.path.join(save_path, f'{image_name}.jpg'), image)
-        plt.imsave(os.path.join(save_path, f'result_{image_name}.jpg'), cumulated)
-
-        result_dict = {'image': os.path.join(save_path, f'{image_name}.jpg'), # needs fixing
+    def result_publish(self, channel, properties, result):
-                       'count': count,
-                       'sizes': sizes}
-
-        with open(os.path.join(save_path, f'result_{image_name}.json'), 'w') as f:
-            json.dump(result_dict, f, indent=4)
-
-        return result_dict, save_path
-
-    def result_publish(self, channel, result):
         channel.basic_publish(exchange='',
                               routing_key=self.__ResultQ,
-                              body=json.dumps(result))
+                              body=json.dumps(result),
+                              properties=properties
+                              )

         print(f"Done!")

@@ -185,32 +205,66 @@ class Consumer:
             chan.queue_declare(queue=self.__TaskQ, durable=True)
             chan.queue_declare(queue=self.__ResultQ, durable=True)

+            save_path = info['inference_result_path']
+
             while True:
                 method, properties, body = chan.basic_get(queue=self.__TaskQ,
                                                           auto_ack=True)

+                amqp_message_properties = pika.BasicProperties(expiration=info['amqp_message_expire_time'])
+
+                # if Queue is empty
                 if not method:
                     send_message_to_slack()
-                    break
+                    logger.info("Empty Queue sleep for 100s")
+                    time.sleep(100)

                 if method:
-                    print(f" [x] Received {body}", end=' | ', flush=True)
+                    logger.info(f" [x] Received {body}")
-                    edge_image_path = get_info(body)
+                    Task_data = json.loads(body)
-                    image_path = image_copy_using_SCP(edge_image_path)
+                    if(Task_data.get("to")):
+                        download_path = os.path.join(os.getcwd(),info['download_data_path'],Task_data['to']['filename'])
-                    if image_path == "":
+                        s3.Bucket(Task_data['to']['Bucket']).download_file(Task_data['to']['filename'],download_path)
+                    else:
+                        logger.info("Check Message Data. key 'to' is missing")
                         continue

-                    result_dict, save_path = self.inference(image_path)
+                    # read image file
-                    self.result_publish(chan, result_dict)
+                    image = cv2.imread(download_path)
-                    self.upload_to_database(result_dict)
-                    self.image_upload_to_cloud(save_path)

+                    # get file name
+                    image_name = download_path.split('/')[-1].split('.')[0]
+
+                    # run inference
+                    result_image, count, sizes = self.inference(image)
+
+                    # delete original file
+                    os.remove(download_path)
+
+                    # save result image
+                    result_filename = f'result_{image_name}.jpg'
+                    plt.imsave(os.path.join(save_path, result_filename), result_image)
+
+                    # message contents set-up
+                    Task_data['Type']="inference_result"
+                    Task_data['result']={
+                        'timestamp':int(time.time()*1000),
+                        'count':count,
+                        'sizes':sizes,
+                        'filename':result_filename
+                    }
+                    Task_data['from']=Task_data['to']
+                    Task_data.pop('to',None)
+
+                    # send message to AMQP Result Queue
+                    self.result_publish(chan, amqp_message_properties,Task_data)
                     time.sleep(1)
                 else:
                     time.sleep(0.5)

         except Exception as e:
+            print(e)
             print(traceback.format_exc())
             conn.close()

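After this change the worker no longer copies images over SCP: it downloads the object named in the task message's "to" block from MinIO, runs segmentation, and publishes the same message back on the result queue with Type set to "inference_result", a result block (timestamp, count, sizes, filename) and "to" renamed to "from", using the configured per-message expiration. A producer sketch (not part of the commit) for enqueuing a task in that shape; the filename is illustrative, everything else comes from the same config file:

# Sketch only: enqueue one task message in the shape the updated consumer expects.
import json
import pika

with open("config/inference_config.json", "r") as f:
    info = json.load(f)

credentials = pika.PlainCredentials(info["amqp_id"], info["amqp_pw"])
parameters = pika.ConnectionParameters(info["amqp_url"], info["amqp_port"],
                                       info["amqp_vhost"], credentials)
connection = pika.BlockingConnection(parameters)
channel = connection.channel()
channel.queue_declare(queue=info["amqp_TaskQ"], durable=True)

# The consumer downloads Bucket/filename from MinIO, runs SAM, and answers on
# amqp_ResultQ with Type="inference_result" and the "to" block renamed to "from".
task = {"to": {"Bucket": info["BucketName"], "filename": "sample.jpg"}}
channel.basic_publish(exchange="", routing_key=info["amqp_TaskQ"],
                      body=json.dumps(task))
connection.close()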
inference service file (systemd unit)
@@ -2,7 +2,8 @@
 Description=inference processor

 [Service]
-ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/rabbitmq_test/inference.py
+ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/inference/inference.py
+WorkingDirectory=/home/sdt/Workspace/gseps/inference/
 Restart=on-failure
 Group=sdt
 User=sdt
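The added WorkingDirectory= line matters because the rewritten inference.py resolves both config/inference_config.json and log_inference.log against os.getcwd(). A small illustration (the directory is the one named in the unit file):

# Illustration only: with WorkingDirectory set, the relative config path used by
# the rewritten loader resolves under the project directory.
import os

working_directory = "/home/sdt/Workspace/gseps/inference/"  # from the unit file
print(os.path.join(working_directory, "config/inference_config.json"))
# /home/sdt/Workspace/gseps/inference/config/inference_config.json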