Skip to content

Commit e147078

Browse files
Fix #41, add tests around evaluate/tuning in MLJ (#42)
* Fix #41, add tests * fix formatting * fix bounds for Test, Statistics * run formatter * Change formatting default to new lts, 1.10 * bump lts to 1.10, increase StatisticalMeasures lowerbound * Revert docs version bump. Leaving it for a different PR * add arm mac to test matrix * Fix arch for arm mac tests * bump Catboost.jl to v0.3.6 * Make sure `y_first` is a CategoricalValue * Pass the pool object to `UnivariateFinite` * Fix `MMI.predict`, ensure `fitresult` is a named tuple * reformat
1 parent 13b1919 commit e147078

File tree

11 files changed

+149
-158
lines changed

11 files changed

+149
-158
lines changed

.github/workflows/CI.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
fail-fast: false
2727
matrix:
2828
version:
29-
- '1.6'
29+
- '1.10'
3030
- '1'
3131
- 'nightly'
3232
os:
@@ -40,6 +40,9 @@ jobs:
4040
- os: macos-latest
4141
version: '1'
4242
arch: x64
43+
- os: macos-latest
44+
version: '1'
45+
arch: arm64
4346
steps:
4447
- uses: actions/checkout@v4
4548
with:

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
- uses: julia-actions/setup-julia@latest
2323
with:
2424
version: 1.6 # earliest supported version
25-
- uses: julia-actions/cache@v1 # https://github.com/julia-actions/cache
25+
- uses: julia-actions/cache@v2 # https://github.com/julia-actions/cache
2626
- uses: julia-actions/julia-docdeploy@releases/v1
2727
env:
2828
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token

.github/workflows/format_check.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
steps:
1616
- uses: julia-actions/setup-julia@latest
1717
with:
18-
version: 1.6.0
18+
version: 1.10.7
1919
- uses: actions/checkout@v4
2020
- name: Instantiate `format` environment and format
2121
run: |

CondaPkg.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
[deps]
2-
numpy = ">=1,<2"
32
catboost = ">=1.1"

Project.toml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,40 @@
11
name = "CatBoost"
22
uuid = "e2e10f9a-a85d-4fa9-b6b2-639a32100a12"
33
authors = ["Beacon Biosignals, Inc."]
4-
version = "0.3.5"
4+
version = "0.3.6"
55

66
[deps]
7+
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
78
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
89
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
910
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
1011
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1112

1213
[compat]
1314
Aqua = "0.8.4"
15+
CategoricalArrays = "0.10.8"
1416
DataFrames = "1.6"
1517
MLJBase = "1"
1618
MLJModelInterface = "1.7"
1719
MLJTestInterface = "0.2.6"
20+
MLJTuning = "0.8"
1821
OrderedCollections = "1.6"
1922
PythonCall = "0.9"
23+
StatisticalMeasures = "0.1.7"
24+
Statistics = "<0.0.1, 1"
2025
Tables = "1.10"
21-
Test = "1.6"
22-
julia = "1.6"
26+
Test = "<0.0.1, 1"
27+
julia = "1.10"
2328

2429
[extras]
2530
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2631
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
2732
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
2833
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
34+
MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"
35+
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
36+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2937
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3038

3139
[targets]
32-
test = ["Aqua", "DataFrames", "MLJBase", "MLJTestInterface", "Test"]
40+
test = ["Aqua", "DataFrames", "MLJBase", "MLJTestInterface", "MLJTuning", "Statistics", "StatisticalMeasures", "Test"]

format/Manifest.toml

Lines changed: 55 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -1,196 +1,113 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
[[ArgTools]]
4-
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
3+
julia_version = "1.10.7"
4+
manifest_format = "2.0"
5+
project_hash = "30b405be1c677184b7703a9bfb3d2100029ccad0"
56

6-
[[Artifacts]]
7-
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
8-
9-
[[Base64]]
7+
[[deps.Base64]]
108
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
119

12-
[[CSTParser]]
10+
[[deps.CSTParser]]
1311
deps = ["Tokenize"]
14-
git-tree-sha1 = "3ddd48d200eb8ddf9cb3e0189fc059fd49b97c1f"
12+
git-tree-sha1 = "0157e592151e39fa570645e2b2debcdfb8a0f112"
1513
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
16-
version = "3.3.6"
14+
version = "3.4.3"
1715

18-
[[CommonMark]]
19-
deps = ["Crayons", "JSON", "PrecompileTools", "URIs"]
20-
git-tree-sha1 = "532c4185d3c9037c0237546d817858b23cf9e071"
16+
[[deps.CommonMark]]
17+
deps = ["Crayons", "PrecompileTools"]
18+
git-tree-sha1 = "3faae67b8899797592335832fccf4b3c80bb04fa"
2119
uuid = "a80b9123-70ca-4bc0-993e-6e3bcb318db6"
22-
version = "0.8.12"
20+
version = "0.8.15"
2321

24-
[[Compat]]
25-
deps = ["Dates", "LinearAlgebra", "UUIDs"]
26-
git-tree-sha1 = "886826d76ea9e72b35fcd000e535588f7b60f21d"
22+
[[deps.Compat]]
23+
deps = ["TOML", "UUIDs"]
24+
git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215"
2725
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
28-
version = "4.10.1"
26+
version = "4.16.0"
27+
28+
[deps.Compat.extensions]
29+
CompatLinearAlgebraExt = "LinearAlgebra"
2930

30-
[[Crayons]]
31+
[deps.Compat.weakdeps]
32+
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
33+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
34+
35+
[[deps.Crayons]]
3136
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
3237
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
3338
version = "4.1.1"
3439

35-
[[DataStructures]]
40+
[[deps.DataStructures]]
3641
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
37-
git-tree-sha1 = "ac67408d9ddf207de5cfa9a97e114352430f01ed"
42+
git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82"
3843
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
39-
version = "0.18.16"
44+
version = "0.18.20"
4045

41-
[[Dates]]
46+
[[deps.Dates]]
4247
deps = ["Printf"]
4348
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
4449

45-
[[Downloads]]
46-
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
47-
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
48-
49-
[[Glob]]
50+
[[deps.Glob]]
5051
git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496"
5152
uuid = "c27321d9-0574-5035-807b-f59d2c89b15c"
5253
version = "1.3.1"
5354

54-
[[InteractiveUtils]]
55+
[[deps.InteractiveUtils]]
5556
deps = ["Markdown"]
5657
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
5758

58-
[[JSON]]
59-
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
60-
git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a"
61-
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
62-
version = "0.21.4"
63-
64-
[[JuliaFormatter]]
65-
deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "PrecompileTools", "Tokenize"]
66-
git-tree-sha1 = "8f5295e46f594ad2d8652f1098488a77460080cd"
59+
[[deps.JuliaFormatter]]
60+
deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "PrecompileTools", "TOML", "Tokenize"]
61+
git-tree-sha1 = "59cf7ad64f1b0708a4fa4369879d33bad3239b56"
6762
uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
68-
version = "1.0.45"
69-
70-
[[LibCURL]]
71-
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
72-
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
73-
74-
[[LibCURL_jll]]
75-
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
76-
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
63+
version = "1.0.62"
7764

78-
[[LibGit2]]
79-
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
80-
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
81-
82-
[[LibSSH2_jll]]
83-
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
84-
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
85-
86-
[[Libdl]]
87-
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
88-
89-
[[LinearAlgebra]]
90-
deps = ["Libdl"]
91-
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
92-
93-
[[Logging]]
94-
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
95-
96-
[[Markdown]]
65+
[[deps.Markdown]]
9766
deps = ["Base64"]
9867
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
9968

100-
[[MbedTLS_jll]]
101-
deps = ["Artifacts", "Libdl"]
102-
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
103-
104-
[[Mmap]]
105-
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
106-
107-
[[MozillaCACerts_jll]]
108-
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
109-
110-
[[NetworkOptions]]
111-
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
112-
113-
[[OrderedCollections]]
114-
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
69+
[[deps.OrderedCollections]]
70+
git-tree-sha1 = "12f1439c4f986bb868acda6ea33ebc78e19b95ad"
11571
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
116-
version = "1.6.3"
117-
118-
[[Parsers]]
119-
deps = ["Dates", "PrecompileTools", "UUIDs"]
120-
git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821"
121-
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
122-
version = "2.8.1"
72+
version = "1.7.0"
12373

124-
[[Pkg]]
125-
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
126-
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
127-
128-
[[PrecompileTools]]
74+
[[deps.PrecompileTools]]
12975
deps = ["Preferences"]
130-
git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f"
76+
git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f"
13177
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
132-
version = "1.2.0"
78+
version = "1.2.1"
13379

134-
[[Preferences]]
80+
[[deps.Preferences]]
13581
deps = ["TOML"]
136-
git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e"
82+
git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6"
13783
uuid = "21216c6a-2e73-6563-6e65-726566657250"
138-
version = "1.4.1"
84+
version = "1.4.3"
13985

140-
[[Printf]]
86+
[[deps.Printf]]
14187
deps = ["Unicode"]
14288
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
14389

144-
[[REPL]]
145-
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
146-
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
147-
148-
[[Random]]
149-
deps = ["Serialization"]
90+
[[deps.Random]]
91+
deps = ["SHA"]
15092
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15193

152-
[[SHA]]
94+
[[deps.SHA]]
15395
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
96+
version = "0.7.0"
15497

155-
[[Serialization]]
156-
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
157-
158-
[[Sockets]]
159-
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
160-
161-
[[TOML]]
98+
[[deps.TOML]]
16299
deps = ["Dates"]
163100
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
101+
version = "1.0.3"
164102

165-
[[Tar]]
166-
deps = ["ArgTools", "SHA"]
167-
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
168-
169-
[[Tokenize]]
170-
git-tree-sha1 = "3ac1ac11b09e8033ec93a7993acdb9b68252be6d"
103+
[[deps.Tokenize]]
104+
git-tree-sha1 = "468b4685af4abe0e9fd4d7bf495a6554a6276e75"
171105
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
172-
version = "0.5.27"
173-
174-
[[URIs]]
175-
git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b"
176-
uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
177-
version = "1.5.1"
106+
version = "0.5.29"
178107

179-
[[UUIDs]]
108+
[[deps.UUIDs]]
180109
deps = ["Random", "SHA"]
181110
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
182111

183-
[[Unicode]]
112+
[[deps.Unicode]]
184113
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
185-
186-
[[Zlib_jll]]
187-
deps = ["Libdl"]
188-
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
189-
190-
[[nghttp2_jll]]
191-
deps = ["Artifacts", "Libdl"]
192-
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
193-
194-
[[p7zip_jll]]
195-
deps = ["Artifacts", "Libdl"]
196-
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

format/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# format
22

3-
Run `julia --project=format format/run.jl` with Julia 1.6 to run JuliaFormatter.
3+
Run `julia --project=format format/run.jl` with Julia 1.10 to run JuliaFormatter.
44

55
If you update the version of Julia used to generate the `Manifest.toml` make sure to also
6-
update the version in `.github/workflows/format_check.yml` to match.
6+
update the version in `.github/workflows/format_check.yml` to match.

src/MLJCatBoostInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Tables
99
using MLJModelInterface: MLJModelInterface
1010
const MMI = MLJModelInterface
1111
using MLJModelInterface: Table, Continuous, Count, Finite, OrderedFactor, Multiclass
12+
using CategoricalArrays: CategoricalArray, CategoricalValue
1213
const PKG = "CatBoost"
1314

1415
"""
@@ -150,7 +151,7 @@ function MMI.update(mlj_model::CatBoostModels, verbosity::Integer, fitresult, ca
150151
report = (feature_importances=feature_importance(new_model),)
151152
cache = (; mlj_model=mlj_model)
152153
else
153-
new_model, cache, report = fit(mlj_model, verbosity, data_pool)
154+
new_model, cache, report = MMI.fit(mlj_model, verbosity, data_pool)
154155
end
155156

156157
return new_model, cache, report

0 commit comments

Comments
 (0)