k-Nearest Neighbors Classification (k-NN)
\(k\)-NN classification and search algorithms find the \(k\) observations in the training set that are nearest to a given feature vector. For classification, the problem is to infer the class of a new feature vector by taking the majority vote over the classes of its \(k\) nearest observations in the training set. For search, the problem is to find the \(k\) observations in the training set that are nearest to a new feature vector. Nearness is measured by the chosen distance metric.
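The following is a minimal, self-contained sketch of the brute-force idea described above. It assumes squared Euclidean distance and in-memory vectors, and it is not oneDAL code. The names `Point`, `knn_search` and `knn_classify` are hypothetical. `knn_search` returns the indices of the \(k\) nearest training rows, which is the search problem. `knn_classify` takes the majority vote over their labels, which is the classification problem.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <map>
#include <numeric>
#include <vector>

// Hypothetical type for this sketch only; not part of oneDAL.
using Point = std::vector<double>;

// Squared Euclidean distance. Ranking by it gives the same neighbor order
// as ranking by the Euclidean distance itself.
double squared_distance(const Point& a, const Point& b) {
    double sum = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        const double d = a[i] - b[i];
        sum += d * d;
    }
    return sum;
}

// Search: indices of the k training rows closest to `query`, nearest first.
std::vector<std::size_t> knn_search(const std::vector<Point>& train,
                                    const Point& query, std::size_t k) {
    std::vector<double> dist(train.size());
    for (std::size_t i = 0; i < train.size(); ++i) {
        dist[i] = squared_distance(train[i], query);
    }
    std::vector<std::size_t> idx(train.size());
    std::iota(idx.begin(), idx.end(), std::size_t{0});
    k = std::min(k, idx.size());
    const auto mid = idx.begin() + static_cast<std::ptrdiff_t>(k);
    // Only the first k positions need to be ordered.
    std::partial_sort(idx.begin(), mid, idx.end(),
                      [&](std::size_t a, std::size_t b) { return dist[a] < dist[b]; });
    idx.resize(k);
    return idx;
}

// Classification: majority vote over the labels of the k nearest rows.
// Ties go to the smallest label because std::map iterates in key order.
std::int64_t knn_classify(const std::vector<Point>& train,
                          const std::vector<std::int64_t>& labels,
                          const Point& query, std::size_t k) {
    std::map<std::int64_t, std::size_t> votes;
    for (std::size_t i : knn_search(train, query, k)) ++votes[labels[i]];
    std::int64_t best = -1;
    std::size_t best_count = 0;
    for (const auto& [label, count] : votes) {
        if (count > best_count) {
            best = label;
            best_count = count;
        }
    }
    return best;
}
```

The tie-breaking rule and the choice of distance belong to this sketch only. The library selects its metric through the descriptor's `Distance` parameter.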
Mathematical formulation
Refer to Developer Guide: k-Nearest Neighbors Classification.
Programming Interface
All types and functions in this section are declared in the oneapi::dal::knn namespace and are available via inclusion of the oneapi/dal/algo/knn.hpp header file.
Enum classes
enum class voting_mode
- voting_mode::uniform
  All neighbors have equal weight in prediction voting.
- voting_mode::distance
  Neighbors are weighted by the inverse of their distance.
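The following hedged sketch shows how one vote could be scored under each voting mode. The `voting_mode` enum below is a local stand-in that only mirrors the names above. The `neighbor` struct and the `vote` function are hypothetical. This document does not say how the library handles a neighbor at zero distance, so the early return for an exact match is an assumption of the sketch.

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Local stand-in that mirrors knn::voting_mode for this sketch only.
enum class voting_mode { uniform, distance };

// One neighbor as seen by the vote: its class label and distance to the query.
struct neighbor {
    std::int64_t label;
    double distance;
};

// Returns the winning label. Under voting_mode::uniform each neighbor adds 1.
// Under voting_mode::distance each neighbor adds 1 / distance. A neighbor at
// distance 0 is treated as an exact match and wins outright (sketch assumption).
std::int64_t vote(const std::vector<neighbor>& neighbors, voting_mode mode) {
    std::map<std::int64_t, double> score;
    for (const auto& n : neighbors) {
        if (mode == voting_mode::distance && n.distance == 0.0) return n.label;
        score[n.label] += (mode == voting_mode::uniform) ? 1.0 : 1.0 / n.distance;
    }
    std::int64_t best = -1;
    double best_score = -1.0;
    for (const auto& [label, s] : score) {
        if (s > best_score) {
            best = label;
            best_score = s;
        }
    }
    return best;
}
```

Take neighbors with (label, distance) pairs (1, 0.5), (2, 2.0) and (2, 2.5). With uniform weights, label 2 wins by 2 votes to 1. With inverse-distance weights, label 1 scores 2.0 against 0.5 + 0.4 = 0.9 for label 2, so label 1 wins.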
Result options
Descriptor
template <typename Float = float, typename Method = method::by_default, typename Task = task::by_default, typename Distance = oneapi::dal::minkowski_distance::descriptor<Float>>
class descriptor
Template Parameters
- Float – The floating-point type that the algorithm uses for intermediate computations. Can be float or double.
- Method – Tag-type that specifies an implementation of the algorithm. Can be method::brute_force or method::kd_tree.
- Task – Tag-type that specifies the type of the problem to solve. Can be task::classification, task::regression, or task::search.
- Distance – The descriptor of the distance used for computations. Can be minkowski_distance::descriptor or chebyshev_distance::descriptor.
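The two `Distance` options correspond to standard metrics. The sketch below writes both as free functions over `std::vector<double>`. The names `minkowski` and `chebyshev` are hypothetical and are not the oneDAL distance descriptors. This document does not cover which order \(p\) the Minkowski descriptor uses by default.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Minkowski distance of order p: (sum_i |a_i - b_i|^p)^(1/p).
// p = 1 gives the Manhattan distance and p = 2 the Euclidean distance.
double minkowski(const std::vector<double>& a, const std::vector<double>& b, double p) {
    double sum = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        sum += std::pow(std::abs(a[i] - b[i]), p);
    }
    return std::pow(sum, 1.0 / p);
}

// Chebyshev distance: max_i |a_i - b_i|. It is the limit of the
// Minkowski distance as p grows without bound.
double chebyshev(const std::vector<double>& a, const std::vector<double>& b) {
    double m = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        m = std::max(m, std::abs(a[i] - b[i]));
    }
    return m;
}
```

For the points (0, 0) and (3, 4), the Euclidean (p = 2) distance is 5, the Manhattan (p = 1) distance is 7, and the Chebyshev distance is 4.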
Constructors
- descriptor(std::int64_t class_count, std::int64_t neighbor_count)
  Creates a new instance of the class with the given class_count and neighbor_count property values.
- template <typename M = Method, typename None = detail::enable_if_brute_force_t<M>>
  descriptor(std::int64_t class_count, std::int64_t neighbor_count, const distance_t &distance)
  Creates a new instance of the class with the given class_count, neighbor_count and distance property values. Used with method::brute_force only.
- template <typename T = Task, typename None = detail::enable_if_not_classification_t<T>>
  descriptor(std::int64_t neighbor_count)
  Creates a new instance of the class with the given neighbor_count property value. Used with task::search only.
- template <typename T = Task, typename None = detail::enable_if_not_classification_t<T>>
  descriptor(std::int64_t neighbor_count, const distance_t &distance)
  Creates a new instance of the class with the given neighbor_count and distance property values. Used with task::search only.
Properties
- std::int64_t class_count
  The number of classes c.
  Getter & Setter
    std::int64_t get_class_count() const
    auto & set_class_count(std::int64_t value)
  Invariants
    class_count > 1
- voting_mode voting_mode
  The voting mode.
  Getter & Setter
    voting_mode get_voting_mode() const
    auto & set_voting_mode(voting_mode value)
- std::int64_t neighbor_count
  The number of neighbors k.
  Getter & Setter
    std::int64_t get_neighbor_count() const
    auto & set_neighbor_count(std::int64_t value)
  Invariants
    neighbor_count > 0
- const distance_t & distance
  The distance descriptor used for calculations. Used with method::brute_force only.
  Getter & Setter
    template <typename M = Method, typename None = detail::enable_if_brute_force_t<M>> const distance_t & get_distance() const
    template <typename M = Method, typename None = detail::enable_if_brute_force_t<M>> auto & set_distance(const distance_t &dist)
- result_option_id result_options
  Selects which results are computed and returned.
  Getter & Setter
    result_option_id get_result_options() const
    auto & set_result_options(const result_option_id &value)
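The setters return `auto &`, which presumably refers to the descriptor itself, so setter calls can be chained. The sketch below uses a hypothetical class, `knn_settings`, which is not oneDAL's descriptor. It mirrors two of the properties above to show the chaining pattern and the documented invariants. The default values are assumptions of the sketch. Throwing `std::invalid_argument` on a violation is also an assumption, not documented library behavior.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for the descriptor's property interface. Each setter
// checks the documented invariant and returns *this so calls can be chained.
class knn_settings {
public:
    std::int64_t get_class_count() const { return class_count_; }
    std::int64_t get_neighbor_count() const { return neighbor_count_; }

    knn_settings& set_class_count(std::int64_t value) {
        // Invariant: class_count > 1.
        if (value <= 1) throw std::invalid_argument("class_count must be > 1");
        class_count_ = value;
        return *this;
    }

    knn_settings& set_neighbor_count(std::int64_t value) {
        // Invariant: neighbor_count > 0.
        if (value <= 0) throw std::invalid_argument("neighbor_count must be > 0");
        neighbor_count_ = value;
        return *this;
    }

private:
    std::int64_t class_count_ = 2;     // assumed default for this sketch
    std::int64_t neighbor_count_ = 1;  // assumed default for this sketch
};
```

Checking the invariant before assigning means that a rejected value leaves the previous setting unchanged.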
Method tags
- struct brute_force
  Tag-type that denotes the brute-force computational method.
- using by_default = brute_force
  Alias tag-type for the brute-force computational method.
Task tags
- struct classification
  Tag-type that parameterizes entities used for solving the classification problem.
- struct regression
  Tag-type that parameterizes entities used for solving the regression problem.
- struct search
  Tag-type that parameterizes entities used for solving the search problem.
- using by_default = classification
  Alias tag-type for the classification task.
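For task::regression, a common formulation predicts the mean response of the \(k\) nearest neighbors. The sketch below assumes uniform weights and assumes that a search has already found the neighbor indices. `knn_regress` is a hypothetical name. This document does not specify the exact regression rule the library applies.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Uniform-weight k-NN regression (sketch assumption): the prediction is the
// mean response of the training rows listed in neighbor_indices.
double knn_regress(const std::vector<double>& responses,
                   const std::vector<std::size_t>& neighbor_indices) {
    if (neighbor_indices.empty()) {
        throw std::invalid_argument("at least one neighbor is required");
    }
    double sum = 0.0;
    for (std::size_t i : neighbor_indices) sum += responses[i];
    return sum / static_cast<double>(neighbor_indices.size());
}
```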
Model
template <typename Task = task::by_default>
class model
Template Parameters
- Task – Tag-type that specifies the type of the problem to solve. Can be task::classification.
Constructors
- model()
  Creates a new instance of the class with the default property values.
Training train(...)
Input
template <typename Task = task::by_default>
class train_input
Template Parameters
- Task – Tag-type that specifies the type of the problem to solve. Can be task::classification or task::search.
Constructors
- train_input(const table &data, const table &responses)
  Creates a new instance of the class with the given data and responses property values.
Properties
- const table & labels
  Vector of labels y for the training set X. Default value: table{}.
  Getter & Setter
    const table & get_labels() const
    template <typename T = Task, typename None = detail::enable_if_classification_t<T>> auto & set_labels(const table &value)
Result
template <typename Task = task::by_default>
class train_result
Template Parameters
- Task – Tag-type that specifies the type of the problem to solve. Can be task::classification or task::search.
Constructors
- train_result()
  Creates a new instance of the class with the default property values.
Properties
Operation
template <typename Descriptor>
knn::train_result train(const Descriptor &desc, const knn::train_input &input)
Parameters
- desc – k-NN algorithm descriptor knn::descriptor
- input – Input data for the training operation
Preconditions
Inference infer(...)
Input
template <typename Task = task::by_default>
class infer_input
Template Parameters
- Task – Tag-type that specifies the type of the problem to solve. Can be task::classification or task::search.
Constructors
- infer_input(const table &data, const model<Task> &model)
  Creates a new instance of the class with the given data and model property values.
Properties
Result
template <typename Task = task::by_default>
class infer_result
Template Parameters
- Task – Tag-type that specifies the type of the problem to solve. Can be task::classification or task::search.
Constructors
- infer_result()
  Creates a new instance of the class with the default property values.
Properties
- const table & labels
  The predicted labels. Default value: table{}.
  Getter & Setter
    const table & get_labels() const
    template <typename T = Task, typename None = detail::enable_if_classification_t<T>> auto & set_labels(const table &value)
- const table & indices
  Indices of the nearest neighbors. Default value: table{}.
  Getter & Setter
    const table & get_indices() const
    auto & set_indices(const table &value)
- const result_option_id & result_options
  Result options that indicate the availability of the properties.
  Getter & Setter
    const result_option_id & get_result_options() const
    auto & set_result_options(const result_option_id &value)
Operation
template <typename Descriptor>
knn::infer_result infer(const Descriptor &desc, const knn::infer_input &input)
Parameters
- desc – k-NN algorithm descriptor knn::descriptor
- input – Input data for the inference operation
Preconditions
  input.data.has_data == true
Postconditions
Usage example
Training
#include "oneapi/dal/algo/knn.hpp"
using namespace oneapi::dal;

knn::model<> run_training(const table& data,
                          const table& labels) {
    const std::int64_t class_count = 10;
    const std::int64_t neighbor_count = 5;
    const auto knn_desc = knn::descriptor<float>{class_count, neighbor_count};
    // Trains on the data table and its labels; the result holds the model.
    const auto result = train(knn_desc, data, labels);
    return result.get_model();
}
Inference
table run_inference(const knn::model<>& model,
                    const table& new_data) {
    const std::int64_t class_count = 10;
    const std::int64_t neighbor_count = 5;
    const auto knn_desc = knn::descriptor<float>{class_count, neighbor_count};
    // infer_input is constructed from (data, model), so new_data comes first.
    const auto result = infer(knn_desc, new_data, model);
    // Return the predicted labels to the caller.
    return result.get_labels();
}
Examples
Batch Processing: