Add fuzzy and text-classification light model #1

Merged
Tjiho merged 18 commits from fuzzy into master 2025-02-01 01:35:11 +01:00
11 changed files with 353 additions and 84 deletions

View file

@ -1,3 +1,19 @@
# ratatouille # ratatouille
Projet ratatouille: interface tampon entre Rhasspy, Hass, Mpd, Bookstack, Systemd et autre Projet ratatouille: interface tampon entre Rhasspy, Hass, Mpd, Bookstack, Systemd et autre.
To work with the models (intent.py), transformers and pytorch are needed.
See https://pytorch.org/get-started/locally/
## Run
```
python main.py server -i 127.0.0.1 -p 7777 --mpd 10.10.10.10
```
ou
```
python main.py prompt
```

12
await.sh Executable file
View file

@ -0,0 +1,12 @@
#!/usr/bin/env bash
# Block until the target host answers an HTTP request with status 200.
# Usage: await.sh [host]   (host defaults to 10.10.10.8)
host="${1:-10.10.10.8}"

while true; do
    /bin/sleep 10
    # -m 10 bounds each attempt so a hung connection cannot stall the loop;
    # when the request fails curl still prints a status code of 000.
    code=$(curl -o /dev/null -s -m 10 -w "%{http_code}\n" "$host")
    echo "$host return code $code"
    # Quote $code: an empty value would otherwise break the `[` test.
    if [ "$code" = '200' ]; then
        exit 0
    else
        echo "retrying in 10s"
    fi
done

79
main.py
View file

@ -1,37 +1,92 @@
import logging import logging
import time import argparse
import schedule
# import schedule
import config import config
from src.rhasspy import rhasspy_mqtt as yoda_listener from src.rhasspy import rhasspy_mqtt as yoda_listener
from src.rhasspy import Rhasspy from src.rhasspy import Rhasspy
from src.ratatouille import Ratatouille from src.ratatouille import Ratatouille
from src.mpd import Mpd from src.mpd import Mpd
from src.hass import Hass from src.hass import Hass
from src.httpServer import get_server from src.httpServer import get_server
from src.intent import AlexaIntent from src.fuzzy import fuzz_predict
from src.intent import BertIntent
from src.tools.simple_sentence_corrector import simple_sentence_corrector
# --------- setup args -------------
# Command line: a positional run mode ("server" starts the HTTP server, any
# other value starts the interactive prompt) plus optional network settings.
parser = argparse.ArgumentParser(
    prog='Ratatouille',
    description='Ratatouille le cerveau domotique !')
parser.add_argument('mode')
parser.add_argument('-i', '--ip', required=False)
parser.add_argument('-p', '--port', required=False, type=int)
parser.add_argument('-m', '--mpd', required=False)

args = parser.parse_args()

if args.mode == "server" and (args.ip is None or args.port is None):
    # parser.error() prints the usage and the message on stderr, then exits
    # with status 2, so a bad invocation is no longer reported as a success.
    parser.error("--ip or --port argument missing")
# -------- setup logging ------------
logging.basicConfig( logging.basicConfig(
level=10, level=10,
format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s" format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s"
) )
IP = "10.10.10.11" logging.info("Loading ratatouilles modules")
PORT = 5555
# ---------- other ---------------
walle = Hass(config.hass_url, config.hass_token) walle = Hass(config.hass_url, config.hass_token)
yoda = None # Rhasspy(config.rhasspy_url) yoda = None # Rhasspy(config.rhasspy_url)
mopidy = Mpd('10.10.10.10', yoda)
ratatouille = Ratatouille(yoda, walle, mopidy, schedule) mopidy = None
alexa = AlexaIntent() # we are not doing any request to the evil amazon but we are using one of its dataset
if args.mpd is not None:
mopidy = Mpd(args.mpd)
else:
logging.warning('Starting without MPD connection')
ratatouille = Ratatouille(yoda, walle, mopidy, None)
# alexa = AlexaIntent() # we are not doing any request to the evil amazon but we are using one of its dataset
bert = BertIntent()
def answer(sentence): def answer(sentence):
return ratatouille.parseAlexa(alexa.predict(sentence)) # return ratatouille.parse_alexa(alexa.predict(sentence))
sentence_corrected = simple_sentence_corrector(sentence)
prediction = bert.predict(sentence_corrected)
return ratatouille.parse_fuzzy(prediction)
# return "42" # return "42"
server = get_server(IP,PORT,answer)
logging.info('Running server on '+IP+':'+str(PORT)) def run_server(ip, port):
server = get_server(ip, port, answer)
logging.info('Running server on '+ip+':'+str(port))
server.serve_forever() server.serve_forever()
def run_prompt():
    """Read questions from standard input and print the answer to each one.

    The loop ends as soon as the user types "stop"; that word itself is not
    passed to answer().
    """
    while True:
        question = input("?")
        if question == "stop":
            break
        print(answer(question))
logging.info("Ratatouille is ready !")
# run_server()
if args.mode == "server":
run_server(str(args.ip), args.port)
else:
run_prompt()

View file

@ -2,5 +2,6 @@ paho-mqtt<1.6
requests<2.26 requests<2.26
schedule<1.1.0 schedule<1.1.0
python-mpd2<3.1 python-mpd2<3.1
transformers<4.28.0 #transformers<4.28.0
torch<2.1.0 #torch<2.1.0
rapidfuzz<4.0.0

69
src/fuzzy.py Normal file
View file

@ -0,0 +1,69 @@
from rapidfuzz import process, fuzz, utils
# Rooms that can appear in a lamp command.
ROOMS = ["cuisine", "salon", "chambre", "bureau"]

# Genre names used after "met de la musique"; these are also the values
# returned as the argument of a PlayMusicGenre intent.
MUSIQUE_GENRE_SHORT = ["synthpop", "jazz classique", "jazz manouche", "latine", "classique", "rock", "jazz",
                       "blues", "film", "francaise", "pop", "reggae", "folk",
                       "électro", "punk", "corse", "arabe", "persane", "piano", "rap", "slam"]

# Genre names used after "met du".
MUSIQUE_GENRE_LONG = ["synthpop", "jazz classique", "jazz manouche", "chanson latine", "classique", "rock", "jazz",
                      "blues", "musique de film", "chanson francaise", "pop", "reggae", "folk",
                      "électro", "punk", "chanson corse", "chanson arabe", "chanson persane", "piano", "rap", "slam", "chanson classique"]


def _prefixed(prefix, words):
    """Return one sentence per word: the prefix, a single space, the word."""
    return [prefix + " " + word for word in words]


def compute_sentences_lamp_on():
    """Return the reference sentences that mean "switch a lamp on"."""
    return _prefixed("allume la lumière de la", ROOMS) + _prefixed("allume la", ROOMS)


def compute_sentences_lamp_off():
    """Return the reference sentences that mean "switch a lamp off"."""
    return _prefixed("éteins la lumière de la", ROOMS) + _prefixed("éteins la", ROOMS)


def compute_sentences_musique_genre_long():
    """Return the "met du <genre>" reference sentences."""
    return _prefixed("met du", MUSIQUE_GENRE_LONG)


def compute_sentences_musique_genre_short():
    """Return the "met de la musique <genre>" reference sentences."""
    return _prefixed("met de la musique", MUSIQUE_GENRE_SHORT)


# Reference sentences, built once at import time.
SENTENCES_HOUR = ["quel heure est-il ?"]
SENTENCES_LAMP_ON = compute_sentences_lamp_on()
SENTENCES_LAMP_OFF = compute_sentences_lamp_off()
SENTENCES_MUSIQUE_GENRE_LONG = compute_sentences_musique_genre_long()
SENTENCES_MUSIQUE_GENRE_SHORT = compute_sentences_musique_genre_short()
def fuzz_predict(text):
    """Map a free-form sentence to an intent by fuzzy matching.

    The sentence is compared with every reference sentence and the best
    scoring one decides the intent.

    Args:
        text: the sentence to interpret.

    Returns:
        A dict with an 'intentName' key ('GetTime', 'LightOn', 'LightOff',
        'PlayMusicGenre' or 'Unknown') and, for lamp and music intents, an
        'intentArg' list holding the room or the short genre name.
    """
    choices = SENTENCES_HOUR + SENTENCES_LAMP_ON + SENTENCES_LAMP_OFF + \
        SENTENCES_MUSIQUE_GENRE_LONG + SENTENCES_MUSIQUE_GENRE_SHORT
    result = process.extractOne(
        text, choices, scorer=fuzz.WRatio, processor=utils.default_process)
    if result is None:
        # Defensive: extractOne yields None when it finds no match at all.
        return {'intentName': 'Unknown'}
    chosen_sentence = result[0]

    def pick_argument(candidates):
        # Prefer what the user actually said. When the word is misspelt in
        # `text`, fall back to the matched reference sentence so callers do
        # not receive a None argument.
        found = find_matching(candidates, text)
        if found is None:
            found = find_matching(candidates, chosen_sentence)
        return found

    if chosen_sentence in SENTENCES_HOUR:
        return {'intentName': 'GetTime'}
    if chosen_sentence in SENTENCES_LAMP_ON:
        return {'intentName': 'LightOn', 'intentArg': [pick_argument(ROOMS)]}
    if chosen_sentence in SENTENCES_LAMP_OFF:
        return {'intentName': 'LightOff', 'intentArg': [pick_argument(ROOMS)]}
    if chosen_sentence in SENTENCES_MUSIQUE_GENRE_LONG or \
            chosen_sentence in SENTENCES_MUSIQUE_GENRE_SHORT:
        # Both template families report the short genre name.
        return {'intentName': 'PlayMusicGenre',
                'intentArg': [pick_argument(MUSIQUE_GENRE_SHORT)]}

    return {'intentName': 'Unknown'}
def find_matching(list_str, text):
    """Return the first entry of list_str that occurs inside text.

    The comparison ignores case, so a capitalised transcription still
    matches the lowercase entries.

    Args:
        list_str: candidate strings, tried in order.
        text: the sentence to search; may be empty or None.

    Returns:
        The matching entry exactly as it appears in list_str, or None when
        nothing matches.
    """
    if not text:
        return None
    haystack = text.casefold()
    for search in list_str:
        if search.casefold() in haystack:
            return search
    return None

View file

@ -2,6 +2,7 @@ import http.server
import json import json
import logging import logging
def get_server(ip, port, answer_function): def get_server(ip, port, answer_function):
class Server(http.server.BaseHTTPRequestHandler): class Server(http.server.BaseHTTPRequestHandler):
@ -16,10 +17,9 @@ def get_server(ip, port, answer_function):
logging.info('Get request:' + text) logging.info('Get request:' + text)
print(text) print(text)
self.send_response(200) self.send_response(200)
self.send_header('Content-type','text/plain') self.send_header('Content-type', 'text/plain; charset=utf-8')
self.end_headers() self.end_headers()
self.wfile.write(res.encode()) self.wfile.write(res.encode())
return http.server.HTTPServer((ip, port), Server) return http.server.HTTPServer((ip, port), Server)

View file

@ -1,10 +1,12 @@
import unittest import unittest
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
from transformers import pipeline
DOMOTIQUE_OBJ_ON = ['iot_wemo_on'] DOMOTIQUE_OBJ_ON = ['iot_wemo_on']
DOMOTIQUE_OBJ_OFF = ['iot_wemo_off'] DOMOTIQUE_OBJ_OFF = ['iot_wemo_off']
class AlexaIntent(): class AlexaIntent():
def __init__(self): def __init__(self):
@ -15,12 +17,15 @@ class AlexaIntent():
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification' model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer) classifier = TextClassificationPipeline(
model=model, tokenizer=tokenizer)
return classifier return classifier
def init_entities_classification(self): def init_entities_classification(self):
tokenizer = AutoTokenizer.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU') tokenizer = AutoTokenizer.from_pretrained(
model = AutoModelForTokenClassification.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU') 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
model = AutoModelForTokenClassification.from_pretrained(
'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
predict = TokenClassificationPipeline(model=model, tokenizer=tokenizer) predict = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
return predict return predict
@ -45,6 +50,47 @@ class AlexaIntent():
} }
class BertIntent():
    """Intent detection backed by a text-classification pipeline."""

    # Predictions scoring below this value are treated as "no intent".
    MIN_SCORE = 0.7

    # Classifier label -> (intent name, optional room argument).
    INTENT_BY_LABEL = {
        "HEURE": ('GetTime', None),
        "DATE": ('GetDate', None),
        "ETEINDRE_CUISINE": ('LightOff', 'cuisine'),
        "ETEINDRE_BUREAU": ('LightOff', 'bureau'),
        "ETEINDRE_SALON": ('LightOff', 'salon'),
        "ETEINDRE_CHAMBRE": ('LightOff', 'chambre'),
        "ALLUMER_CUISINE": ('LightOn', 'cuisine'),
        "ALLUMER_SALON": ('LightOn', 'salon'),
        "ALLUMER_BUREAU": ('LightOn', 'bureau'),
        "ALLUMER_CHAMBRE": ('LightOn', 'chambre'),
        "METEO": ('Meteo', None),
        "TEMPERATURE_EXTERIEUR": ('Temperature_ext', None),
        "TEMPERATURE_INTERIEUR": ('Temperature_int', None),
    }

    def __init__(self):
        """Create the text-classification pipeline used by predict()."""
        self.classifier = pipeline("text-classification",
                                   model="Tjiho/french-intents-classificaton")

    def predict(self, sentence):
        """Classify a sentence and translate the label into an intent dict.

        Args:
            sentence: the sentence to classify.

        Returns:
            {'intentName': name} or {'intentName': name, 'intentArg': [room]}.
            'intentName' is an empty string when the classifier returns
            nothing, when the score is below MIN_SCORE, or when the label is
            not in INTENT_BY_LABEL, so callers always receive a dict.
        """
        classification = self.classifier(sentence)
        if not classification:
            return {'intentName': ''}

        best = classification[0]
        if best["score"] < self.MIN_SCORE:  # score too low
            return {'intentName': ''}

        entry = self.INTENT_BY_LABEL.get(best["label"])
        if entry is None:
            # Unmapped label: answer like a low score instead of returning None.
            return {'intentName': ''}

        name, room = entry
        if room is None:
            return {'intentName': name}
        # A fresh list on every call, so callers may mutate the result.
        return {'intentName': name, 'intentArg': [room]}
class TestAlexa(unittest.TestCase): class TestAlexa(unittest.TestCase):
@classmethod @classmethod
@ -60,11 +106,11 @@ class TestAlexa(unittest.TestCase):
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lighton') self.assertEqual(res['intents'][0]['label'], 'iot_hue_lighton')
self.assertEqual(res['entities'][0]['word'], '▁cuisine') self.assertEqual(res['entities'][0]['word'], '▁cuisine')
def test_bad_transcribe(self): def test_bad_transcribe(self):
res = self.alexa.predict("dépeint la cuisine") res = self.alexa.predict("dépeint la cuisine")
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lightoff') self.assertEqual(res['intents'][0]['label'], 'iot_hue_lightoff')
self.assertEqual(res['entities'][0]['word'], '▁cuisine') self.assertEqual(res['entities'][0]['word'], '▁cuisine')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -2,16 +2,16 @@ import logging
import random import random
from mpd import MPDClient from mpd import MPDClient
class Mpd(): class Mpd():
def __init__(self,ip, yoda): def __init__(self, ip):
self.ip = ip self.ip = ip
self.port = 6600 self.port = 6600
self.client = MPDClient() # create client object self.client = MPDClient() # create client object
self.client.timeout = 10 # network timeout in seconds (floats allowed), default: None # network timeout in seconds (floats allowed), default: None
self.client.timeout = 10
self.client.idletimeout = None self.client.idletimeout = None
self.yoda = yoda
self.client.connect(ip, self.port) self.client.connect(ip, self.port)
logging.debug(self.client.mpd_version) logging.debug(self.client.mpd_version)
@ -32,26 +32,32 @@ class Mpd():
self.client.disconnect() self.client.disconnect()
def play_genre(self, genre): def play_genre(self, genre):
res = ""
self.client.connect(self.ip, self.port) self.client.connect(self.ip, self.port)
logging.debug('Listing '+self.normalize_genre(genre)) logging.debug('Listing '+self.normalize_genre(genre))
list_album = []
try: try:
# todo: select only one album instead of the artist # todo: select only one album instead of the artist
list_album = self.client.lsinfo("Files/Genres/"+self.normalize_genre(genre)) list_album = self.client.lsinfo(
except: "Subsonic/Genre/"+self.normalize_genre(genre))
list_album = []
if (list_album): if (list_album):
random_album = random.choice(list_album)["directory"] random_album = random.choice(list_album)["directory"]
self.play_album(random_album) self.play_album(random_album)
res = "C'est parti !"
else: else:
self.yoda.say("Je n\'ai rien trouvé") res = "Je n\'ai rien trouvé."
except:
res = "Il y a eu une erreur durant le lancement de la musique"
finally:
self.client.close() self.client.close()
self.client.disconnect() self.client.disconnect()
return res
def play_album(self, directory): def play_album(self, directory):
self.client.stop() self.client.stop()
self.yoda.say("Lancement de "+directory.split('/')[-1]) # self.yoda.say("Lancement de "+directory.split('/')[-1])
logging.debug('Playing '+directory) logging.debug('Playing '+directory)
self.client.clear() self.client.clear()
self.client.add(directory) self.client.add(directory)
@ -61,17 +67,31 @@ class Mpd():
return NORMALIZED_GENRE[genre.lower()] return NORMALIZED_GENRE[genre.lower()]
NORMALIZED_GENRE = { NORMALIZED_GENRE = {
'classique': 'classique', 'classique': 'Classique',
'musique classique': 'classique', 'musique classique': 'Classique',
'jazz': 'jazz', 'jazz': 'Jazz',
'chanson française': 'chanson francaise', 'chanson française': 'Chanson Francaise',
'chanson anglaise': 'chanson anglaise', 'francaise': 'Chanson Francaise',
'musique de film': 'musique de film', 'chanson anglaise': 'Chanson anglaise',
'électro': 'electro', 'reggae': 'Reggae',
'musique électronique': 'electro', 'folk': 'Folk',
'rock': 'rock', 'électro': 'Electro',
'pop': 'pop', 'musique électronique': 'Electro',
'chanson latine': 'chanson latine', 'punk': 'Punk',
'rock': 'Rock',
'pop': 'Pop',
'latine': 'Latin',
'arabe': 'Arabic',
'corse': 'Chants Corse',
'persane': 'Chanson Persane',
'piano': 'Piano',
'rap': 'Rap',
'slam': 'Slam',
'synthpop': 'Synthpop',
'jazz classique': 'Classique Jazz',
'jazz manouche': 'Jazz Manouche',
'blues': 'Blues',
'film': 'Soundtrack',
'musique de film': 'Soundtrack',
} }

View file

@ -11,6 +11,7 @@ from src.tools.parse_entities import parse_entities
from src.const.temperature_keyword import TEMPERATURE_KEYWORD from src.const.temperature_keyword import TEMPERATURE_KEYWORD
class Ratatouille(): class Ratatouille():
def __init__(self, yoda, walle, mopidy, schedule): def __init__(self, yoda, walle, mopidy, schedule):
@ -22,7 +23,7 @@ class Ratatouille():
# schedule.every().day.at(config.hibernate_time).do(self.hibernate) # schedule.every().day.at(config.hibernate_time).do(self.hibernate)
# schedule.every().day.at(config.wakeup_time).do(self.clear_hibernate) # schedule.every().day.at(config.wakeup_time).do(self.clear_hibernate)
# yoda.say('Ratatouille a bien démmaré') # yoda.say('Ratatouille a bien démmaré')
logging.info('loaded') # logging.info('loaded')
def parse_rhasspy_command(self, payload): def parse_rhasspy_command(self, payload):
command = payload['intent']['intentName'] command = payload['intent']['intentName']
@ -44,7 +45,7 @@ class Ratatouille():
self.yoda.say(response) self.yoda.say(response)
return return
def parseAlexa(self,payload): def parse_alexa(self, payload):
print(payload) print(payload)
intent = payload['intents'][0]['label'] intent = payload['intents'][0]['label']
entities = payload['entities'] entities = payload['entities']
@ -59,6 +60,20 @@ class Ratatouille():
return self.weather_query(parse_entities(entities)) return self.weather_query(parse_entities(entities))
return '42' return '42'
def parse_fuzzy(self, payload):
    """Run the action named by an intent payload and return the spoken reply.

    Args:
        payload: dict with an 'intentName' key and, for lamp and music
            intents, an 'intentArg' list whose first item is the room or
            the genre. A missing (None or empty) payload is tolerated.

    Returns:
        The reply text of the action, or '42' when the intent is not handled.
    """
    if not payload:
        return '42'

    command = payload.get('intentName')
    if command == 'GetTime':
        return self.send_hour()
    if command == 'GetDate':
        return self.send_date()
    if command == "LightOn":
        return self.light_on_single(payload['intentArg'][0])
    if command == "LightOff":
        return self.light_off_single(payload['intentArg'][0])
    if command == "PlayMusicGenre":
        if self.mopidy is None:
            # No MPD client was supplied, so there is nothing to play with.
            return "La musique n'est pas disponible."
        genre = payload['intentArg'][0]
        if genre is None:
            return "Je n'ai rien trouvé."
        return self.mopidy.play_genre(genre)
    return '42'
def weather_query(self, entities): def weather_query(self, entities):
if any(a in entities['B-weather_descriptor'] for a in TEMPERATURE_KEYWORD): if any(a in entities['B-weather_descriptor'] for a in TEMPERATURE_KEYWORD):
res = self.send_temperature() res = self.send_temperature()
@ -99,9 +114,30 @@ class Ratatouille():
def send_temperature(self): def send_temperature(self):
logging.info('Send temperature') logging.info('Send temperature')
data = self.walle.get('weather.toulouse') data = self.walle.get('weather.toulouse')
temperature = str(data['attributes']['temperature']).replace('.',' virgule ') temperature = str(data['attributes']['temperature']
).replace('.', ' virgule ')
return 'il fait '+temperature+' degrés' return 'il fait '+temperature+' degrés'
def light_off_single(self, lamp):
    """Switch one lamp off and return the sentence describing the outcome.

    Args:
        lamp: name of the lamp; None or empty when no lamp was recognised.

    Returns:
        A French sentence saying the lamp was switched off, could not be
        switched off, or was not understood.
    """
    if not lamp:
        # Nothing to switch, and concatenating None below would raise.
        return "Je n'ai pas compris quelle lampe éteindre."
    try:
        self.walle.light_off(lamp)
    except Exception as e:
        # Broad on purpose: any failure of the call becomes a spoken reply.
        logging.warning("Error light off:")
        logging.warning(e)
        return "J'ai pas pu éteindre la lampe '" + str(lamp) + "'."
    return "J'ai éteint la lampe '" + str(lamp) + "'."
def light_on_single(self, lamp):
    """Switch one lamp on and return the sentence describing the outcome.

    Args:
        lamp: name of the lamp; None or empty when no lamp was recognised.

    Returns:
        A French sentence saying the lamp was switched on, could not be
        switched on, or was not understood.
    """
    if not lamp:
        # Nothing to switch, and concatenating None below would raise.
        return "Je n'ai pas compris quelle lampe allumer."
    try:
        self.walle.light_on(lamp)
    except Exception as e:
        # Broad on purpose: any failure of the call becomes a spoken reply.
        logging.warning("Error light on:")
        logging.warning(e)
        return "J'ai pas pu allumer la lampe '" + str(lamp) + "'."
    return "J'ai allumé la lampe '" + str(lamp) + "'."
def light_off(self, entities): def light_off(self, entities):
number_error = 0 number_error = 0
for lamp in entities: for lamp in entities:
@ -132,8 +168,8 @@ class Ratatouille():
else: else:
return 'J\'ai pas pu allumer ' + int_to_str(number_error, 'f') + ' lampes' return 'J\'ai pas pu allumer ' + int_to_str(number_error, 'f') + ' lampes'
# -- -- hibernate -- -- # -- -- hibernate -- --
def can_hibernate(self): def can_hibernate(self):
lamp_desk = self.walle.get('switch.prise_bureau_switch') lamp_desk = self.walle.get('switch.prise_bureau_switch')
if lamp_desk: if lamp_desk:
@ -153,11 +189,13 @@ class Ratatouille():
time.sleep(5) time.sleep(5)
else: else:
self.schedule.clear('hourly-hibernate') self.schedule.clear('hourly-hibernate')
self.schedule.every(3).minutes.do(self.hibernate).tag('hourly-hibernate') self.schedule.every(3).minutes.do(
self.hibernate).tag('hourly-hibernate')
logging.info('retry to hibernate in 3 minutes') logging.info('retry to hibernate in 3 minutes')
return return
def clear_hibernate(self): def clear_hibernate(self):
self.schedule.clear('hourly-hibernate') self.schedule.clear('hourly-hibernate')

View file

@ -1,7 +1,10 @@
from src.tools.str import int_to_str from src.tools.str import int_to_str
MONTHS = ['janvier','février','mars','avril','mai','juin','juillet','aout','septembre','octobre','novembre','decembre'] MONTHS = ['janvier', 'février', 'mars', 'avril', 'mai', 'juin',
WEEKDAY = ['lundi','mardi','mercredi','jeudi','vendredi','samedi','dimanche'] 'juillet', 'aout', 'septembre', 'octobre', 'novembre', 'decembre']
WEEKDAY = ['lundi', 'mardi', 'mercredi',
'jeudi', 'vendredi', 'samedi', 'dimanche']
def get_time(date): def get_time(date):
hour = date.hour hour = date.hour
@ -13,11 +16,12 @@ def get_time(date):
elif minute == 45: elif minute == 45:
return 'il est '+format_hour(hour+1)+' moins le quart' return 'il est '+format_hour(hour+1)+' moins le quart'
elif minute == 15: elif minute == 15:
return 'il est '+format_hour(hour+1)+' et quart' return 'il est '+format_hour(hour)+' et quart'
else: else:
return 'il est '+format_hour(hour)+' '+str(minute) return 'il est '+format_hour(hour)+' '+str(minute)
return return
def get_date(date): def get_date(date):
week_day = date.weekday() week_day = date.weekday()
day = date.day day = date.day
@ -26,6 +30,7 @@ def get_date(date):
return 'Nous somme le '+format_weekday(week_day)+' '+str(day)+' '+format_month(month)+' '+str(year) return 'Nous somme le '+format_weekday(week_day)+' '+str(day)+' '+format_month(month)+' '+str(year)
# self.yoda.say('Nous somme le '+str(week_day)+' '+str(day)+' '+str(month)+' '+str(year)) # self.yoda.say('Nous somme le '+str(week_day)+' '+str(day)+' '+str(month)+' '+str(year))
def format_hour(hour): def format_hour(hour):
if hour == 12: if hour == 12:
return 'midi' return 'midi'
@ -34,8 +39,10 @@ def format_hour(hour):
return int_to_str(hour, 'f') + ' heure' return int_to_str(hour, 'f') + ' heure'
def format_month(month): def format_month(month):
return MONTHS[month - 1] return MONTHS[month - 1]
def format_weekday(week_day): def format_weekday(week_day):
return WEEKDAY[week_day] return WEEKDAY[week_day]

View file

@ -0,0 +1,5 @@
def simple_sentence_corrector(sentence):
    """Replace words that are commonly mis-transcribed in voice commands.

    Only whole words are replaced, so words that merely contain one of the
    patterns ("maison", "jamais") are left untouched.

    Args:
        sentence: the transcribed sentence.

    Returns:
        The sentence with 'étant' and 'dépeint' turned into 'éteins' and
        'mais' turned into 'mets'.
    """
    import re  # imported here because the module has no import section

    corrections = {
        'étant': 'éteins',
        'dépeint': 'éteins',
        'mais': 'mets',
    }
    # \b anchors the match on word boundaries; accented letters count as
    # word characters for str patterns.
    pattern = r'\b(?:' + '|'.join(re.escape(word) for word in corrections) + r')\b'
    return re.sub(pattern, lambda match: corrections[match.group(0)], sentence)