ratatouille/src/intent.py

70 lines
2.7 KiB
Python

import re
import unittest

from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
# Intent labels that AlexaIntent.intent_corrector rewrites to 'iot_hue_lighton'.
DOMOTIQUE_OBJ_ON = ['iot_wemo_on']
# Intent labels that AlexaIntent.intent_corrector rewrites to 'iot_hue_lightoff'.
DOMOTIQUE_OBJ_OFF = ['iot_wemo_off']
class AlexaIntent():
    """Predict intents and entities for a sentence with two transformers pipelines.

    Construction loads both pretrained models, so instances are expensive to
    create and should be reused across calls to ``predict``.
    """

    # Whole-word substitutions applied before classification. The targets look
    # like speech-to-text mis-hearings of "éteins" (see the "bad transcribe"
    # test) -- TODO confirm with the transcription side. Word boundaries keep
    # longer words that merely contain a trigger (e.g. "répétant") intact.
    _TRANSCRIPTION_FIXES = (
        (re.compile(r'\bétant\b'), 'éteins'),
        (re.compile(r'\bdépeint\b'), 'éteins'),
    )

    def __init__(self):
        """Build the intent and entity pipelines and store them as callables."""
        self.get_intents = self.init_intent_classification()
        self.get_entities = self.init_entities_classification()

    def init_intent_classification(self):
        """Return a TextClassificationPipeline for the intent model."""
        model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        return TextClassificationPipeline(model=model, tokenizer=tokenizer)

    def init_entities_classification(self):
        """Return a TokenClassificationPipeline for the entity (NER) model."""
        model_name = 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        return TokenClassificationPipeline(model=model, tokenizer=tokenizer)

    def simple_sentence_corrector(self, sentence):
        """Return ``sentence`` with known mis-transcribed words replaced.

        Only whole words are replaced; the match is case-sensitive.
        """
        for pattern, replacement in self._TRANSCRIPTION_FIXES:
            sentence = pattern.sub(replacement, sentence)
        return sentence

    def intent_corrector(self, intents):
        """Remap on/off labels listed in the DOMOTIQUE_OBJ_* constants.

        ``intents`` is an iterable of dicts with a ``'label'`` key. The dicts
        are modified in place and the same ``intents`` object is returned.
        """
        for intent in intents:
            if intent['label'] in DOMOTIQUE_OBJ_ON:
                intent['label'] = 'iot_hue_lighton'
            elif intent['label'] in DOMOTIQUE_OBJ_OFF:
                intent['label'] = 'iot_hue_lightoff'
        return intents

    def predict(self, sentence):
        """Correct ``sentence``, then run both pipelines on it.

        Returns a dict with the label-corrected intent predictions under
        ``'intents'`` and the raw entity predictions under ``'entities'``.
        """
        sentence = self.simple_sentence_corrector(sentence)
        return {
            'intents': self.intent_corrector(self.get_intents(sentence)),
            'entities': self.get_entities(sentence)
        }
class TestAlexa(unittest.TestCase):
    """End-to-end checks of AlexaIntent.predict on short French commands."""

    # Token expected as the first detected entity in every sentence below.
    ROOM_TOKEN = '▁cuisine'

    @classmethod
    def setUpClass(cls):
        # One shared predictor for the whole test class.
        cls.alexa = AlexaIntent()

    def test_lampe(self):
        """Turning the kitchen off, then on, yields the matching light intents."""
        cases = (
            ("éteins la cuisine", 'iot_hue_lightoff'),
            ("allume la cuisine", 'iot_hue_lighton'),
        )
        for sentence, expected_label in cases:
            outcome = self.alexa.predict(sentence)
            self.assertEqual(outcome['intents'][0]['label'], expected_label)
            self.assertEqual(outcome['entities'][0]['word'], self.ROOM_TOKEN)

    def test_bad_transcribe(self):
        """A sentence using 'dépeint' is still classified as light-off."""
        outcome = self.alexa.predict("dépeint la cuisine")
        top_label = outcome['intents'][0]['label']
        self.assertEqual(top_label, 'iot_hue_lightoff')
        first_word = outcome['entities'][0]['word']
        self.assertEqual(first_word, self.ROOM_TOKEN)
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()