Feature/add more tests #239

Merged · 3 commits · Feb 6, 2024
34 changes: 29 additions & 5 deletions Backend/README.md
@@ -99,25 +99,49 @@

## Run Test Model

1. Navigate into `./backend/app/models/t5` directory.
### T5

2. Train the model by running the following command:
1. Train the model by navigating to `./backend/app/models/train/text_generation_t5` and running the following command:

```bash
python train.py
```

3. Test the trained model by running the following command:
2. Test our trained model, which is hosted on Hugging Face, by navigating to `./backend/app/models/ai_ticket_service/t5` and running the following command:

```bash
python use_trained_t5_model.py
```

4. Test the untrained T5 model by running the following command:
3. Test the untrained T5 model by navigating to `./backend/app/models/ai_ticket_service/t5` and running the following command:

```bash
python train.py
python use_untrained_t5_model.py
```

### Text Classification (RoBERTa)

1. To train the model, first navigate to the training script's directory:

```bash
cd ./backend/app/models/train/text_classification
```
2. Prepare a JSON-formatted string for both your classes and your data paths; a short sketch after this list shows how these strings are parsed. For example, if you have three classes "Class1", "Class2", and "Class3", and your data is located at "path/to/data1.json" and "path/to/data2.json", the JSON strings would look like:
 - Classes: '["Class1", "Class2", "Class3"]'
 - Data Paths: '["path/to/data1.json", "path/to/data2.json"]'

3. Run the training command with the necessary arguments. Here's an example command that includes arguments for classes, ticket field, and data paths:

```bash
python train.py --batch_size 4 --epochs 4 --lr 2e-5 --no_cuda --save_model --classes '["Class1", "Class2", "Class3"]' --ticket_field "service" --data_paths '["path/to/data1.json", "path/to/data2.json"]'
```
- `--batch_size`: The number of training samples to work through before the model's internal parameters are updated.
- `--epochs`: The number of complete passes through the training dataset.
- `--lr`: The learning rate used by the optimizer.
- `--no_cuda`: Add this flag if you do not wish to use CUDA for training even if it's available.
- `--save_model`: Add this flag if you wish to save the model after training.
- `--classes`: The list of classes for the classifier in a JSON-formatted string.
- `--ticket_field`: The field name for ticket classification.
- `--data_paths`: The list of paths to your training data files in a JSON-formatted string.
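
For reference, here is a minimal sketch of how these JSON-formatted arguments are consumed on the Python side, mirroring the `type=json.loads` pattern used in `train.py` (the hard-coded argument list below is purely illustrative):

```python
import argparse
import json

parser = argparse.ArgumentParser(
    description="Train a neural network for text classification"
)
# json.loads converts each JSON-formatted string into a Python list.
parser.add_argument("--classes", type=json.loads, required=True)
parser.add_argument("--ticket_field", type=str, required=True)
parser.add_argument("--data_paths", type=json.loads, required=True)

args = parser.parse_args([
    "--classes", '["Class1", "Class2", "Class3"]',
    "--ticket_field", "service",
    "--data_paths", '["path/to/data1.json", "path/to/data2.json"]',
])
print(args.classes)     # ['Class1', 'Class2', 'Class3']
print(args.data_paths)  # ['path/to/data1.json', 'path/to/data2.json']
```

Note that the single quotes around the JSON strings matter in most shells: they keep the inner double quotes from being stripped before the string reaches `json.loads`.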

## Run the Email Proxy

Expand Down
5 changes: 5 additions & 0 deletions Backend/app/model/ai_ticket_service/ai_ticket_service.py
Original file line number Diff line number Diff line change
@@ -242,6 +242,11 @@ def generate_prediction(self, input_text, pipe, field, field_values, ticket_dict
field
)
)
if field == "priority":
    prediction = "Medium"
elif field == "customerPriority":
    prediction = "Disruption but can work"
else:
    prediction = self.map_label_to_class(
        generated_output[0]["label"], field_values
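The hunk above hard-codes fallback predictions for the `priority` and `customerPriority` fields instead of mapping the classifier's label. A minimal standalone sketch of that control flow (the field names and default values come from the diff; the function itself is illustrative):

```python
def fallback_prediction(field: str, mapped_label: str) -> str:
    """Return a hard-coded default for fields the classifier does not handle."""
    if field == "priority":
        return "Medium"
    elif field == "customerPriority":
        return "Disruption but can work"
    else:
        # Every other field keeps the label mapped from the model output.
        return mapped_label


assert fallback_prediction("priority", "High") == "Medium"
assert fallback_prediction("service", "Jira") == "Jira"
```

An `elif` chain (rather than two independent `if` statements) is what keeps the `priority` default from being overwritten by the final `else` branch.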
251 changes: 45 additions & 206 deletions Backend/app/model/train/text_classification/train.py
@@ -1,4 +1,5 @@
import argparse
import json
import os
import time
import torch
@@ -16,26 +17,26 @@


def parse_args():
# input and output arguments
parser = argparse.ArgumentParser(
description="Train a neural network to diffuse images"
description="Train a neural network for text classification"
)
parser.add_argument(
"--batch_size",
type=int,
default=4,
help="input batch size for training (default: 64)",
help="input batch size for training (default: 4)",
)
parser.add_argument(
"--epochs", type=int, default=4, help="number of epochs to train (default: 5)"
"--epochs", type=int, default=4, help="number of epochs to train (default: 4)"
)
parser.add_argument(
"--lr", type=float, default=2e-5, help="learning rate (default: 0.003)"
"--lr", type=float, default=2e-5, help="learning rate (default: 0.00002)"
)
# parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum (default: 0.9)')

parser.add_argument(
"--no_cuda", action="store_true", default=False, help="disables CUDA training"
)
# parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument(
"--log_interval",
type=int,
@@ -48,173 +49,36 @@ def parse_args():
default=True,
help="For Saving the current Model",
)
parser.add_argument(
"--classes",
type=json.loads, # Expect a JSON-formatted string for classes
required=True,
help='JSON-formatted string list of classes (e.g., \'["Class1", "Class2"]\')',
)
parser.add_argument(
"--ticket_field",
type=str,
required=True,
help="Field name for the ticket classification",
)
parser.add_argument(
"--data_paths",
type=json.loads, # Expect a JSON-formatted string for data paths
required=True,
help='JSON-formatted string list of paths to training data (e.g., \'["path1", "path2"]\')',
)
return parser.parse_args()


args = parse_args()

# classes = ["Incident", "Service Request"]
# ticket_field = "requestType"
# classes = ['SAP ERP', 'Atlassian', 'Adobe', 'Salesforce', 'Reporting', 'Microsoft Power Platform', 'Microsoft SharePoint', 'Snowflake', 'Microsoft Office']
# ticket_field = "service"
classes = [
"HANA -> Technical Issues",
"HANA -> Billing & Payment",
"HANA -> Product Inquiries",
"HANA -> Account Management",
"HANA -> Policy Questions",
"Business One -> Technical Issues",
"Business One -> Billing & Payment",
"Business One -> Product Inquiries",
"Business One -> Account Management",
"Business One -> Policy Questions",
"Jira -> Technical Issues",
"Jira -> Billing & Payment",
"Jira -> Product Inquiries",
"Jira -> Account Management",
"Jira -> Policy Questions",
"Sourcetree -> Technical Issues",
"Sourcetree -> Billing & Payment",
"Sourcetree -> Product Inquiries",
"Sourcetree -> Account Management",
"Sourcetree -> Policy Questions",
"Opsgenie -> Technical Issues",
"Opsgenie -> Billing & Payment",
"Opsgenie -> Product Inquiries",
"Opsgenie -> Account Management",
"Opsgenie -> Policy Questions",
"Trello -> Technical Issues",
"Trello -> Billing & Payment",
"Trello -> Product Inquiries",
"Trello -> Account Management",
"Trello -> Policy Questions",
"Illustrator -> Technical Issues",
"Illustrator -> Billing & Payment",
"Illustrator -> Product Inquiries",
"Illustrator -> Account Management",
"Illustrator -> Policy Questions",
"Photoshop -> Technical Issues",
"Photoshop -> Billing & Payment",
"Photoshop -> Product Inquiries",
"Photoshop -> Account Management",
"Photoshop -> Policy Questions",
"InDesign -> Technical Issues",
"InDesign -> Billing & Payment",
"InDesign -> Product Inquiries",
"InDesign -> Account Management",
"InDesign -> Policy Questions",
"Premiere -> Technical Issues",
"Premiere -> Billing & Payment",
"Premiere -> Product Inquiries",
"Premiere -> Account Management",
"Premiere -> Policy Questions",
"Apex -> Technical Issues",
"Apex -> Billing & Payment",
"Apex -> Product Inquiries",
"Apex -> Account Management",
"Apex -> Policy Questions",
"Trailhead -> Technical Issues",
"Trailhead -> Billing & Payment",
"Trailhead -> Product Inquiries",
"Trailhead -> Account Management",
"Trailhead -> Policy Questions",
"Visualforce -> Technical Issues",
"Visualforce -> Billing & Payment",
"Visualforce -> Product Inquiries",
"Visualforce -> Account Management",
"Visualforce -> Policy Questions",
"Sales Cloud -> Technical Issues",
"Sales Cloud -> Billing & Payment",
"Sales Cloud -> Product Inquiries",
"Sales Cloud -> Account Management",
"Sales Cloud -> Policy Questions",
"Tableau -> Technical Issues",
"Tableau -> Billing & Payment",
"Tableau -> Product Inquiries",
"Tableau -> Account Management",
"Tableau -> Policy Questions",
"Microsoft PowerBI -> Technical Issues",
"Microsoft PowerBI -> Billing & Payment",
"Microsoft PowerBI -> Product Inquiries",
"Microsoft PowerBI -> Account Management",
"Microsoft PowerBI -> Policy Questions",
"Datasource -> Technical Issues",
"Datasource -> Billing & Payment",
"Datasource -> Product Inquiries",
"Datasource -> Account Management",
"Datasource -> Policy Questions",
"DataFlow -> Technical Issues",
"DataFlow -> Billing & Payment",
"DataFlow -> Product Inquiries",
"DataFlow -> Account Management",
"DataFlow -> Policy Questions",
"Microsoft Power Apps -> Technical Issues",
"Microsoft Power App -> Billing & Payment",
"Microsoft Power App -> Product Inquiries",
"Microsoft Power App -> Account Management",
"Microsoft Power App -> Policy Questions",
"Microsoft Power BI -> Technical Issues",
"Microsoft Power BI -> Billing & Payment",
"Microsoft Power BI -> Product Inquiries",
"Microsoft Power BI -> Account Management",
"Microsoft Power BI -> Policy Questions",
"Microsoft Power Pages Automate -> Technical Issues",
"Microsoft Power Pages Automate -> Billing & Payment",
"Microsoft Power Pages Automate -> Product Inquiries",
"Microsoft Power Pages Automate -> Account Management",
"Microsoft Power Pages Automate -> Policy Questions",
"Microsoft SharePoint -> Technical Issues",
"Microsoft SharePoint -> Billing & Payment",
"Microsoft SharePoint -> Product Inquiries",
"Microsoft SharePoint -> Account Management",
"Microsoft SharePoint -> Policy Questions",
"SharePoint -> Technical Issues",
"SharePoint -> Billing & Payment",
"SharePoint -> Product Inquiries",
"SharePoint -> Account Management",
"SharePoint -> Policy Questions",
"SharePoint List -> Technical Issues",
"SharePoint List -> Billing & Payment",
"SharePoint List -> Product Inquiries",
"SharePoint List -> Account Management",
"SharePoint List -> Policy Questions",
"SharePoint Document Library -> Technical Issues",
"SharePoint Document Library -> Billing & Payment",
"SharePoint Document Library -> Product Inquiries",
"SharePoint Document Library -> Account Management",
"SharePoint Document Library -> Policy Questions",
"Snowflake -> Technical Issues",
"Snowflake -> Billing & Payment",
"Snowflake -> Product Inquiries",
"Snowflake -> Account Management",
"Snowflake -> Policy Question",
"SnowSQL -> Technical Issues",
"SnowSQL -> Billing & Payment",
"SnowSQL -> Product Inquiries",
"SnowSQL -> Account Management",
"SnowSQL -> Policy Question",
"Microsoft Office -> Technical Issues",
"Microsoft Office -> Billing & Payment",
"Microsoft Office -> Product Inquiries",
"Microsoft Office -> Account Management",
"Microsoft Office -> Policy Questions",
"Microsoft Word -> Technical Issues",
"Microsoft Word -> Billing & Payment",
"Microsoft Word -> Product Inquiries",
"Microsoft Word -> Account Management",
"Microsoft Word -> Policy Questions",
"Microsoft Excel -> Technical Issues",
"Microsoft Excel -> Billing & Payment",
"Microsoft Excel -> Product Inquiries",
"Microsoft Excel -> Account Management",
"Microsoft Excel -> Policy Questions",
"Microsoft PowerPoint -> Technical Issues",
"Microsoft PowerPoint -> Billing & Payment",
"Microsoft PowerPoint -> Product Inquiries",
"Microsoft PowerPoint -> Account Management",
"Microsoft PowerPoint -> Policy Questions",
]
ticket_field = "category"
###################
# <prepare the data>
###################

classes = args.classes
ticket_field = args.ticket_field
data_paths = args.data_paths

model_name = "roberta-base"
model = RobertaForSequenceClassification.from_pretrained(
@@ -224,22 +88,6 @@ def parse_args():

root_directory = os.path.dirname(__file__)

test_dir = os.path.join(root_directory, "..", "test_data")
data_paths = [
# os.path.join(test_dir, "test_data_garvin", "data_translated.json"),
# os.path.join(test_dir, "test_data_irild", "data_translated.json"),
# os.path.join(test_dir, "test_data_fabian", "data_translated.json"),
# os.path.join(test_dir, "test_data_sajjad", "data_translated.json"),
# os.path.join(test_dir, "test_data_marco", "data_translated.json"),
# os.path.join(test_dir, "test_data_tino", "data_translated.json"),
os.path.join(test_dir, "test_data_with_gpt", "data_1.json"),
# os.path.join(test_dir, "test_data_with_gpt", "data_2.json"),
# os.path.join(test_dir, "test_data_with_gpt", "data_3.json"),
# os.path.join(test_dir, "test_data_with_gpt", "data_4.json"),
# os.path.join(test_dir, "test_data_with_gpt", "data_5.json"),
# os.path.join(test_dir, "test_data_with_gpt", "test_data.json")
]

# Create datasets
custom_dataset = CustomDataset(tokenizer, data_paths, ticket_field, classes)
train_set, val_set = torch.utils.data.random_split(
@@ -250,6 +98,12 @@ def parse_args():
],
)

num_train_data = len(train_set)

##################
# <train the model>
##################

# Define data loaders
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
valid_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
@@ -298,7 +152,6 @@ def parse_args():
total_valid_loss = 0
val_true_labels = []
val_predictions = []
val_confidence_scores = []

with torch.no_grad():
for batch in valid_loader:
@@ -316,11 +169,6 @@ def parse_args():
loss = torch.nn.functional.cross_entropy(logits, labels)
total_valid_loss += loss.item()

# Calculate confidence scores
softmax_scores = F.softmax(logits, dim=1)
max_confidence_scores = torch.max(softmax_scores, dim=1)
val_confidence_scores.extend(max_confidence_scores.values.cpu().numpy())

average_valid_loss = total_valid_loss / len(valid_loader)
valid_losses.append(average_valid_loss)

@@ -352,20 +200,11 @@ def parse_args():
plt.savefig(f"training_validation_loss_curve_{ticket_field}.png")
plt.close()

# Visualization of Confidence Scores
plt.figure(figsize=(10, 6))
plt.hist(val_confidence_scores, bins=50, alpha=0.7, color="blue")
plt.title(f"Confidence Score Distribution for {ticket_field}")
plt.xlabel("Confidence Score")
plt.ylabel("Number of Predictions")
plt.grid(True)

# Save Confidence Scores visualization as image
confidence_score_filename = f"confidence_scores_distribution_{ticket_field}.png"
plt.savefig(confidence_score_filename)
plt.close()

# Save the fine-tuned model
# if args.save_model:
model.save_pretrained("fine_tuned_roberta_model_" + ticket_field + "28k_data")
tokenizer.save_pretrained("fine_tuned_roberta_model_" + ticket_field + "28k_data")
if args.save_model:
model.save_pretrained(
"fine_tuned_roberta_model_" + ticket_field + str(num_train_data)
)
tokenizer.save_pretrained(
"fine_tuned_roberta_model_" + ticket_field + str(num_train_data)
)
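
For context, a minimal sketch of the dataset split and loader pattern `train.py` relies on. `random_split` and `DataLoader` are standard `torch.utils.data` APIs; the `TensorDataset` below is an illustrative stand-in for the repository's `CustomDataset`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Illustrative stand-in: 100 samples with 8 features and 3 classes.
dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 3, (100,)))

# 80/20 split (the ratio is assumed; the actual values sit in the
# truncated random_split call in the diff above).
train_len = int(0.8 * len(dataset))
train_set, val_set = random_split(dataset, [train_len, len(dataset) - train_len])

# Shuffle only the training loader, matching the diff above.
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
valid_loader = DataLoader(val_set, batch_size=4, shuffle=False)

print(len(train_set), len(val_set))  # 80 20
```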