-
Notifications
You must be signed in to change notification settings - Fork 0
replace DictTable with DictColumnTable #59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| name = "NearestNeighborModels" | ||
| uuid = "636a865e-7cf4-491e-846c-de09b730eb36" | ||
| authors = ["Anthony D. Blaom <[email protected]>", "Sebastian Vollmer <[email protected]>", "Thibaut Lienart <[email protected]>", "Okon Samuel <[email protected]>"] | ||
| version = "0.2.3" | ||
| version = "0.2.2" | ||
|
|
||
| [deps] | ||
| Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" | ||
|
|
@@ -13,15 +13,17 @@ NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" | |
| Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
| StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
| Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" | ||
| OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" | ||
|
|
||
| [compat] | ||
| Distances = "^0.9, ^0.10" | ||
| FillArrays = "^0.9, ^0.10, ^0.11, 0.12, 0.13, 1.0" | ||
| FillArrays = "^0.9, ^0.10, ^0.11, 0.12, 0.13, 1" | ||
| MLJModelInterface = "1.4" | ||
| NearestNeighbors = "^0.4" | ||
| StatsBase = "0.33, 0.34" | ||
| OrderedCollections = "1.1" | ||
| StatsBase = "^0.33, 0.34" | ||
| Tables = "^1.2" | ||
| julia = "1.6" | ||
| julia = "1.3" | ||
|
|
||
| [extras] | ||
| MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,19 +19,21 @@ end | |
| function dict_preds(::Val{:columnaccess}, func, target_table, idxsvec, weights) | ||
| cols = Tables.columns(target_table) | ||
| colnames = Tables.columnnames(cols) | ||
| dict_table = Dict( | ||
| dict = OrderedDict{Symbol, AbstractVector}( | ||
| nm => func(weights, Tables.getcolumn(cols, nm), idxsvec) for nm in colnames | ||
| ) | ||
| dict_table = Tables.DictColumnTable(Tables.Schema(colnames, eltype.(values(dict))), dict) | ||
| return dict_table | ||
| end | ||
|
|
||
| function dict_preds(::Val{:noncolumnaccess}, func, target_table, idxsvec, weights) | ||
| cols = Tables.dictcolumntable(target_table) | ||
| colnames = Tables.columnnames(cols) | ||
| dict_table = Dict( | ||
| nm => func(weights, Tables.getcolumn(cols, nm), idxsvec) for nm in colnames | ||
| ) | ||
| return dict_table | ||
| cols = Tables.dictcolumntable(target_table) | ||
| colnames = Tables.columnnames(cols) | ||
| dict = OrderedDict{Symbol, AbstractVector}( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Re |
||
| nm => func(weights, Tables.getcolumn(cols, nm), idxsvec) for nm in colnames | ||
| ) | ||
| dict_table = Tables.DictColumnTable(Tables.Schema(colnames, eltype.(values(dict))), dict) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The object There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it safest to construct your own table and use |
||
| return dict_table | ||
| end | ||
|
|
||
| function dict_preds(func::F, target_table, idxsvec, weights) where {F<:Function} | ||
|
|
@@ -71,7 +73,7 @@ function ntuple_preds(func::F, target_table, idxsvec, weights) where {F <: Funct | |
| end | ||
|
|
||
| function univariate_table(::Type{T}, weights, target_table, idxsvec) where {T} | ||
| table = if T <: DictTable | ||
| table = if T <: DictColumnTable | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is fine, but maybe type dispatch instead of |
||
| dict_preds(_predict_knnclassifier, target_table, idxsvec, weights) | ||
| else | ||
| ntuple_preds(_predict_knnclassifier, target_table, idxsvec, weights) | ||
|
|
@@ -80,7 +82,7 @@ function univariate_table(::Type{T}, weights, target_table, idxsvec) where {T} | |
| end | ||
|
|
||
| function categorical_table(::Type{T}, weights, target_table, idxsvec) where {T} | ||
| table = if T <: DictTable | ||
| table = if T <: DictColumnTable | ||
| dict_preds(_predictmode_knnclassifier, target_table, idxsvec, weights) | ||
| else | ||
| ntuple_preds(_predictmode_knnclassifier, target_table, idxsvec, weights) | ||
|
|
@@ -311,7 +313,7 @@ function MultitargetKNNClassifier(; | |
| leafsize::Int = (algorithm == :brutetree) ? 0 : 10, | ||
| reorder::Bool = algorithm != :brutetree, | ||
| weights::KNNKernel=Uniform(), | ||
| output_type::Type{<:MultiUnivariateFinite} = DictTable | ||
| output_type::Type{<:MultiUnivariateFinite} = DictColumnTable | ||
| ) | ||
| model = MultitargetKNNClassifier( | ||
| K, algorithm, metric, leafsize, reorder, weights, output_type | ||
|
|
@@ -623,7 +625,7 @@ Here: | |
| - `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype | ||
| `Continuous`; check column scitypes with `schema(X)`. | ||
|
|
||
| - y` is the target, which can be any table of responses whose element scitype is either | ||
| - `y` is the target, which can be any table of responses whose element scitype is either | ||
| `<:Finite` (`<:Multiclass` or `<:OrderedFactor` will do); check the columns scitypes with `schema(y)`. | ||
| Each column of `y` is assumed to belong to a common categorical pool. | ||
|
|
||
|
|
@@ -637,16 +639,16 @@ Train the machine using `fit!(mach, rows=...)`. | |
|
|
||
| $KNNFIELDS | ||
|
|
||
| * `output_type::Type{<:MultiUnivariateFinite}=DictTable` : One of | ||
| (`ColumnTable`, `DictTable`). The type of table type to use for predictions. | ||
| * `output_type::Type{<:MultiUnivariateFinite}=DictColumnTable` : One of | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In docs I recommend you qualify |
||
| (`ColumnTable`, `DictColumnTable`). The type of table type to use for predictions. | ||
| Setting to `ColumnTable` might improve performance for narrow tables while setting to | ||
| `DictTable` improves performance for wide tables. | ||
| `DictColumnTable` improves performance for wide tables. | ||
|
|
||
| # Operations | ||
|
|
||
| - `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which | ||
| should have same scitype as `X` above. Predictions are either a `ColumnTable` or | ||
| `DictTable` of `UnivariateFiniteVector` columns depending on the value set for the | ||
| `DictColumnTable` of `UnivariateFiniteVector` columns depending on the value set for the | ||
| `output_type` parameter discussed above. The probabilistic predictions are uncalibrated. | ||
|
|
||
| - `predict_mode(mach, Xnew)`: Return the modes of each column of the table of probabilistic | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we reverting to Julia 1.3 support?