diff --git a/Source/Skript/english/run_SAG_mnli.py b/Source/Skript/english/run_SAG_mnli.py
index 20c5f7bd3b076d0f097bf0f36d5a82fb55afa624..37e203a4983fb83d4ba9edf1b8108f231abf37ba 100644
--- a/Source/Skript/english/run_SAG_mnli.py
+++ b/Source/Skript/english/run_SAG_mnli.py
@@ -33,6 +33,7 @@ import torch
 
 from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
 from sklearn.metrics import f1_score, accuracy_score
+from sklearn.metrics import classification_report
 
 from transformers import (
     BertConfig,
@@ -491,10 +492,53 @@ def print_predictions(args, preds):
         processor.get_test_examples(args.data_dir)
     )
 
+    # observed grade list created
+    obs_grade = [ex.label for ex in examples]
+
+    # suggested grade list created
+    sugg_grade = ['correct' if pred == 0 else 'incorrect' for pred in preds]
+
+    # flag: do observed grades exist?
+    count=0
+
+    # Check if obs_grade contains "NONE" values or is empty
+    if not obs_grade or all(grade == 'NONE' for grade in obs_grade):
+        count += 1
+
+    else:
+        # classification report
+        classification_rep = classification_report(obs_grade, sugg_grade)
+
+        report_string = classification_rep
+
+        report_lines = report_string.split('\n')
+
+        # print(report_lines)
+
+        # accuracy line
+        formatted_accuracy_line = "\t".join(report_lines[5].split())
+        formatted_acc_line_with_tabs = (formatted_accuracy_line[:formatted_accuracy_line.index('\t',
+                                        formatted_accuracy_line.index(
+                                            '\t'))] + '\t\t' +
+                                        formatted_accuracy_line[
+                                            formatted_accuracy_line.index('\t', formatted_accuracy_line.index('\t')):])
+
+        # #weighted avg printing
+        #
+        wt_avg_line = "\t".join(report_lines[7].split())
+
+        new_wt_avg_line = wt_avg_line.replace("\t", " ", 1)
+
+        # Join the entire newly formatted list into a single string
+        formatted_output = "\n".join([
+            "\t precision \t recall \t f1-score \t support",
+            "\t".join(report_lines[2].split()),
+            "\t".join(report_lines[3].split()),
+            formatted_acc_line_with_tabs,
+            new_wt_avg_line
+        ])
+
     with open(args.data_dir + "/" + dir_name + "/predictions.txt", "w", encoding="utf8") as writer:
-        # print("# examples: " + str(len(examples)))
-        # print("# labels: " + str(len(labels)))
-        # print("# preds: " + str(len(preds)))
         writer.write(
             "question\treferenceAnswer\tstudentAnswer\tsuggested grade\tobserved grade\n")
 
@@ -519,7 +563,17 @@ def print_predictions(args, preds):
                 + examples[i].label
                 + "\n"
             )
-            # else: print("Labels don't match! "+str(i)+": "+str(examples[i].label)+" "+str(labels[i]))
+
+        if count == 1:
+            writer.write("\nClassification Report cannot be printed as observed grade column is empty or filled "
+                         "with 'NONE' or 'none' values\n")
+        else:
+
+            # Write the classification report to the file
+
+            writer.write(
+                "\nClassification Report - high Precision for classes correct or incorrect indicates that the class prediction is reliable:\n")
+            writer.write(formatted_output)
 
 
 if __name__ == "__main__":
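
Reviewer note, not part of the patch: the added code rebuilds the tab-separated summary by slicing fixed line indices of classification_report's text output (report_lines[2], [3], [5], [7]), which silently breaks if the label set or sklearn's report layout changes. Below is a minimal standalone sketch of an alternative that produces an equivalent block from classification_report(..., output_dict=True); the helper name format_report and the 'correct'/'incorrect' label set are assumptions, not part of run_SAG_mnli.py.

# Sketch only, assuming string labels 'correct'/'incorrect'.
from sklearn.metrics import classification_report


def format_report(obs_grade, sugg_grade):
    # output_dict=True (sklearn >= 0.20) returns nested dicts keyed by label
    # name plus 'accuracy', 'macro avg' and 'weighted avg'.
    rep = classification_report(obs_grade, sugg_grade, output_dict=True)

    lines = ["\t precision \t recall \t f1-score \t support"]
    for label in ("correct", "incorrect"):  # assumed label set
        row = rep.get(label)
        if row is None:  # label never occurs in the data
            continue
        lines.append("%s\t%.2f\t%.2f\t%.2f\t%d" % (
            label, row["precision"], row["recall"],
            row["f1-score"], int(row["support"])))

    # 'accuracy' is a plain float; total support comes from the weighted avg row.
    wt = rep["weighted avg"]
    if "accuracy" in rep:
        lines.append("accuracy\t\t\t%.2f\t%d" % (rep["accuracy"], int(wt["support"])))
    lines.append("weighted avg\t%.2f\t%.2f\t%.2f\t%d" % (
        wt["precision"], wt["recall"], wt["f1-score"], int(wt["support"])))
    return "\n".join(lines)

Usage inside the existing with-block would then reduce to writer.write(format_report(obs_grade, sugg_grade)), removing the need for the line-index arithmetic in the patch.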