94 lines
4.1 KiB
Python
94 lines
4.1 KiB
Python
# This file is part of transport-accessibility.
|
|
# Copyright (C) 2024 Janek Kiljanski, Johannes Randerath
|
|
#
|
|
# transport-accessibility is free software: you can redistribute it and/or modify it under the terms of the
|
|
# GNU General Public License as published by the Free Software Foundation, either version 3
|
|
# of the License, or (at your option) any later version.
|
|
#
|
|
# transport-accessibility is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
|
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
|
# See the GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License along with transport-accessibility.
|
|
# If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
import csv
|
|
import io
|
|
import os
|
|
from pt_map.models import *
|
|
import pt_map
|
|
import pt_map.class_names
|
|
from django.db import models
|
|
import inspect
|
|
import datetime
|
|
import sys
|
|
|
|
def csv_queryset(q: models.query.QuerySet) -> str:
    """Serialize every row of a queryset to CSV text.

    The header lists the names of the model's concrete fields
    (``q.model._meta.fields``) in declaration order. Cell values are
    converted as follows:

    * ``None`` -> empty cell (``DictWriter`` fills missing keys with ``""``).
    * ``DateField`` and its subclasses -> ``YYYYMMDD``.
    * ``ForeignKey`` -> the related object's ``pk``; a related ``Shape`` is
      written as its ``shape_id`` instead.
    * ``BooleanField`` -> ``0`` / ``1``.
    * anything else -> written as-is (``csv`` converts it with ``str``).

    Many-to-many fields are not part of ``_meta.fields`` and are therefore
    not exported.

    Args:
        q: Queryset whose rows are exported.

    Returns:
        The CSV document as a string, header row included.
    """
    fields = q.model._meta.fields
    result = io.StringIO()
    csv_writer = csv.DictWriter(result, [field.name for field in fields])
    csv_writer.writeheader()

    for instance in q:
        row = {}
        for field in fields:
            # For a ForeignKey this loads the related object, which may cost
            # one query per row unless ``q`` was built with select_related.
            value = getattr(instance, field.name)
            if value is None:
                continue
            if isinstance(field, models.DateField):
                row[field.name] = value.strftime("%Y%m%d")
            elif isinstance(field, models.ForeignKey):
                # Shapes are referenced by their shape_id rather than by pk.
                # The previous check compared the row instance with the Shape
                # class, which could never be true.
                row[field.name] = value.shape_id if isinstance(value, Shape) else value.pk
            elif isinstance(field, models.BooleanField):
                row[field.name] = int(value)
            else:
                row[field.name] = value
        csv_writer.writerow(row)
    return result.getvalue()
|
|
|
|
def _gtfs_import_order() -> list:
    """Return the model classes to import, de-duplicated, in a fixed order.

    Classes referenced by foreign keys (``pt_map.class_names.fks``) come first,
    then many-to-many targets (``pt_map.class_names.mtm``), then every class
    found in ``pt_map.models``. Only the first occurrence of each class is
    kept. Presumably this ordering makes referenced rows exist before the rows
    that point at them — TODO confirm against ``pt_map.class_names``.
    """
    candidates = [
        *pt_map.class_names.fks.values(),
        *pt_map.class_names.mtm.values(),
        *(cls for _, cls in inspect.getmembers(pt_map.models, inspect.isclass)),
    ]
    # dicts keep insertion order, so this drops duplicates without reordering.
    return list(dict.fromkeys(candidates))


def _parse_gtfs_date(text: str) -> datetime.datetime:
    """Parse a GTFS date cell into a ``datetime``.

    GTFS writes dates as ``YYYYMMDD`` (the format ``csv_queryset`` emits).
    ``datetime.fromisoformat`` only accepts that compact form on Python 3.11+,
    so it is parsed explicitly first, with ISO 8601 kept as a fallback.
    """
    try:
        return datetime.datetime.strptime(text, "%Y%m%d")
    except ValueError:
        return datetime.datetime.fromisoformat(text)


def models_csv(path: str) -> None:
    """Load every GTFS table found in directory *path* into the database.

    For each model class, in the order given by ``_gtfs_import_order``, the
    file named by ``pt_map.class_names.file_names[model]`` is read from *path*
    if it exists. One object is created per CSV row via ``objects.create``.
    Classes without a registered file name, and missing files, are skipped.

    Per-cell conversion:

    * empty cells are dropped, so the field keeps its model default;
    * ``ForeignKey`` columns are replaced by the related object, fetched by
      its feed-prefixed primary key;
    * ``DateField`` columns are parsed as ``YYYYMMDD`` (ISO 8601 fallback);
    * primary-key and ``service_id`` columns receive the feed prefix;
    * many-to-many columns are resolved with a ``filter`` on the related
      model and attached after the object has been created.

    The first object created becomes the *feed*. Every later object gets
    ``feed_info_id`` set to it, and ids are prefixed with ``"<feed pk>_"``,
    presumably so identifiers from different feeds cannot collide.

    Column names are assumed to match the model's field names (a compliant
    feed). An unrecognised column is passed to ``objects.create`` as an
    unexpected keyword argument.

    NOTE(review): rows are created one by one outside any explicit
    transaction, so a failure midway leaves a partially imported feed unless
    the caller wraps this call in ``transaction.atomic``.

    Args:
        path: Directory containing the feed's CSV files.

    Raises:
        FileNotFoundError: If *path* is not a directory.
        django.core.exceptions.ObjectDoesNotExist: If a foreign-key cell
            refers to a row that has not been imported.
    """
    # Resolve files relative to *path* instead of os.chdir(), which would
    # change the working directory of the whole process.
    if not os.path.isdir(path):
        raise FileNotFoundError(f"GTFS feed directory not found: {path!r}")

    feed = None

    def scoped(value):
        # Prefix an id with the feed's pk. Before the feed exists, ids are
        # used unchanged.
        return f"{feed.pk}_{value}" if feed else value

    for m in _gtfs_import_order():
        try:
            file_name = pt_map.class_names.file_names[m]
        except KeyError:
            continue
        file_path = os.path.join(path, file_name)
        if not os.path.exists(file_path):
            continue

        # GTFS files are UTF-8 and often start with a byte-order mark.
        # utf-8-sig strips it so the first header matches its field name.
        # newline="" is required by the csv module for quoted newlines.
        with open(file_path, newline="", encoding="utf-8-sig") as f:
            reader = csv.DictReader(f)
            header = set(reader.fieldnames or ())
            csv_fields = [field for field in m._meta.fields if field.name in header]

            for row in reader:
                for field in csv_fields:
                    raw = row[field.name]
                    if not raw:
                        del row[field.name]
                        continue
                    if isinstance(field, models.ForeignKey):
                        target = pt_map.class_names.fks[field.name]
                        row[field.name] = target.objects.get(pk=scoped(raw))
                    elif isinstance(field, models.DateField):
                        row[field.name] = _parse_gtfs_date(raw)
                    elif field.primary_key or field.name == "service_id":
                        row[field.name] = scoped(raw)

                # Fresh per row, so a missing column cannot reuse the
                # previous row's relation.
                mtm = {}
                for field in m._meta.many_to_many:
                    if field.name not in row:
                        continue
                    related = pt_map.class_names.mtm[field.name]
                    # NOTE(review): unlike pk/service_id columns, this lookup
                    # value is not feed-prefixed — confirm it matches how the
                    # related ids are stored.
                    mtm[field.name] = related.objects.filter(
                        **{field.name: row.pop(field.name)}
                    )

                if feed:
                    row["feed_info_id"] = feed
                    if m is pt_map.models.Shape:
                        # shape_id is not the primary key here, so it was not
                        # prefixed above.
                        row["shape_id"] = scoped(row["shape_id"])

                obj = m.objects.create(**row)
                for name, value in mtm.items():
                    getattr(obj, name).set(value)
                if feed is None:
                    feed = obj
|
|
|
|
|
|
|
|
|
|
|