# Spaces:
# Runtime error
# Runtime error
import hashlib | |
import json | |
import tempfile | |
from copy import deepcopy | |
from pathlib import Path | |
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union | |
from langchain.callbacks.base import BaseCallbackHandler | |
from langchain.schema import AgentAction, AgentFinish, LLMResult | |
def import_wandb() -> Any:
    """Import the `wandb` package and hand the module back to the caller.

    Returns:
        (module): The imported `wandb` module.

    Raises:
        ImportError: If `wandb` cannot be imported; the message carries the
            install hint.
    """
    try:
        import wandb  # noqa: F401
    except ImportError:
        install_hint = (
            "To use the wandb callback manager you need to have the `wandb` python "
            "package installed. Please install it with `pip install wandb`"
        )
        raise ImportError(install_hint)
    else:
        return wandb
def import_spacy() -> Any:
    """Import the `spacy` package and hand the module back to the caller.

    Returns:
        (module): The imported `spacy` module.

    Raises:
        ImportError: If `spacy` cannot be imported; the message carries the
            install hint.
    """
    try:
        import spacy  # noqa: F401
    except ImportError:
        install_hint = (
            "To use the wandb callback manager you need to have the `spacy` python "
            "package installed. Please install it with `pip install spacy`"
        )
        raise ImportError(install_hint)
    else:
        return spacy
def import_pandas() -> Any:
    """Import the `pandas` package and hand the module back to the caller.

    Returns:
        (module): The imported `pandas` module.

    Raises:
        ImportError: If `pandas` cannot be imported; the message carries the
            install hint.
    """
    try:
        import pandas  # noqa: F401
    except ImportError:
        install_hint = (
            "To use the wandb callback manager you need to have the `pandas` python "
            "package installed. Please install it with `pip install pandas`"
        )
        raise ImportError(install_hint)
    else:
        return pandas
def import_textstat() -> Any:
    """Import the `textstat` package and hand the module back to the caller.

    Returns:
        (module): The imported `textstat` module.

    Raises:
        ImportError: If `textstat` cannot be imported; the message carries the
            install hint.
    """
    try:
        import textstat  # noqa: F401
    except ImportError:
        install_hint = (
            "To use the wandb callback manager you need to have the `textstat` python "
            "package installed. Please install it with `pip install textstat`"
        )
        raise ImportError(install_hint)
    else:
        return textstat
def _flatten_dict(
    nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Iterable[Tuple[str, Any]]:
    """
    Walk a nested dictionary depth-first and yield its leaves as flat pairs.

    Parameters:
        nested_dict (dict): The nested dictionary to flatten.
        parent_key (str): The prefix to prepend to the keys of the flattened dict.
        sep (str): The separator placed between the parent key and each child key.

    Yields:
        (str, any): A key-value pair from the flattened dictionary.
    """
    for name, item in nested_dict.items():
        # Only prefix when there is a non-empty parent key to prefix with.
        if parent_key:
            full_key = parent_key + sep + name
        else:
            full_key = name

        if not isinstance(item, dict):
            yield full_key, item
            continue

        # Nested dict: descend, using the key built so far as the new prefix.
        yield from _flatten_dict(item, full_key, sep)


def flatten_dict(
    nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Dict[str, Any]:
    """Collapse a nested dictionary into a single-level dictionary.

    Parameters:
        nested_dict (dict): The nested dictionary to flatten.
        parent_key (str): The prefix to prepend to the keys of the flattened dict.
        sep (str): The separator placed between the parent key and each child key.

    Returns:
        (dict): A flat dictionary; when two paths produce the same flattened
            key, the value yielded last wins.
    """
    return dict(_flatten_dict(nested_dict, parent_key, sep))
def hash_string(s: str) -> str:
    """Return the sha1 hex digest of a string.

    Parameters:
        s (str): The string to hash; it is encoded as UTF-8 before hashing.

    Returns:
        (str): The hexadecimal sha1 digest.
    """
    hasher = hashlib.sha1()
    hasher.update(s.encode("utf-8"))
    return hasher.hexdigest()
def load_json_to_dict(json_path: Union[str, Path]) -> dict:
    """Load a json file into a dictionary.

    Parameters:
        json_path (str | Path): The path to the json file.

    Returns:
        (dict): The parsed content of the json file.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (which would garble non-ASCII text on some systems).
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)
def analyze_text(
    text: str,
    complexity_metrics: bool = True,
    visualize: bool = True,
    nlp: Any = None,
    output_dir: Optional[Union[str, Path]] = None,
) -> dict:
    """Analyze text using textstat and spacy.

    Parameters:
        text (str): The text to analyze.
        complexity_metrics (bool): Whether to compute complexity metrics.
        visualize (bool): Whether to visualize the text.
        nlp (spacy.lang): The spacy language model to use for visualization.
        output_dir (str): The directory to save the visualization files to.

    Returns:
        (dict): A dictionary containing the complexity metrics and visualization
            files serialized in a wandb.Html element. Visualizations are only
            produced when `visualize` is set, `nlp` is truthy and `output_dir`
            is not None.

    Raises:
        ImportError: If `textstat`, `wandb` or `spacy` is not installed (all
            three are imported up front, regardless of the flags).
    """
    resp = {}
    textstat = import_textstat()
    wandb = import_wandb()
    spacy = import_spacy()

    if complexity_metrics:
        text_complexity_metrics = {
            "flesch_reading_ease": textstat.flesch_reading_ease(text),
            "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
            "smog_index": textstat.smog_index(text),
            "coleman_liau_index": textstat.coleman_liau_index(text),
            "automated_readability_index": textstat.automated_readability_index(text),
            "dale_chall_readability_score": textstat.dale_chall_readability_score(text),
            "difficult_words": textstat.difficult_words(text),
            "linsear_write_formula": textstat.linsear_write_formula(text),
            "gunning_fog": textstat.gunning_fog(text),
            "text_standard": textstat.text_standard(text),
            "fernandez_huerta": textstat.fernandez_huerta(text),
            "szigriszt_pazos": textstat.szigriszt_pazos(text),
            "gutierrez_polini": textstat.gutierrez_polini(text),
            "crawford": textstat.crawford(text),
            "gulpease_index": textstat.gulpease_index(text),
            "osman": textstat.osman(text),
        }
        resp.update(text_complexity_metrics)

    if visualize and nlp and output_dir is not None:
        doc = nlp(text)

        dep_out = spacy.displacy.render(  # type: ignore
            doc, style="dep", jupyter=False, page=True
        )
        # File names are derived from a hash of the text, prefixed per style.
        dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
        # write_text closes the handle, so the file is complete on disk before
        # wandb.Html is given its path below.
        dep_output_path.write_text(dep_out, encoding="utf-8")

        ent_out = spacy.displacy.render(  # type: ignore
            doc, style="ent", jupyter=False, page=True
        )
        ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
        ent_output_path.write_text(ent_out, encoding="utf-8")

        text_visualizations = {
            "dependency_tree": wandb.Html(str(dep_output_path)),
            "entities": wandb.Html(str(ent_output_path)),
        }
        resp.update(text_visualizations)

    return resp
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
    """Build a wandb html element showing a prompt followed by its generation.

    Newlines in both texts are turned into `<br>` tags; the prompt is rendered
    in black and the generation in green inside a blockquote.

    Parameters:
        prompt (str): The prompt.
        generation (str): The generation.

    Returns:
        (wandb.Html): The html element."""
    wandb = import_wandb()
    prompt_html, generation_html = (
        part.replace("\n", "<br>") for part in (prompt, generation)
    )
    markup = f"""
    <p style="color:black;">{prompt_html}:</p>
    <blockquote>
      <p style="color:green;">
        {generation_html}
      </p>
    </blockquote>
    """
    return wandb.Html(markup, inject=False)
class BaseMetadataCallbackHandler:
    """Bookkeeping for callback handlers: counters, flags and per-callback records.

    Attributes:
        step (int): The current step.
        starts (int): The number of times the start method has been called.
        ends (int): The number of times the end method has been called.
        errors (int): The number of times the error method has been called.
        text_ctr (int): The number of times the text method has been called.
        ignore_llm_ (bool): Whether to ignore llm callbacks.
        ignore_chain_ (bool): Whether to ignore chain callbacks.
        ignore_agent_ (bool): Whether to ignore agent callbacks.
        always_verbose_ (bool): Whether to always be verbose.
        chain_starts (int): The number of times the chain start method has been called.
        chain_ends (int): The number of times the chain end method has been called.
        llm_starts (int): The number of times the llm start method has been called.
        llm_ends (int): The number of times the llm end method has been called.
        llm_streams (int): The number of times the llm new token method has been called.
        tool_starts (int): The number of times the tool start method has been called.
        tool_ends (int): The number of times the tool end method has been called.
        agent_ends (int): The number of times the agent end method has been called.
        on_llm_start_records (list): A list of records of the on_llm_start method.
        on_llm_token_records (list): A list of records of the on_llm_new_token method.
        on_llm_end_records (list): A list of records of the on_llm_end method.
        on_chain_start_records (list): A list of records of the on_chain_start method.
        on_chain_end_records (list): A list of records of the on_chain_end method.
        on_tool_start_records (list): A list of records of the on_tool_start method.
        on_tool_end_records (list): A list of records of the on_tool_end method.
        on_text_records (list): A list of records of the on_text method.
        on_agent_finish_records (list): A list of records of the on_agent_finish method.
        on_agent_action_records (list): A list of records of the on_agent_action method.
    """

    def __init__(self) -> None:
        self._reset_state()

    def _reset_state(self) -> None:
        """Put every counter, flag and record list back to its initial value."""
        # Counters.
        self.step = 0
        self.starts = 0
        self.ends = 0
        self.errors = 0
        self.text_ctr = 0

        # Behaviour flags exposed through the accessor methods below.
        self.ignore_llm_ = False
        self.ignore_chain_ = False
        self.ignore_agent_ = False
        self.always_verbose_ = False

        # Per-category counters.
        self.chain_starts = 0
        self.chain_ends = 0
        self.llm_starts = 0
        self.llm_ends = 0
        self.llm_streams = 0
        self.tool_starts = 0
        self.tool_ends = 0
        self.agent_ends = 0

        # Fresh list objects each time so earlier references are not mutated.
        self.on_llm_start_records: list = []
        self.on_llm_token_records: list = []
        self.on_llm_end_records: list = []
        self.on_chain_start_records: list = []
        self.on_chain_end_records: list = []
        self.on_tool_start_records: list = []
        self.on_tool_end_records: list = []
        self.on_text_records: list = []
        self.on_agent_finish_records: list = []
        self.on_agent_action_records: list = []

    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return self.always_verbose_

    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    def get_custom_callback_meta(self) -> Dict[str, Any]:
        """Return a snapshot of all counters keyed by attribute name."""
        counter_names = (
            "step",
            "starts",
            "ends",
            "errors",
            "text_ctr",
            "chain_starts",
            "chain_ends",
            "llm_starts",
            "llm_ends",
            "llm_streams",
            "tool_starts",
            "tool_ends",
            "agent_ends",
        )
        return {counter: getattr(self, counter) for counter in counter_names}

    def reset_callback_meta(self) -> None:
        """Reset the callback metadata."""
        self._reset_state()
        return None
class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback Handler that logs to Weights and Biases.

    Parameters:
        job_type (str): The type of job.
        project (str): The project to log to.
        entity (str): The entity to log to.
        tags (list): The tags to log.
        group (str): The group to log to.
        name (str): The name of the run.
        notes (str): The notes to log.
        visualize (bool): Whether to visualize the run.
        complexity_metrics (bool): Whether to log complexity metrics.
        stream_logs (bool): Whether to stream callback actions to W&B

    This handler will utilize the associated callback method called and formats
    the input of each callback function with metadata regarding the state of LLM run,
    and adds the response to the list of records for both the {method}_records and
    action. It then logs the response using the run.log() method to Weights and Biases.
    """

    def __init__(
        self,
        job_type: Optional[str] = None,
        project: Optional[str] = "langchain_callback_demo",
        entity: Optional[str] = None,
        tags: Optional[Sequence] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        visualize: bool = False,
        complexity_metrics: bool = False,
        stream_logs: bool = False,
    ) -> None:
        """Initialize callback handler.

        Starts a W&B run, creates a temporary directory for generated files
        and loads the `en_core_web_sm` spacy model.

        Raises:
            ImportError: If `wandb`, `pandas`, `textstat` or `spacy` is missing.
        """
        # Fail fast on every optional dependency the handler relies on later.
        wandb = import_wandb()
        import_pandas()
        import_textstat()
        spacy = import_spacy()
        super().__init__()

        self.job_type = job_type
        self.project = project
        self.entity = entity
        self.tags = tags
        self.group = group
        self.name = name
        self.notes = notes
        self.visualize = visualize
        self.complexity_metrics = complexity_metrics
        self.stream_logs = stream_logs

        # Holds visualization html files and the serialized asset until flush.
        self.temp_dir = tempfile.TemporaryDirectory()
        self.run: wandb.sdk.wandb_run.Run = wandb.init(  # type: ignore
            job_type=self.job_type,
            project=self.project,
            entity=self.entity,
            tags=self.tags,
            group=self.group,
            name=self.name,
            notes=self.notes,
        )
        warning = (
            "The wandb callback is currently in beta and is subject to change "
            "based on updates to `langchain`. Please report any issues to "
            "https://github.com/wandb/wandb/issues with the tag `langchain`."
        )
        wandb.termwarn(
            warning,
            repeat=False,
        )
        self.callback_columns: list = []
        self.action_records: list = []
        self.nlp = spacy.load("en_core_web_sm")

    def _init_resp(self) -> Dict:
        """Return a record skeleton with every callback column set to None."""
        return {k: None for k in self.callback_columns}

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        # One record per prompt, each sharing the same metadata snapshot.
        for prompt in prompts:
            prompt_resp = deepcopy(resp)
            prompt_resp["prompts"] = prompt
            self.on_llm_start_records.append(prompt_resp)
            self.action_records.append(prompt_resp)
            if self.stream_logs:
                self.run.log(prompt_resp)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.step += 1
        self.llm_streams += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.get_custom_callback_meta())

        self.on_llm_token_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.get_custom_callback_meta())

        # One record per generation, enriched with the text analysis output.
        for generations in response.generations:
            for generation in generations:
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))
                generation_resp.update(
                    analyze_text(
                        generation.text,
                        complexity_metrics=self.complexity_metrics,
                        visualize=self.visualize,
                        nlp=self.nlp,
                        output_dir=self.temp_dir.name,
                    )
                )
                self.on_llm_end_records.append(generation_resp)
                self.action_records.append(generation_resp)
                if self.stream_logs:
                    self.run.log(generation_resp)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running.

        Raises:
            KeyError: If `inputs` has no "input" key.
            ValueError: If `inputs["input"]` is neither a string nor a list.
        """
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        chain_input = inputs["input"]

        if isinstance(chain_input, str):
            input_resp = deepcopy(resp)
            input_resp["input"] = chain_input
            self.on_chain_start_records.append(input_resp)
            self.action_records.append(input_resp)
            if self.stream_logs:
                self.run.log(input_resp)
        elif isinstance(chain_input, list):
            # Each list element is merged into its own record via dict.update.
            for inp in chain_input:
                input_resp = deepcopy(resp)
                input_resp.update(inp)
                self.on_chain_start_records.append(input_resp)
                self.action_records.append(input_resp)
                if self.stream_logs:
                    self.run.log(input_resp)
        else:
            raise ValueError("Unexpected data format provided!")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running.

        Raises:
            KeyError: If `outputs` has no "output" key.
        """
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
        resp.update(self.get_custom_callback_meta())

        self.on_chain_end_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        self.on_tool_start_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.get_custom_callback_meta())

        self.on_tool_end_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """
        Run when arbitrary text is passed to the handler.
        """
        self.step += 1
        self.text_ctr += 1

        resp = self._init_resp()
        resp.update({"action": "on_text", "text": text})
        resp.update(self.get_custom_callback_meta())

        self.on_text_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.on_agent_finish_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.step += 1
        # NOTE(review): agent actions are counted under `tool_starts`; there is
        # no dedicated agent-action counter — confirm this is intended.
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.on_agent_action_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session.

        Joins the recorded LLM prompts with the recorded LLM outputs column-wise
        and adds a `chat_html` column rendering each prompt/output pair.

        Raises:
            KeyError: If the recorded LLM start/end records lack one of the
                selected columns (e.g. no LLM call was recorded yet).
        """
        pd = import_pandas()
        on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
        on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        selected_columns = ["step", "prompts", "name"]
        on_llm_start_records_selected_df = on_llm_start_records_df[selected_columns]
        llm_input_prompts_df = (
            on_llm_start_records_selected_df
            .dropna(axis=1)
            .rename({"step": "prompt_step"}, axis=1)
        )

        complexity_metrics_columns = []
        visualizations_columns = []

        if self.complexity_metrics:
            complexity_metrics_columns = [
                "flesch_reading_ease",
                "flesch_kincaid_grade",
                "smog_index",
                "coleman_liau_index",
                "automated_readability_index",
                "dale_chall_readability_score",
                "difficult_words",
                "linsear_write_formula",
                "gunning_fog",
                "text_standard",
                "fernandez_huerta",
                "szigriszt_pazos",
                "gutierrez_polini",
                "crawford",
                "gulpease_index",
                "osman",
            ]

        if self.visualize:
            visualizations_columns = ["dependency_tree", "entities"]

        llm_outputs_df = (
            on_llm_end_records_df[
                [
                    "step",
                    "text",
                    "token_usage_total_tokens",
                    "token_usage_prompt_tokens",
                    "token_usage_completion_tokens",
                ]
                + complexity_metrics_columns
                + visualizations_columns
            ]
            .dropna(axis=1)
            .rename({"step": "output_step", "text": "output"}, axis=1)
        )
        # Side-by-side concat: rows are aligned by index, not by step number.
        session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
        session_analysis_df["chat_html"] = session_analysis_df[
            ["prompts", "output"]
        ].apply(
            lambda row: construct_html_from_prompt_and_generation(
                row["prompts"], row["output"]
            ),
            axis=1,
        )
        return session_analysis_df

    def flush_tracker(
        self,
        langchain_asset: Any = None,
        reset: bool = True,
        finish: bool = False,
        job_type: Optional[str] = None,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        tags: Optional[Sequence] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        visualize: Optional[bool] = None,
        complexity_metrics: Optional[bool] = None,
    ) -> None:
        """Flush the tracker and reset the session.

        Logs the recorded actions and the session analysis as W&B tables,
        optionally stores the asset as a model artifact, and then finishes
        and/or re-initializes the run.

        Args:
            langchain_asset: The langchain asset to save.
            reset: Whether to reset the session.
            finish: Whether to finish the run.
            job_type: The job type.
            project: The project.
            entity: The entity.
            tags: The tags.
            group: The group.
            name: The name.
            notes: The notes.
            visualize: Whether to visualize. None keeps the current setting.
            complexity_metrics: Whether to compute complexity metrics. None
                keeps the current setting.

            Returns:
                None
        """
        pd = import_pandas()
        wandb = import_wandb()
        action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
        session_analysis_dataframe = self._create_session_analysis_df()
        session_analysis_table = wandb.Table(
            dataframe=session_analysis_dataframe
        )
        self.run.log(
            {
                "action_records": action_records_table,
                "session_analysis": session_analysis_table,
            }
        )

        if langchain_asset:
            langchain_asset_path = Path(self.temp_dir.name, "model.json")
            model_artifact = wandb.Artifact(name="model", type="model")
            model_artifact.add(action_records_table, name="action_records")
            model_artifact.add(session_analysis_table, name="session_analysis")
            try:
                try:
                    langchain_asset.save(langchain_asset_path)
                except ValueError:
                    # Presumably raised by assets (e.g. agent executors) that
                    # refuse `.save()`; fall back to `.save_agent()` rather
                    # than repeating the very same call.
                    langchain_asset.save_agent(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
            except NotImplementedError as e:
                # Best effort: the artifact is still logged with its tables.
                print("Could not save model.")
                print(repr(e))
            self.run.log_artifact(model_artifact)

        if finish or reset:
            self.run.finish()
            self.temp_dir.cleanup()
            self.reset_callback_meta()
        if reset:
            self.__init__(  # type: ignore
                job_type=job_type if job_type else self.job_type,
                project=project if project else self.project,
                entity=entity if entity else self.entity,
                tags=tags if tags else self.tags,
                group=group if group else self.group,
                name=name if name else self.name,
                notes=notes if notes else self.notes,
                # `is not None` so an explicit False overrides a current True.
                visualize=visualize if visualize is not None else self.visualize,
                complexity_metrics=(
                    complexity_metrics
                    if complexity_metrics is not None
                    else self.complexity_metrics
                ),
                # Carry the streaming setting over instead of silently
                # dropping back to the default (False) after a reset.
                stream_logs=self.stream_logs,
            )