def cleanup_code(
    code: str,
    language_type: str = None,
    dataset: str = None,
    issft: bool = False,
    stop_words=None,
):
    """
    Clean up generated code and truncate it at the earliest stop word.

    Args:
        code: The raw generated code.
        language_type: Target language name, compared case-insensitively.
            "python" and "ts" get special handling; any other value,
            including None, only applies ``stop_words``.
        dataset: Currently unused; kept so existing callers keep working.
        issft: If true and the language is Python, first extract the body of
            a ```python fenced block via ``_clean_python_code_for_sft``.
        stop_words: Iterable of substrings; the code is cut at the earliest
            occurrence of any of them. Ignored for Python, which uses its own
            fixed list. Defaults to no stop words.

    Returns:
        The cleaned, truncated code.
    """
    # Fresh list per call: avoids the shared mutable-default pitfall and lets
    # callers pass any iterable (e.g. a tuple) without breaking list concat.
    stop_words = [] if stop_words is None else list(stop_words)
    # Treat a missing language as "generic" instead of crashing on None.lower().
    language = (language_type or "").lower()

    if language == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        # The Python branch replaces (does not extend) the caller's stop words.
        stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
    elif language == "ts":
        stop_words = stop_words + [
            "\nexport",
            "\nimport",
            "\nexport default",
            "\nimport default",
            "\nconsole.log",
        ]
    return _truncate_code_at_stopwords(code, stop_words)
| def _clean_python_code_for_sft(code): | |
| code = code.replace("\r", "") | |
| if "```python" in code: | |
| code_start_idx = code.index("```python") | |
| code = code[code_start_idx:].replace("```python", "").strip() | |
| end_idx = code.find("```") if "```" in code else len(code) | |
| code = code[:end_idx].strip() | |
| return code | |
| def _truncate_code_at_stopwords(code, stop_words): | |
| min_stop_idx = len(code) | |
| for stop_word in stop_words: | |
| stop_index = code.find(stop_word) | |
| if 0 <= stop_index < min_stop_idx: | |
| min_stop_idx = stop_index | |
| return code[:min_stop_idx] |