Spaces:
Sleeping
Sleeping
| import matplotlib | |
| matplotlib.use('Agg') | |
| import gradio as gr | |
| import gymnasium as gym | |
| from stable_baselines3 import SAC | |
| from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| import gym_laser # Registers env name for gym.make() | |
# Dropdown label -> model identifier, in the order shown in the UI.
# "Random Policy" (None) and "Upload Custom Model" ("upload") are placeholders
# handled by label in load_selected_model; every other value is an id <id>
# whose checkpoint <id>.zip is downloaded from the Hub repo fracapuano/<id>.
# TODO: list additional policies here once they are published on the Hub.
PRETRAINED_MODELS = {
    "Random Policy": None,  # sample actions uniformly from the action space
    "Upload Custom Model": "upload",  # user supplies a SAC .zip checkpoint
    "SAC-UDR(1.5,2.5)": "sac-udr-narrow",
    "SAC-UDR(1.0,9.0)": "sac-udr-wide-extra",
}

# Step budget for one simulation run; big enough to look "always on".
MAX_STEPS = 10**5
def get_model_path(model_id):
    """Return the relative path of a locally stored checkpoint.

    The result has the form ``pretrained-policies/<model_id>.zip``. Within this
    file, pre-trained models are fetched from the Hub by
    ``load_pretrained_model`` instead; this helper is not referenced there.
    """
    template = "pretrained-policies/{}.zip"
    return template.format(model_id)
def load_pretrained_model(model_id):
    """Fetch a pre-trained SAC policy from the Hugging Face Hub and load it.

    Downloads ``<model_id>.zip`` from the repository ``fracapuano/<model_id>``
    and hands the local file path returned by ``hf_hub_download`` to
    ``SAC.load``.

    Returns:
        The loaded SAC model.
    """
    repository = f"fracapuano/{model_id}"
    checkpoint_path = hf_hub_download(repo_id=repository, filename=f"{model_id}.zip")
    return SAC.load(checkpoint_path)
def make_env_fn():
    """Create one ``LaserEnv`` that renders frames as RGB arrays.

    Passed as the environment factory to ``DummyVecEnv``.
    """
    env_id = "LaserEnv"
    return gym.make(env_id, render_mode="rgb_array")
def initialize_environment():
    """Build the environment and the per-session state dict on app load.

    The environment is a single ``LaserEnv`` (created by ``make_env_fn``)
    wrapped in ``DummyVecEnv`` and then ``VecFrameStack`` with 5 stacked frames.

    Returns:
        A dict with keys ``env``, ``obs``, ``model`` (``None`` means random
        actions), ``step_num``, ``current_b_integral`` (2.0, matching the
        slider's initial value) and ``model_filename``. Returns ``None`` if the
        environment cannot be created; the other callbacks already treat a
        ``None`` state as "not ready".
    """
    try:
        env = VecFrameStack(DummyVecEnv([make_env_fn]), n_stack=5)
        obs = env.reset()
    except Exception as exc:  # UI boundary: report instead of crashing the load event
        # The load event has exactly one output (the state), so return a single
        # value; returning a (None, message) tuple would store the tuple as state.
        gr.Warning(f"Error: {exc}")
        return None

    return {
        "env": env,
        "obs": obs,
        "model": None,
        "step_num": 0,
        "current_b_integral": 2.0,  # read live by run_continuous_simulation
        "model_filename": "Random Policy",  # default model name
    }
def load_selected_model(state, model_selection, uploaded_file):
    """Load the policy chosen in the UI into *state* and reset the episode.

    Args:
        state: Session dict from ``initialize_environment``, or ``None``.
        model_selection: A key of ``PRETRAINED_MODELS``. "Random Policy" clears
            the model; "Upload Custom Model" loads *uploaded_file*; any other
            key is downloaded via ``load_pretrained_model``.
        uploaded_file: Value of the upload button. This is either an object
            exposing the file path as ``.name`` or the path string itself.

    Returns:
        ``(state, gr.update())`` in every case, matching the two declared
        outputs ``[sim_state, model_uploader]``. Problems are shown to the user
        with ``gr.Warning`` and leave *state* as it was.
    """
    if state is None:
        return state, gr.update()

    try:
        if model_selection == "Random Policy":
            model, label = None, "Random Policy"
        elif model_selection == "Upload Custom Model":
            if uploaded_file is None:
                gr.Warning("Please upload a model file.")
                return state, gr.update()
            # Older Gradio passes a tempfile wrapper (path in ``.name``); newer
            # versions pass the path string directly.
            path = getattr(uploaded_file, "name", uploaded_file)
            model, label = SAC.load(path), os.path.basename(path)
        else:
            model = load_pretrained_model(PRETRAINED_MODELS[model_selection])
            label = model_selection

        state["model"] = model
        state["model_filename"] = label
        # Start a fresh episode so the new policy is shown from the beginning.
        state["obs"] = state["env"].reset()
        state["step_num"] = 0
    except Exception as exc:  # UI boundary: surface the error, keep the app alive
        gr.Warning(f"Error loading model: {exc}")

    return state, gr.update()
def update_b_integral(state, b_integral):
    """Record a new B-integral value in *state* in place.

    ``run_continuous_simulation`` reads ``state["current_b_integral"]`` on
    every step, so the simulation does not need to be restarted. A ``None``
    state is passed through unchanged.
    """
    if state is None:
        return state
    state["current_b_integral"] = b_integral
    return state
def run_continuous_simulation(state):
    """Step the environment in a loop and stream a rendered frame after each step.

    On every iteration the loop re-reads ``state["model"]`` and
    ``state["current_b_integral"]``:

    - When the model is ``None``, a random action is sampled.
    - The B-integral is written to ``env.envs[0].unwrapped.laser.B`` before
      stepping.

    So changes that other callbacks make to *state* take effect on the next
    step. The loop ends after ``MAX_STEPS`` steps, or earlier if a newer call
    to this function claims the same *state*.

    Args:
        state: Session dict from ``initialize_environment``, or ``None``.

    Yields:
        ``(state, frame)`` pairs for the ``[sim_state, image_display]``
        outputs. If the environment is not ready, yields a single
        ``(state, None)``.
    """
    if not state or "env" not in state:
        # Exactly two values: the event declares two outputs.
        yield state, None
        return

    env = state["env"]
    obs = state["obs"]
    step_num = state.get("step_num", 0)

    # Ownership token. A newer loop on the same state dict overwrites it, which
    # makes this older loop exit instead of stepping the shared env concurrently.
    run_token = object()
    state["run_token"] = run_token

    for _ in range(MAX_STEPS):
        if state.get("run_token") is not run_token:
            return

        model = state.get("model")
        env.envs[0].unwrapped.laser.B = float(state.get("current_b_integral", 2.0))

        if model is not None:
            action, _ = model.predict(obs, deterministic=True)
        else:
            # Batch of one action for the single vectorised environment.
            action = env.action_space.sample().reshape(1, -1)

        # Vectorised envs reset a finished sub-environment automatically inside
        # step(), so when done[0] is set, ``obs`` already belongs to the new
        # episode. An extra reset() here would just restart it a second time.
        obs, _, done, _ = env.step(action)
        frame = env.render()
        step_num = 0 if done[0] else step_num + 1

        state["obs"] = obs
        state["step_num"] = step_num
        yield state, frame
# UI layout and event wiring.
with gr.Blocks(css="body {zoom: 90%}") as demo:
    gr.Markdown("# Shaping Laser Pulses with Reinforcement Learning")

    with gr.Tab("Demo"):
        # Per-session simulation state (the dict built by initialize_environment).
        sim_state = gr.State()

        with gr.Row():
            b_slider = gr.Slider(
                minimum=0,
                maximum=10,
                step=0.5,
                value=2.0,
                label="B-integral",
                info="Adjust nonlinearity live during simulation.",
            )
        with gr.Row():
            image_display = gr.Image(label="Environment Render", interactive=False, height=360)
        with gr.Row():
            with gr.Column():
                model_selector = gr.Dropdown(
                    choices=list(PRETRAINED_MODELS.keys()),
                    value="Random Policy",
                    label="Model Selection",
                    info="Choose a pre-trained model or upload your own",
                )
        with gr.Row():
            with gr.Column(scale=1):
                model_uploader = gr.UploadButton(
                    "Upload Model (.zip)",
                    file_types=['.zip'],
                    elem_id="model-upload",
                    visible=False,  # shown only when "Upload Custom Model" is selected
                )

        def update_upload_visibility(selection):
            """Show the upload button only for the "Upload Custom Model" option."""
            return gr.update(visible=(selection == "Upload Custom Model"))

        model_selector.change(
            fn=update_upload_visibility,
            inputs=[model_selector],
            outputs=[model_uploader],
        )

        # On page load: build the environment, then start the simulation loop.
        init_event = demo.load(
            fn=initialize_environment,
            inputs=None,
            outputs=[sim_state],
        )
        continuous_event = init_event.then(
            fn=run_continuous_simulation,
            inputs=[sim_state],
            outputs=[sim_state, image_display],
        )

        # Changing the dropdown stops the running loop, loads the model and restarts.
        model_change_event = model_selector.change(
            fn=load_selected_model,
            inputs=[sim_state, model_selector, model_uploader],
            outputs=[sim_state, model_uploader],
            cancels=[continuous_event],
        ).then(
            fn=run_continuous_simulation,
            inputs=[sim_state],
            outputs=[sim_state, image_display],
        )

        # Uploading a model does the same. It must also cancel a loop started by
        # an earlier dropdown change, otherwise two loops would step one env.
        model_upload_event = model_uploader.upload(
            fn=load_selected_model,
            inputs=[sim_state, model_selector, model_uploader],
            outputs=[sim_state, model_uploader],
            cancels=[continuous_event, model_change_event],
        ).then(
            fn=run_continuous_simulation,
            inputs=[sim_state],
            outputs=[sim_state, image_display],
        )

        # The running loop reads the B-integral from state on every step,
        # so the slider only needs to update state (no restart).
        b_slider.change(
            fn=update_b_integral,
            inputs=[sim_state, b_slider],
            outputs=[sim_state],
        )

    with gr.Tab("About"):
        with open("copy.md", "r", encoding="utf-8") as f:
            gr.Markdown(f.read())

demo.launch()