309 lines
11 KiB
Python
309 lines
11 KiB
Python
import os
|
|
import time
|
|
import json
|
|
import requests
|
|
import traceback
|
|
import logging
|
|
import logging.handlers
|
|
|
|
import cv2
|
|
import pika
|
|
import boto3
|
|
from botocore.client import Config
|
|
import paramiko
|
|
|
|
import torch
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
|
|
|
###############################################
|
|
# Config #
|
|
###############################################
|
|
# Runtime configuration (AMQP, object storage, model, local paths) loaded from
# <cwd>/config/inference_config.json and shared by the whole module as ``info``.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale encoding (which breaks on non-ASCII values elsewhere).
with open(os.path.join(os.getcwd(), 'config/inference_config.json'), 'r', encoding='utf-8') as f:
    info = json.load(f)
|
|
|
|
###############################################
|
|
# Logger Setting #
|
|
###############################################
|
|
# Root logger: records at INFO and above are written to a size-rotated file
# in the current working directory.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

_log_file = os.path.join(os.getcwd(), "log_inference.log")
log_fileHandler = logging.handlers.RotatingFileHandler(
    filename=_log_file,
    maxBytes=1024000,  # rotate after ~1 MB
    backupCount=3,     # keep at most three rotated files
    mode='a',
)
log_fileHandler.setFormatter(formatter)
logger.addHandler(log_fileHandler)

# Only let the pika library's own loggers through at WARNING and above.
logging.getLogger("pika").setLevel(logging.WARNING)
|
|
|
|
################################################################################
|
|
# S3 Set up #
|
|
################################################################################
|
|
# S3-API resource pointed at info['Minio_url'] with the configured keys,
# signature version and region; used below to download task images.
_minio_settings = {
    'endpoint_url': info['Minio_url'],
    'aws_access_key_id': info['AccessKey'],
    'aws_secret_access_key': info['SecretKey'],
    'config': Config(signature_version=info['Boto3SignatureVersion']),
    'region_name': info['Boto3RegionName'],
}
s3 = boto3.resource('s3', **_minio_settings)
|
|
|
|
|
|
def send_message_to_slack(message, timeout=10):
    """Post *message* as ``{"text": message}`` to the webhook at info['slack_url'].

    Best-effort: network errors, timeouts and non-2xx responses are logged and
    swallowed. This function is called from the worker's exception handlers,
    so it must never raise a transport error itself.

    Args:
        message: text to post.
        timeout: seconds to wait for the webhook before giving up; without a
            timeout ``requests`` can block the worker indefinitely.
    """
    try:
        response = requests.post(
            url=info['slack_url'],
            data=json.dumps({"text": message}),
            timeout=timeout,
        )
        response.raise_for_status()
    except requests.RequestException:
        logger.exception("Failed to send Slack message")
|
|
|
|
|
|
def image_upload_to_s3(path, bucket='gseps-data'):
    """Upload every file under the local directory *path* to *bucket*.

    Objects are keyed ``<folder_name>/<path relative to *path*>``, where
    ``folder_name`` is the last component of *path*. Failures are printed and
    logged but not raised (best-effort upload); nothing is returned.

    Args:
        path: local directory to upload (a trailing separator is allowed).
        bucket: destination bucket name.
    """
    try:
        # NOTE(review): this client uses boto3's default credential chain and
        # endpoint, NOT the module-level ``s3`` resource built from
        # info['Minio_url']; confirm uploads are meant to go elsewhere.
        s3_client = boto3.client('s3')

        # normpath drops a trailing separator so "dir/" still yields "dir".
        folder_name = os.path.basename(os.path.normpath(path))

        for root, _dirs, files in os.walk(path):
            for file_name in files:
                local_file_path = os.path.join(root, file_name)
                # Keep the sub-directory structure in the key so same-named
                # files in different sub-folders do not overwrite each other.
                relative_key = os.path.relpath(local_file_path, path).replace(os.sep, '/')
                s3_client.upload_file(local_file_path, bucket, f'{folder_name}/{relative_key}')
    except Exception:
        print(traceback.format_exc())
        logger.exception("image_upload_to_s3 failed for %s", path)
|
|
|
|
|
|
def image_copy_using_SCP(remote_path):
    """Copy *remote_path* from the remote server over SFTP, then delete the remote file.

    Connection details come from info['remote_server_ip'],
    info['remote_server_id'] and info['remote_server_pw']. The file is saved
    under info['copied_image_path_from_remote_server'], which is created if
    missing.

    Args:
        remote_path: '/'-separated path of the file on the remote server.

    Returns:
        The local path of the copied file, or "" if any step failed.
    """
    ssh = paramiko.SSHClient()
    try:
        # NOTE(review): AutoAddPolicy accepts unknown host keys without verification.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(info['remote_server_ip'],
                    username=info['remote_server_id'],
                    password=info['remote_server_pw'])

        sftp = ssh.open_sftp()
        try:
            image_name = remote_path.split('/')[-1]
            local_path = info['copied_image_path_from_remote_server']
            os.makedirs(local_path, exist_ok=True)
            local_file = os.path.join(local_path, image_name)

            sftp.get(remote_path, local_file)
            # Only reached when the download succeeded, so the remote copy is
            # never deleted without a local one.
            sftp.remove(remote_path)
        finally:
            sftp.close()

        return local_file
    except Exception as e:
        print(e)
        logger.exception("SFTP copy failed for %s", remote_path)
        return ""
    finally:
        # Release the SSH transport on every path (it previously leaked).
        ssh.close()
|
|
|
|
|
|
def get_info(body):
    """Decode the JSON message *body* (str or bytes) and return its 'image_path' value.

    Raises:
        KeyError: if the decoded message has no 'image_path' key.
    """
    return json.loads(body)['image_path']
|
|
|
|
|
|
class Consumer:
    """AMQP worker that segments task images with Segment Anything (SAM).

    Each message on the task queue is JSON whose ``to`` entry names an object
    (``bucket``/``filename``) in the module-level ``s3`` store. The image is
    downloaded, segmented by :meth:`inference`, and, when enough masks survive
    filtering, a result image is written to info['inference_result_path'] and
    a result message is published to the result queue.
    """

    def __init__(self):
        # AMQP connection settings from the module-level ``info`` config.
        self.__url = info['amqp_url']
        self.__port = info['amqp_port']
        self.__vhost = info['amqp_vhost']
        self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
        self.__TaskQ = info['amqp_TaskQ']
        self.__ResultQ = info['amqp_ResultQ']

        # Mask-generator and filtering parameters used by initialize()/inference().
        self.cfg = info['model_config']

        # Built by initialize(); inference() cannot run before that.
        self.mask_generator = None

        # Upload backend for image_upload_to_cloud(); hard-coded, the
        # info['cloud_vender'] key is currently not read.
        self.cloud_vender = 'S3'

    def initialize(self):
        """Load the SAM ``vit_h`` checkpoint onto cfg['device'] and build the mask generator."""
        start = time.time()

        sam = sam_model_registry['vit_h'](checkpoint=info['model_checkpoints'])
        sam = sam.to(self.cfg['device'])

        self.mask_generator = SamAutomaticMaskGenerator(
            model=sam,
            points_per_side=self.cfg['points_per_side'],
            pred_iou_thresh=self.cfg['pred_iou_thresh'],
            stability_score_thresh=self.cfg['stability_score_thresh'],
            crop_n_layers=self.cfg['crop_n_layers'],
            crop_n_points_downscale_factor=self.cfg['crop_n_points_downscale_factor'],
            box_nms_thresh=self.cfg['box_nms_thresh'],
            min_mask_region_area=self.cfg['min_mask_region_area'],
        )

        minutes, seconds = divmod(time.time() - start, 60)
        logger.info(f"Initialize {str(info['model_config'])}")
        # '.4f' = four decimals (the previous ':4f' was a field width, not a precision).
        print(f'Initialize time: {int(minutes)}m {seconds:.4f}s')

    def image_upload_to_ncp(self, path):
        """Placeholder for uploading *path* to NCP object storage; not implemented."""
        pass

    def image_upload_to_cloud(self, path):
        """Upload the directory *path* using the backend selected by ``self.cloud_vender``."""
        if self.cloud_vender == 'S3':
            image_upload_to_s3(path)
        elif self.cloud_vender == 'NCP':
            self.image_upload_to_ncp(path)
        else:
            logger.warning(f"Unknown cloud vendor: {self.cloud_vender}")

    def inference(self, image):
        """Segment *image* and keep small, stable, mutually non-overlapping masks.

        A mask is kept when its area is below cfg['area_thresh'] * H * W, its
        stability score exceeds cfg['stability_score_thresh'], and it shares
        no pixel with a previously kept mask (masks are visited in the order
        the generator returns them).

        Args:
            image: image array as returned by ``cv2.imread``.

        Returns:
            (result_image, count, sizes): a float array that is 1 where a kept
            mask lies and 0 elsewhere; the number of kept masks; and, per kept
            mask, the mean of its bbox width and height.

        Raises:
            ValueError: if *image* is None (e.g. ``cv2.imread`` failed).
        """
        if image is None:
            raise ValueError("image is None (the file could not be read)")

        masks = self.mask_generator.generate(image)

        # SAM masks are assumed to share the image's H x W; fall back to the
        # image shape so an image with no detections no longer hits masks[0].
        shape = masks[0]['segmentation'].shape if masks else image.shape[:2]
        result_image = np.zeros(shape)
        max_area = self.cfg['area_thresh'] * shape[0] * shape[1]
        min_stability = self.cfg['stability_score_thresh']

        count = 0
        sizes = []
        for mask in masks:
            if mask['area'] >= max_area or mask['stability_score'] <= min_stability:
                continue
            segmentation = mask['segmentation'].astype(int)
            # result_image only holds 0/1, so a 2 anywhere means overlap.
            if np.amax(result_image + segmentation) >= 2:
                continue
            result_image += segmentation
            count += 1
            _, _, w, h = mask['bbox']
            sizes.append(np.mean([w, h]))

        return result_image, count, sizes

    def result_publish(self, channel, properties, result):
        """Publish *result* as JSON to the result queue via the default exchange."""
        channel.basic_publish(exchange='',
                              routing_key=self.__ResultQ,
                              body=json.dumps(result),
                              properties=properties)
        print("Done!")

    def upload_to_database(self, results):
        """
        Insert inference result data into DB(postgreSQL).
        Columns consist of file_name(or image_name), count.
        """
        pass

    @staticmethod
    def _build_histogram(sizes, n_bins):
        """Return {'x': bin edges as strings, 'y': per-bin counts as ints} for *sizes*.

        Uses ``np.histogram`` directly (the same computation ``Axes.hist``
        performs) so no matplotlib figure is created per message.
        """
        counts, edges = np.histogram(sizes, bins=n_bins)
        return {
            'x': [str(edge) for edge in edges.tolist()],
            'y': counts.astype(np.int32).tolist(),
        }

    def _process_message(self, chan, body, save_path):
        """Download, segment and (if enough masks were found) publish the result for one task."""
        logger.info(f" [x] Received {body}")
        task_data = json.loads(body)

        target = task_data.get("to")
        if not target:
            logger.info("Check Message Data. key 'to' is missing")
            return

        download_path = os.path.join(os.getcwd(), info['download_data_path'], target['filename'])
        os.makedirs(os.path.dirname(download_path), exist_ok=True)
        s3.Bucket(target['bucket']).download_file(target['filename'], download_path)

        try:
            image = cv2.imread(download_path)
            result_image, count, sizes = self.inference(image)
        finally:
            # Delete the original even if reading/inference fails, so failed
            # tasks do not accumulate files on disk.
            if os.path.exists(download_path):
                os.remove(download_path)

        logger.info(f" len(sizes) : {len(sizes)}")
        # Too few masks: skip without publishing. The threshold defaults to
        # the previous hard-coded 30 and can be overridden via the
        # (optional) 'min_mask_count' config key.
        if len(sizes) < info.get('min_mask_count', 30):
            logger.info("PASS")
            return

        # splitext keeps dots inside the stem ("2023.01.01_cam.jpg" ->
        # "2023.01.01_cam"); splitting at the first dot caused name collisions.
        image_name = os.path.splitext(os.path.basename(download_path))[0]
        result_filename = f'result_{image_name}.jpg'
        plt.imsave(os.path.join(save_path, result_filename), result_image)

        task_data['type'] = "inference_result"
        task_data['result'] = {
            'timestamp': int(time.time() * 1000),
            'count': count,
            'sizes': sizes,
            'filename': result_filename,
            'hist': self._build_histogram(sizes, info['n_bins_for_histogram']),
        }
        # Rename 'to' -> 'from' (the key ends up last, as before).
        task_data['from'] = task_data.pop('to')

        properties = pika.BasicProperties(expiration=info['amqp_message_expire_time'])
        self.result_publish(chan, properties, task_data)
        time.sleep(1)

    def main(self):
        """Poll the task queue forever, handling one message per iteration.

        Per-message errors are printed, sent to Slack and logged, and the loop
        continues. Connection-level errors end the loop. The connection is
        closed on every exit path.
        """
        conn = None
        try:
            conn = pika.BlockingConnection(pika.ConnectionParameters(self.__url,
                                                                     self.__port,
                                                                     self.__vhost,
                                                                     self.__cred))
            chan = conn.channel()
            chan.queue_declare(queue=self.__TaskQ, durable=True)
            chan.queue_declare(queue=self.__ResultQ, durable=True)

            save_path = info['inference_result_path']
            os.makedirs(save_path, exist_ok=True)

            while True:
                try:
                    # auto_ack=True: the broker considers the message handled on
                    # receipt, so a processing failure drops it (no redelivery).
                    method, _properties, body = chan.basic_get(queue=self.__TaskQ,
                                                               auto_ack=True)
                    if not method:
                        send_message_to_slack("Empty Queue")
                        logger.info(f"Empty Queue sleep for {info['amqp_Q_check_interval']}")
                        time.sleep(info['amqp_Q_check_interval'])
                        time.sleep(0.5)
                        continue

                    self._process_message(chan, body, save_path)
                except Exception as e:
                    print(traceback.format_exc())
                    send_message_to_slack(f"who : inference_server // error {str(e)}")
                    logger.exception(str(e))
        except Exception as e:
            print(e)
            print(traceback.format_exc())
            send_message_to_slack(f"who : inference_server // error {str(e)}")
            logger.exception(str(e))
        finally:
            # conn is still None when opening the connection failed; the old
            # unconditional conn.close() raised UnboundLocalError there.
            if conn is not None and conn.is_open:
                conn.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the worker, load the SAM model once, then poll the task queue.
    worker = Consumer()
    worker.initialize()
    worker.main()
|