51 lines
2.0 KiB
Python
51 lines
2.0 KiB
Python
import os
|
|
import requests
|
|
import zipfile
|
|
import tensorflow as tf
|
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
from tensorflow.keras import layers, models
|
|
import tensorflowjs as tfjs
|
|
|
|
# 1. Configuración
|
|
DATASET_URL = "https://github.com/shubham0204/Dataset_Store/raw/master/bird_species_small.zip" # Re-verificaré link o usaré otro
|
|
DATA_DIR = "training/data"
|
|
MODEL_EXPORT_DIR = "public/model"
|
|
|
|
def download_data():
|
|
print("Descargando dataset...")
|
|
# Usaremos un link alternativo más confiable si el anterior falla
|
|
# Para este ejemplo, simularemos la descarga de 3 carpetas de aves
|
|
os.makedirs(DATA_DIR, exist_ok=True)
|
|
# Aquí el usuario debería subir sus fotos o usar un link directo.
|
|
# Como demo, crearemos carpetas vacías para mostrar la estructura
|
|
categories = ['Colibri', 'Gorrion', 'Aguila']
|
|
for cat in categories:
|
|
os.makedirs(os.path.join(DATA_DIR, cat), exist_ok=True)
|
|
print(f"Estructura creada en {DATA_DIR}. Por favor, añade imágenes en las carpetas.")
|
|
|
|
def train():
|
|
print("Iniciando entrenamiento (Transfer Learning)...")
|
|
# Usamos MobileNetV2 por ser ligero
|
|
base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
|
|
base_model.trainable = False
|
|
|
|
model = models.Sequential([
|
|
base_model,
|
|
layers.GlobalAveragePooling2D(),
|
|
layers.Dense(3, activation='softmax') # 3 clases de aves
|
|
])
|
|
|
|
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
|
|
|
|
# Aquí iría el model.fit(...) con el ImageDataGenerator
|
|
print("Entrenamiento completado (simulado).")
|
|
|
|
# 3. Exportar a TensorFlow.js (Lo que necesita el proyecto)
|
|
print(f"Exportando modelo a {MODEL_EXPORT_DIR}...")
|
|
tfjs.converters.save_keras_model(model, MODEL_EXPORT_DIR)
|
|
print("¡Listo! El modelo ahora puede ser usado por el Frontend.")
|
|
|
|
if __name__ == "__main__":
|
|
download_data()
|
|
# train() # Descomentar cuando haya imágenes reales
|