Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Pado
ASYST
Commits
bd414ae2
Commit
bd414ae2
authored
Oct 22, 2024
by
Pado
Browse files
Generate evaluation table as for German (if observed labels are given)
parent
78131ec0
Changes
1
Show whitespace changes
Inline
Side-by-side
Source/Skript/english/run_SAG_mnli.py
View file @
bd414ae2
...
...
@@ -33,6 +33,7 @@ import torch
from
torch.utils.data
import
DataLoader
,
SequentialSampler
,
TensorDataset
from
sklearn.metrics
import
f1_score
,
accuracy_score
from
sklearn.metrics
import
classification_report
from
transformers
import
(
BertConfig
,
...
...
@@ -491,10 +492,53 @@ def print_predictions(args, preds):
processor
.
get_test_examples
(
args
.
data_dir
)
)
# observed grade list created
obs_grade
=
[
ex
.
label
for
ex
in
examples
]
# suggested grade list created
sugg_grade
=
[
'correct'
if
pred
==
0
else
'incorrect'
for
pred
in
preds
]
# flag: do observed grades exist?
count
=
0
# Check if obs_grade contains "NONE" values or is empty
if
not
obs_grade
or
all
(
grade
==
'NONE'
for
grade
in
obs_grade
):
count
+=
1
else
:
# classification report
classification_rep
=
classification_report
(
obs_grade
,
sugg_grade
)
report_string
=
classification_rep
report_lines
=
report_string
.
split
(
'
\n
'
)
# print(report_lines)
# accuracy line
formatted_accuracy_line
=
"
\t
"
.
join
(
report_lines
[
5
].
split
())
formatted_acc_line_with_tabs
=
(
formatted_accuracy_line
[:
formatted_accuracy_line
.
index
(
'
\t
'
,
formatted_accuracy_line
.
index
(
'
\t
'
))]
+
'
\t\t
'
+
formatted_accuracy_line
[
formatted_accuracy_line
.
index
(
'
\t
'
,
formatted_accuracy_line
.
index
(
'
\t
'
)):])
# #weighted avg printing
#
wt_avg_line
=
"
\t
"
.
join
(
report_lines
[
7
].
split
())
new_wt_avg_line
=
wt_avg_line
.
replace
(
"
\t
"
,
" "
,
1
)
# Join the entire newly formatted list into a single string
formatted_output
=
"
\n
"
.
join
([
"
\t
precision
\t
recall
\t
f1-score
\t
support"
,
"
\t
"
.
join
(
report_lines
[
2
].
split
()),
"
\t
"
.
join
(
report_lines
[
3
].
split
()),
formatted_acc_line_with_tabs
,
new_wt_avg_line
])
with
open
(
args
.
data_dir
+
"/"
+
dir_name
+
"/predictions.txt"
,
"w"
,
encoding
=
"utf8"
)
as
writer
:
# print("# examples: " + str(len(examples)))
# print("# labels: " + str(len(labels)))
# print("# preds: " + str(len(preds)))
writer
.
write
(
"question
\t
referenceAnswer
\t
studentAnswer
\t
suggested grade
\t
observed grade
\n
"
)
...
...
@@ -519,7 +563,17 @@ def print_predictions(args, preds):
+
examples
[
i
].
label
+
"
\n
"
)
# else: print("Labels don't match! "+str(i)+": "+str(examples[i].label)+" "+str(labels[i]))
if
count
==
1
:
writer
.
write
(
"
\n
Classification Report cannot be printed as observed grade column is empty or filled "
"with 'NONE' or 'none' values
\n
"
)
else
:
# Write the classification report to the file
writer
.
write
(
"
\n
Classification Report - high Precision for classes correct or incorrect indicates that the class prediction is reliable:
\n
"
)
writer
.
write
(
formatted_output
)
if
__name__
==
"__main__"
:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment