Add fuzzy and text-classification light model #1
3 changed files with 115 additions and 22 deletions
65
main.py
65
main.py
|
@ -1,9 +1,9 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import argparse
|
||||||
import schedule
|
|
||||||
|
# import schedule
|
||||||
|
|
||||||
import config
|
import config
|
||||||
|
|
||||||
from src.rhasspy import rhasspy_mqtt as yoda_listener
|
from src.rhasspy import rhasspy_mqtt as yoda_listener
|
||||||
from src.rhasspy import Rhasspy
|
from src.rhasspy import Rhasspy
|
||||||
from src.ratatouille import Ratatouille
|
from src.ratatouille import Ratatouille
|
||||||
|
@ -11,30 +11,73 @@ from src.mpd import Mpd
|
||||||
from src.hass import Hass
|
from src.hass import Hass
|
||||||
from src.httpServer import get_server
|
from src.httpServer import get_server
|
||||||
from src.fuzzy import fuzz_predict
|
from src.fuzzy import fuzz_predict
|
||||||
# from src.intent import AlexaIntent
|
from src.intent import BertIntent
|
||||||
|
from src.tools.simple_sentence_corrector import simple_sentence_corrector
|
||||||
|
|
||||||
|
|
||||||
|
# --------- setup args -------------
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog='Ratatouille',
|
||||||
|
description='Ratatouille le cerveau domotique !')
|
||||||
|
|
||||||
|
parser.add_argument('mode')
|
||||||
|
parser.add_argument('-i', '--ip', required=False)
|
||||||
|
parser.add_argument('-p', '--port', required=False, type=int)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.mode == "server":
|
||||||
|
if args.ip is None or args.port is None:
|
||||||
|
logging.error(" --ip or --port argument missing")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
# -------- setup logging ------------
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=10,
|
level=10,
|
||||||
format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s"
|
format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s"
|
||||||
)
|
)
|
||||||
|
|
||||||
IP = "10.10.10.8"
|
logging.info("Loading ratatouilles modules")
|
||||||
PORT = 5555
|
|
||||||
|
|
||||||
|
# ---------- other ---------------
|
||||||
|
|
||||||
walle = Hass(config.hass_url, config.hass_token)
|
walle = Hass(config.hass_url, config.hass_token)
|
||||||
yoda = None # Rhasspy(config.rhasspy_url)
|
yoda = None # Rhasspy(config.rhasspy_url)
|
||||||
mopidy = Mpd('10.10.10.8')
|
mopidy = Mpd('10.10.10.8')
|
||||||
ratatouille = Ratatouille(yoda, walle, mopidy, schedule)
|
ratatouille = Ratatouille(yoda, walle, mopidy, None)
|
||||||
# alexa = AlexaIntent() # we are not doing any request to the evil amazon but we are using one of its dataset
|
# alexa = AlexaIntent() # we are not doing any request to the evil amazon but we are using one of its dataset
|
||||||
|
bert = BertIntent()
|
||||||
|
|
||||||
|
|
||||||
def answer(sentence):
|
def answer(sentence):
|
||||||
# return ratatouille.parse_alexa(alexa.predict(sentence))
|
# return ratatouille.parse_alexa(alexa.predict(sentence))
|
||||||
return ratatouille.parse_fuzzy(fuzz_predict(sentence))
|
sentence_corrected = simple_sentence_corrector(sentence)
|
||||||
|
prediction = bert.predict(sentence_corrected)
|
||||||
|
return ratatouille.parse_fuzzy(prediction)
|
||||||
# return "42"
|
# return "42"
|
||||||
|
|
||||||
|
|
||||||
server = get_server(IP, PORT, answer)
|
def run_server(ip, port):
|
||||||
|
server = get_server(ip, port, answer)
|
||||||
logging.info('Running server on '+IP+':'+str(PORT))
|
logging.info('Running server on '+ip+':'+str(port))
|
||||||
server.serve_forever()
|
server.serve_forever()
|
||||||
|
|
||||||
|
|
||||||
|
def run_prompt():
|
||||||
|
question = "empty"
|
||||||
|
while question != "stop":
|
||||||
|
question = input("?")
|
||||||
|
print(answer(question))
|
||||||
|
|
||||||
|
|
||||||
|
logging.info("Ratatouille is ready !")
|
||||||
|
|
||||||
|
# run_server()
|
||||||
|
if args.mode == "server":
|
||||||
|
run_server(str(args.ip), args.port)
|
||||||
|
else:
|
||||||
|
run_prompt()
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
import unittest
|
import unittest
|
||||||
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
|
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
|
||||||
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
|
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
DOMOTIQUE_OBJ_ON = ['iot_wemo_on']
|
DOMOTIQUE_OBJ_ON = ['iot_wemo_on']
|
||||||
DOMOTIQUE_OBJ_OFF = ['iot_wemo_off']
|
DOMOTIQUE_OBJ_OFF = ['iot_wemo_off']
|
||||||
|
|
||||||
|
|
||||||
class AlexaIntent():
|
class AlexaIntent():
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -15,12 +17,15 @@ class AlexaIntent():
|
||||||
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
|
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
||||||
classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
|
classifier = TextClassificationPipeline(
|
||||||
|
model=model, tokenizer=tokenizer)
|
||||||
return classifier
|
return classifier
|
||||||
|
|
||||||
def init_entities_classification(self):
|
def init_entities_classification(self):
|
||||||
tokenizer = AutoTokenizer.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model = AutoModelForTokenClassification.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
|
'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
|
||||||
|
model = AutoModelForTokenClassification.from_pretrained(
|
||||||
|
'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
|
||||||
predict = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
|
predict = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
|
||||||
return predict
|
return predict
|
||||||
|
|
||||||
|
@ -45,6 +50,47 @@ class AlexaIntent():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BertIntent():
|
||||||
|
def __init__(self):
|
||||||
|
self.classifier = pipeline("text-classification",
|
||||||
|
model="Tjiho/french-intents-classificaton")
|
||||||
|
|
||||||
|
def predict(self, sentence):
|
||||||
|
# sentence = self.simple_sentence_corrector(sentence)
|
||||||
|
classification = self.classifier(sentence)
|
||||||
|
|
||||||
|
if classification[0]["score"] < 0.7: # score too low
|
||||||
|
return
|
||||||
|
|
||||||
|
label = classification[0]["label"]
|
||||||
|
|
||||||
|
if label == "HEURE":
|
||||||
|
return {'intentName': 'GetTime'}
|
||||||
|
elif label == "DATE":
|
||||||
|
return {'intentName': 'GetDate'}
|
||||||
|
elif label == "ETEINDRE_CUISINE":
|
||||||
|
return {'intentName': 'GetTime', 'intentArg': ['cuisine']}
|
||||||
|
elif label == "ETEINDRE_BUREAU":
|
||||||
|
return {'intentName': 'LightOff', 'intentArg': ['bureau']}
|
||||||
|
elif label == "ETEINDRE_SALON":
|
||||||
|
return {'intentName': 'LightOff', 'intentArg': ['salon']}
|
||||||
|
elif label == "ETEINDRE_CHAMBRE":
|
||||||
|
return {'intentName': 'LightOff', 'intentArg': ['chambre']}
|
||||||
|
elif label == "ALLUMER_CUISINE":
|
||||||
|
return {'intentName': 'LightOn', 'intentArg': ['cuisine']}
|
||||||
|
elif label == "ALLUMER_SALON":
|
||||||
|
return {'intentName': 'LightOn', 'intentArg': ['salon']}
|
||||||
|
elif label == "ALLUMER_BUREAU":
|
||||||
|
return {'intentName': 'LightOn', 'intentArg': ['bureau']}
|
||||||
|
elif label == "ALLUMER_CHAMBRE":
|
||||||
|
return {'intentName': 'LightOn', 'intentArg': ['chambre']}
|
||||||
|
elif label == "METEO":
|
||||||
|
return {'intentName': 'Meteo'}
|
||||||
|
elif label == "TEMPERATURE_EXTERIEUR":
|
||||||
|
return {'intentName': 'Temperature_ext'}
|
||||||
|
elif label == "TEMPERATURE_INTERIEUR":
|
||||||
|
return {'intentName': 'Temperature_int'}
|
||||||
|
|
||||||
|
|
||||||
class TestAlexa(unittest.TestCase):
|
class TestAlexa(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -60,11 +106,11 @@ class TestAlexa(unittest.TestCase):
|
||||||
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lighton')
|
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lighton')
|
||||||
self.assertEqual(res['entities'][0]['word'], '▁cuisine')
|
self.assertEqual(res['entities'][0]['word'], '▁cuisine')
|
||||||
|
|
||||||
|
|
||||||
def test_bad_transcribe(self):
|
def test_bad_transcribe(self):
|
||||||
res = self.alexa.predict("dépeint la cuisine")
|
res = self.alexa.predict("dépeint la cuisine")
|
||||||
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lightoff')
|
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lightoff')
|
||||||
self.assertEqual(res['entities'][0]['word'], '▁cuisine')
|
self.assertEqual(res['entities'][0]['word'], '▁cuisine')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
4
src/tools/simple_sentence_corrector.py
Normal file
4
src/tools/simple_sentence_corrector.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
def simple_sentence_corrector(sentence):
|
||||||
|
sentence = sentence.replace('étant', 'éteins')
|
||||||
|
sentence = sentence.replace('dépeint', 'éteins')
|
||||||
|
return sentence
|
Loading…
Add table
Reference in a new issue