alozowski HF staff committed on
Commit
c13b5be
·
1 Parent(s): 9213aed

Add check_official_provider_status function

Browse files
backend/app/utils/model_validation.py CHANGED
@@ -1,15 +1,14 @@
1
  import json
2
  import logging
3
  import asyncio
4
- import re
5
  from typing import Tuple, Optional, Dict, Any
6
- import aiohttp
7
  from huggingface_hub import HfApi, ModelCard, hf_hub_download
8
  from huggingface_hub import hf_api
9
  from transformers import AutoConfig, AutoTokenizer
10
- from app.config.base import HF_TOKEN, API
11
- from app.utils.logging import LogFormatter
12
-
13
 
14
  logger = logging.getLogger(__name__)
15
 
@@ -207,4 +206,61 @@ class ModelValidator:
207
  except Exception as e:
208
  if "You are trying to access a gated repo." in str(e):
209
  return True, "The model is gated and requires special access permissions.", None
210
- return False, f"The model was not found or is misconfigured on the Hub. Error: {e.args[0]}", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  import logging
3
  import asyncio
 
4
  from typing import Tuple, Optional, Dict, Any
5
+ from datasets import load_dataset
6
  from huggingface_hub import HfApi, ModelCard, hf_hub_download
7
  from huggingface_hub import hf_api
8
  from transformers import AutoConfig, AutoTokenizer
9
+ from app.config.base import HF_TOKEN
10
+ from app.config.hf_config import OFFICIAL_PROVIDERS_REPO
11
+ from app.core.formatting import LogFormatter
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
206
  except Exception as e:
207
  if "You are trying to access a gated repo." in str(e):
208
  return True, "The model is gated and requires special access permissions.", None
209
+ return False, f"The model was not found or is misconfigured on the Hub. Error: {e.args[0]}", None
210
+
211
async def check_official_provider_status(
    self,
    model_id: str,
    existing_models: Dict[str, list]
) -> Tuple[bool, Optional[str]]:
    """
    Reject re-submission of an already evaluated official-provider model.

    The organization is the part of ``model_id`` before the first ``/``.
    If that organization is listed in the official providers dataset and
    ``model_id`` already appears under ``existing_models["finished"]``,
    the submission is refused; every other case is accepted.

    Args:
        model_id: The model identifier (org/model-name). Identifiers
            without an organization prefix are accepted without any lookup.
        existing_models: Dictionary of models by status from get_models().
            Only the "finished" entry is read; each item in it must
            expose a "name" key.

    Returns:
        Tuple[bool, Optional[str]]: (is_valid, error_message).
        ``(True, None)`` when the submission may proceed, otherwise
        ``(False, <reason>)``. Any exception raised during the check
        (dataset download, unexpected data layout, ...) is caught and
        reported as ``(False, <message>)``, i.e. the check fails closed.
    """
    try:
        logger.info(LogFormatter.info(f"Checking official provider status for {model_id}"))

        # Get model organization
        model_org = model_id.split('/')[0] if '/' in model_id else None

        # No organization prefix (or an empty one): nothing to check.
        if not model_org:
            return True, None

        # load_dataset() is synchronous and performs network/disk I/O.
        # Run it in the default executor so this coroutine does not stall
        # the event loop while the dataset is being fetched.
        loop = asyncio.get_running_loop()
        dataset = await loop.run_in_executor(None, load_dataset, OFFICIAL_PROVIDERS_REPO)

        # The first row of the "train" split carries the provider list in
        # its CURATED_SET field (membership is tested with ``in`` below).
        official_providers = dataset["train"][0]["CURATED_SET"]

        if model_org not in official_providers:
            logger.info(LogFormatter.info(f"Model organization '{model_org}' is not an official provider"))
            return True, None

        logger.info(LogFormatter.info(f"Model organization '{model_org}' is an official provider"))

        # Check for finished submissions
        finished_models = existing_models.get("finished", [])
        if any(model["name"] == model_id for model in finished_models):
            error_msg = (
                f"Model {model_id} is an official provider model "
                f"with a completed evaluation. "
                f"To re-evaluate, please open a discussion."
            )
            logger.error(LogFormatter.error("Validation failed", error_msg))
            return False, error_msg

        logger.info(LogFormatter.success("No finished submission found for this official provider model"))
        return True, None

    except Exception as e:
        # Deliberately broad: any failure here must surface as a validation
        # error to the caller rather than crash the submission flow.
        error_msg = f"Failed to check official provider status: {str(e)}"
        logger.error(LogFormatter.error(error_msg))
        return False, error_msg