92 lines
2.3 KiB
Python
92 lines
2.3 KiB
Python
import logging
|
|
import argparse
|
|
|
|
# import schedule
|
|
|
|
import config
|
|
from src.rhasspy import rhasspy_mqtt as yoda_listener
|
|
from src.rhasspy import Rhasspy
|
|
from src.ratatouille import Ratatouille
|
|
from src.mpd import Mpd
|
|
from src.hass import Hass
|
|
from src.httpServer import get_server
|
|
from src.fuzzy import fuzz_predict
|
|
from src.intent import BertIntent
|
|
from src.tools.simple_sentence_corrector import simple_sentence_corrector
|
|
|
|
|
|
# --------- setup args -------------

parser = argparse.ArgumentParser(
    prog='Ratatouille',
    description='Ratatouille le cerveau domotique !')
# Positional run mode: "server" starts the HTTP server, any other value
# starts the interactive console prompt (see the dispatch at the end of file).
parser.add_argument('mode')
# --ip / --port are only needed (and then mandatory) in "server" mode.
parser.add_argument('-i', '--ip', required=False)
parser.add_argument('-p', '--port', required=False, type=int)
# Optional MPD connection target handed to Mpd(); without it the assistant
# starts with no music backend.
parser.add_argument('-m', '--mpd', required=False)

args = parser.parse_args()

if args.mode == "server" and (args.ip is None or args.port is None):
    # parser.error prints the usage line plus this message to stderr and exits
    # with status 2. The former bare `exit()` exited with status 0 (reporting
    # success for a failed launch) and relies on the `site` module's helper.
    parser.error("--ip or --port argument missing")
|
|
|
|
|
|
# -------- setup logging ------------

# Root logger at DEBUG (numeric value 10, the level this script always used),
# with timestamp, source file:line and level on every record.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s"
)

logging.info("Loading ratatouilles modules")
|
|
|
|
|
|
# ---------- other ---------------

# Home Assistant client, built from the URL/token in config.
walle = Hass(config.hass_url, config.hass_token)

# Rhasspy voice integration is currently disabled; it would be built with
# Rhasspy(config.rhasspy_url).
yoda = None

# MPD music backend is optional and only created when --mpd was given.
if args.mpd is None:
    mopidy = None
    logging.warning('Starting without MPD connection')
else:
    mopidy = Mpd(args.mpd)

ratatouille = Ratatouille(yoda, walle, mopidy, None)

# Intent classifier used by answer(). (An Alexa-dataset based AlexaIntent
# was used previously; no request is ever sent to Amazon.)
bert = BertIntent()
|
|
|
|
|
|
def answer(sentence):
    """Produce Ratatouille's reply to one free-text user sentence.

    The sentence is first normalised by simple_sentence_corrector, then
    classified by the BERT intent model; the resulting prediction is handed
    to Ratatouille.parse_fuzzy, whose return value is returned unchanged.
    """
    intent = bert.predict(simple_sentence_corrector(sentence))
    return ratatouille.parse_fuzzy(intent)
|
|
|
|
|
|
def run_server(ip, port):
    """Build the HTTP server for ``answer`` and run it via serve_forever().

    Args:
        ip: address passed to get_server and shown in the startup log line.
        port: port number passed to get_server.

    ``answer`` is handed to src.httpServer.get_server; presumably it is the
    callback used to reply to each request (confirm in src.httpServer).
    """
    server = get_server(ip, port, answer)
    # Lazy %-style args: the message is only formatted if INFO is enabled,
    # and a non-str ip no longer raises TypeError as `'...' + ip` did.
    logging.info('Running server on %s:%s', ip, port)
    server.serve_forever()
|
|
|
|
|
|
def run_prompt():
    """Interactive console loop: read a question, print Ratatouille's answer.

    Each line read from stdin (after the "?" prompt) is passed to answer()
    and the result printed. Typing "stop" ends the loop; end of input
    (Ctrl-D, or an exhausted piped stdin) now ends it too, instead of
    crashing with an unhandled EOFError traceback.
    """
    while True:
        try:
            question = input("?")
        except EOFError:
            # No more input will ever arrive: stop just like "stop" does.
            break
        if question == "stop":
            break
        print(answer(question))
|
|
|
|
|
|
logging.info("Ratatouille is ready !")

# Entry point: "server" serves answers over HTTP, every other mode falls
# back to the interactive console prompt.
if args.mode != "server":
    run_prompt()
else:
    run_server(str(args.ip), args.port)
|