# Listing metadata: 70 lines, 2.7 KiB, Python
import re
import unittest

from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
|
|
|
|
# Intent labels that AlexaIntent.intent_corrector rewrites to 'iot_hue_lighton'.
DOMOTIQUE_OBJ_ON = ['iot_wemo_on']

# Intent labels that AlexaIntent.intent_corrector rewrites to 'iot_hue_lightoff'.
DOMOTIQUE_OBJ_OFF = ['iot_wemo_off']
|
|
|
|
class AlexaIntent:
    """Predict intents and entities for a spoken command.

    Two `transformers` pipelines are built in ``__init__``: a text
    classification pipeline (stored as ``get_intents``) and a token
    classification pipeline (stored as ``get_entities``). ``predict`` runs
    both on a lightly corrected version of the input sentence.
    """

    _INTENT_MODEL_NAME = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
    _ENTITY_MODEL_NAME = 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU'

    # Words that are rewritten to 'éteins' before prediction. The pattern is
    # anchored on word boundaries so that longer words which merely contain
    # one of them (e.g. 'inquiétant') are not corrupted.
    _MISHEARD_ETEINS = re.compile(r'\b(?:étant|dépeint)\b')

    def __init__(self):
        """Build both pipelines; each one loads its pretrained model."""
        self.get_intents = self.init_intent_classification()
        self.get_entities = self.init_entities_classification()

    def init_intent_classification(self):
        """Return a TextClassificationPipeline for the intent model."""
        tokenizer = AutoTokenizer.from_pretrained(self._INTENT_MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(self._INTENT_MODEL_NAME)
        return TextClassificationPipeline(model=model, tokenizer=tokenizer)

    def init_entities_classification(self):
        """Return a TokenClassificationPipeline for the entity model."""
        tokenizer = AutoTokenizer.from_pretrained(self._ENTITY_MODEL_NAME)
        model = AutoModelForTokenClassification.from_pretrained(self._ENTITY_MODEL_NAME)
        return TokenClassificationPipeline(model=model, tokenizer=tokenizer)

    def simple_sentence_corrector(self, sentence):
        """Return ``sentence`` with known mis-transcribed words fixed.

        The whole words 'étant' and 'dépeint' are replaced by 'éteins'.
        Matching is case-sensitive and limited to whole words; any other
        text is returned unchanged.
        """
        return self._MISHEARD_ETEINS.sub('éteins', sentence)

    def intent_corrector(self, intents):
        """Rewrite wemo on/off labels to the hue light labels.

        ``intents`` is an iterable of dicts with a 'label' key. Matching
        entries are modified in place and the same object is returned.
        """
        for intent in intents:
            label = intent['label']
            if label in DOMOTIQUE_OBJ_ON:
                intent['label'] = 'iot_hue_lighton'
            elif label in DOMOTIQUE_OBJ_OFF:
                intent['label'] = 'iot_hue_lightoff'
        return intents

    def predict(self, sentence):
        """Return the intents and entities predicted for ``sentence``.

        The sentence is first passed through ``simple_sentence_corrector``.
        The result is a dict with two keys: 'intents' (the output of
        ``get_intents`` after ``intent_corrector``) and 'entities' (the
        output of ``get_entities``).
        """
        sentence = self.simple_sentence_corrector(sentence)
        return {
            'intents': self.intent_corrector(self.get_intents(sentence)),
            'entities': self.get_entities(sentence),
        }
|
|
|
|
|
|
|
|
class TestAlexa(unittest.TestCase):
    """Checks of AlexaIntent.predict on French light-control commands."""

    @classmethod
    def setUpClass(cls):
        # Build the predictor once and share it between all tests of the class.
        cls.alexa = AlexaIntent()

    def _assert_first_prediction(self, sentence, expected_intent):
        """Predict ``sentence`` and check the first intent and first entity.

        The first entity's word is expected to be '▁cuisine' in every case
        exercised by this class.
        """
        res = self.alexa.predict(sentence)
        self.assertEqual(res['intents'][0]['label'], expected_intent)
        self.assertEqual(res['entities'][0]['word'], '▁cuisine')

    def test_lampe(self):
        """Switching the kitchen off and on yields the hue light intents."""
        self._assert_first_prediction("éteins la cuisine", 'iot_hue_lightoff')
        self._assert_first_prediction("allume la cuisine", 'iot_hue_lighton')

    def test_bad_transcribe(self):
        """A mis-transcribed 'dépeint' is still understood as a light-off command."""
        self._assert_first_prediction("dépeint la cuisine", 'iot_hue_lightoff')
|
|
|
|
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()