gseps_inference/inference.py

import os
import time
import json
import requests
import traceback
import cv2
import pika
import boto3
import paramiko
import torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
###############################################
#                   Config                    #
###############################################
with open('/home/sdt/Workspace/gseps_inference/config/inference_config.json', 'r') as f:
    info = json.load(f)
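# Illustrative shape of inference_config.json, inferred from the keys used
# below. The values are hypothetical placeholders, not the real deployment
# settings:
#
#   {
#       "slack_url": "https://hooks.slack.com/services/...",
#       "slack_message": "TaskQ is empty; consumer stopped.",
#       "remote_server_ip": "192.0.2.10",
#       "remote_server_id": "edge-user",
#       "remote_server_pw": "...",
#       "copied_image_path_from_remote_server": "/home/sdt/images",
#       "inference_result_path": "/home/sdt/results",
#       "amqp_url": "localhost", "amqp_port": 5672, "amqp_vhost": "/",
#       "amqp_id": "guest", "amqp_pw": "guest",
#       "amqp_taskq": "TaskQ", "amqp_resultq": "ResultQ",
#       "model_checkpoints": "/path/to/sam_vit_h.pth",
#       "model_config": {
#           "device": "cuda",
#           "points_per_side": 32, "pred_iou_thresh": 0.88,
#           "stability_score_thresh": 0.95, "crop_n_layers": 0,
#           "crop_n_points_downscale_factor": 1, "box_nms_thresh": 0.7,
#           "min_mask_region_area": 0, "area_thresh": 0.1
#       }
#   }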


def send_message_to_slack():
    """Post the configured notification text to the Slack incoming-webhook URL."""
    data = {"text": info['slack_message']}
    requests.post(url=info['slack_url'], data=json.dumps(data))


def image_upload_to_s3(path):
    """Upload every file under `path` to the gseps-data bucket, keyed by folder name."""
    try:
        s3 = boto3.client('s3')
        folder_name = path.split('/')[-1]
        for root, dirs, files in os.walk(path):
            for f in files:
                local_file_path = os.path.join(root, f)
                s3.upload_file(local_file_path, 'gseps-data', f'{folder_name}/{f}')
    except Exception:
        print(traceback.format_exc())


def image_copy_using_SCP(remote_path):
    """Fetch an image from the edge device over SFTP, delete the remote copy,
    and return the local path ("" on failure)."""
    try:
        # SSH connection to the edge device
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(info['remote_server_ip'],
                    username=info['remote_server_id'],
                    password=info['remote_server_pw'])

        # Copy the image from the edge device, then remove the original.
        sftp = ssh.open_sftp()
        image_name = remote_path.split('/')[-1]
        local_path = info['copied_image_path_from_remote_server']
        os.makedirs(local_path, exist_ok=True)
        sftp.get(remote_path, os.path.join(local_path, image_name))
        sftp.remove(remote_path)
        sftp.close()
        ssh.close()
        return os.path.join(local_path, image_name)
    except Exception as e:
        print(e)
        return ""


def get_info(body):
    """Extract the edge-side image path from a TaskQ message body."""
    payload = json.loads(body)
    return payload['image_path']
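
# A TaskQ message body is expected to be JSON carrying at least an
# "image_path" key, e.g. (the path below is a hypothetical example):
#   {"image_path": "/DATA/edge/captures/cam0_0001.jpg"}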


class Consumer:
    def __init__(self):
        # RabbitMQ connection settings and queue names come from the
        # config file loaded above.
        self.__url = info['amqp_url']
        self.__port = info['amqp_port']
        self.__vhost = info['amqp_vhost']
        self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
        self.__TaskQ = info['amqp_taskq']
        self.__ResultQ = info['amqp_resultq']
        self.cfg = info['model_config']
        self.mask_generator = None
        # self.cloud_vender = info['cloud_vender']
        self.cloud_vender = 'S3'
    def initialize(self):
        """Load the SAM ViT-H checkpoint and build the automatic mask generator."""
        start = time.time()
        sam = sam_model_registry['vit_h'](checkpoint=info['model_checkpoints'])
        sam = sam.to(self.cfg['device'])
        self.mask_generator = SamAutomaticMaskGenerator(
            model=sam,
            points_per_side=self.cfg['points_per_side'],
            pred_iou_thresh=self.cfg['pred_iou_thresh'],
            stability_score_thresh=self.cfg['stability_score_thresh'],
            crop_n_layers=self.cfg['crop_n_layers'],
            crop_n_points_downscale_factor=self.cfg['crop_n_points_downscale_factor'],
            box_nms_thresh=self.cfg['box_nms_thresh'],
            min_mask_region_area=self.cfg['min_mask_region_area'])
        end = time.time()
        print(f'Initialize time: {(end - start) // 60:.0f}m {(end - start) % 60:.4f}s')

    def image_upload_to_ncp(self, path):
        pass
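        # A minimal sketch of what this stub could do, assuming NCP Object
        # Storage's S3-compatible API. The endpoint and the info['ncp_*']
        # credential keys below are illustrative assumptions, not part of
        # this project's config:
        #
        #   ncp = boto3.client(
        #       's3',
        #       endpoint_url='https://kr.object.ncloudstorage.com',
        #       aws_access_key_id=info['ncp_access_key'],      # hypothetical key
        #       aws_secret_access_key=info['ncp_secret_key'])  # hypothetical key
        #   for root, dirs, files in os.walk(path):
        #       for f in files:
        #           ncp.upload_file(os.path.join(root, f),
        #                           'gseps-data', f'{path.split("/")[-1]}/{f}')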

    def image_upload_to_cloud(self, path):
        # Dispatch to the configured cloud vendor.
        if self.cloud_vender == 'S3':
            image_upload_to_s3(path)
        elif self.cloud_vender == 'NCP':
            self.image_upload_to_ncp(path)

    def inference(self, image_path):
        """Run SAM on one image; return per-object stats and the save directory."""
        image = cv2.imread(image_path)
        image_name = image_path.split('/')[-1].split('.')[0]
        result = self.mask_generator.generate(image)

        shape = result[0]['segmentation'].shape
        cumulated = np.zeros(shape)
        count = 0
        sizes = []
        for r in result:
            # Keep masks that are small enough and stable enough.
            if (r['area'] < self.cfg['area_thresh'] * shape[0] * shape[1]
                    and r['stability_score'] > self.cfg['stability_score_thresh']):
                # Accept a mask only if it does not overlap one already counted.
                if np.amax(cumulated + r['segmentation'].astype(int)) < 2:
                    cumulated = cumulated + r['segmentation'].astype(int)
                    count += 1
                    x, y, w, h = r['bbox']
                    sizes.append(np.mean([w, h]))

        save_path = os.path.join(info['inference_result_path'], image_name)
        os.makedirs(save_path, exist_ok=True)
        cv2.imwrite(os.path.join(save_path, f'{image_name}.jpg'), image)
        plt.imsave(os.path.join(save_path, f'result_{image_name}.jpg'), cumulated)

        result_dict = {'image': os.path.join(save_path, f'{image_name}.jpg'),  # TODO: needs revision
                       'count': count,
                       'sizes': sizes}
        with open(os.path.join(save_path, f'result_{image_name}.json'), 'w') as f:
            json.dump(result_dict, f, indent=4)

        return result_dict, save_path

    def result_publish(self, channel, result):
        # Publish the result on the default exchange, routed to ResultQ.
        channel.basic_publish(exchange='',
                              routing_key=self.__ResultQ,
                              body=json.dumps(result))
        print("Done!")

    def upload_to_database(self, results):
        """
        Insert inference result data into the database (PostgreSQL).
        Columns consist of file_name (or image_name) and count.
        """
        pass
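        # A minimal sketch of the insert described in the docstring, assuming
        # psycopg2, a hypothetical `results` table, and info['db_*'] settings
        # (none of these names are defined in this project's config):
        #
        #   conn = psycopg2.connect(host=info['db_host'], dbname=info['db_name'],
        #                           user=info['db_user'], password=info['db_pw'])
        #   with conn, conn.cursor() as cur:
        #       cur.execute('INSERT INTO results (file_name, count) VALUES (%s, %s)',
        #                   (results['image'], results['count']))
        #   conn.close()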

    def main(self):
        conn = None
        try:
            conn = pika.BlockingConnection(pika.ConnectionParameters(self.__url,
                                                                     self.__port,
                                                                     self.__vhost,
                                                                     self.__cred))
            chan = conn.channel()
            chan.queue_declare(queue=self.__TaskQ, durable=True)
            chan.queue_declare(queue=self.__ResultQ, durable=True)

            while True:
                method, properties, body = chan.basic_get(queue=self.__TaskQ,
                                                          auto_ack=True)
                if not method:
                    # Queue drained: notify Slack and stop consuming.
                    send_message_to_slack()
                    break

                print(f" [x] Received {body}", end=' | ', flush=True)
                edge_image_path = get_info(body)
                image_path = image_copy_using_SCP(edge_image_path)
                if image_path == "":
                    continue

                result_dict, save_path = self.inference(image_path)
                self.result_publish(chan, result_dict)
                self.upload_to_database(result_dict)
                self.image_upload_to_cloud(save_path)
                time.sleep(1)
        except Exception:
            print(traceback.format_exc())
        finally:
            if conn is not None and conn.is_open:
                conn.close()


if __name__ == "__main__":
    consumer = Consumer()
    consumer.initialize()
    consumer.main()
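# Run directly (assumes the config path hard-coded at the top of this file
# exists on the host):
#   python inference.py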