Add fuzzy and text-classification light model #1
3 changed files with 115 additions and 22 deletions
main.py (65 changes)

@@ -1,9 +1,9 @@
import logging
import time
import schedule
import argparse

# import schedule

import config

from src.rhasspy import rhasspy_mqtt as yoda_listener
from src.rhasspy import Rhasspy
from src.ratatouille import Ratatouille
@@ -11,30 +11,73 @@ from src.mpd import Mpd
from src.hass import Hass
from src.httpServer import get_server
from src.fuzzy import fuzz_predict
# from src.intent import AlexaIntent
from src.intent import BertIntent
from src.tools.simple_sentence_corrector import simple_sentence_corrector


# --------- setup args -------------

parser = argparse.ArgumentParser(
    prog='Ratatouille',
    description='Ratatouille le cerveau domotique !')

parser.add_argument('mode')
parser.add_argument('-i', '--ip', required=False)
parser.add_argument('-p', '--port', required=False, type=int)

args = parser.parse_args()

if args.mode == "server":
    if args.ip is None or args.port is None:
        logging.error(" --ip or --port argument missing")
        exit()


# -------- setup logging ------------

logging.basicConfig(
    level=10,
    format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s"
)

IP = "10.10.10.8"
PORT = 5555
logging.info("Loading ratatouilles modules")


# ---------- other ---------------

walle = Hass(config.hass_url, config.hass_token)
yoda = None  # Rhasspy(config.rhasspy_url)
mopidy = Mpd('10.10.10.8')
ratatouille = Ratatouille(yoda, walle, mopidy, schedule)
ratatouille = Ratatouille(yoda, walle, mopidy, None)
# alexa = AlexaIntent() # we are not doing any request to the evil amazon but we are using one of its dataset
bert = BertIntent()


def answer(sentence):
    # return ratatouille.parse_alexa(alexa.predict(sentence))
    return ratatouille.parse_fuzzy(fuzz_predict(sentence))
    sentence_corrected = simple_sentence_corrector(sentence)
    prediction = bert.predict(sentence_corrected)
    return ratatouille.parse_fuzzy(prediction)
    # return "42"


server = get_server(IP, PORT, answer)
def run_server(ip, port):
    server = get_server(ip, port, answer)
    logging.info('Running server on '+ip+':'+str(port))
    server.serve_forever()

logging.info('Running server on '+IP+':'+str(PORT))
server.serve_forever()

def run_prompt():
    question = "empty"
    while question != "stop":
        question = input("?")
        print(answer(question))


logging.info("Ratatouille is ready !")

# run_server()
if args.mode == "server":
    run_server(str(args.ip), args.port)
else:
    run_prompt()
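Note: the fuzzy branch of answer() depends on src/fuzzy.fuzz_predict, which is not shown in this diff. A minimal sketch of what such a matcher could look like, assuming the rapidfuzz package and a hypothetical INTENT_EXAMPLES table (both are assumptions, not the actual contents of src/fuzzy.py):

# Hypothetical sketch only; the real src/fuzzy.py is not part of this diff.
from rapidfuzz import fuzz, process

# Assumed lookup table mapping known phrasings to intent payloads.
INTENT_EXAMPLES = {
    "quelle heure est-il": {'intentName': 'GetTime'},
    "éteins la cuisine": {'intentName': 'LightOff', 'intentArg': ['cuisine']},
    "allume le salon": {'intentName': 'LightOn', 'intentArg': ['salon']},
}

def fuzz_predict(sentence, threshold=80):
    # Return the intent of the closest known phrase, or None if nothing is close enough.
    match, score, _ = process.extractOne(sentence, list(INTENT_EXAMPLES), scorer=fuzz.ratio)
    if score < threshold:
        return None
    return INTENT_EXAMPLES[match]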
src/intent.py

@@ -1,10 +1,12 @@
import unittest
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
from transformers import pipeline

DOMOTIQUE_OBJ_ON = ['iot_wemo_on']
DOMOTIQUE_OBJ_OFF = ['iot_wemo_off']


class AlexaIntent():

    def __init__(self):
@@ -15,21 +17,24 @@ class AlexaIntent():
        model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
        classifier = TextClassificationPipeline(
            model=model, tokenizer=tokenizer)
        return classifier

    def init_entities_classification(self):
        tokenizer = AutoTokenizer.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
        model = AutoModelForTokenClassification.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
        tokenizer = AutoTokenizer.from_pretrained(
            'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
        model = AutoModelForTokenClassification.from_pretrained(
            'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
        predict = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
        return predict

    def simple_sentence_corrector(self,sentence):
        sentence = sentence.replace('étant','éteins')
        sentence = sentence.replace('dépeint','éteins')
    def simple_sentence_corrector(self, sentence):
        sentence = sentence.replace('étant', 'éteins')
        sentence = sentence.replace('dépeint', 'éteins')
        return sentence

    def intent_corrector(self,intents):
    def intent_corrector(self, intents):
        for intent in intents:
            if intent['label'] in DOMOTIQUE_OBJ_ON:
                intent['label'] = 'iot_hue_lighton'
@@ -37,7 +42,7 @@ class AlexaIntent():
                intent['label'] = 'iot_hue_lightoff'
        return intents

    def predict(self,sentence):
    def predict(self, sentence):
        sentence = self.simple_sentence_corrector(sentence)
        return {
            'intents': self.intent_corrector(self.get_intents(sentence)),
@@ -45,6 +50,47 @@ class AlexaIntent():
        }


class BertIntent():
    def __init__(self):
        self.classifier = pipeline("text-classification",
                                   model="Tjiho/french-intents-classificaton")

    def predict(self, sentence):
        # sentence = self.simple_sentence_corrector(sentence)
        classification = self.classifier(sentence)

        if classification[0]["score"] < 0.7:  # score too low
            return

        label = classification[0]["label"]

        if label == "HEURE":
            return {'intentName': 'GetTime'}
        elif label == "DATE":
            return {'intentName': 'GetDate'}
        elif label == "ETEINDRE_CUISINE":
            return {'intentName': 'LightOff', 'intentArg': ['cuisine']}
elif label == "ETEINDRE_BUREAU":
|
||||
return {'intentName': 'LightOff', 'intentArg': ['bureau']}
|
||||
elif label == "ETEINDRE_SALON":
|
||||
return {'intentName': 'LightOff', 'intentArg': ['salon']}
|
||||
elif label == "ETEINDRE_CHAMBRE":
|
||||
return {'intentName': 'LightOff', 'intentArg': ['chambre']}
|
||||
elif label == "ALLUMER_CUISINE":
|
||||
return {'intentName': 'LightOn', 'intentArg': ['cuisine']}
|
||||
elif label == "ALLUMER_SALON":
|
||||
return {'intentName': 'LightOn', 'intentArg': ['salon']}
|
||||
elif label == "ALLUMER_BUREAU":
|
||||
return {'intentName': 'LightOn', 'intentArg': ['bureau']}
|
||||
elif label == "ALLUMER_CHAMBRE":
|
||||
return {'intentName': 'LightOn', 'intentArg': ['chambre']}
|
||||
elif label == "METEO":
|
||||
return {'intentName': 'Meteo'}
|
||||
elif label == "TEMPERATURE_EXTERIEUR":
|
||||
return {'intentName': 'Temperature_ext'}
|
||||
elif label == "TEMPERATURE_INTERIEUR":
|
||||
return {'intentName': 'Temperature_int'}
|
||||
|
||||
|
||||
class TestAlexa(unittest.TestCase):
|
||||
@classmethod
|
||||
|
@ -60,11 +106,11 @@ class TestAlexa(unittest.TestCase):
|
|||
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lighton')
|
||||
self.assertEqual(res['entities'][0]['word'], '▁cuisine')
|
||||
|
||||
|
||||
def test_bad_transcribe(self):
|
||||
res = self.alexa.predict("dépeint la cuisine")
|
||||
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lightoff')
|
||||
self.assertEqual(res['entities'][0]['word'], '▁cuisine')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
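Note: a transformers text-classification pipeline returns a list of dicts such as [{'label': 'HEURE', 'score': 0.98}], which is why BertIntent.predict reads classification[0]["score"] and classification[0]["label"]. The long if/elif chain could equally be written as a lookup table; a sketch of that refactor (it only restates the mappings already in predict and is not part of this PR):

LABEL_TO_INTENT = {
    "HEURE": {'intentName': 'GetTime'},
    "DATE": {'intentName': 'GetDate'},
    "ETEINDRE_CUISINE": {'intentName': 'LightOff', 'intentArg': ['cuisine']},
    "ETEINDRE_BUREAU": {'intentName': 'LightOff', 'intentArg': ['bureau']},
    "ETEINDRE_SALON": {'intentName': 'LightOff', 'intentArg': ['salon']},
    "ETEINDRE_CHAMBRE": {'intentName': 'LightOff', 'intentArg': ['chambre']},
    "ALLUMER_CUISINE": {'intentName': 'LightOn', 'intentArg': ['cuisine']},
    "ALLUMER_SALON": {'intentName': 'LightOn', 'intentArg': ['salon']},
    "ALLUMER_BUREAU": {'intentName': 'LightOn', 'intentArg': ['bureau']},
    "ALLUMER_CHAMBRE": {'intentName': 'LightOn', 'intentArg': ['chambre']},
    "METEO": {'intentName': 'Meteo'},
    "TEMPERATURE_EXTERIEUR": {'intentName': 'Temperature_ext'},
    "TEMPERATURE_INTERIEUR": {'intentName': 'Temperature_int'},
}

def predict(self, sentence):
    classification = self.classifier(sentence)
    if classification[0]["score"] < 0.7:  # score too low
        return None
    # Unknown labels fall through to None, same as the elif chain.
    return LABEL_TO_INTENT.get(classification[0]["label"])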
src/tools/simple_sentence_corrector.py (new file, 4 lines)

@@ -0,0 +1,4 @@
def simple_sentence_corrector(sentence):
    sentence = sentence.replace('étant', 'éteins')
    sentence = sentence.replace('dépeint', 'éteins')
    return sentence
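The new helper only rewrites two recurring speech-to-text mistakes before classification; a quick usage sketch (not part of the diff):

from src.tools.simple_sentence_corrector import simple_sentence_corrector

print(simple_sentence_corrector("dépeint la cuisine"))  # -> "éteins la cuisine"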