davidberenstein1957 HF staff committed on
Commit
8395748
·
1 Parent(s): ad7d65a

fix InferenceEndpointsLLM

Browse files
src/distilabel_dataset_generator/__init__.py CHANGED
@@ -1,12 +1,64 @@
 
1
  from typing import Optional
2
 
3
  import distilabel
4
  import distilabel.distiset
 
5
  from distilabel.utils.card.dataset_card import (
6
  DistilabelDatasetCard,
7
  size_categories_parser,
8
  )
9
  from huggingface_hub import DatasetCardData, HfApi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
@@ -111,3 +163,4 @@ class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
111
 
112
 
113
  distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag
 
 
1
+ import warnings
2
  from typing import Optional
3
 
4
  import distilabel
5
  import distilabel.distiset
6
+ from distilabel.llms import InferenceEndpointsLLM
7
  from distilabel.utils.card.dataset_card import (
8
  DistilabelDatasetCard,
9
  size_categories_parser,
10
  )
11
  from huggingface_hub import DatasetCardData, HfApi
12
+ from pydantic import (
13
+ ValidationError,
14
+ model_validator,
15
+ )
16
+
17
+
18
class CustomInferenceEndpointsLLM(InferenceEndpointsLLM):
    """`InferenceEndpointsLLM` subclass with a replacement argument validator.

    NOTE(review): the validator below presumably overrides a validator of the
    same name on the parent class -- confirm against the installed distilabel
    version.
    """

    @model_validator(mode="after")  # type: ignore
    def only_one_of_model_id_endpoint_name_or_base_url_provided(
        self,
    ) -> "InferenceEndpointsLLM":
        """Validate the combination of `base_url`, `model_id` and `endpoint_name`.

        Accepted combinations are: `base_url` alone, `model_id` without
        `endpoint_name`, or `endpoint_name` without `model_id`. If `base_url`
        is given together with `model_id` or `endpoint_name`, a warning is
        emitted saying that `base_url` will be ignored or overwritten.

        As a side effect, when `model_id` is set, `tokenizer_id` is `None` and
        `structured_output` is not `None`, `tokenizer_id` is set to `model_id`.

        Returns:
            The validated instance itself.

        Raises:
            ValueError: if `use_magpie_template` is truthy while `tokenizer_id`
                is `None`, or if the `model_id` / `endpoint_name` / `base_url`
                combination is not one of the accepted ones. Raised from inside
                a pydantic validator, so pydantic reports it to the caller as a
                `ValidationError`.
        """
        if self.base_url and (self.model_id or self.endpoint_name):
            warnings.warn(  # type: ignore
                f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
                " or `endpoint_name` is also provided, the `base_url` will either be ignored"
                " or overwritten with the one generated from either of those args, for serverless"
                " or dedicated inference endpoints, respectively."
            )

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )

        # Fall back to the model id as tokenizer id when structured output is
        # requested and no tokenizer was given explicitly.
        if (
            self.model_id
            and self.tokenizer_id is None
            and self.structured_output is not None
        ):
            self.tokenizer_id = self.model_id

        # Accepted: `base_url` on its own.
        if self.base_url and not (self.model_id or self.endpoint_name):
            return self

        # Accepted: exactly one of `model_id` / `endpoint_name`.
        if self.model_id and not self.endpoint_name:
            return self

        if self.endpoint_name and not self.model_id:
            return self

        # pydantic v2's `ValidationError` cannot be instantiated with a plain
        # message (doing so raises `TypeError`); validators must raise
        # `ValueError`, which pydantic converts into a `ValidationError`.
        raise ValueError(
            "Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
            f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
            f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
        )
62
 
63
 
64
  class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
 
163
 
164
 
165
  distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag
166
+ distilabel.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM