# Page residue captured together with the file; not part of the program:
# kenken999's picture
# fda
# 0f43f8a
# raw
# history blame
# 735 Bytes
import tensorflow as tf
from sklearn.model_selection import train_test_split
def train_model(
    processed_data,
    *,
    target_column="target",
    test_size=0.2,
    random_state=42,
    epochs=10,
    batch_size=32,
):
    """Fit a small fully connected network that predicts the target column.

    The rows are split into a training subset and a held-out subset, then a
    Sequential model (two 64-unit ReLU layers followed by a single output
    unit with no activation) is compiled with the Adam optimizer and
    mean-squared-error loss and fit on the training subset.

    Args:
        processed_data: Table exposing ``.columns``, ``.drop(name, axis=1)``
            and ``[name]`` column access -- presumably a pandas DataFrame;
            confirm against callers. Every column other than
            ``target_column`` is used as an input feature.
        target_column: Name of the column to predict.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        random_state: Forwarded to ``train_test_split`` as ``random_state``
            so the split is reproducible.
        epochs: Number of epochs passed to ``model.fit``.
        batch_size: Batch size passed to ``model.fit``.

    Returns:
        The fitted ``tf.keras.models.Sequential`` model.

    Raises:
        KeyError: If ``target_column`` is not a column of ``processed_data``.
    """
    # Fail early with a message that names the missing column instead of
    # letting the lookup fail inside ``drop``.
    if target_column not in processed_data.columns:
        raise KeyError(
            f"processed_data has no {target_column!r} column to use as target"
        )

    features = processed_data.drop(target_column, axis=1)
    target = processed_data[target_column]

    # Split data into training and held-out sets.
    X_train, X_test, y_train, y_test = train_test_split(
        features,
        target,
        test_size=test_size,
        random_state=random_state,
    )

    # Define the network; the input width follows the number of feature columns.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(
            64, activation="relu", input_shape=(X_train.shape[1],)
        ),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mean_squared_error")

    # NOTE: the held-out split is passed as validation data, so it is seen
    # (for monitoring) during training and is not an untouched test set.
    model.fit(
        X_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
    )
    return model