gseps_inference/inference.py

309 lines
11 KiB
Python

import os
import time
import json
import requests
import traceback
import logging
import logging.handlers
import cv2
import pika
import boto3
from botocore.client import Config
import paramiko
import torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
###############################################
# Config #
###############################################
# Runtime settings (AMQP, MinIO/S3, Slack, SSH, model and path options) are
# read once at import time from config/inference_config.json, resolved
# against the current working directory. JSON text is UTF-8, so decode it
# explicitly instead of relying on the platform's locale encoding.
with open(os.path.join(os.getcwd(), 'config', 'inference_config.json'), 'r', encoding='utf-8') as f:
    info = json.load(f)
###############################################
# Logger Setting #
###############################################
# Configure the root logger: INFO and above are written to
# log_inference.log in the working directory, rotated at 1,024,000 bytes
# with 3 backups kept.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log_fileHandler = logging.handlers.RotatingFileHandler(
    filename=os.path.join(os.getcwd(), "log_inference.log"),
    maxBytes=1024000,
    backupCount=3,
    mode='a',
    # Explicit UTF-8 so non-ASCII log text (file names, error messages) is
    # written the same way regardless of the platform's default encoding.
    encoding='utf-8',
)
log_fileHandler.setFormatter(formatter)
logger.addHandler(log_fileHandler)
# Only record pika's warnings and errors; its INFO output is suppressed.
logging.getLogger("pika").setLevel(logging.WARNING)
################################################################################
# S3 Set up #
################################################################################
# Shared boto3 S3 resource bound to the MinIO endpoint and credentials from
# the config; the Consumer uses it to download task images.
s3 = boto3.resource(
    's3',
    endpoint_url=info['Minio_url'],
    aws_access_key_id=info['AccessKey'],
    aws_secret_access_key=info['SecretKey'],
    config=Config(signature_version=info['Boto3SignatureVersion']),
    region_name=info['Boto3RegionName'],
)
def send_message_to_slack(message, timeout=10):
    """Post *message* to the Slack webhook configured as ``info['slack_url']``.

    Notification is best-effort: network failures are logged rather than
    raised, and a non-2xx response is logged as a warning. This matters
    because ``Consumer.main`` calls this function from inside its own error
    handlers, where a second exception would escape the handler.

    Args:
        message: Text to post.
        timeout: Seconds to wait for the webhook. Without a timeout,
            ``requests`` may block indefinitely and stall the worker loop.

    Returns:
        The ``requests.Response``, or ``None`` if the request could not be
        sent.
    """
    data = {"text": message}
    try:
        response = requests.post(
            url=info['slack_url'],
            data=json.dumps(data),
            timeout=timeout,
        )
    except requests.RequestException:
        logger.exception("Failed to send message to Slack")
        return None
    if not response.ok:
        logger.warning(f"Slack webhook returned HTTP {response.status_code}: {response.text}")
    return response
def image_upload_to_s3(path, bucket='gseps-data', client=None):
    """Upload every file under the directory *path* to an S3 bucket.

    Each file is stored under the key ``<folder_name>/<path relative to
    *path*>``, where ``folder_name`` is the last component of *path*.
    Sub-directory structure is kept, so files with the same name in
    different sub-directories do not overwrite each other. The upload is
    best-effort: any exception is logged and swallowed.

    Args:
        path: Local directory to upload (walked recursively). A trailing
            path separator is allowed.
        bucket: Destination bucket name.
        client: A boto3 S3 client. Defaults to ``boto3.client('s3')``, which
            uses boto3's default credential chain and endpoint.
            NOTE(review): that default ignores the MinIO endpoint and keys
            used by the module-level ``s3`` resource. Pass
            ``s3.meta.client`` if uploads are meant to go to MinIO.
    """
    try:
        if client is None:
            client = boto3.client('s3')
        # normpath drops a trailing separator, which would otherwise yield an
        # empty folder name and keys like "/file.jpg".
        folder_name = os.path.basename(os.path.normpath(path))
        for root, _dirs, files in os.walk(path):
            for file_name in files:
                local_file_path = os.path.join(root, file_name)
                # S3 keys always use '/', whatever the local separator is.
                relative_key = os.path.relpath(local_file_path, path).replace(os.sep, '/')
                client.upload_file(local_file_path, bucket, f'{folder_name}/{relative_key}')
    except Exception:
        logger.exception(f"Failed to upload {path} to S3")
def image_copy_using_SCP(remote_path):
    """Move one file from the remote edge server into the local copy folder.

    Connects over SSH with the credentials in ``info``, downloads
    *remote_path* via SFTP into
    ``info['copied_image_path_from_remote_server']``, then deletes the
    remote file. The SSH and SFTP connections are closed on every path,
    including failures.

    Args:
        remote_path: '/'-separated path of the file on the remote server.

    Returns:
        The local path of the copied file, or ``""`` if any step failed
        (the error is logged).
    """
    try:
        image_name = remote_path.split('/')[-1]
        local_path = info['copied_image_path_from_remote_server']
        local_file = os.path.join(local_path, image_name)
        os.makedirs(local_path, exist_ok=True)
        with paramiko.SSHClient() as ssh:
            # NOTE(review): AutoAddPolicy trusts unknown host keys, so the
            # server's identity is not verified.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            ssh.connect(info['remote_server_ip'],
                        username=info['remote_server_id'],
                        password=info['remote_server_pw'])
            with ssh.open_sftp() as sftp:
                sftp.get(remote_path, local_file)
                # Delete the source only after the download succeeded.
                sftp.remove(remote_path)
        return local_file
    except Exception:
        logger.exception(f"Failed to copy {remote_path} from remote server")
        return ""
def get_info(body):
    """Extract the ``image_path`` field from a JSON task message.

    Args:
        body: JSON document (``str``, ``bytes`` or ``bytearray``).

    Returns:
        The value stored under ``image_path``.

    Raises:
        json.JSONDecodeError: if *body* is not valid JSON.
        KeyError: if the decoded object has no ``image_path`` key.
    """
    return json.loads(body)['image_path']
class Consumer:
    """RabbitMQ worker that segments task images with SAM and publishes results.

    Task messages are JSON objects whose ``to`` entry names an object in the
    S3/MinIO store (``bucket`` and ``filename``). For each task the object is
    downloaded, segmented with ``SamAutomaticMaskGenerator``, and, when
    enough objects are found, a result message is published to the result
    queue. The result contains the object count, per-object sizes, a size
    histogram and the saved result-image name. All settings come from the
    module-level ``info`` config dict.
    """

    def __init__(self):
        """Read AMQP, model and vendor settings from ``info``.

        The broker connection is opened by ``main()``; the model is loaded
        by ``initialize()``.
        """
        self.__url = info['amqp_url']
        self.__port = info['amqp_port']
        self.__vhost = info['amqp_vhost']
        self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
        self.__TaskQ = info['amqp_TaskQ']
        self.__ResultQ = info['amqp_ResultQ']
        self.cfg = info['model_config']
        self.mask_generator = None  # built by initialize()
        # NOTE(review): vendor is fixed to 'S3'; reading it from
        # info['cloud_vender'] may be intended. Confirm.
        self.cloud_vender = 'S3'

    def initialize(self):
        """Load the SAM ViT-H checkpoint and build the automatic mask generator."""
        start = time.time()
        sam = sam_model_registry['vit_h'](checkpoint=info['model_checkpoints'])
        sam = sam.to(self.cfg['device'])
        self.mask_generator = SamAutomaticMaskGenerator(
            model=sam,
            points_per_side=self.cfg['points_per_side'],
            pred_iou_thresh=self.cfg['pred_iou_thresh'],
            stability_score_thresh=self.cfg['stability_score_thresh'],
            crop_n_layers=self.cfg['crop_n_layers'],
            crop_n_points_downscale_factor=self.cfg['crop_n_points_downscale_factor'],
            box_nms_thresh=self.cfg['box_nms_thresh'],
            min_mask_region_area=self.cfg['min_mask_region_area'],
        )
        elapsed = time.time() - start
        logger.info(f"Initialize {str(info['model_config'])}")
        print(f'Initialize time: {int(elapsed // 60)}m {elapsed % 60:.4f}s')

    def image_upload_to_ncp(self, path):
        """Upload the files under *path* to NCP object storage (not implemented)."""
        pass

    def image_upload_to_cloud(self, path):
        """Upload the files under *path* with the configured cloud vendor."""
        if self.cloud_vender == 'S3':
            image_upload_to_s3(path)
        elif self.cloud_vender == 'NCP':
            self.image_upload_to_ncp(path)
        else:
            logger.warning(f"Unknown cloud vendor: {self.cloud_vender}")

    def inference(self, image):
        """Segment *image* and keep small, stable, non-overlapping masks.

        A mask is kept when its area is below ``area_thresh`` times the image
        area, its stability score is above ``stability_score_thresh``, and it
        does not overlap any mask kept earlier. Masks are visited in the
        order the generator returns them.

        Args:
            image: Array passed straight to the mask generator's ``generate``.

        Returns:
            ``(result_image, count, sizes)``:
            - ``result_image``: float array of the mask shape, 1 on
              kept-mask pixels and 0 elsewhere.
            - ``count``: number of kept masks.
            - ``sizes``: for each kept mask, the mean of its bbox width and
              height.
            When no masks are generated, ``result_image`` is all zeros with
            shape ``image.shape[:2]``.
        """
        result = self.mask_generator.generate(image)
        if result:
            shape = result[0]['segmentation'].shape
        else:
            # Nothing detected: previously result[0] raised IndexError here.
            shape = image.shape[:2]
        result_image = np.zeros(shape)
        max_area = self.cfg['area_thresh'] * shape[0] * shape[1]
        min_stability = self.cfg['stability_score_thresh']
        count = 0
        sizes = []
        for r in result:
            if not (r['area'] < max_area and r['stability_score'] > min_stability):
                continue
            segmentation = r['segmentation'].astype(int)
            # A value of 2 anywhere means this mask overlaps an accepted one.
            if np.amax(result_image + segmentation) >= 2:
                continue
            result_image = result_image + segmentation
            count += 1
            _, _, w, h = r['bbox']
            sizes.append(np.mean([w, h]))
        return result_image, count, sizes

    def result_publish(self, channel, properties, result):
        """Publish *result* as JSON to the result queue via the default exchange."""
        channel.basic_publish(exchange='',
                              routing_key=self.__ResultQ,
                              body=json.dumps(result),
                              properties=properties)
        print("Done!")

    def upload_to_database(self, results):
        """
        Insert inference result data into DB(postgreSQL).
        Columns consist of file_name(or image_name), count.
        (Not implemented yet.)
        """
        pass

    @staticmethod
    def _image_stem(path):
        """Return the file name of *path* without its directory and final extension."""
        return os.path.splitext(os.path.basename(path))[0]

    @staticmethod
    def _histogram(sizes, n_bins):
        """Return ``(bin edges as strings, bin counts as ints)`` for *sizes*."""
        counts, edges = np.histogram(sizes, bins=n_bins)
        return [str(edge) for edge in edges.tolist()], counts.astype(np.int32).tolist()

    @staticmethod
    def _download_path(filename):
        """Return the local download path for *filename*.

        Raises:
            ValueError: if *filename* (taken from the task message) would
                resolve outside ``info['download_data_path']``, for example
                an absolute path or one containing ``..``.
        """
        download_dir = os.path.realpath(os.path.join(os.getcwd(), info['download_data_path']))
        download_path = os.path.realpath(os.path.join(download_dir, filename))
        if os.path.commonpath([download_dir, download_path]) != download_dir:
            raise ValueError(f"refusing to write outside {download_dir}: {filename!r}")
        return download_path

    def _process_message(self, chan, properties, body, save_path):
        """Run inference for one task message and publish the result.

        Messages without a ``to`` entry are logged and skipped. So are images
        with fewer accepted masks than ``model_config['min_object_count']``
        (default 30). Neither case publishes anything.
        """
        logger.info(f" [x] Received {body}")
        task_data = json.loads(body)
        if not task_data.get("to"):
            logger.info("Check Message Data. key 'to' is missing")
            return
        target = task_data['to']
        download_path = self._download_path(target['filename'])
        os.makedirs(os.path.dirname(download_path), exist_ok=True)
        try:
            s3.Bucket(target['bucket']).download_file(target['filename'], download_path)
            image = cv2.imread(download_path)
            if image is None:
                raise ValueError(f"cv2 could not read image: {download_path}")
            # cv2 decodes to BGR; SAM's mask generator expects RGB input.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            result_image, count, sizes = self.inference(image)
        finally:
            # Delete the downloaded original even if reading/inference failed.
            if os.path.exists(download_path):
                os.remove(download_path)

        logger.info(f" len(sizes) : {len(sizes)}")
        if len(sizes) < self.cfg.get('min_object_count', 30):
            logger.info("PASS")
            return

        result_filename = f'result_{self._image_stem(download_path)}.jpg'
        plt.imsave(os.path.join(save_path, result_filename), result_image)
        # np.histogram yields the same counts/edges as Axes.hist without
        # creating (and leaking) a matplotlib figure per message.
        hist_x, hist_y = self._histogram(sizes, info['n_bins_for_histogram'])

        task_data['type'] = "inference_result"
        task_data['result'] = {
            'timestamp': int(time.time() * 1000),
            'count': count,
            'sizes': sizes,
            'filename': result_filename,
            'hist': {
                'x': hist_x,
                'y': hist_y,
            },
        }
        task_data['from'] = task_data.pop('to')
        self.result_publish(chan, properties, task_data)
        # NOTE(review): fixed 1 s pause after each published result; purpose
        # unclear. Confirm whether it is still needed.
        time.sleep(1)

    def main(self):
        """Connect to RabbitMQ and poll the task queue until a fatal error.

        Errors while handling a single message are printed, sent to Slack and
        logged, and the loop continues. Setup errors and AMQP
        connection/channel errors are reported the same way and end the
        loop. The connection is closed on every exit path.
        """
        conn = None
        try:
            conn = pika.BlockingConnection(pika.ConnectionParameters(self.__url,
                                                                     self.__port,
                                                                     self.__vhost,
                                                                     self.__cred))
            chan = conn.channel()
            chan.queue_declare(queue=self.__TaskQ, durable=True)
            chan.queue_declare(queue=self.__ResultQ, durable=True)
            save_path = info['inference_result_path']
            os.makedirs(save_path, exist_ok=True)
            amqp_message_properties = pika.BasicProperties(expiration=info['amqp_message_expire_time'])
            while True:
                try:
                    # auto_ack=True: the task is acknowledged on receipt, so a
                    # message whose processing fails is not redelivered.
                    method, properties, body = chan.basic_get(queue=self.__TaskQ,
                                                              auto_ack=True)
                    if not method:
                        send_message_to_slack("Empty Queue")
                        logger.info(f"Empty Queue sleep for {info['amqp_Q_check_interval']}")
                        time.sleep(info['amqp_Q_check_interval'])
                        continue
                    self._process_message(chan, amqp_message_properties, body, save_path)
                except pika.exceptions.AMQPError:
                    # A broken connection/channel fails on every retry; let
                    # the outer handler report it instead of spinning here.
                    raise
                except Exception as e:
                    print(traceback.format_exc())
                    send_message_to_slack(f"who : inference_server // error {str(e)}")
                    logger.error(str(e))
        except Exception as e:
            print(e)
            print(traceback.format_exc())
            send_message_to_slack(f"who : inference_server // error {str(e)}")
            logger.error(str(e))
        finally:
            # conn is None when the connection itself could not be opened.
            if conn is not None and conn.is_open:
                conn.close()
if __name__ == "__main__":
    # Script entry point: build the consumer, load the SAM model, then run
    # the queue-polling loop.
    worker = Consumer()
    worker.initialize()
    worker.main()