From 0334f67eb2acb4815a8cc9666f00075f365c6164 Mon Sep 17 00:00:00 2001 From: Koffi Tino Gnagniko Date: Mon, 5 Feb 2024 18:07:06 +0100 Subject: [PATCH 1/3] Added more tests for AI service MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marco Martin Härtl Co-authored-by: garvinkon Signed-off-by: Koffi Tino Gnagniko --- Backend/test/model/ai_ticket_service_test.py | 84 ++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/Backend/test/model/ai_ticket_service_test.py b/Backend/test/model/ai_ticket_service_test.py index a380cb3c..e4090cad 100644 --- a/Backend/test/model/ai_ticket_service_test.py +++ b/Backend/test/model/ai_ticket_service_test.py @@ -8,6 +8,7 @@ class TestAITicketService(unittest.TestCase): def setUp(self): self.ai_ticket_service = AITicketService() + # Existing mocks self.mock_title_pipeline = MagicMock() self.mock_title_pipeline.return_value = [{"generated_text": "Mocked Title"}] self.ai_ticket_service.title_generator_pipe = self.mock_title_pipeline @@ -26,6 +27,18 @@ def setUp(self): ] self.ai_ticket_service.generate_keywords = self.mock_keywords_pipeline + # Mocks for category + self.mock_category_pipeline = MagicMock() + self.ai_ticket_service.category_generator_pipe = self.mock_category_pipeline + + # Mock for service + self.mock_service_pipeline = MagicMock() + self.ai_ticket_service.service_generator_pipe = self.mock_service_pipeline + + # Mock for priority + self.mock_priority_pipeline = MagicMock() + self.ai_ticket_service.priority_generator_pipe = self.mock_priority_pipeline + def test_create_ticket(self): # Arrange input_text = "Sample input text" @@ -61,3 +74,74 @@ def test_generate_affected_person(self): # Assert self.assertEqual(ticket_dict["affectedPerson"], "John") + + def test_generate_prediction_success(self): + # Arrange + input_text = "Sample input text for a successful prediction" + ticket_dict = {} + field = "category" + field_values = self.ai_ticket_service.category_values + + # Mocking the pipeline to return a high confidence prediction + self.mock_category_pipeline.return_value = [{"label": "LABEL_0", "score": 0.9}] + self.ai_ticket_service.category_generator_pipe = self.mock_category_pipeline + + # Act + self.ai_ticket_service.generate_prediction( + input_text, + self.ai_ticket_service.category_generator_pipe, + field, + field_values, + ticket_dict, + ) + + # Assert + self.assertEqual( + ticket_dict[field], field_values[0] + ) # Assuming LABEL_0 maps to the first category + + def test_generate_prediction_low_confidence(self): + # Arrange + input_text = "Sample input text for a low confidence prediction" + ticket_dict = {} + field = "service" + field_values = self.ai_ticket_service.service_values + + # Mocking the pipeline to return a low confidence prediction + self.mock_service_pipeline.return_value = [{"label": "LABEL_1", "score": 0.4}] + self.ai_ticket_service.service_generator_pipe = self.mock_service_pipeline + + # Act + self.ai_ticket_service.generate_prediction( + input_text, + self.ai_ticket_service.service_generator_pipe, + field, + field_values, + ticket_dict, + ) + + # Assert + self.assertIsNone(ticket_dict[field]) + + def test_generate_prediction_empty_output(self): + # Arrange + input_text = "Sample input text for an empty prediction" + ticket_dict = {} + field = "priority" + field_values = self.ai_ticket_service.priority_values + + # Mocking the pipeline to return an empty output + self.mock_priority_pipeline.return_value = [] + self.ai_ticket_service.priority_generator_pipe = self.mock_priority_pipeline + + # Act + self.ai_ticket_service.generate_prediction( + input_text, + self.ai_ticket_service.priority_generator_pipe, + field, + field_values, + ticket_dict, + ) + + # Assert + self.assertIsNone(ticket_dict[field]) From 9bb508fddfeb3e8b4725a1968261a6c0938ed3c1 Mon Sep 17 00:00:00 2001 From: Koffi Tino Gnagniko Date: Mon, 5 Feb 2024 21:06:19 +0100 Subject: [PATCH 2/3] Added more tests ticket api MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marco Martin Härtl Co-authored-by: garvinkon Signed-off-by: Koffi Tino Gnagniko --- Backend/test/api/v1/ticket_api_test.py | 244 ++++++++++++++++++++++++- 1 file changed, 240 insertions(+), 4 deletions(-) diff --git a/Backend/test/api/v1/ticket_api_test.py b/Backend/test/api/v1/ticket_api_test.py index 96418e88..c8ff15f0 100644 --- a/Backend/test/api/v1/ticket_api_test.py +++ b/Backend/test/api/v1/ticket_api_test.py @@ -4,7 +4,9 @@ from unittest.mock import MagicMock from app.api.dto.ticket import Ticket +from app.api.dto.user import User from app.dependency.collection import get_ticket_collection, get_user_collection +from app.dependency.db_service import get_user_db_service from app.dependency.email_service import get_email_service from app.enum.customer_prio import CustomerPrio from app.enum.prio import Prio @@ -21,8 +23,8 @@ class TicketAPIIntegrationTest(TestCase): def override_get_ticket_collection(self): return self.ticket_collection_mock - def override_get_user_collection(self): - return self.user_collection_mock + def override_get_user_db_service(self): + return self.user_db_service_mock def override_get_email_service(self): return self.email_service_mock @@ -30,11 +32,11 @@ def override_get_email_service(self): def setUp(self): self.client = TestClient(app) self.ticket_collection_mock = MagicMock() - self.user_collection_mock = MagicMock() + self.user_db_service_mock = MagicMock() self.email_service_mock = MagicMock() app.dependency_overrides = { get_ticket_collection: self.override_get_ticket_collection, - get_user_collection: self.override_get_user_collection, + get_user_db_service: self.override_get_user_db_service, get_email_service: self.override_get_email_service, } self.ticket_id = ObjectId("6554b34d82161e93bff08df6") @@ -305,6 +307,240 @@ def test_update_ticket_attributes_invalid_ticket_id(self): self.assertEqual(status.HTTP_400_BAD_REQUEST, response.status_code) self.assertEqual(exp_json, response.json()) + def test_process_text_with_email_and_user_data_success(self): + # Define + email_input = "test@example.com" + user_mock = User( + email_address=email_input, + location="Test Location", + first_name="John", + family_name="Doe", + password="", + ) + ticket_entity = TicketEntity( + _id=self.ticket_id, + title="Test Ticket", + service="Test Location", + category="", + keywords=[], + customerPriority=CustomerPrio.can_work, + affectedPerson="John Doe", + description="", + priority=Prio.low, + attachments=[], + requestType="", + state=State.draft, + ) + insert_one_result = InsertOneResult( + inserted_id=self.ticket_id, acknowledged=True + ) + text_input = {"text": "Hello from the test!", "email": email_input} + exp_ticket = Ticket( + id=str(self.ticket_id), + title="Test Ticket", + service="Test Location", + category="", + keywords=[], + customerPriority=CustomerPrio.can_work, + affectedPerson="John Doe", + description="", + priority=Prio.low, + attachmentNames=[], + requestType="", + state=State.draft, + ) + + # Define mocking behavior + self.user_db_service_mock.get_user_by_email.return_value = user_mock + self.ticket_collection_mock.insert_one.return_value = insert_one_result + self.ticket_collection_mock.find.return_value = [ticket_entity] + + # Act + response = self._run_process_text_endpoint(text_input) + + # Assert + # Mocks + self.user_db_service_mock.get_user_by_email.assert_called_with(email_input) + self.ticket_collection_mock.insert_one.assert_called_once() + self.ticket_collection_mock.find.assert_called_once() + # Response + self.assertEqual(status.HTTP_201_CREATED, response.status_code) + self.assertEqual(exp_ticket, Ticket.parse_obj(response.json())) + + def test_process_text_with_email_and_only_first_name(self): + # Define + email_input = "test@example.com" + user_mock = User( + email_address=email_input, + location="Test Location", + first_name="John", + family_name="", + password="", + ) + ticket_entity = TicketEntity( + _id=self.ticket_id, + title="Test Ticket", + service="Test Location", + category="", + keywords=[], + customerPriority=CustomerPrio.can_work, + affectedPerson="John", + description="", + priority=Prio.low, + attachments=[], + requestType="", + state=State.draft, + ) + insert_one_result = InsertOneResult( + inserted_id=self.ticket_id, acknowledged=True + ) + text_input = {"text": "Hello from the test!", "email": email_input} + exp_ticket = Ticket( + id=str(self.ticket_id), + title="Test Ticket", + service="Test Location", + category="", + keywords=[], + customerPriority=CustomerPrio.can_work, + affectedPerson="John", + description="", + priority=Prio.low, + attachmentNames=[], + requestType="", + state=State.draft, + ) + + # Define mocking behavior + self.user_db_service_mock.get_user_by_email.return_value = user_mock + self.ticket_collection_mock.insert_one.return_value = insert_one_result + self.ticket_collection_mock.find.return_value = [ticket_entity] + + # Act + response = self._run_process_text_endpoint(text_input) + + # Assert + self.user_db_service_mock.get_user_by_email.assert_called_with(email_input) + self.ticket_collection_mock.insert_one.assert_called_once() + self.ticket_collection_mock.find.assert_called_once() + self.assertEqual(status.HTTP_201_CREATED, response.status_code) + self.assertEqual(exp_ticket, Ticket.parse_obj(response.json())) + + def test_process_text_with_email_and_only_family_name(self): + # Define + email_input = "test@example.com" + user_mock = User( + email_address=email_input, + location="Test Location", + first_name="", + family_name="Doe", + password="", + ) + ticket_entity = TicketEntity( + _id=self.ticket_id, + title="Test Ticket", + service="", + category="", + keywords=[], + customerPriority=CustomerPrio.can_work, + affectedPerson="Doe", + description="", + priority=Prio.low, + attachments=[], + requestType="", + state=State.draft, + ) + insert_one_result = InsertOneResult( + inserted_id=self.ticket_id, acknowledged=True + ) + text_input = {"text": "Hello from the test!", "email": email_input} + exp_ticket = Ticket( + id=str(self.ticket_id), + title="Test Ticket", + service="", + category="", + keywords=[], + customerPriority=CustomerPrio.can_work, + affectedPerson="Doe", + description="", + priority=Prio.low, + attachmentNames=[], + requestType="", + state=State.draft, + ) + + # Define mocking behavior + self.user_db_service_mock.get_user_by_email.return_value = user_mock + self.ticket_collection_mock.insert_one.return_value = insert_one_result + self.ticket_collection_mock.find.return_value = [ticket_entity] + + # Act + response = self._run_process_text_endpoint(text_input) + + # Assert + self.user_db_service_mock.get_user_by_email.assert_called_with(email_input) + self.ticket_collection_mock.insert_one.assert_called_once() + self.ticket_collection_mock.find.assert_called_once() + self.assertEqual(status.HTTP_201_CREATED, response.status_code) + self.assertEqual(exp_ticket, Ticket.parse_obj(response.json())) + + def test_process_text_with_email_and_no_name(self): + # Define + email_input = "test@example.com" + user_mock = User( + email_address=email_input, + location="Test Location", + first_name="", + family_name="", # No name is provided + password="", + ) + ticket_entity = TicketEntity( + _id=self.ticket_id, + title="Test Ticket", + service="", + category="", + keywords=[], + customerPriority=CustomerPrio.can_work, + affectedPerson="", + description="", + priority=Prio.low, + attachments=[], + requestType="", + state=State.draft, + ) + insert_one_result = InsertOneResult( + inserted_id=self.ticket_id, acknowledged=True + ) + text_input = {"text": "Hello from the test!", "email": email_input} + exp_ticket = Ticket( + id=str(self.ticket_id), + title="Test Ticket", + service="", + category="", + keywords=[], + customerPriority=CustomerPrio.can_work, + affectedPerson="", # Expected to remain unset + description="", + priority=Prio.low, + attachmentNames=[], + requestType="", + state=State.draft, + ) + + # Define mocking behavior + self.user_db_service_mock.get_user_by_email.return_value = user_mock + self.ticket_collection_mock.insert_one.return_value = insert_one_result + self.ticket_collection_mock.find.return_value = [ticket_entity] + + # Act + response = self._run_process_text_endpoint(text_input) + + # Assert + self.user_db_service_mock.get_user_by_email.assert_called_with(email_input) + self.ticket_collection_mock.insert_one.assert_called_once() + self.ticket_collection_mock.find.assert_called_once() + self.assertEqual(status.HTTP_201_CREATED, response.status_code) + self.assertEqual(exp_ticket, Ticket.parse_obj(response.json())) + def _run_process_text_endpoint(self, text_input: dict): return self.client.post( "/api/v1/ticket/text", From b6b06fdf5d68333a3921e1d87ad8c7edf3e39a13 Mon Sep 17 00:00:00 2001 From: Koffi Tino Gnagniko Date: Mon, 5 Feb 2024 22:09:50 +0100 Subject: [PATCH 3/3] Adjust ai ticket service and add instruction for using training scripts --- Backend/README.md | 34 ++- .../ai_ticket_service/ai_ticket_service.py | 5 + .../model/train/text_classification/train.py | 251 ++++-------------- 3 files changed, 79 insertions(+), 211 deletions(-) diff --git a/Backend/README.md b/Backend/README.md index 0ecc85d9..a4cef01c 100644 --- a/Backend/README.md +++ b/Backend/README.md @@ -99,25 +99,49 @@ ## Run Test Model -1. Navigate into `./backend/app/models/t5` directory. +### T5 -2. Train the model by running the following command: +1. Train the model by navigating to `./backend/app/models/train/text_generation_t5` and running the following command: ```bash python train.py ``` -3. Test the trained model by running the following command: +2. Test our trained model which is hosted on hugging face by navigating to `./backend/app/models/ai_ticket_service/t5` and running the following command: ```bash python use_trained_t5_model.py ``` -4. Test the untrained T5 model by running the following command: +3. Test the untrained T5 model by navigating to `./backend/app/models/ai_ticket_service/t5` and running the following command: ```bash - python train.py + python use_untrained_t5_model.py + ``` + +### Text Classification (Roberta) +1. To train the model, you first need to navigate to the training script's directory: + + ```bash + cd ./backend/app/models/train/text_classification + ``` +2. Prepare a JSON-formatted string for both your classes and data paths. For example, if you have three classes "Class1", "Class2", "Class3" and your data is located at "path/to/data1.json" and "path/to/data2.json", your JSON strings would look like: + - Classes: '["Class1", "Class2", "Class3"]' + - Data Paths: '["path/to/data1.json", "path/to/data2.json"]' + +3. Run the training command with the necessary arguments. Here's an example command that includes arguments for classes, ticket field, and data paths: + + ```bash + python train.py --batch_size 4 --epochs 4 --lr 2e-5 --no_cuda --save_model --classes '["Class1", "Class2", "Class3"]' --ticket_field "service" --data_paths '["path/to/data1.json", "path/to/data2.json"]' ``` + - `--batch_size`: The number of training samples to work through before the model's internal parameters are updated. + - `--epochs`: The number of complete passes through the training dataset. + - `--lr`: The learning rate used by the optimizer. + - `--no_cuda`: Add this flag if you do not wish to use CUDA for training even if it's available. + - `--save_model`: Add this flag if you wish to save the model after training. + - `--classes`: The list of classes for the classifier in a JSON-formatted string. + - `--ticket_field`: The field name for ticket classification. + - `--data_paths`: The list of paths to your training data files in a JSON-formatted string. ## Run the Email Proxy diff --git a/Backend/app/model/ai_ticket_service/ai_ticket_service.py b/Backend/app/model/ai_ticket_service/ai_ticket_service.py index 3d34bd69..d0d3b2e8 100644 --- a/Backend/app/model/ai_ticket_service/ai_ticket_service.py +++ b/Backend/app/model/ai_ticket_service/ai_ticket_service.py @@ -242,6 +242,11 @@ def generate_prediction(self, input_text, pipe, field, field_values, ticket_dict field ) ) + if field is "priority": + prediction = "Medium" + + if field is "customerPriority": + prediction = "Disruption but can work" else: prediction = self.map_label_to_class( generated_output[0]["label"], field_values diff --git a/Backend/app/model/train/text_classification/train.py b/Backend/app/model/train/text_classification/train.py index ee672f38..ed90fd44 100644 --- a/Backend/app/model/train/text_classification/train.py +++ b/Backend/app/model/train/text_classification/train.py @@ -1,4 +1,5 @@ import argparse +import json import os import time import torch @@ -16,26 +17,26 @@ def parse_args(): + # input and output arguments parser = argparse.ArgumentParser( - description="Train a neural network to diffuse images" + description="Train a neural network for text classification" ) parser.add_argument( "--batch_size", type=int, default=4, - help="input batch size for training (default: 64)", + help="input batch size for training (default: 4)", ) parser.add_argument( - "--epochs", type=int, default=4, help="number of epochs to train (default: 5)" + "--epochs", type=int, default=4, help="number of epochs to train (default: 4)" ) parser.add_argument( - "--lr", type=float, default=2e-5, help="learning rate (default: 0.003)" + "--lr", type=float, default=2e-5, help="learning rate (default: 0.00002)" ) - # parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum (default: 0.9)') + parser.add_argument( "--no_cuda", action="store_true", default=False, help="disables CUDA training" ) - # parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') parser.add_argument( "--log_interval", type=int, @@ -48,173 +49,36 @@ def parse_args(): default=True, help="For Saving the current Model", ) + parser.add_argument( + "--classes", + type=json.loads, # Expect a JSON-formatted string for classes + required=True, + help='JSON-formatted string list of classes (e.g., \'["Class1", "Class2"]\')', + ) + parser.add_argument( + "--ticket_field", + type=str, + required=True, + help="Field name for the ticket classification", + ) + parser.add_argument( + "--data_paths", + type=json.loads, # Expect a JSON-formatted string for data paths + required=True, + help='JSON-formatted string list of paths to training data (e.g., \'["path1", "path2"]\')', + ) return parser.parse_args() args = parse_args() -# classes = ["Incident", "Service Request"] -# ticket_field = "requestType" -# classes = ['SAP ERP', 'Atlassian', 'Adobe', 'Salesforce', 'Reporting', 'Microsoft Power Platform', 'Microsoft SharePoint', 'Snowflake', 'Microsoft Office'] -# ticket_field = "service" -classes = [ - "HANA -> Technical Issues", - "HANA -> Billing & Payment", - "HANA -> Product Inquiries", - "HANA -> Account Management", - "HANA -> Policy Questions", - "Business One -> Technical Issues", - "Business One -> Billing & Payment", - "Business One -> Product Inquiries", - "Business One -> Account Management", - "Business One -> Policy Questions", - "Jira -> Technical Issues", - "Jira -> Billing & Payment", - "Jira -> Product Inquiries", - "Jira -> Account Management", - "Jira -> Policy Questions", - "Sourcetree -> Technical Issues", - "Sourcetree -> Billing & Payment", - "Sourcetree -> Product Inquiries", - "Sourcetree -> Account Management", - "Sourcetree -> Policy Questions", - "Opsgenie -> Technical Issues", - "Opsgenie -> Billing & Payment", - "Opsgenie -> Product Inquiries", - "Opsgenie -> Account Management", - "Opsgenie -> Policy Questions", - "Trello -> Technical Issues", - "Trello -> Billing & Payment", - "Trello -> Product Inquiries", - "Trello -> Account Management", - "Trello -> Policy Questions", - "Illustrator -> Technical Issues", - "Illustrator -> Billing & Payment", - "Illustrator -> Product Inquiries", - "Illustrator -> Account Management", - "Illustrator -> Policy Questions", - "Photoshop -> Technical Issues", - "Photoshop -> Billing & Payment", - "Photoshop -> Product Inquiries", - "Photoshop -> Account Management", - "Photoshop -> Policy Questions", - "InDesign -> Technical Issues", - "InDesign -> Billing & Payment", - "InDesign -> Product Inquiries", - "InDesign -> Account Management", - "InDesign -> Policy Questions", - "Premiere -> Technical Issues", - "Premiere -> Billing & Payment", - "Premiere -> Product Inquiries", - "Premiere -> Account Management", - "Premiere -> Policy Questions", - "Apex -> Technical Issues", - "Apex -> Billing & Payment", - "Apex -> Product Inquiries", - "Apex -> Account Management", - "Apex -> Policy Questions", - "Trailhead -> Technical Issues", - "Trailhead -> Billing & Payment", - "Trailhead -> Product Inquiries", - "Trailhead -> Account Management", - "Trailhead -> Policy Questions", - "Visualforce -> Technical Issues", - "Visualforce -> Billing & Payment", - "Visualforce -> Product Inquiries", - "Visualforce -> Account Management", - "Visualforce -> Policy Questions", - "Sales Cloud -> Technical Issues", - "Sales Cloud -> Billing & Payment", - "Sales Cloud -> Product Inquiries", - "Sales Cloud -> Account Management", - "Sales Cloud -> Policy Questions", - "Tableau -> Technical Issues", - "Tableau -> Billing & Payment", - "Tableau -> Product Inquiries", - "Tableau -> Account Management", - "Tableau -> Policy Questions", - "Microsoft PowerBI -> Technical Issues", - "Microsoft PowerBI -> Billing & Payment", - "Microsoft PowerBI -> Product Inquiries", - "Microsoft PowerBI -> Account Management", - "Microsoft PowerBI -> Policy Questions", - "Datasource -> Technical Issues", - "Datasource -> Billing & Payment", - "Datasource -> Product Inquiries", - "Datasource -> Account Management", - "Datasource -> Policy Questions", - "DataFlow -> Technical Issues", - "DataFlow -> Billing & Payment", - "DataFlow -> Product Inquiries", - "DataFlow -> Account Management", - "DataFlow -> Policy Questions", - "Microsoft Power Apps -> Technical Issues", - "Microsoft Power App -> Billing & Payment", - "Microsoft Power App -> Product Inquiries", - "Microsoft Power App -> Account Management", - "Microsoft Power App -> Policy Questions", - "Microsoft Power BI -> Technical Issues", - "Microsoft Power BI -> Billing & Payment", - "Microsoft Power BI -> Product Inquiries", - "Microsoft Power BI -> Account Management", - "Microsoft Power BI -> Policy Questions", - "Microsoft Power Pages Automate -> Technical Issues", - "Microsoft Power Pages Automate -> Billing & Payment", - "Microsoft Power Pages Automate -> Product Inquiries", - "Microsoft Power Pages Automate -> Account Management", - "Microsoft Power Pages Automate -> Policy Questions", - "Microsoft SharePoint -> Technical Issues", - "Microsoft SharePoint -> Billing & Payment", - "Microsoft SharePoint -> Product Inquiries", - "Microsoft SharePoint -> Account Management", - "Microsoft SharePoint -> Policy Questions", - "SharePoint -> Technical Issues", - "SharePoint -> Billing & Payment", - "SharePoint -> Product Inquiries", - "SharePoint -> Account Management", - "SharePoint -> Policy Questions", - "SharePoint List -> Technical Issues", - "SharePoint List -> Billing & Payment", - "SharePoint List -> Product Inquiries", - "SharePoint List -> Account Management", - "SharePoint List -> Policy Questions", - "SharePoint Document Library -> Technical Issues", - "SharePoint Document Library -> Billing & Payment", - "SharePoint Document Library -> Product Inquiries", - "SharePoint Document Library -> Account Management", - "SharePoint Document Library -> Policy Questions", - "Snowflake -> Technical Issues", - "Snowflake -> Billing & Payment", - "Snowflake -> Product Inquiries", - "Snowflake -> Account Management", - "Snowflake -> Policy Question", - "SnowSQL -> Technical Issues", - "SnowSQL -> Billing & Payment", - "SnowSQL -> Product Inquiries", - "SnowSQL -> Account Management", - "SnowSQL -> Policy Question", - "Microsoft Office -> Technical Issues", - "Microsoft Office -> Billing & Payment", - "Microsoft Office -> Product Inquiries", - "Microsoft Office -> Account Management", - "Microsoft Office -> Policy Questions", - "Microsoft Word -> Technical Issues", - "Microsoft Word -> Billing & Payment", - "Microsoft Word -> Product Inquiries", - "Microsoft Word -> Account Management", - "Microsoft Word -> Policy Questions", - "Microsoft Excel -> Technical Issues", - "Microsoft Excel -> Billing & Payment", - "Microsoft Excel -> Product Inquiries", - "Microsoft Excel -> Account Management", - "Microsoft Excel -> Policy Questions", - "Microsoft PowerPoint -> Technical Issues", - "Microsoft PowerPoint -> Billing & Payment", - "Microsoft PowerPoint -> Product Inquiries", - "Microsoft PowerPoint -> Account Management", - "Microsoft PowerPoint -> Policy Questions", -] -ticket_field = "category" +################### +# +################### + +classes = args.classes +ticket_field = args.ticket_field +data_paths = args.data_paths model_name = "roberta-base" model = RobertaForSequenceClassification.from_pretrained( @@ -224,22 +88,6 @@ def parse_args(): root_directory = os.path.dirname(__file__) -test_dir = os.path.join(root_directory, "..", "test_data") -data_paths = [ - # os.path.join(test_dir, "test_data_garvin", "data_translated.json"), - # os.path.join(test_dir, "test_data_irild", "data_translated.json"), - # os.path.join(test_dir, "test_data_fabian", "data_translated.json"), - # os.path.join(test_dir, "test_data_sajjad", "data_translated.json"), - # os.path.join(test_dir, "test_data_marco", "data_translated.json"), - # os.path.join(test_dir, "test_data_tino", "data_translated.json"), - os.path.join(test_dir, "test_data_with_gpt", "data_1.json"), - # os.path.join(test_dir, "test_data_with_gpt", "data_2.json"), - # os.path.join(test_dir, "test_data_with_gpt", "data_3.json"), - # os.path.join(test_dir, "test_data_with_gpt", "data_4.json"), - # os.path.join(test_dir, "test_data_with_gpt", "data_5.json"), - # os.path.join(test_dir, "test_data_with_gpt", "test_data.json") -] - # Create datasets custom_dataset = CustomDataset(tokenizer, data_paths, ticket_field, classes) train_set, val_set = torch.utils.data.random_split( @@ -250,6 +98,12 @@ def parse_args(): ], ) +num_train_data = len(train_set) + +################## +# +################## + # Define data loaders train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) valid_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False) @@ -298,7 +152,6 @@ def parse_args(): total_valid_loss = 0 val_true_labels = [] val_predictions = [] - val_confidence_scores = [] with torch.no_grad(): for batch in valid_loader: @@ -316,11 +169,6 @@ def parse_args(): loss = torch.nn.functional.cross_entropy(logits, labels) total_valid_loss += loss.item() - # Calculate confidence scores - softmax_scores = F.softmax(logits, dim=1) - max_confidence_scores = torch.max(softmax_scores, dim=1) - val_confidence_scores.extend(max_confidence_scores.values.cpu().numpy()) - average_valid_loss = total_valid_loss / len(valid_loader) valid_losses.append(average_valid_loss) @@ -352,20 +200,11 @@ def parse_args(): plt.savefig(f"training_validation_loss_curve_{ticket_field}.png") plt.close() -# Visualization of Confidence Scores -plt.figure(figsize=(10, 6)) -plt.hist(val_confidence_scores, bins=50, alpha=0.7, color="blue") -plt.title(f"Confidence Score Distribution for {ticket_field}") -plt.xlabel("Confidence Score") -plt.ylabel("Number of Predictions") -plt.grid(True) - -# Save Confidence Scores visualization as image -confidence_score_filename = f"confidence_scores_distribution_{ticket_field}.png" -plt.savefig(confidence_score_filename) -plt.close() - # Save the fine-tuned model -# if args.save_model: -model.save_pretrained("fine_tuned_roberta_model_" + ticket_field + "28k_data") -tokenizer.save_pretrained("fine_tuned_roberta_model_" + ticket_field + "28k_data") +if args.save_model: + model.save_pretrained( + "fine_tuned_roberta_model_" + ticket_field + str(num_train_data) + ) + tokenizer.save_pretrained( + "fine_tuned_roberta_model_" + ticket_field + str(num_train_data) + )