10 changes: 6 additions & 4 deletions Project.toml
@@ -1,7 +1,7 @@
name = "NearestNeighborModels"
uuid = "636a865e-7cf4-491e-846c-de09b730eb36"
authors = ["Anthony D. Blaom <[email protected]>", "Sebastian Vollmer <[email protected]>", "Thibaut Lienart <[email protected]>", "Okon Samuel <[email protected]>"]
version = "0.2.3"
version = "0.2.2"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -13,15 +13,17 @@ NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"

[compat]
Distances = "^0.9, ^0.10"
FillArrays = "^0.9, ^0.10, ^0.11, 0.12, 0.13, 1.0"
FillArrays = "^0.9, ^0.10, ^0.11, 0.12, 0.13, 1"
MLJModelInterface = "1.4"
NearestNeighbors = "^0.4"
StatsBase = "0.33, 0.34"
OrderedCollections = "1.1"
StatsBase = "^0.33, 0.34"
Tables = "^1.2"
julia = "1.6"
julia = "1.3"
[Member] Why are we reverting to Julia 1.3 support?

[extras]
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
5 changes: 3 additions & 2 deletions src/NearestNeighborModels.jl
@@ -14,6 +14,7 @@ using Distances
using FillArrays
using LinearAlgebra
using Statistics
+using OrderedCollections

# ==============================================================================================
## EXPORTS
@@ -36,8 +37,8 @@ const Vec{T} = AbstractVector{T}
const Mat{T} = AbstractMatrix{T}
const Arr{T, N} = AbstractArray{T, N}
const ColumnTable = Tables.ColumnTable
-const DictTable = Tables.DictColumns
-const MultiUnivariateFinite = Union{DictTable, ColumnTable}
+const DictColumnTable = Tables.DictColumnTable
+const MultiUnivariateFinite = Union{DictColumnTable, ColumnTable}

# Define constants for easy referencing of packages
const MMI = MLJModelInterface
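For orientation (not part of the diff): both members of `MultiUnivariateFinite` are Tables.jl column tables. A minimal sketch contrasting them, using only documented Tables.jl functions (`istable`, `dictcolumntable`, `getcolumn`):

```julia
using Tables

# A ColumnTable is a NamedTuple of vectors; the schema lives in the type,
# which suits narrow tables.
ct = (a = [1, 2, 3], b = ["x", "y", "z"])
Tables.istable(ct)          # true

# A DictColumnTable keeps its columns in a dictionary, avoiding a huge
# NamedTuple type for very wide tables.
dct = Tables.dictcolumntable(ct)
Tables.istable(dct)         # true
Tables.getcolumn(dct, :a)   # [1, 2, 3]
```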
32 changes: 17 additions & 15 deletions src/models.jl
@@ -19,19 +19,21 @@ end
function dict_preds(::Val{:columnaccess}, func, target_table, idxsvec, weights)
    cols = Tables.columns(target_table)
    colnames = Tables.columnnames(cols)
-    dict_table = Dict(
+    dict = OrderedDict{Symbol, AbstractVector}(
        nm => func(weights, Tables.getcolumn(cols, nm), idxsvec) for nm in colnames
    )
+    dict_table = Tables.DictColumnTable(Tables.Schema(colnames, eltype.(values(dict))), dict)
    return dict_table
end

function dict_preds(::Val{:noncolumnaccess}, func, target_table, idxsvec, weights)
-    cols = Tables.dictcolumntable(target_table)
-    colnames = Tables.columnnames(cols)
-    dict_table = Dict(
-        nm => func(weights, Tables.getcolumn(cols, nm), idxsvec) for nm in colnames
-    )
-    return dict_table
+    cols = Tables.dictcolumntable(target_table)
+    colnames = Tables.columnnames(cols)
+    dict = OrderedDict{Symbol, AbstractVector}(
[Member] Re AbstractVector: remind me, why can we not have a concrete type here?

+        nm => func(weights, Tables.getcolumn(cols, nm), idxsvec) for nm in colnames
+    )
+    dict_table = Tables.DictColumnTable(Tables.Schema(colnames, eltype.(values(dict))), dict)
[Member] Not sure the DictColumnTable constructor is strictly public. Maybe we should be using Tables.dictcolumntable instead?

[Member, author] The object DictColumnTable is public, so I don't see any harm in using the constructor. The other alternative is to construct my own table, as Tables.dictcolumntable isn't a constructor; it's a utility method that converts other tables into a DictColumnTable.

[Member] I think it safest to construct your own table and use dictcolumntable (see the sketch after this hunk).

+    return dict_table
end

function dict_preds(func::F, target_table, idxsvec, weights) where {F<:Function}
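A minimal sketch of the approach the reviewers settle on above (illustrative only; `dict_preds_via_conversion` is a hypothetical name and the `:noncolumnaccess` case is elided): assemble an ordinary column table, then let `Tables.dictcolumntable` do the conversion, so the possibly non-public `Tables.DictColumnTable` constructor is never called.

```julia
using Tables

# Hypothetical variant of dict_preds: build the predictions as a NamedTuple
# of vectors, which is itself a valid Tables.jl column table ...
function dict_preds_via_conversion(func, target_table, idxsvec, weights)
    cols = Tables.columns(target_table)
    colnames = Tables.columnnames(cols)
    nt = NamedTuple{Tuple(colnames)}(
        Tuple(func(weights, Tables.getcolumn(cols, nm), idxsvec) for nm in colnames)
    )
    # ... then convert with the public utility instead of the constructor.
    return Tables.dictcolumntable(nt)
end
```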
@@ -71,7 +73,7 @@ function ntuple_preds(func::F, target_table, idxsvec, weights) where {F <: Function}
end

function univariate_table(::Type{T}, weights, target_table, idxsvec) where {T}
-    table = if T <: DictTable
+    table = if T <: DictColumnTable
[Member] This is fine, but maybe type dispatch instead of if...else...end is more usual? (A sketch follows this hunk.)

        dict_preds(_predict_knnclassifier, target_table, idxsvec, weights)
    else
        ntuple_preds(_predict_knnclassifier, target_table, idxsvec, weights)
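A sketch of the reviewer's dispatch suggestion (assuming the `DictColumnTable` and `ColumnTable` aliases defined in this PR; this is not how the PR actually implements it): each branch of the `if` becomes its own method.

```julia
# Dispatch on the requested table type instead of branching at runtime.
univariate_table(::Type{<:DictColumnTable}, weights, target_table, idxsvec) =
    dict_preds(_predict_knnclassifier, target_table, idxsvec, weights)

univariate_table(::Type{<:ColumnTable}, weights, target_table, idxsvec) =
    ntuple_preds(_predict_knnclassifier, target_table, idxsvec, weights)
```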
@@ -80,7 +82,7 @@ function univariate_table(::Type{T}, weights, target_table, idxsvec) where {T}
end

function categorical_table(::Type{T}, weights, target_table, idxsvec) where {T}
-    table = if T <: DictTable
+    table = if T <: DictColumnTable
        dict_preds(_predictmode_knnclassifier, target_table, idxsvec, weights)
    else
        ntuple_preds(_predictmode_knnclassifier, target_table, idxsvec, weights)
@@ -311,7 +313,7 @@ function MultitargetKNNClassifier(;
    leafsize::Int = (algorithm == :brutetree) ? 0 : 10,
    reorder::Bool = algorithm != :brutetree,
    weights::KNNKernel=Uniform(),
-    output_type::Type{<:MultiUnivariateFinite} = DictTable
+    output_type::Type{<:MultiUnivariateFinite} = DictColumnTable
)
    model = MultitargetKNNClassifier(
        K, algorithm, metric, leafsize, reorder, weights, output_type
@@ -623,7 +625,7 @@ Here:
- `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype
`Continuous`; check column scitypes with `schema(X)`.

-- y` is the target, which can be any table of responses whose element scitype is either
+- `y` is the target, which can be any table of responses whose element scitype is either
`<:Finite` (`<:Multiclass` or `<:OrderedFactor` will do); check the columns scitypes with `schema(y)`.
Each column of `y` is assumed to belong to a common categorical pool.

@@ -637,16 +639,16 @@ Train the machine using `fit!(mach, rows=...)`.

$KNNFIELDS

-* `output_type::Type{<:MultiUnivariateFinite}=DictTable` : One of
-  (`ColumnTable`, `DictTable`). The type of table type to use for predictions.
+* `output_type::Type{<:MultiUnivariateFinite}=DictColumnTable` : One of
[Member] In the docs I recommend you qualify DictColumnTable as Tables.DictColumnTable, or otherwise explain that this type is owned by a third-party package.

+  (`ColumnTable`, `DictColumnTable`). The type of table type to use for predictions.
  Setting to `ColumnTable` might improve performance for narrow tables while setting to
-  `DictTable` improves performance for wide tables.
+  `DictColumnTable` improves performance for wide tables.

# Operations

- `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which
should have same scitype as `X` above. Predictions are either a `ColumnTable` or
-  `DictTable` of `UnivariateFiniteVector` columns depending on the value set for the
+  `DictColumnTable` of `UnivariateFiniteVector` columns depending on the value set for the
`output_type` parameter discussed above. The probabilistic predictions are uncalibrated.

- `predict_mode(mach, Xnew)`: Return the modes of each column of the table of probabilistic
  predictions.
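To round out the docstring change, a usage sketch (not from the PR; the data and column names are invented, and it assumes `ColumnTable` is exported alongside the model, as the docstring's unqualified reference suggests):

```julia
using MLJ
using NearestNeighborModels: MultitargetKNNClassifier, ColumnTable

X = (x1 = rand(20), x2 = rand(20))                 # Continuous features
y = coerce(                                        # two Multiclass targets
    (t1 = rand(["a", "b"], 20), t2 = rand(["a", "b"], 20)),
    :t1 => Multiclass, :t2 => Multiclass,
)

# A narrow target table, so ColumnTable output may be the faster choice.
model = MultitargetKNNClassifier(K = 3, output_type = ColumnTable)
mach = machine(model, X, y)
fit!(mach)
probs = predict(mach, X)        # ColumnTable of UnivariateFiniteVector columns
modes = predict_mode(mach, X)   # modes of each probabilistic column
```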