davidberenstein1957 HF staff committed on
Commit
b4ac9ca
·
1 Parent(s): a5b5003

fix: Add user creation during validation

Browse files
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -247,18 +247,6 @@ def push_to_argilla(
247
  progress(0.1, desc="Setting up user and workspace")
248
  client = get_argilla_client()
249
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
250
-
251
- # Create user if it doesn't exist
252
- rg_user = client.users(username=hf_user)
253
- if rg_user is None:
254
- rg_user = client.users.add(rg.User(username=hf_user, role="admin"))
255
-
256
- # Create workspace if it doesn't exist
257
- workspace = client.workspaces(name=rg_user.username)
258
- if workspace is None:
259
- workspace = client.workspaces.add(rg.Workspace(name=rg_user.username))
260
- workspace.add_user(rg_user)
261
-
262
  if "messages" in dataframe.columns:
263
  settings = rg.Settings(
264
  fields=[
@@ -356,11 +344,11 @@ def push_to_argilla(
356
  dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
357
 
358
  progress(0.5, desc="Creating dataset")
359
- rg_dataset = client.datasets(name=dataset_name, workspace=rg_user.username)
360
  if rg_dataset is None:
361
  rg_dataset = rg.Dataset(
362
  name=dataset_name,
363
- workspace=rg_user.username,
364
  settings=settings,
365
  client=client,
366
  )
@@ -386,6 +374,16 @@ def validate_argilla_dataset_name(
386
  client = get_argilla_client()
387
  if dataset_name is None or dataset_name == "":
388
  raise gr.Error("Dataset name is required")
 
 
 
 
 
 
 
 
 
 
389
  dataset = client.datasets(name=dataset_name, workspace=hf_user)
390
  if dataset and not add_to_existing_dataset:
391
  raise gr.Error(f"Dataset {dataset_name} already exists")
 
247
  progress(0.1, desc="Setting up user and workspace")
248
  client = get_argilla_client()
249
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
 
 
 
 
 
 
 
 
 
 
 
 
250
  if "messages" in dataframe.columns:
251
  settings = rg.Settings(
252
  fields=[
 
344
  dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
345
 
346
  progress(0.5, desc="Creating dataset")
347
+ rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
348
  if rg_dataset is None:
349
  rg_dataset = rg.Dataset(
350
  name=dataset_name,
351
+ workspace=hf_user,
352
  settings=settings,
353
  client=client,
354
  )
 
374
  client = get_argilla_client()
375
  if dataset_name is None or dataset_name == "":
376
  raise gr.Error("Dataset name is required")
377
+ # Create user if it doesn't exist
378
+ rg_user = client.users(username=hf_user)
379
+ if rg_user is None:
380
+ rg_user = client.users.add(rg.User(username=hf_user, role="admin"))
381
+ # Create workspace if it doesn't exist
382
+ workspace = client.workspaces(name=hf_user)
383
+ if workspace is None:
384
+ workspace = client.workspaces.add(rg.Workspace(name=hf_user))
385
+ workspace.add_user(hf_user)
386
+ # Check if dataset exists
387
  dataset = client.datasets(name=dataset_name, workspace=hf_user)
388
  if dataset and not add_to_existing_dataset:
389
  raise gr.Error(f"Dataset {dataset_name} already exists")