transport-accessibility/transport_accessibility/api/io.py
Johannes Randerath 3853d25c1e Added LICENSE
- Code uses AGPL
- Docs use GNU FDL
2024-07-08 22:10:53 +02:00

94 lines
4.1 KiB
Python

# This file is part of transport-accessibility.
# Copyright (C) 2024 Janek Kiljanski, Johannes Randerath
#
# transport-accessibility is free software: you can redistribute it and/or modify it under the terms of the
# GNU General Public License as published by the Free Software Foundation, either version 3
# of the License, or (at your option) any later version.
#
# transport-accessibility is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with transport-accessibility.
# If not, see <https://www.gnu.org/licenses/>.
import csv
import io
import os
from pt_map.models import *
import pt_map
import pt_map.class_names
from django.db import models
import inspect
import datetime
import sys
def csv_queryset(q: models.query.QuerySet) -> str:
    """Serialize the objects of *q* as CSV text with a header row.

    One column is written per entry of ``q.model._meta.fields``, in that
    order. Values are encoded as follows:

    * ``DateField`` -> ``YYYYMMDD``
    * ``ForeignKey`` -> the related object's ``pk``, or its ``shape_id`` when
      the related object is a ``Shape``
    * ``BooleanField`` -> ``0`` / ``1``
    * anything else -> written as-is

    ``None`` values produce an empty cell.

    NOTE: Django's ``_meta.fields`` lists only concrete forward fields and
    excludes many-to-many fields, so many-to-many relations are not exported.

    :param q: queryset whose objects are exported.
    :returns: the CSV document as a string.
    """
    fields = q.model._meta.fields
    result = io.StringIO()
    csv_writer = csv.DictWriter(result, [field.name for field in fields])
    csv_writer.writeheader()
    for obj in q:
        row = {}
        for field in fields:
            value = getattr(obj, field.name)
            if value is None:
                # DictWriter fills missing keys with its restval (''), i.e. an empty cell.
                continue
            if isinstance(field, models.DateField):
                row[field.name] = value.strftime("%Y%m%d")
            elif isinstance(field, models.ForeignKey):
                # Test the *related* object (not the exported row) against Shape:
                # shapes are referenced by their shape_id, not their database pk.
                row[field.name] = value.shape_id if isinstance(value, Shape) else value.pk
            elif isinstance(field, models.BooleanField):
                row[field.name] = int(value)
            else:
                row[field.name] = value
        csv_writer.writerow(row)
    return result.getvalue()
def _parse_gtfs_date(value: str) -> datetime.date:
    """Parse a date cell from a GTFS CSV file.

    GTFS dates are ``YYYYMMDD``, the same format the CSV exporter writes with
    ``strftime("%Y%m%d")``. ISO ``YYYY-MM-DD`` is accepted as a fallback.

    :raises ValueError: if *value* matches neither format.
    """
    try:
        return datetime.datetime.strptime(value, "%Y%m%d").date()
    except ValueError:
        return datetime.date.fromisoformat(value)


def models_csv(path: str) -> list[models.Model]:
    """Import the GTFS CSV files found in the directory *path* into the database.

    Models are visited in this order, each at most once:

    1. the foreign-key targets (``pt_map.class_names.fks``);
    2. the many-to-many targets (``pt_map.class_names.mtm``);
    3. every other class in ``pt_map.models``.

    For each model whose file (``pt_map.class_names.file_names[model]``)
    exists under *path*, one object is created per CSV row. Empty cells are
    left out of the ``create()`` call.

    The first object created is treated as the feed (presumably the FeedInfo
    row — confirm that it is always visited first). Every later row gets
    ``feed_info_id`` set to it. Ids are namespaced as ``"<feed.pk>_<id>"``:
    primary keys, ``service_id`` values, ``Shape.shape_id``, and the ids used
    to resolve foreign-key and many-to-many references. Ids read before the
    feed exists are used unprefixed.

    NOTE(review): no ``return`` statement is visible here, so callers receive
    ``None`` despite the ``list[models.Model]`` annotation — confirm.

    :param path: directory containing the feed's CSV files.
    :raises KeyError: if a visited model has no entry in ``file_names``.
    :raises ValueError: if a date cell is neither ``YYYYMMDD`` nor ISO format.
    """
    # Always True; kept so the import path (and any alternative branch) is unchanged.
    assume_compliance = True
    if assume_compliance:
        feed = None
        # Dependency order: referenced models first, so FK/M2M lookups find their targets.
        order = []
        for m in [
            *pt_map.class_names.fks.values(),
            *pt_map.class_names.mtm.values(),
            *[cls for _, cls in inspect.getmembers(pt_map.models, inspect.isclass)],
        ]:
            if m not in order:
                order.append(m)
        for m in order:
            # Resolve against *path* instead of os.chdir(path): chdir changes the
            # working directory of the whole process (every thread/request).
            file_path = os.path.join(path, pt_map.class_names.file_names[m])
            if not os.path.exists(file_path):
                continue
            # The csv module requires newline=''; utf-8-sig strips a leading BOM
            # that would otherwise be glued onto the first header name.
            with open(file_path, 'r', encoding='utf-8-sig', newline='') as f:
                csvreader = csv.DictReader(f)
                columns = set(csvreader.fieldnames or ())
                csv_fields = [field for field in m._meta.fields if field.name in columns]
                for row in csvreader:
                    # Same rule as for primary keys: namespace ids only once a feed exists.
                    prefix = f"{feed.pk}_" if feed else ""
                    for field in csv_fields:
                        if not row[field.name]:
                            # Omit empty cells so the model's own default applies.
                            del row[field.name]
                            continue
                        if isinstance(field, models.ForeignKey):
                            target = pt_map.class_names.fks[field.name]
                            row[field.name] = target.objects.get(pk=f"{prefix}{row[field.name]}")
                        elif isinstance(field, models.DateField):
                            row[field.name] = _parse_gtfs_date(row[field.name])
                        elif (field.primary_key and feed) or field.name == 'service_id':
                            row[field.name] = f"{prefix}{row[field.name]}"
                    # Fresh per row so one row's relations never carry over to the next.
                    mtm = {}
                    for field in m._meta.many_to_many:
                        value = row.pop(field.name, None)
                        if value:
                            # Targets store these ids prefixed (see service_id /
                            # shape_id handling), so look them up prefixed too.
                            target = pt_map.class_names.mtm[field.name]
                            mtm[field.name] = target.objects.filter(**{field.name: f"{prefix}{value}"})
                    if feed:
                        row['feed_info_id'] = feed
                        if m == pt_map.models.Shape and 'shape_id' in row:
                            row['shape_id'] = f"{prefix}{row['shape_id']}"
                    obj = m.objects.create(**row)
                    for name, value in mtm.items():
                        getattr(obj, name).set(value)
                    if not feed:
                        feed = obj