Files
Claia/quick_train.py
T

42 lines
1.3 KiB
Python

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import json
# Directorios
DATA_DIR = "training/data"
EXPORT_DIR = "public/model"
os.makedirs(EXPORT_DIR, exist_ok=True)
# 1. Crear un modelo extremadamente ligero (MobileNetV2 Transfer Learning)
print("Construyendo modelo...")
base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
base_model.trainable = False
model = models.Sequential([
base_model,
layers.GlobalAveragePooling2D(),
layers.Dense(3, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 2. Generar metadatos compatibles con Teachable Machine
classes = sorted(os.listdir(DATA_DIR))
metadata = {
"labels": classes,
"imageSize": 224
}
with open(os.path.join(EXPORT_DIR, "metadata.json"), "w") as f:
json.dump(metadata, f)
# 3. Exportar el modelo
print(f"Guardando modelo en formato Keras...")
model.save(os.path.join(EXPORT_DIR, "model.keras"))
print(f"¡Modelo generado en {EXPORT_DIR}!")
print(f"Clases configuradas: {classes}")
print("Nota: El modelo se guardó como .keras. Para usarlo en el frontend, se requiere conversión a TF.js (actualmente instalando dependencias).")