1212import numpy as np
1313import requests
1414from branca .colormap import ColorMap , LinearColormap , StepColormap
15- from branca .element import Element , Figure , Html , IFrame , JavascriptLink , MacroElement
15+ from branca .element import (
16+ Div ,
17+ Element ,
18+ Figure ,
19+ Html ,
20+ IFrame ,
21+ JavascriptLink ,
22+ MacroElement ,
23+ )
1624from branca .utilities import color_brewer
1725
1826from folium .elements import JSCSSMixin
1927from folium .folium import Map
2028from folium .map import FeatureGroup , Icon , Layer , Marker , Popup , Tooltip
2129from folium .template import Template
2230from folium .utilities import (
31+ TypeBoundsReturn ,
32+ TypeContainer ,
2333 TypeJsonValue ,
2434 TypeLine ,
2535 TypePathOptions ,
@@ -165,7 +175,7 @@ def __init__(
165175 self .top = _parse_size (top )
166176 self .position = position
167177
168- def render (self , ** kwargs ) -> None :
178+ def render (self , ** kwargs ):
169179 """Renders the HTML representation of the element."""
170180 super ().render (** kwargs )
171181
@@ -284,9 +294,15 @@ def __init__(
284294 self .top = _parse_size (top )
285295 self .position = position
286296
287- def render (self , ** kwargs ) -> None :
297+ def render (self , ** kwargs ):
288298 """Renders the HTML representation of the element."""
289- self ._parent .html .add_child (
299+ parent = self ._parent
300+ if not isinstance (parent , (Figure , Div , Popup )):
301+ raise TypeError (
302+ "VegaLite elements can only be added to a Figure, Div, or Popup"
303+ )
304+
305+ parent .html .add_child (
290306 Element (
291307 Template (
292308 """
@@ -331,7 +347,7 @@ def render(self, **kwargs) -> None:
331347 embed_vegalite = embed_mapping .get (
332348 self .vegalite_major_version , self ._embed_vegalite_v2
333349 )
334- embed_vegalite (figure )
350+ embed_vegalite (figure = figure , parent = parent )
335351
336352 @property
337353 def vegalite_major_version (self ) -> Optional [int ]:
@@ -342,8 +358,8 @@ def vegalite_major_version(self) -> Optional[int]:
342358
343359 return int (schema .split ("/" )[- 1 ].split ("." )[0 ].lstrip ("v" ))
344360
345- def _embed_vegalite_v5 (self , figure : Figure ) -> None :
346- self ._vega_embed ()
361+ def _embed_vegalite_v5 (self , figure : Figure , parent : TypeContainer ) -> None :
362+ self ._vega_embed (parent = parent )
347363
348364 figure .header .add_child (
349365 JavascriptLink ("https://cdn.jsdelivr.net/npm//vega@5" ), name = "vega"
@@ -356,8 +372,8 @@ def _embed_vegalite_v5(self, figure: Figure) -> None:
356372 name = "vega-embed" ,
357373 )
358374
359- def _embed_vegalite_v4 (self , figure : Figure ) -> None :
360- self ._vega_embed ()
375+ def _embed_vegalite_v4 (self , figure : Figure , parent : TypeContainer ) -> None :
376+ self ._vega_embed (parent = parent )
361377
362378 figure .header .add_child (
363379 JavascriptLink ("https://cdn.jsdelivr.net/npm//vega@5" ), name = "vega"
@@ -370,8 +386,8 @@ def _embed_vegalite_v4(self, figure: Figure) -> None:
370386 name = "vega-embed" ,
371387 )
372388
373- def _embed_vegalite_v3 (self , figure : Figure ) -> None :
374- self ._vega_embed ()
389+ def _embed_vegalite_v3 (self , figure : Figure , parent : TypeContainer ) -> None :
390+ self ._vega_embed (parent = parent )
375391
376392 figure .header .add_child (
377393 JavascriptLink ("https://cdn.jsdelivr.net/npm/vega@4" ), name = "vega"
@@ -384,8 +400,8 @@ def _embed_vegalite_v3(self, figure: Figure) -> None:
384400 name = "vega-embed" ,
385401 )
386402
387- def _embed_vegalite_v2 (self , figure : Figure ) -> None :
388- self ._vega_embed ()
403+ def _embed_vegalite_v2 (self , figure : Figure , parent : TypeContainer ) -> None :
404+ self ._vega_embed (parent = parent )
389405
390406 figure .header .add_child (
391407 JavascriptLink ("https://cdn.jsdelivr.net/npm/vega@3" ), name = "vega"
@@ -398,8 +414,8 @@ def _embed_vegalite_v2(self, figure: Figure) -> None:
398414 name = "vega-embed" ,
399415 )
400416
401- def _vega_embed (self ) -> None :
402- self . _parent .script .add_child (
417+ def _vega_embed (self , parent : TypeContainer ) -> None :
418+ parent .script .add_child (
403419 Element (
404420 Template (
405421 """
@@ -412,8 +428,8 @@ def _vega_embed(self) -> None:
412428 name = self .get_name (),
413429 )
414430
415- def _embed_vegalite_v1 (self , figure : Figure ) -> None :
416- self . _parent .script .add_child (
431+ def _embed_vegalite_v1 (self , figure : Figure , parent : TypeContainer ) -> None :
432+ parent .script .add_child (
417433 Element (
418434 Template (
419435 """
@@ -436,19 +452,19 @@ def _embed_vegalite_v1(self, figure: Figure) -> None:
436452 figure .header .add_child (
437453 JavascriptLink ("https://cdnjs.cloudflare.com/ajax/libs/vega/2.6.5/vega.js" ),
438454 name = "vega" ,
439- ) # noqa
455+ )
440456 figure .header .add_child (
441457 JavascriptLink (
442458 "https://cdnjs.cloudflare.com/ajax/libs/vega-lite/1.3.1/vega-lite.js"
443459 ),
444460 name = "vega-lite" ,
445- ) # noqa
461+ )
446462 figure .header .add_child (
447463 JavascriptLink (
448464 "https://cdnjs.cloudflare.com/ajax/libs/vega-embed/2.2.0/vega-embed.js"
449465 ),
450466 name = "vega-embed" ,
451- ) # noqa
467+ )
452468
453469
454470class GeoJson (Layer ):
@@ -820,7 +836,7 @@ def _get_self_bounds(self) -> List[List[Optional[float]]]:
820836 """
821837 return get_bounds (self .data , lonlat = True )
822838
823- def render (self , ** kwargs ) -> None :
839+ def render (self , ** kwargs ):
824840 self .parent_map = get_obj_in_upper_tree (self , Map )
825841 # Need at least one feature, otherwise style mapping fails
826842 if (self .style or self .highlight ) and self .data ["features" ]:
@@ -1041,12 +1057,12 @@ def recursive_get(data, keys):
10411057 self .style_function (feature )
10421058 ) # noqa
10431059
1044- def render (self , ** kwargs ) -> None :
1060+ def render (self , ** kwargs ):
10451061 """Renders the HTML representation of the element."""
10461062 self .style_data ()
10471063 super ().render (** kwargs )
10481064
1049- def get_bounds (self ) -> List [ List [ float ]] :
1065+ def get_bounds (self ) -> TypeBoundsReturn :
10501066 """
10511067 Computes the bounds of the object itself (not including it's children)
10521068 in the form [[lat_min, lon_min], [lat_max, lon_max]]
@@ -1146,6 +1162,7 @@ def __init__(
11461162
11471163 def warn_for_geometry_collections (self ) -> None :
11481164 """Checks for GeoJson GeometryCollection features to warn user about incompatibility."""
1165+ assert isinstance (self ._parent , GeoJson )
11491166 geom_collections = [
11501167 feature .get ("properties" ) if feature .get ("properties" ) is not None else key
11511168 for key , feature in enumerate (self ._parent .data ["features" ])
@@ -1160,7 +1177,7 @@ def warn_for_geometry_collections(self) -> None:
11601177 UserWarning ,
11611178 )
11621179
1163- def render (self , ** kwargs ) -> None :
1180+ def render (self , ** kwargs ):
11641181 """Renders the HTML representation of the element."""
11651182 figure = self .get_root ()
11661183 if isinstance (self ._parent , GeoJson ):
@@ -1565,7 +1582,7 @@ def __init__(
15651582 color_range = color_brewer (fill_color , n = nb_bins )
15661583 self .color_scale = StepColormap (
15671584 color_range ,
1568- index = bin_edges ,
1585+ index = list ( bin_edges ) ,
15691586 vmin = bins_min ,
15701587 vmax = bins_max ,
15711588 caption = legend_name ,
@@ -1625,7 +1642,7 @@ def highlight_function(x):
16251642 return {"weight" : line_weight + 2 , "fillOpacity" : fill_opacity + 0.2 }
16261643
16271644 if topojson :
1628- self .geojson = TopoJson (
1645+ self .geojson : Union [ TopoJson , GeoJson ] = TopoJson (
16291646 geo_data ,
16301647 topojson ,
16311648 style_function = style_function ,
@@ -1657,7 +1674,7 @@ def _get_by_key(cls, obj: Union[dict, list], key: str) -> Union[float, str, None
16571674 else :
16581675 return value
16591676
1660- def render (self , ** kwargs ) -> None :
1677+ def render (self , ** kwargs ):
16611678 """Render the GeoJson/TopoJson and color scale objects."""
16621679 if self .color_scale :
16631680 # ColorMap needs Map as its parent
@@ -1963,8 +1980,13 @@ def __init__(
19631980 vmin = min (colors ),
19641981 vmax = max (colors ),
19651982 ).to_step (nb_steps )
1966- else :
1983+ elif isinstance ( colormap , StepColormap ) :
19671984 cm = colormap
1985+ else :
1986+ raise TypeError (
1987+ f"Unexpected type for argument `colormap`: { type (colormap )} "
1988+ )
1989+
19681990 out : Dict [str , List [List [List [float ]]]] = {}
19691991 for (lat1 , lng1 ), (lat2 , lng2 ), color in zip (coords [:- 1 ], coords [1 :], colors ):
19701992 out .setdefault (cm (color ), []).append ([[lat1 , lng1 ], [lat2 , lng2 ]])
0 commit comments