"""
    majority_label(label_values::Vector{String})

Return the most frequent label in `label_values`. Ties are broken by
choosing the lexicographically smallest of the tied labels.
"""
function majority_label(label_values::Vector{String})
    tally = Dict{String, Int}()
    for item in label_values
        tally[item] = get(tally, item, 0) + 1
    end
    top = maximum(values(tally))
    winners = sort!([lbl for (lbl, c) in tally if c == top])
    return winners[1]
end
"""
    vote_by_distance(neighbors; weighted = false, fallback = "")

Predict a label from a collection of `(distance, label)` neighbor pairs.

With `weighted = true`, each neighbor contributes `1 / (distance + 1e-8)`
to its label's score and the highest-scoring label wins (lexicographic
tie-break). Otherwise a plain majority vote is taken; count ties are
broken by the smallest mean distance, then lexicographically. Returns
`fallback` when `neighbors` is empty.
"""
function vote_by_distance(neighbors; weighted = false, fallback = "")
    isempty(neighbors) && return fallback
    if weighted
        scores = Dict{String, Float64}()
        for (dist, lbl) in neighbors
            # Inverse-distance weighting; epsilon guards against zero distance.
            scores[lbl] = get(scores, lbl, 0.0) + 1.0 / (dist + 1e-8)
        end
        top_score = maximum(values(scores))
        return sort!([lbl for (lbl, s) in scores if s == top_score])[1]
    end
    tally = Dict{String, Int}()
    for (_, lbl) in neighbors
        tally[lbl] = get(tally, lbl, 0) + 1
    end
    top_count = maximum(values(tally))
    contenders = [lbl for (lbl, c) in tally if c == top_count]
    length(contenders) == 1 && return only(contenders)
    # Break count ties by the smallest mean distance to the query sample.
    avg_dist = Dict(
        lbl => mean([d for (d, l) in neighbors if l == lbl])
        for lbl in contenders
    )
    closest = minimum(values(avg_dist))
    return sort!([lbl for (lbl, a) in avg_dist if a == closest])[1]
end
"""
    predict_distance_knn_fold(D, labels, train_idx, test_idx; k = 1, weighted = false)

Classify each test sample with k-nearest-neighbor voting over the
precomputed distance matrix `D`, considering only finite distances to
training samples. A test sample with no finite training distance receives
the training-set majority label.
"""
function predict_distance_knn_fold(D, labels, train_idx, test_idx; k = 1, weighted = false)
    default_label = majority_label(labels[train_idx])
    out = Vector{String}(undef, length(test_idx))
    for (pos, sample) in enumerate(test_idx)
        finite_pairs = [
            (D[sample, j], labels[j])
            for j in train_idx
            if isfinite(D[sample, j])
        ]
        sort!(finite_pairs, by = first)
        # Take at most k neighbors; fewer may be available after filtering.
        nearest = finite_pairs[1:min(k, length(finite_pairs))]
        out[pos] = vote_by_distance(nearest; weighted = weighted, fallback = default_label)
    end
    return out
end
"""
    predict_distance_centroid_fold(D, labels, train_idx, test_idx)

Assign each test sample to the class whose training members have the
smallest mean distance to it in `D`; ties are broken by choosing the
lexicographically smallest class name.
"""
function predict_distance_centroid_fold(D, labels, train_idx, test_idx)
    families = sort(unique(labels[train_idx]))
    out = Vector{String}(undef, length(test_idx))
    for (pos, sample) in enumerate(test_idx)
        avg_dist = Dict{String, Float64}()
        for fam in families
            members = [j for j in train_idx if labels[j] == fam]
            avg_dist[fam] = mean(D[sample, members])
        end
        nearest = minimum(values(avg_dist))
        winners = sort!([fam for (fam, d) in avg_dist if d == nearest])
        out[pos] = winners[1]
    end
    return out
end
"""
    repeated_stratified_distance_cv(D, labels; classifier, k = 1, weighted = false,
                                    folds, repeats, rng_seed)

Run repeated stratified k-fold cross-validation for a distance-based
classifier (`:centroid` or k-NN) over distance matrix `D`.

Returns a named tuple with `repeat_metrics` (one DataFrame row per
repeat) and `pooled_metrics` (metrics over all repeats' predictions
concatenated).
"""
function repeated_stratified_distance_cv(
    D::Matrix,
    labels::Vector{String};
    classifier::Symbol,
    k::Int = 1,
    weighted::Bool = false,
    folds::Int,
    repeats::Int,
    rng_seed::Int,
)
    distances = sanitize_distance_matrix(D)
    n_samples = length(labels)
    per_repeat = DataFrame(
        repeat = Int[],
        accuracy = Float64[],
        macro_f1 = Float64[],
        macro_recall = Float64[],
    )
    all_true = String[]
    all_pred = String[]
    for rep in 1:repeats
        # A fresh, seed-offset RNG per repeat keeps fold assignment reproducible.
        fold_rng = MersenneTwister(rng_seed + rep)
        assignments = stratified_kfolds(labels; k = folds, rng = fold_rng)
        predicted = Vector{String}(undef, n_samples)
        for fold in 1:folds
            held_out = sort(assignments[fold])
            kept = setdiff(1:n_samples, held_out)
            if classifier == :centroid
                predicted[held_out] =
                    predict_distance_centroid_fold(distances, labels, kept, held_out)
            else
                predicted[held_out] = predict_distance_knn_fold(
                    distances,
                    labels,
                    kept,
                    held_out;
                    k = k,
                    weighted = weighted,
                )
            end
        end
        scores = classification_metrics(labels, predicted)
        # Balanced accuracy equals macro-averaged recall, hence the column name.
        push!(per_repeat, (rep, scores.accuracy, scores.macro_f1, scores.balanced_accuracy))
        append!(all_true, labels)
        append!(all_pred, predicted)
    end
    return (
        repeat_metrics = per_repeat,
        pooled_metrics = classification_metrics(all_true, all_pred),
    )
end
# Classifier configurations evaluated on each distance matrix.
# `k` and `weighted` are ignored by the :centroid classifier but kept so
# every spec has the same fields.
distance_classifier_specs = [
(classifier = :knn, label = "1-NN", k = 1, weighted = false),
(classifier = :knn, label = "3-NN", k = 3, weighted = false),
(classifier = :knn, label = "5-NN", k = 5, weighted = false),
(classifier = :knn, label = "3-NN weighted", k = 3, weighted = true),
(classifier = :knn, label = "5-NN weighted", k = 5, weighted = true),
(classifier = :centroid, label = "Nearest family average", k = 1, weighted = false),
]
# Evaluate every classifier configuration on every Wasserstein distance matrix,
# collecting one pooled-metrics row per (distance, classifier) pair.
wasserstein_cv_rows = NamedTuple[]
for spec in wasserstein_specs
    distance_matrix = rips_wasserstein_distances[spec.distance]
    for cspec in distance_classifier_specs
        cv = repeated_stratified_distance_cv(
            distance_matrix,
            labels;
            classifier = cspec.classifier,
            k = cspec.k,
            weighted = cspec.weighted,
            folds = CV_K,
            repeats = CV_REPEATS,
            # Seed offset keeps these folds independent of other analyses.
            rng_seed = RNG_SEED + 200_000,
        )
        pooled = cv.pooled_metrics
        push!(wasserstein_cv_rows, (
            distance = spec.distance,
            classifier = cspec.label,
            accuracy = pooled.accuracy,
            macro_f1 = pooled.macro_f1,
            # Balanced accuracy is reported as macro recall.
            macro_recall = pooled.balanced_accuracy,
        ))
    end
end
# Rank configurations best-first and present percentage-scaled metrics
# for the top 12 rows.
wasserstein_cv_results = DataFrame(wasserstein_cv_rows)
sort!(wasserstein_cv_results, [:macro_f1, :macro_recall, :accuracy], rev = true)
as_percent(column) = round.(column .* 100, digits = 1)
wasserstein_cv_results.accuracy_percent = as_percent(wasserstein_cv_results.accuracy)
wasserstein_cv_results.macro_f1_percent = as_percent(wasserstein_cv_results.macro_f1)
wasserstein_cv_results.macro_recall_percent = as_percent(wasserstein_cv_results.macro_recall)
select(
    first(wasserstein_cv_results, 12),
    :distance,
    :classifier,
    :accuracy_percent,
    :macro_f1_percent,
    :macro_recall_percent,
)