inference server code
commit 46fe2a88f2
config/ftp_config.json
@@ -0,0 +1,7 @@
{
    "ftp_ip": "0.0.0.0",
    "ftp_port": 21,
    "ftp_id": "sdt",
    "ftp_pw": "251327",
    "ftp_root_dir": "/home/sdt/Workspace/gseps/rabbitmq_test/inference_result/"
}
config/inference_config.json
@@ -0,0 +1,28 @@
{
    "slack_url": "https://hooks.slack.com/services/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",
    "slack_message": "Inference Done.",
    "amqp_url": "localhost",
    "amqp_port": 5672,
    "amqp_vhost": "/",
    "amqp_id": "worker",
    "amqp_pw": "gseps1234",
    "amqp_taskq": "TaskQ",
    "amqp_resultq": "ResultQ",
    "model_config": {
        "points_per_side": 36,
        "pred_iou_thresh": 0.86,
        "stability_score_thresh": 0.9,
        "crop_n_layers": 1,
        "crop_n_points_downscale_factor": 1,
        "box_nms_thresh": 0.8,
        "min_mask_region_area": 10,
        "area_thresh": 0.1,
        "device": "cuda"
    },
    "model_checkpoints": "/home/sdt/Workspace/gseps/rabbitmq_test/weights/sam_vit_h_4b8939.pth",
    "remote_server_ip": "25.15.14.31",
    "remote_server_id": "sdt",
    "remote_server_pw": "251327",
    "copied_image_path_from_remote_server": "/home/sdt/Workspace/gseps/rabbitmq_test/image_bucket",
    "inference_result_path": "/home/sdt/Workspace/gseps/rabbitmq_test/inference_result/"
}
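Since inference.py fails at startup if any of these keys is absent, a quick sanity check can be run before launching the worker. The snippet below is a minimal sketch: the key list simply mirrors the fields above, and the relative path assumes it runs from the repo root.

import json

# the key list mirrors what inference.py reads; adjust if the config grows
REQUIRED = ['slack_url', 'slack_message', 'amqp_url', 'amqp_port', 'amqp_vhost',
            'amqp_id', 'amqp_pw', 'amqp_taskq', 'amqp_resultq', 'model_config',
            'model_checkpoints', 'remote_server_ip', 'remote_server_id',
            'remote_server_pw', 'copied_image_path_from_remote_server',
            'inference_result_path']

with open('./config/inference_config.json', 'r') as f:
    info = json.load(f)

missing = [k for k in REQUIRED if k not in info]
if missing:
    raise KeyError(f'inference_config.json is missing: {missing}')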
ftp_server.py
@@ -0,0 +1,23 @@
import json
from pyftpdlib.authorizers import DummyAuthorizer
from pyftpdlib.handlers import FTPHandler
from pyftpdlib.servers import FTPServer

# config
with open('./config/ftp_config.json', 'r') as f:
    info = json.load(f)

# server setup
authorizer = DummyAuthorizer()
authorizer.add_user(info['ftp_id'],
                    info['ftp_pw'],
                    info['ftp_root_dir'],
                    perm="elradfmw")

handler = FTPHandler
handler.authorizer = authorizer

# start the server
address = (info['ftp_ip'], info['ftp_port'])  # server address and port
server = FTPServer(address, handler)
server.serve_forever()
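To verify the server accepts the configured credentials, a short client session with Python's standard ftplib works. This is a sketch under two assumptions: the server is already running, and it is reachable on 127.0.0.1.

import json
from ftplib import FTP

# load the same credentials the server uses
with open('./config/ftp_config.json', 'r') as f:
    info = json.load(f)

ftp = FTP()
ftp.connect('127.0.0.1', info['ftp_port'])  # assumes the server runs on this host
ftp.login(info['ftp_id'], info['ftp_pw'])
print(ftp.nlst())  # list the contents of ftp_root_dir
ftp.quit()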
@@ -0,0 +1,9 @@
[Unit]
Description=FTP Server

[Service]
ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/ftp_server.py
Restart=on-failure

[Install]
WantedBy=multi-user.target
inference.py
@@ -0,0 +1,221 @@
import os
import time
import json
import requests
import traceback

import cv2
import pika
import boto3
import paramiko

import torch
import numpy as np
import matplotlib.pyplot as plt

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

###############################################
#                   Config                    #
###############################################
with open('./config/inference_config.json', 'r') as f:
    info = json.load(f)


def send_message_to_slack():
    # notify via Slack incoming webhook that inference is done
    data = {"text": info['slack_message']}
    requests.post(
        url=info['slack_url'],
        data=json.dumps(data)
    )


def image_upload_to_s3(path):
    # upload every file under `path` to the gseps-data bucket,
    # keyed by the leaf folder name
    try:
        s3 = boto3.client('s3')

        folder_name = path.split('/')[-1]
        for root, dirs, files in os.walk(path):
            for f in files:
                local_file_path = os.path.join(root, f)
                s3.upload_file(local_file_path, 'gseps-data', f'{folder_name}/{f}')
    except Exception:
        print(traceback.format_exc())


def image_copy_using_SCP(remote_path):
    # fetch an image from the edge device over SFTP,
    # then delete the remote copy
    try:
        # ssh connect
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(info['remote_server_ip'], username=info['remote_server_id'], password=info['remote_server_pw'])

        # copy image from edge
        sftp = ssh.open_sftp()

        image_name = remote_path.split('/')[-1]
        local_path = info['copied_image_path_from_remote_server']

        if not os.path.exists(local_path):
            os.makedirs(local_path)

        sftp.get(remote_path, os.path.join(local_path, image_name))
        sftp.remove(remote_path)

        sftp.close()
        ssh.close()

        return os.path.join(local_path, image_name)
    except Exception as e:
        print(e)
        return ""


def get_info(body):
    # task messages carry the image location on the edge device
    payload = json.loads(body)
    image_path = payload['image_path']

    return image_path


class Consumer:
    def __init__(self):
        # AMQP connection parameters come from inference_config.json
        self.__url = info['amqp_url']
        self.__port = info['amqp_port']
        self.__vhost = info['amqp_vhost']
        self.__cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
        self.__TaskQ = info['amqp_taskq']
        self.__ResultQ = info['amqp_resultq']

        self.cfg = info['model_config']

        self.mask_generator = None

        # self.cloud_vender = info['cloud_vender']
        self.cloud_vender = 'S3'

    def initialize(self):
        # load the SAM checkpoint and build the automatic mask generator
        start = time.time()

        sam = sam_model_registry['vit_h'](
            checkpoint=info['model_checkpoints'])
        sam = sam.to(self.cfg['device'])

        self.mask_generator = SamAutomaticMaskGenerator(model=sam,
                                                        points_per_side=self.cfg['points_per_side'],
                                                        pred_iou_thresh=self.cfg['pred_iou_thresh'],
                                                        stability_score_thresh=self.cfg['stability_score_thresh'],
                                                        crop_n_layers=self.cfg['crop_n_layers'],
                                                        crop_n_points_downscale_factor=self.cfg[
                                                            'crop_n_points_downscale_factor'],
                                                        box_nms_thresh=self.cfg['box_nms_thresh'],
                                                        min_mask_region_area=self.cfg['min_mask_region_area'])
        end = time.time()
        print(f'Initialize time: {(end - start) // 60}m {(end - start) % 60:.4f}s')

    def image_upload_to_ncp(self, path):
        pass

    def image_upload_to_cloud(self, path):
        if self.cloud_vender == 'S3':
            image_upload_to_s3(path)
        elif self.cloud_vender == 'NCP':
            self.image_upload_to_ncp(path)

    def inference(self, image_path):
        image = cv2.imread(image_path)
        image_name = image_path.split('/')[-1].split('.')[0]

        # generate masks, then keep only masks that are small enough
        # (area_thresh) and stable enough, and do not overlap an
        # already accepted mask
        result = self.mask_generator.generate(image)
        shape = result[0]['segmentation'].shape
        cumulated = np.zeros(shape)
        count = 0
        sizes = []

        for n, r in enumerate(result):
            if r['area'] < self.cfg['area_thresh'] * shape[0] * shape[1] and r['stability_score'] > self.cfg['stability_score_thresh']:
                if np.amax(cumulated + r['segmentation'].astype(int)) < 2:
                    cumulated = cumulated + r['segmentation'].astype(int)
                    count += 1
                    x, y, w, h = r['bbox']
                    sizes.append(np.mean([w, h]))

        save_path = os.path.join(info['inference_result_path'], image_name)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        cv2.imwrite(os.path.join(save_path, f'{image_name}.jpg'), image)
        plt.imsave(os.path.join(save_path, f'result_{image_name}.jpg'), cumulated)

        result_dict = {'image': os.path.join(save_path, f'{image_name}.jpg'),  # needs revision
                       'count': count,
                       'sizes': sizes}

        with open(os.path.join(save_path, f'result_{image_name}.json'), 'w') as f:
            json.dump(result_dict, f, indent=4)

        return result_dict, save_path

    def result_publish(self, channel, result):
        channel.basic_publish(exchange='',
                              routing_key=self.__ResultQ,
                              body=json.dumps(result))

        print("Done!")

    def upload_to_database(self, results):
        """
        Insert inference result data into DB (PostgreSQL).
        Columns consist of file_name (or image_name) and count.
        """
        pass

    def main(self):
        conn = None
        try:
            conn = pika.BlockingConnection(pika.ConnectionParameters(self.__url,
                                                                     self.__port,
                                                                     self.__vhost,
                                                                     self.__cred))
            chan = conn.channel()
            chan.queue_declare(queue=self.__TaskQ, durable=True)
            chan.queue_declare(queue=self.__ResultQ, durable=True)

            while True:
                method, properties, body = chan.basic_get(queue=self.__TaskQ,
                                                          auto_ack=True)

                # empty queue: notify Slack and stop consuming
                if not method:
                    send_message_to_slack()
                    break

                print(f" [x] Received {body}", end=' | ', flush=True)
                edge_image_path = get_info(body)
                image_path = image_copy_using_SCP(edge_image_path)

                if image_path == "":
                    continue

                result_dict, save_path = self.inference(image_path)
                self.result_publish(chan, result_dict)
                self.upload_to_database(result_dict)
                self.image_upload_to_cloud(save_path)

                time.sleep(1)

        except Exception:
            print(traceback.format_exc())
        finally:
            if conn is not None:
                conn.close()


if __name__ == "__main__":
    consumer = Consumer()
    consumer.initialize()
    consumer.main()
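The worker expects each TaskQ message body to be a JSON object with an image_path key (see get_info above). A producer on the edge side could enqueue a task as sketched below; it reuses the AMQP settings from inference_config.json, and the sample image path is purely illustrative.

import json
import pika

with open('./config/inference_config.json', 'r') as f:
    info = json.load(f)

cred = pika.PlainCredentials(info['amqp_id'], info['amqp_pw'])
conn = pika.BlockingConnection(pika.ConnectionParameters(
    info['amqp_url'], info['amqp_port'], info['amqp_vhost'], cred))
chan = conn.channel()
chan.queue_declare(queue=info['amqp_taskq'], durable=True)

# body format expected by get_info(); the path below is illustrative only
task = {'image_path': '/home/sdt/images/sample.jpg'}
chan.basic_publish(exchange='', routing_key=info['amqp_taskq'],
                   body=json.dumps(task))
conn.close()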
@@ -0,0 +1,9 @@
[Unit]
Description=inference processor

[Service]
ExecStart=/home/sdt/miniconda3/bin/python /home/sdt/Workspace/gseps/rabbitmq_test/inference.py
Restart=on-failure

[Install]
WantedBy=multi-user.target