Project 'ulrike.pado/ASYST' was moved to 'knight/ASYST'. Please update any links and bookmarks that may still have the old path.
Commit bd414ae2 authored by Pado
Browse files

Generate evaluation table as for German (if observed labels are given)

2 merge requests: !2 "Deleted Source/dist/ASYST.exe, Source/build/main/ASYST.pkg,..."; !1 "Deleted Source/dist/ASYST.exe, Source/build/main/ASYST.pkg,..."
Showing with 58 additions and 4 deletions
+58 -4
......@@ -33,6 +33,7 @@ import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from sklearn.metrics import f1_score, accuracy_score
from sklearn.metrics import classification_report
from transformers import (
BertConfig,
......@@ -491,10 +492,53 @@ def print_predictions(args, preds):
processor.get_test_examples(args.data_dir)
)
# observed grade list created
obs_grade = [ex.label for ex in examples]
# suggested grade list created
sugg_grade = ['correct' if pred == 0 else 'incorrect' for pred in preds]
# flag: do observed grades exist?
count=0
# Check whether obs_grade is empty or consists only of 'NONE' values (case-sensitive comparison)
if not obs_grade or all(grade == 'NONE' for grade in obs_grade):
count += 1
else:
# classification report
classification_rep = classification_report(obs_grade, sugg_grade)
report_string = classification_rep
report_lines = report_string.split('\n')
# print(report_lines)
# accuracy line
formatted_accuracy_line = "\t".join(report_lines[5].split())
formatted_acc_line_with_tabs = (formatted_accuracy_line[:formatted_accuracy_line.index('\t',
formatted_accuracy_line.index(
'\t'))] + '\t\t' +
formatted_accuracy_line[
formatted_accuracy_line.index('\t', formatted_accuracy_line.index('\t')):])
# #weighted avg printing
#
wt_avg_line = "\t".join(report_lines[7].split())
new_wt_avg_line = wt_avg_line.replace("\t", " ", 1)
# Join the entire newly formatted list into a single string
formatted_output = "\n".join([
"\t precision \t recall \t f1-score \t support",
"\t".join(report_lines[2].split()),
"\t".join(report_lines[3].split()),
formatted_acc_line_with_tabs,
new_wt_avg_line
])
with open(args.data_dir + "/" + dir_name + "/predictions.txt", "w", encoding="utf8") as writer:
# print("# examples: " + str(len(examples)))
# print("# labels: " + str(len(labels)))
# print("# preds: " + str(len(preds)))
writer.write(
"question\treferenceAnswer\tstudentAnswer\tsuggested grade\tobserved grade\n")
......@@ -519,7 +563,17 @@ def print_predictions(args, preds):
+ examples[i].label
+ "\n"
)
# else: print("Labels don't match! "+str(i)+": "+str(examples[i].label)+" "+str(labels[i]))
if count == 1:
writer.write("\nClassification Report cannot be printed as observed grade column is empty or filled "
"with 'NONE' or 'none' values\n")
else:
# Write the classification report to the file
writer.write(
"\nClassification Report - high Precision for classes correct or incorrect indicates that the class prediction is reliable:\n")
writer.write(formatted_output)
if __name__ == "__main__":
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment