gseps_inference/inference.py

222 lines
7.4 KiB
Python
Raw Normal View History

2023-08-31 09:32:31 +00:00
import os
import time
import json
import requests
import traceback
import cv2
import pika
import boto3
import paramiko
import torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
###############################################
# Config #
###############################################
2023-09-01 05:54:57 +00:00
# Location of the JSON config. Set GSEPS_INFERENCE_CONFIG to point elsewhere
# (e.g. on another machine) instead of editing this file; the default is the
# original development path.
CONFIG_PATH = os.environ.get(
    'GSEPS_INFERENCE_CONFIG',
    '/home/sdt/Workspace/gseps_inference/config/inference_config.json',
)
# Explicit UTF-8 so non-ASCII values (e.g. the Slack message text) decode the
# same way regardless of the machine's locale.
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    info = json.load(f)
def send_message_to_slack(message=None, timeout=10):
    """Post a text notification to the Slack webhook configured in ``info``.

    Args:
        message: Text to send. Defaults to ``info['slack_message']``.
        timeout: Seconds to wait for Slack. Without a timeout a stalled
            webhook would block the consumer indefinitely.

    Returns:
        The ``requests.Response`` returned by the webhook call.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    payload = {"text": info['slack_message'] if message is None else message}
    # `json=` serializes the payload and sets Content-Type: application/json.
    response = requests.post(url=info['slack_url'], json=payload, timeout=timeout)
    if not response.ok:
        # Notification is informational only, so report rather than raise.
        print(f'Slack notification failed: {response.status_code} {response.text}')
    return response
def image_upload_to_s3(path, bucket='gseps-data'):
    """Upload every file under directory *path* to an S3 bucket.

    Objects are keyed as ``<directory name>/<path relative to the directory>``,
    so a flat directory maps to ``<directory name>/<file name>``.

    Args:
        path: Local directory to upload (a trailing slash is tolerated).
        bucket: Target S3 bucket name.

    Failures are printed with a traceback and not re-raised, so a failed
    upload does not stop the caller.
    """
    try:
        s3 = boto3.client('s3')
        # normpath strips a trailing separator so basename is never empty.
        folder_name = os.path.basename(os.path.normpath(path))
        for root, _dirs, files in os.walk(path):
            for file_name in files:
                local_file_path = os.path.join(root, file_name)
                # Keep any sub-directory in the key so same-named files in
                # different sub-directories do not overwrite each other.
                relative_key = os.path.relpath(local_file_path, path).replace(os.sep, '/')
                s3.upload_file(local_file_path, bucket, f'{folder_name}/{relative_key}')
    except Exception:
        print(traceback.format_exc())
def image_copy_using_SCP(remote_path):
    """Copy *remote_path* from the edge server over SFTP, then delete the remote file.

    Connection details and the local destination directory come from ``info``.

    Args:
        remote_path: POSIX path of the image on the remote server.

    Returns:
        The local path of the copied file, or ``""`` if any step failed
        (the error is printed). Callers test for ``""``.
    """
    try:
        # Context managers guarantee the SSH and SFTP sessions are closed on
        # both success and failure; previously the SSH connection leaked.
        with paramiko.SSHClient() as ssh:
            # NOTE(review): AutoAddPolicy accepts any unknown host key (no
            # protection against host impersonation); consider known_hosts +
            # RejectPolicy.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            ssh.connect(info['remote_server_ip'],
                        username=info['remote_server_id'],
                        password=info['remote_server_pw'])

            # The remote path is POSIX regardless of the local OS.
            image_name = remote_path.split('/')[-1]
            local_path = info['copied_image_path_from_remote_server']
            os.makedirs(local_path, exist_ok=True)
            local_file = os.path.join(local_path, image_name)

            with ssh.open_sftp() as sftp:
                sftp.get(remote_path, local_file)
                # Only reached if the download succeeded.
                sftp.remove(remote_path)
            return local_file
    except Exception:
        print(traceback.format_exc())
        return ""
def get_info(body):
    """Extract the ``image_path`` field from a JSON task message.

    Args:
        body: Raw message payload (``str`` or ``bytes``) from the task queue.

    Returns:
        The value stored under ``"image_path"``.

    Raises:
        KeyError: if the message has no ``image_path`` field.
    """
    return json.loads(body)['image_path']
class Consumer:
    """RabbitMQ task consumer that segments edge-server images with SAM.

    Each task message carries an ``image_path`` on the edge server. The image
    is copied locally over SFTP and segmented. The result is published to the
    result queue and the result directory is uploaded to cloud storage.
    """

    def __init__(self):
        # Connection settings and model config come from the module-level
        # `info` dict loaded from the JSON config file.
        self.__url = info['amqp_url']
        self.__port = info['amqp_port']
        self.__vhost = info['amqp_vhost']
        self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
        self.__TaskQ = info['amqp_taskq']
        self.__ResultQ = info['amqp_resultq']
        self.cfg = info['model_config']
        self.mask_generator = None  # built by initialize()
        # Storage backend is fixed to S3 for now (the NCP uploader is a stub).
        # A commented-out line suggested reading it from info['cloud_vender'].
        self.cloud_vender = 'S3'

    def initialize(self):
        """Load the ViT-H SAM checkpoint and build the automatic mask generator.

        Must be called before inference(). Prints how long loading took.
        """
        start = time.perf_counter()
        sam = sam_model_registry['vit_h'](checkpoint=info['model_checkpoints'])
        sam = sam.to(self.cfg['device'])
        self.mask_generator = SamAutomaticMaskGenerator(
            model=sam,
            points_per_side=self.cfg['points_per_side'],
            pred_iou_thresh=self.cfg['pred_iou_thresh'],
            stability_score_thresh=self.cfg['stability_score_thresh'],
            crop_n_layers=self.cfg['crop_n_layers'],
            crop_n_points_downscale_factor=self.cfg['crop_n_points_downscale_factor'],
            box_nms_thresh=self.cfg['box_nms_thresh'],
            min_mask_region_area=self.cfg['min_mask_region_area'],
        )
        # Whole minutes plus remaining seconds shown with 4 decimals.
        minutes, seconds = divmod(time.perf_counter() - start, 60)
        print(f'Initialize time: {int(minutes)}m {seconds:.4f}s')

    def image_upload_to_ncp(self, path):
        """Placeholder for uploading *path* to NCP; currently does nothing."""
        pass

    def image_upload_to_cloud(self, path):
        """Upload the result directory *path* using the configured vendor."""
        if self.cloud_vender == 'S3':
            image_upload_to_s3(path)
        elif self.cloud_vender == 'NCP':
            self.image_upload_to_ncp(path)
        else:
            print(f'Unsupported cloud vendor {self.cloud_vender!r}; skipped upload of {path}')

    def _select_masks(self, masks, shape):
        """Greedily keep small, stable, mutually non-overlapping masks.

        A mask is kept when its area is below ``cfg['area_thresh']`` times the
        image area, its stability score exceeds
        ``cfg['stability_score_thresh']``, and it does not overlap any mask
        already kept (masks are considered in the order given).

        Args:
            masks: Mask records with 'segmentation', 'area', 'stability_score'
                and 'bbox' (x, y, w, h) keys, as produced by the generator.
            shape: (height, width) of the image the masks belong to.

        Returns:
            (coverage, count, sizes): a float array of ``shape`` with 1.0 on
            kept pixels, the number of kept masks, and for each kept mask the
            mean of its bbox width and height.
        """
        height, width = shape
        max_area = self.cfg['area_thresh'] * height * width
        min_stability = self.cfg['stability_score_thresh']
        coverage = np.zeros(shape)
        sizes = []
        for mask in masks:
            if mask['area'] >= max_area or mask['stability_score'] <= min_stability:
                continue
            segmentation = np.asarray(mask['segmentation'], dtype=bool)
            if np.any(coverage[segmentation]):  # overlaps a kept mask
                continue
            coverage[segmentation] = 1
            _x, _y, w, h = mask['bbox']
            sizes.append((w + h) / 2)
        return coverage, len(sizes), sizes

    def inference(self, image_path):
        """Segment one image and save the image, mask overlay and JSON summary.

        Files are written to ``<info['inference_result_path']>/<image stem>/``.

        Args:
            image_path: Local path of the image to segment.

        Returns:
            (result_dict, save_path): the summary dict with 'image', 'count' and
            'sizes' keys, and the directory the results were written to.

        Raises:
            ValueError: if the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise ValueError(f'Could not read image: {image_path}')
        # splitext keeps dots inside the stem (e.g. "a.b.jpg" -> "a.b").
        image_name = os.path.splitext(os.path.basename(image_path))[0]

        # NOTE(review): cv2.imread returns BGR, while SAM's reference usage
        # converts to RGB before generate(); confirm the thresholds were tuned
        # on BGR input before changing this.
        masks = self.mask_generator.generate(image)
        # Use the image size so an empty mask list is handled gracefully.
        cumulated, count, sizes = self._select_masks(masks, image.shape[:2])

        save_path = os.path.join(info['inference_result_path'], image_name)
        os.makedirs(save_path, exist_ok=True)
        image_file = os.path.join(save_path, f'{image_name}.jpg')
        cv2.imwrite(image_file, image)
        plt.imsave(os.path.join(save_path, f'result_{image_name}.jpg'), cumulated)
        result_dict = {'image': image_file,  # TODO: needs revision
                       'count': count,
                       'sizes': sizes}
        with open(os.path.join(save_path, f'result_{image_name}.json'), 'w',
                  encoding='utf-8') as f:
            json.dump(result_dict, f, indent=4)
        return result_dict, save_path

    def result_publish(self, channel, result):
        """Publish *result* as JSON to the result queue via the default exchange."""
        channel.basic_publish(exchange='',
                              routing_key=self.__ResultQ,
                              body=json.dumps(result))
        print("Done!")

    def upload_to_database(self, results):
        """
        Insert inference result data into DB(postgreSQL).
        Columns consist of file_name(or image_name), count.

        Not implemented yet; currently a no-op.
        """
        pass

    def main(self):
        """Drain the task queue, processing one image per message.

        Stops when the queue is empty (after sending a Slack notification) or
        on the first unhandled error, whose traceback is printed. The broker
        connection is always closed on exit.
        """
        conn = None  # bound up front so cleanup works if connecting fails
        try:
            conn = pika.BlockingConnection(pika.ConnectionParameters(self.__url,
                                                                     self.__port,
                                                                     self.__vhost,
                                                                     self.__cred))
            chan = conn.channel()
            chan.queue_declare(queue=self.__TaskQ, durable=True)
            chan.queue_declare(queue=self.__ResultQ, durable=True)
            while True:
                # auto_ack=True: the message is acknowledged on delivery, so it
                # is not redelivered if processing below fails.
                method, _properties, body = chan.basic_get(queue=self.__TaskQ,
                                                           auto_ack=True)
                if not method:
                    # Queue drained: notify and stop.
                    send_message_to_slack()
                    break
                print(f" [x] Received {body}", end=' | ', flush=True)
                edge_image_path = get_info(body)
                image_path = image_copy_using_SCP(edge_image_path)
                if image_path == "":  # copy failed; skip this task
                    continue
                result_dict, save_path = self.inference(image_path)
                self.result_publish(chan, result_dict)
                self.upload_to_database(result_dict)
                self.image_upload_to_cloud(save_path)
                time.sleep(1)
        except Exception:
            print(traceback.format_exc())
        finally:
            if conn is not None and conn.is_open:
                conn.close()
if __name__ == "__main__":
    # Script entry point: load the SAM model once, then drain the task queue.
    worker = Consumer()
    worker.initialize()
    worker.main()